kuprel/min-dalle

TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float16, float32

Closed this issue · 3 comments

papul commented

Running python image_from_text.py --text='a comfy chair' --seed=7 shows the following error:

$ python image_from_text.py --text='a comfy chair' --seed=7

Namespace(mega=False, torch=False, text='a comfy chair', seed=7, image_path='generated', image_token_count=256)
parsing metadata from ./pretrained/dalle_bart_mini
tokenizing text
['Ġa']
['Ġcomfy']
['Ġchair']
text tokens [0, 58, 29872, 2408, 2]
loading flax encoder
encoding text tokens
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
loading flax decoder
sampling image tokens
Traceback (most recent call last):
  File "/home/papul/min-dalle/image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/home/papul/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
    image_tokens[...] = generate_image_tokens_flax(
  File "/home/papul/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
    image_tokens = decode_flax(
  File "/home/papul/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
    image_tokens = decoder.sample_image_tokens(
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 255, in sample_image_tokens
    _, image_tokens = lax.scan(
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/control_flow.py", line 1498, in scan
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/control_flow.py", line 1484, in _create_jaxpr
    jaxpr, consts, out_tree = _initial_style_jaxpr(
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/util.py", line 219, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/util.py", line 212, in cached
    return f(*args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/control_flow.py", line 82, in _initial_style_jaxpr
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/util.py", line 219, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/util.py", line 212, in cached
    return f(*args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/control_flow.py", line 76, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1828, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/papul/.local/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1865, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 214, in sample_next_image_token
    logits, keys_state, values_state = self.apply(
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1159, in apply
    return apply(
  File "/home/papul/.local/lib/python3.10/site-packages/flax/core/scope.py", line 831, in wrapper
    y = fn(root, *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1535, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
    decoder_state, (keys_state, values_state) = self.layers(
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 314, in wrapped_fn
    ret = trafo_fn(module_scopes, *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/core/lift.py", line 218, in wrapper
    y, out_variable_groups_xs_t = fn(
  File "/home/papul/.local/lib/python3.10/site-packages/flax/core/lift.py", line 770, in inner
    broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
  File "/home/papul/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 616, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/papul/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 114, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/core/lift.py", line 754, in scanned
    c, y = fn(scope, c, *args)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 307, in core_fn
    res = fn(cloned, *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
    decoder_state, keys_values_state = self.self_attn(
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
    keys_state = lax.dynamic_update_slice(
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/slicing.py", line 147, in dynamic_update_slice
    return dynamic_update_slice_p.bind(operand, update, *start_indices)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/core.py", line 323, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/core.py", line 326, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 166, in process_primitive
    return self.default_process_primitive(primitive, tracers, params)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 177, in default_process_primitive
    out_aval, effects = primitive.abstract_eval(*avals, **params)
  File "/home/papul/.local/lib/python3.10/site-packages/jax/core.py", line 359, in abstract_eval_
    return abstract_eval(*args, **kwargs), no_effects
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/utils.py", line 67, in standard_abstract_eval
    dtype_rule(*avals, **kwargs), weak_type=weak_type,
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/slicing.py", line 933, in _dynamic_update_slice_dtype_rule
    lax._check_same_dtypes("dynamic_update_slice", False, operand.dtype,
  File "/home/papul/.local/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 4373, in _check_same_dtypes
    raise TypeError(msg.format(name, ", ".join(map(str, types))))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float16, float32.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/papul/min-dalle/image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/home/papul/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
    image_tokens[...] = generate_image_tokens_flax(
  File "/home/papul/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
    image_tokens = decode_flax(
  File "/home/papul/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
    image_tokens = decoder.sample_image_tokens(
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 255, in sample_image_tokens
    _, image_tokens = lax.scan(
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 214, in sample_next_image_token
    logits, keys_state, values_state = self.apply(
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
    decoder_state, (keys_state, values_state) = self.layers(
  File "/home/papul/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/home/papul/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 114, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
    decoder_state, keys_values_state = self.self_attn(
  File "/home/papul/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
    keys_state = lax.dynamic_update_slice(
TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float16, float32.
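For anyone hitting the same message: the check comes from `lax.dynamic_update_slice`, which, unlike most `jax.numpy` operations, does not promote mismatched dtypes and requires the operand and the update to share one. The minimal sketch below is not min-dalle's code. The `update_cache` helper and the array shapes are hypothetical, chosen only to reproduce a float16/float32 mismatch and show one generic workaround: casting the update to the operand's dtype. Later in this thread, the problem was resolved by pinning the flax version, not by patching the model.

```python
import jax.numpy as jnp
from jax import lax


def update_cache(cache, new_entries, index):
    """Hypothetical helper: write new_entries into cache starting at `index` on axis 0.

    lax.dynamic_update_slice requires both arrays to share a dtype, so the
    update is cast to the cache's dtype first (here float32 -> float16).
    """
    start = (index,) + (0,) * (cache.ndim - 1)
    return lax.dynamic_update_slice(cache, new_entries.astype(cache.dtype), start)


cache = jnp.zeros((4, 2), dtype=jnp.float16)   # half-precision cache
update = jnp.ones((1, 2), dtype=jnp.float32)   # full-precision new row

# Without the cast, JAX raises the same kind of TypeError as in the traceback.
mismatch_error = None
try:
    lax.dynamic_update_slice(cache, update, (1, 0))
except TypeError as err:
    mismatch_error = err
    print(err)

fixed = update_cache(cache, update, 1)
print(fixed.dtype)  # float16
```

Casting the update down to the cache's dtype keeps the cache's memory footprint unchanged; casting the cache up to float32 would also satisfy the check, at the cost of doubling its size.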
papul commented

Works after doing a git pull.

I just did a git pull and it still doesn't work :/

Someone pointed out that I had pinned torch, not flax. It should be flax==0.4.2. I just updated the requirements.txt file.
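After pulling the updated requirements.txt and reinstalling (for example with `pip install -r requirements.txt`), one way to confirm the pin took effect is to compare the installed flax version against `flax==0.4.2`. The sketch below uses only the standard library's `importlib.metadata`; the helper names `parse_pin` and `matches_pin` are hypothetical, and it only handles exact `==` pins compared as plain strings, not full requirement specifiers.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_pin(requirement):
    """Split an exact pin such as 'flax==0.4.2' into (name, version)."""
    name, sep, pinned = requirement.partition("==")
    if not sep:
        raise ValueError(f"not an exact '==' pin: {requirement!r}")
    return name.strip(), pinned.strip()


def matches_pin(requirement, get_version=version):
    """Return True if the installed package's version string equals the pin.

    get_version defaults to importlib.metadata.version; it can be swapped
    out (e.g. for testing). A missing package counts as a mismatch.
    """
    name, pinned = parse_pin(requirement)
    try:
        installed = get_version(name)
    except PackageNotFoundError:
        return False
    return installed == pinned


if __name__ == "__main__":
    print("flax pin satisfied:", matches_pin("flax==0.4.2"))
```

For ranges or other specifiers, the `packaging` library's `Requirement` and `SpecifierSet` classes are a more robust choice than string comparison.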