TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.
Closed this issue · 4 comments
andreafdaf commented
When running `python image_from_text.py --text='court sketch of godzilla on trial' --mega --seed=100`
(or any other command with `--mega`), the output is the following:
Namespace(mega=True, torch=False, text='court sketch of godzilla on trial', seed=100, image_path='generated', image_token_count=256)
parsing metadata from ./pretrained/dalle_bart_mega
tokenizing text
['Ġcourt']
['Ġsketch']
['Ġof']
['Ġgodzilla']
['Ġon']
['Ġtrial']
text tokens [0, 2634, 4189, 111, 14450, 133, 5167, 2]
loading flax encoder
encoding text tokens
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
loading flax decoder
sampling image tokens
Traceback (most recent call last):
File "/Users/REDACTED/workspace/min-dalle/image_from_text.py", line 44, in <module>
image = generate_image_from_text(
File "/Users/REDACTED/workspace/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
image_tokens[...] = generate_image_tokens_flax(
File "/Users/REDACTED/workspace/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
image_tokens = decode_flax(
File "/Users/REDACTED/workspace/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
image_tokens = decoder.sample_image_tokens(
File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 254, in sample_image_tokens
_, image_tokens = lax.scan(
File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 1498, in scan
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 1484, in _create_jaxpr
jaxpr, consts, out_tree = _initial_style_jaxpr(
File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/util.py", line 219, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/util.py", line 212, in cached
return f(*args, **kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 82, in _initial_style_jaxpr
jaxpr, consts, out_tree = _initial_style_open_jaxpr(
File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/util.py", line 219, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/util.py", line 212, in cached
return f(*args, **kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 76, in _initial_style_open_jaxpr
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1828, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/opt/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1865, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/opt/homebrew/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 213, in sample_next_image_token
logits, keys_state, values_state = self.apply(
File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 1159, in apply
return apply(
File "/opt/homebrew/lib/python3.9/site-packages/flax/core/scope.py", line 831, in wrapper
y = fn(root, *args, **kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 1535, in scope_fn
return fn(module.clone(parent=scope), *args, **kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
decoder_state, (keys_state, values_state) = self.layers(
File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 314, in wrapped_fn
ret = trafo_fn(module_scopes, *args, **kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/flax/core/lift.py", line 218, in wrapper
y, out_variable_groups_xs_t = fn(
File "/opt/homebrew/lib/python3.9/site-packages/flax/core/lift.py", line 770, in inner
broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
File "/opt/homebrew/lib/python3.9/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
_, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 616, in trace_to_jaxpr_nounits
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/opt/homebrew/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/opt/homebrew/lib/python3.9/site-packages/flax/core/axes_scan.py", line 114, in body_fn
broadcast_out, c, ys = fn(broadcast_in, c, *xs)
File "/opt/homebrew/lib/python3.9/site-packages/flax/core/lift.py", line 754, in scanned
c, y = fn(scope, c, *args)
File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 307, in core_fn
res = fn(cloned, *args, **kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
decoder_state, keys_values_state = self.self_attn(
File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
keys_state = lax.dynamic_update_slice(
File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/slicing.py", line 147, in dynamic_update_slice
return dynamic_update_slice_p.bind(operand, update, *start_indices)
File "/opt/homebrew/lib/python3.9/site-packages/jax/core.py", line 323, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/opt/homebrew/lib/python3.9/site-packages/jax/core.py", line 326, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/opt/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 166, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File "/opt/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 177, in default_process_primitive
out_aval, effects = primitive.abstract_eval(*avals, **params)
File "/opt/homebrew/lib/python3.9/site-packages/jax/core.py", line 359, in abstract_eval_
return abstract_eval(*args, **kwargs), no_effects
File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/utils.py", line 67, in standard_abstract_eval
dtype_rule(*avals, **kwargs), weak_type=weak_type,
File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/slicing.py", line 933, in _dynamic_update_slice_dtype_rule
lax._check_same_dtypes("dynamic_update_slice", False, operand.dtype,
File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 4373, in _check_same_dtypes
raise TypeError(msg.format(name, ", ".join(map(str, types))))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/REDACTED/workspace/min-dalle/image_from_text.py", line 44, in <module>
image = generate_image_from_text(
File "/Users/REDACTED/workspace/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
image_tokens[...] = generate_image_tokens_flax(
File "/Users/REDACTED/workspace/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
image_tokens = decode_flax(
File "/Users/REDACTED/workspace/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
image_tokens = decoder.sample_image_tokens(
File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 254, in sample_image_tokens
_, image_tokens = lax.scan(
File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 213, in sample_next_image_token
logits, keys_state, values_state = self.apply(
File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
decoder_state, (keys_state, values_state) = self.layers(
File "/opt/homebrew/lib/python3.9/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
_, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
File "/opt/homebrew/lib/python3.9/site-packages/flax/core/axes_scan.py", line 114, in body_fn
broadcast_out, c, ys = fn(broadcast_in, c, *xs)
File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
decoder_state, keys_values_state = self.self_attn(
File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
keys_state = lax.dynamic_update_slice(
TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.
Running on an M1 MacBook Pro.
robmckinnon commented
I get the same error when the `--mega` flag is set, running on an M1 MacBook Pro.
kuprel commented
This issue occurred with the latest Flax version, 0.5.2. It should work now with the latest commit.
warmlogic commented
> This issue occurred in the latest flax version 0.5.2. It should work now with the latest commit
I'm new to this, but I was still running into the same error after your commit. It works for me when I specify `dtype=jnp.float16` for each of the 4 `nn.Dense` layers in `AttentionFlax.setup()` in `dalle_bart_encoder_flax.py`.