kuprel/min-dalle

TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.

Closed this issue · 4 comments

Running python image_from_text.py --text='court sketch of godzilla on trial' --mega --seed=100, or any other command with --mega, fails with a TypeError.

Output is the following:

Namespace(mega=True, torch=False, text='court sketch of godzilla on trial', seed=100, image_path='generated', image_token_count=256)
parsing metadata from ./pretrained/dalle_bart_mega
tokenizing text
['Ġcourt']
['Ġsketch']
['Ġof']
['Ġgodzilla']
['Ġon']
['Ġtrial']
text tokens [0, 2634, 4189, 111, 14450, 133, 5167, 2]
loading flax encoder
encoding text tokens
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
loading flax decoder
sampling image tokens
Traceback (most recent call last):
  File "/Users/REDACTED/workspace/min-dalle/image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
    image_tokens[...] = generate_image_tokens_flax(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
    image_tokens = decode_flax(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
    image_tokens = decoder.sample_image_tokens(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 254, in sample_image_tokens
    _, image_tokens = lax.scan(
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 1498, in scan
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 1484, in _create_jaxpr
    jaxpr, consts, out_tree = _initial_style_jaxpr(
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/util.py", line 219, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/util.py", line 212, in cached
    return f(*args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 82, in _initial_style_jaxpr
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/util.py", line 219, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/util.py", line 212, in cached
    return f(*args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 76, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1828, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/opt/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1865, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 213, in sample_next_image_token
    logits, keys_state, values_state = self.apply(
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 1159, in apply
    return apply(
  File "/opt/homebrew/lib/python3.9/site-packages/flax/core/scope.py", line 831, in wrapper
    y = fn(root, *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 1535, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
    decoder_state, (keys_state, values_state) = self.layers(
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 314, in wrapped_fn
    ret = trafo_fn(module_scopes, *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/core/lift.py", line 218, in wrapper
    y, out_variable_groups_xs_t = fn(
  File "/opt/homebrew/lib/python3.9/site-packages/flax/core/lift.py", line 770, in inner
    broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
  File "/opt/homebrew/lib/python3.9/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 616, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/opt/homebrew/lib/python3.9/site-packages/flax/core/axes_scan.py", line 114, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/core/lift.py", line 754, in scanned
    c, y = fn(scope, c, *args)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 307, in core_fn
    res = fn(cloned, *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
    decoder_state, keys_values_state = self.self_attn(
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
    keys_state = lax.dynamic_update_slice(
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/slicing.py", line 147, in dynamic_update_slice
    return dynamic_update_slice_p.bind(operand, update, *start_indices)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/core.py", line 323, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/core.py", line 326, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 166, in process_primitive
    return self.default_process_primitive(primitive, tracers, params)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 177, in default_process_primitive
    out_aval, effects = primitive.abstract_eval(*avals, **params)
  File "/opt/homebrew/lib/python3.9/site-packages/jax/core.py", line 359, in abstract_eval_
    return abstract_eval(*args, **kwargs), no_effects
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/utils.py", line 67, in standard_abstract_eval
    dtype_rule(*avals, **kwargs), weak_type=weak_type,
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/slicing.py", line 933, in _dynamic_update_slice_dtype_rule
    lax._check_same_dtypes("dynamic_update_slice", False, operand.dtype,
  File "/opt/homebrew/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 4373, in _check_same_dtypes
    raise TypeError(msg.format(name, ", ".join(map(str, types))))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/REDACTED/workspace/min-dalle/image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
    image_tokens[...] = generate_image_tokens_flax(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
    image_tokens = decode_flax(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
    image_tokens = decoder.sample_image_tokens(
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 254, in sample_image_tokens
    _, image_tokens = lax.scan(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 213, in sample_next_image_token
    logits, keys_state, values_state = self.apply(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
    decoder_state, (keys_state, values_state) = self.layers(
  File "/opt/homebrew/lib/python3.9/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/opt/homebrew/lib/python3.9/site-packages/flax/core/axes_scan.py", line 114, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
    decoder_state, keys_values_state = self.self_attn(
  File "/Users/REDACTED/workspace/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
    keys_state = lax.dynamic_update_slice(
TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.

I'm running this on an M1 MacBook Pro.
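
For reference, the final TypeError can be reproduced outside of min-dalle. This is a minimal sketch, assuming (from the order of dtypes in the message) a float32 operand and a float16 update. The array names, shapes, and start index are hypothetical and not taken from the repository:

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-ins for the cached keys and the newly projected keys.
keys_state = jnp.zeros((1, 4, 8), dtype=jnp.float32)  # operand (cache buffer)
new_keys = jnp.ones((1, 1, 8), dtype=jnp.float16)     # update (one new position)
start = (0, 2, 0)                                     # write at position 2

# lax functions do not promote dtypes implicitly, so this call is rejected.
try:
    lax.dynamic_update_slice(keys_state, new_keys, start)
    mismatch_error = None
except TypeError as err:
    mismatch_error = str(err)

print(mismatch_error)

# Casting the update to the operand's dtype makes the call valid.
updated = lax.dynamic_update_slice(
    keys_state, new_keys.astype(keys_state.dtype), start
)
print(updated.dtype)
```

Casting at the call site only hides the symptom. Whether the cache or the projection output should change dtype depends on which precision the model is meant to run in.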

I get the same error whenever the --mega flag is set, also on an M1 MacBook Pro.

see #2

This issue occurs with the latest flax version, 0.5.2. It should work now with the latest commit.


I'm new to this, but I was still running into the same error after your commit. It works for me when I specify dtype=jnp.float16 for each of the 4 nn.Dense layers in AttentionFlax.setup() in dalle_bart_encoder_flax.py.