load_in_8bit doesn't work on Kaggle TPU
IvoryTower800 opened this issue · 2 comments
Hi, I tried to use the new load_in_8bit=True feature to fine-tune the Gemma model on a Kaggle TPU. However, it raised the error below. I'm wondering whether this is a bug or whether this feature simply isn't supported on TPU v3.
Besides, I found that ring attention can now be used on Kaggle TPUs, which wasn't possible before. Amazing! Thank you.
Time Took to Complete Task configure functions and sharding them (microseconds) : 1513.6299133300781
Action : Sharding Passed Parameters
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[13], line 144
137 trainer = CausalLanguageModelTrainer(
138 train_arguments,
139 dataset_train,
140 # checkpoint_path='/root/' + ckpt_name
141 )
143 # output = trainer.train()
--> 144 output = trainer.train(flax.core.FrozenDict({"params": params}))
145 print(f"Hey ! , here's where your model saved {output.checkpoint_path}")
147 api.upload_file(
148 path_or_fileobj=output.checkpoint_path,
149 path_in_repo=output.checkpoint_path.split('/')[-1],
150 repo_id="ivt1993/writer_2b_gemma",
151 repo_type="model"
152 )
File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:397, in CausalLanguageModelTrainer.train(self, model_parameters, state)
390 termcolor.cprint(
391 "Performance Mode is ON, we will ignore the Memory Tracking, WANDB Logging, and extra information "
392 "Process.",
393 color="red",
394 force_color=True
395 )
396 start_time = time.time()
--> 397 sharded_state, shard_fns, gather_fns = self.initialize_state(
398 model_parameters=model_parameters,
399 state=state
400 )
402 count_model_parameters(sharded_state.params)
403 with self.mesh:
File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:287, in CausalLanguageModelTrainer.initialize_state(self, model_parameters, state)
280 if not isinstance(model_parameters, flax.core.FrozenDict):
281 prefix_print(
282 "Warning",
283 "Model Parameters should be like FrozenDict({'params': params}) make sure to "
284 "pass as type FrozenDict in case of not getting UnExcepted Errors "
285 )
--> 287 model_parameters = model_parameters if not self.arguments.do_shard_fns else jax.tree_util.tree_map(
288 lambda f, x: f(x),
289 shard_fns.params,
290 model_parameters,
291 )
292 sharded_state = self.create_sharded_state_from_params_function(model_parameters)
293 elif model_parameters is not None and self.checkpoint_path is not None:
File /usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py:312, in tree_map(f, tree, is_leaf, *rest)
310 leaves, treedef = tree_flatten(tree, is_leaf)
311 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 312 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File /usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py:312, in <genexpr>(.0)
310 leaves, treedef = tree_flatten(tree, is_leaf)
311 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 312 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:288, in CausalLanguageModelTrainer.initialize_state.<locals>.<lambda>(f, x)
280 if not isinstance(model_parameters, flax.core.FrozenDict):
281 prefix_print(
282 "Warning",
283 "Model Parameters should be like FrozenDict({'params': params}) make sure to "
284 "pass as type FrozenDict in case of not getting UnExcepted Errors "
285 )
287 model_parameters = model_parameters if not self.arguments.do_shard_fns else jax.tree_util.tree_map(
--> 288 lambda f, x: f(x),
289 shard_fns.params,
290 model_parameters,
291 )
292 sharded_state = self.create_sharded_state_from_params_function(model_parameters)
293 elif model_parameters is not None and self.checkpoint_path is not None:
File /usr/local/lib/python3.10/site-packages/fjformer/partition_utils/mesh_utils.py:50, in make_shard_and_gather_fns.<locals>.make_shard_fn.<locals>.shard_fn(tensor)
49 def shard_fn(tensor):
---> 50 return jax_shard_function(tensor).block_until_ready()
[... skipping hidden 12 frame]
File /usr/local/lib/python3.10/site-packages/fjformer/partition_utils/mesh_utils.py:38, in make_shard_and_gather_fns.<locals>.make_to_dtype_fn.<locals>.to_dtype(tensor)
36 elif hasattr(dtype_spec, 'dtype') and hasattr(tensor, 'dtype'):
37 return jnp.asarray(tensor).astype(dtype_spec.dtype)
---> 38 return jnp.asarray(tensor)
File /usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:2233, in asarray(a, dtype, order, copy)
2231 if dtype is not None:
2232 dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment]
-> 2233 return array(a, dtype=dtype, copy=bool(copy), order=order)
File /usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:2186, in array(object, dtype, copy, order, ndmin)
2184 out = np.array(memoryview(object), copy=copy)
2185 else:
-> 2186 raise TypeError(f"Unexpected input type for array: {type(object)}")
2188 out_array: Array = lax_internal._convert_element_type(
2189 out, dtype, weak_type=weak_type)
2190 if ndmin > ndim(out_array):
TypeError: Unexpected input type for array: <class 'fjformer.linen.linen.LinearBitKernel'>
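For context on where the TypeError comes from: with do_shard_fns enabled, the trainer maps the shard functions over every parameter leaf, and each shard function ends in a jnp.asarray call (mesh_utils.py line 38 above). That call handles ordinary array leaves, but it has no way to convert fjformer's 8-bit LinearBitKernel wrapper, so a quantized checkpoint fails at the sharding step before training even starts. A minimal sketch of the same pattern, with DummyQuantKernel as a hypothetical stand-in for LinearBitKernel:

```python
import jax
import jax.numpy as jnp

class DummyQuantKernel:
    """Hypothetical stand-in for fjformer's LinearBitKernel: a quantized
    weight wrapper that is not itself a jax/numpy array."""
    def __init__(self, packed, scale):
        self.packed = packed
        self.scale = scale

# One shard/convert function per parameter leaf, mirroring the shard_fns
# built by make_shard_and_gather_fns in the traceback above.
shard_fns = {"kernel": lambda tensor: jnp.asarray(tensor)}

params_fp = {"kernel": jnp.ones((4, 4))}                       # plain checkpoint
params_8bit = {"kernel": DummyQuantKernel(b"\x01" * 16, 0.1)}  # 8-bit checkpoint

jax.tree_util.tree_map(lambda f, x: f(x), shard_fns, params_fp)      # works fine
# jax.tree_util.tree_map(lambda f, x: f(x), shard_fns, params_8bit)  # raises
#   TypeError: Unexpected input type for array: <class '__main__.DummyQuantKernel'>
```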
@IvoryTower800 hi, and thanks for using EasyDeL!
The load_in_8bit=True option is not supported for fine-tuning right now; it is for serving purposes only (for now).
It will soon be supported for LoRA, DPO, SFT, and CLM fine-tuning and pre-training, and when that lands it won't be a TPU-only feature; it will be supported across all chips and devices (GPU/CPU/TPU/...).
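In the meantime, a quick way to catch this before calling trainer.train is to look at the leaf types of the loaded parameters and reload the checkpoint without load_in_8bit if quantized wrappers show up. A small sketch, assuming the 8-bit kernels appear as non-array leaves of the params tree (the helper name here is made up for illustration):

```python
import jax
import numpy as np

def find_non_array_leaves(params):
    """Return the type names of any leaves that are not plain jax/numpy arrays,
    e.g. LinearBitKernel leaves from a checkpoint loaded with load_in_8bit=True."""
    return sorted({
        type(leaf).__name__
        for leaf in jax.tree_util.tree_leaves(params)
        if not isinstance(leaf, (jax.Array, np.ndarray))
    })

# quantized = find_non_array_leaves(params)
# if quantized:  # e.g. ['LinearBitKernel']
#     print("8-bit leaves found:", quantized,
#           "-> reload the model without load_in_8bit before fine-tuning")
```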
@erfanzar That will be really cool! Thank you!