Run demo: error loading checkpoint
kelisiya opened this issue · 11 comments
I downloaded large-3m to a local path and ran the demo on an A100 GPU. I set:
FULL_CKPT_PATH = './experiment/unified-io-2/large-3m'
MODEL_TYPE = "large"
TypeError: JAX only supports number and bool dtypes, got dtype bfloat16 in asarray
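For reference, the failure can be isolated from the demo: in an affected environment even a bare jnp.asarray call with a bfloat16 dtype raises the same message (a standalone sketch, independent of the Unified-IO 2 code; on a healthy install it simply returns a bfloat16 array):

```python
import jax.numpy as jnp

# In an environment hit by this issue, requesting bfloat16 through asarray raises:
#   TypeError: JAX only supports number and bool dtypes, got dtype bfloat16 in asarray
x = jnp.asarray([1.0, 2.0], dtype="bfloat16")
print(x.dtype)  # "bfloat16" when the environment is healthy
```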
Same error.
Looks like an issue with the checkpoints on GPUs; I will try to reproduce it.
Same issue on A6000
File /usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py:2026, in asarray(a, dtype, order)
2024 @_wraps(np.asarray, lax_description=_ARRAY_DOC)
2025 def asarray(a: Any, dtype: Optional[DTypeLike] = None, order: Optional[str] = None) -> Array:
-> 2026 lax_internal._check_user_dtype_supported(dtype, "asarray")
2027 dtype = dtypes.canonicalize_dtype(dtype) if dtype is not None else dtype
2028 return array(a, dtype=dtype, copy=False, order=order)
File /usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py:4812, in _check_user_dtype_supported(dtype, fun_name)
4810 msg = f"JAX only supports number and bool dtypes, got dtype {dtype}"
4811 msg += f" in {fun_name}" if fun_name else ""
-> 4812 raise TypeError(msg)
4813 if dtype is not None and np_dtype != dtypes.canonicalize_dtype(dtype):
4814 msg = ("Explicitly requested dtype {} {} is not available, "
4815 "and will be truncated to dtype {}. To enable more dtypes, set the "
4816 "jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell "
4817 "environment variable. "
4818 "See https://github.com/google/jax#current-gotchas for more.")
TypeError: JAX only supports number and bool dtypes, got dtype bfloat16 in asarray
How can this be resolved?
I changed the dtype in t5x/examples/unified_io/t5_1_1/{model_size}.gin and t5x/examples/unified_io.config.py to float32, and got a new error:
AssertionError Traceback (most recent call last)
Cell In[11], line 18
16 vocab = get_default_vocabulary()
17 partitioner = partitioning.PjitPartitioner(num_partitions=8)
---> 18 parameters, param_axes = uio_utils.get_parameters(model, FULL_CKPT_PATH, partitioner)
get_parameters(model, model_checkpoint, partitioner, rng)
83 input_shapes, input_types = get_input_spec(1)
84 if partitioner is not None:
---> 85 train_state_initializer = TrainStateInitializer(
86 optimizer_def=None,
87 init_fn=model.get_initial_variables,
88 input_shapes=input_shapes,
89 input_types=input_types,
90 partitioner=partitioner
91 )
92 param_axes = train_state_initializer.train_state_axes.params
93 params = LegacyCheckpointManager(
94 restore_cfg=RestoreCheckpointConfig(model_checkpoint),
95 train_state_shape=train_state_initializer.global_train_state_shape,
96 partitioner=partitioner
97 ).restore([model_checkpoint], RestoreCheckpointConfig(model_checkpoint)).params
unified_io_2/t5x/utils.py:958, in TrainStateInitializer.__init__(self, optimizer_def, init_fn, input_shapes, partitioner, model, input_types)
955 return train_state_lib.InferenceState.create(initial_variables)
957 self._partitioner = partitioner
--> 958 self.global_train_state_shape = jax.eval_shape(
959 initialize_train_state, rng=jax.random.PRNGKey(0))
961 self.train_state_axes = partitioner.get_mesh_axes(
962 self.global_train_state_shape)
963 self._initialize_train_state = initialize_train_state
File /home/pai/envs/unified_io/lib/python3.9/site-packages/jax/_src/api.py:3201, in eval_shape(fun, *args, **kwargs)
3199 wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
3200 debug_info = pe.debug_info(fun, in_tree, True, "eval_shape")
-> 3201 out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
3202 *map(shaped_abstractify, args_flat),
3203 debug_info=debug_info)
3204 out = [ShapeDtypeStruct(x.shape, x.dtype, x.named_shape) for x in out]
3205 return tree_unflatten(out_tree(), out)
unified_io/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:660, in abstract_eval_fun(fun, debug_info, *avals, **params)
659 def abstract_eval_fun(fun, *avals, debug_info=None, **params):
--> 660 _, avals_out, _ = trace_to_jaxpr_dynamic(
661 lu.wrap_init(fun, params), avals, debug_info)
662 assert all(isinstance(aval, AbstractValue) for aval in avals_out)
663 return avals_out
unified_io/lib/python3.9/site-packages/jax/_src/profiler.py:314, in annotate_function.<locals>.wrapper(*args, **kwargs)
311 @wraps(func)
312 def wrapper(*args, **kwargs):
313 with TraceAnnotation(name, **decorator_kwargs):
--> 314 return func(*args, **kwargs)
315 return wrapper
unified_io/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1981, in trace_to_jaxpr_dynamic(fun, in_avals, debug_info, keep_inputs)
1979 with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
1980 main.jaxpr_stack = () # type: ignore
-> 1981 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
1982 fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info)
1983 del main, fun
1984 return jaxpr, out_avals, consts
unified_io/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1998, in trace_to_subjaxpr_dynamic(fun, main, in_avals, keep_inputs, debug_info)
1996 in_tracers = input_type_to_tracers(trace.new_arg, in_avals)
1997 in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
-> 1998 ans = fun.call_wrapped(*in_tracers_)
1999 out_tracers = map(trace.full_raise, ans)
2000 jaxpr, consts = frame.to_jaxpr(out_tracers)
/unified_io/lib/python3.9/site-packages/jax/linear_util.py:167, in WrappedFun.call_wrapped(self, *args, **kwargs)
164 gen = gen_static_args = out_store = None
166 try:
--> 167 ans = self.f(*args, **dict(self.params, **kwargs))
168 except:
169 # Some transformations yield from inside context managers, so we have to
170 # interrupt them before reraising the exception. Otherwise they will only
171 # get garbage-collected at some later time, running their cleanup tasks
172 # only after this exception is handled, which can corrupt the global
173 # state.
174 while stack:
File /home/pai/envs/unified_io/lib/python3.9/site-packages/jax/linear_util.py:167, in WrappedFun.call_wrapped(self, *args, **kwargs)
164 gen = gen_static_args = out_store = None
166 try:
--> 167 ans = self.f(*args, **dict(self.params, **kwargs))
168 except:
169 # Some transformations yield from inside context managers, so we have to
170 # interrupt them before reraising the exception. Otherwise they will only
171 # get garbage-collected at some later time, running their cleanup tasks
172 # only after this exception is handled, which can corrupt the global
173 # state.
174 while stack:
, in TargetSequence.__post_init__(self)
90 assert self.position_id.shape[:2] in [(1, seq_len), (bs, seq_len)]
92 assert self.modality_id.shape in [(), (1, seq_len), (bs, seq_len)]
---> 93 assert self.modality_id.dtype == jnp.int32
95 if self.target_tokens is not None:
96 assert self.target_tokens.shape == (bs, seq_len)
AssertionError:
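The assertion itself only checks that modality_id is an int32 tensor; a hedged guess is that the float32 override also gets applied to these index tensors when the gin/config dtype is changed. A standalone illustration of the check (the field name comes from the traceback above; everything else is simplified):

```python
import jax.numpy as jnp

seq_len = 4

# Default behaviour: modality_id is an integer index tensor, so the check passes.
modality_id = jnp.zeros((1, seq_len), dtype=jnp.int32)
print(modality_id.dtype == jnp.int32)  # True

# If the float32 override leaks into modality_id, the same check fails,
# which is what the AssertionError in TargetSequence.__post_init__ reports.
modality_id_f32 = jnp.zeros((1, seq_len), dtype=jnp.float32)
print(modality_id_f32.dtype == jnp.int32)  # False
```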
I have added the ability to load the model in float32; I was able to run the demo with the XL model using one A6000 and the XXL model with two A6000s.
I am not sure about the assertion error, but feel free to create a new issue with instructions to reproduce it.
@chrisc36 @robooootx did you ever find a solution to this?
For the error
TypeError: JAX only supports number and bool dtypes, got dtype bfloat16 in asarray
it is mostly due to import orbax.checkpoint. Since this project has taken quite a long time, some of the packages we used are from older versions; we recently found that pip-installing the dependencies with Python 3.9 can indeed cause conflicts between JAX and orbax.checkpoint when specifying dtype="bfloat16", but everything still works with Python 3.8 (e.g., 3.8.10, which is the default in TPU VMs). After downgrading Python to 3.8, please also downgrade to pyglove==0.4.3, which is required by seqio; the latest pyglove release from 3 weeks ago only supports Python 3.9. We'll look into this dependency issue more deeply, but feel free to use this workaround for now!
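A quick way to check whether an environment matches this workaround (a minimal sketch; the version numbers are the ones mentioned above):

```python
# Sanity check for the workaround above: Python 3.8 plus pyglove==0.4.3.
import sys
from importlib.metadata import PackageNotFoundError, version

py_ok = sys.version_info[:2] == (3, 8)
print(f"python {sys.version.split()[0]} -> {'ok' if py_ok else 'expected 3.8.x (e.g. 3.8.10)'}")

try:
    pg = version("pyglove")
    print(f"pyglove {pg} -> {'ok' if pg == '0.4.3' else 'expected 0.4.3 (required by seqio)'}")
except PackageNotFoundError:
    print("pyglove not installed; pin it with pyglove==0.4.3")

# If the conflict is resolved, bfloat16 arrays work again:
import jax.numpy as jnp
print(jnp.asarray([1.0], dtype="bfloat16").dtype)  # bfloat16
```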
For the AssertionError, I haven't hit this in recent debugging. Could you please share more details, ideally a minimal script to reproduce it? With the above change there's no need to change the dtype in t5x/examples/unified_io/t5_1_1/{model_size}.gin and t5x/examples/unified_io.config.py to float32; you can directly set supports_bfloat16 = False or True if using a GPU.
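For example, something along these lines in the demo, assuming supports_bfloat16 is a module-level flag in the unified_io config module mentioned above (the exact location may differ):

```python
# Hypothetical sketch of the flag mentioned above; set it before building the model.
from t5x.examples.unified_io import config

config.supports_bfloat16 = False   # run in float32 on GPU
# config.supports_bfloat16 = True  # or keep bfloat16 if your setup supports it
```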
I only seem to get the AssertionError when attempting to use float32. With your solution above, using bfloat16 worked for me.