erfanzar/EasyDeL

load_in_8bit doesn't work on Kaggle TPU

IvoryTower800 opened this issue · 2 comments

Hi, I tried to use the new load_in_8bit=True feature to fine-tune a Gemma model on a Kaggle TPU, but it raised the error below. I'm wondering whether this is a bug, or whether this feature will not be supported on TPU v3.

Also, I found that ring attention now works on Kaggle TPU, which it didn't before. Amazing! Thank you.

Time Took to Complete Task configure functions and sharding them (microseconds) : 1513.6299133300781
Action : Sharding Passed Parameters
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[13], line 144
    137 trainer = CausalLanguageModelTrainer(
    138     train_arguments,
    139     dataset_train,
    140 #     checkpoint_path='/root/' + ckpt_name
    141 )
    143 # output = trainer.train()
--> 144 output = trainer.train(flax.core.FrozenDict({"params": params}))
    145 print(f"Hey ! , here's where your model saved {output.checkpoint_path}")
    147 api.upload_file(
    148     path_or_fileobj=output.checkpoint_path,
    149     path_in_repo=output.checkpoint_path.split('/')[-1],
    150     repo_id="ivt1993/writer_2b_gemma",
    151     repo_type="model"
    152 )

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:397, in CausalLanguageModelTrainer.train(self, model_parameters, state)
    390     termcolor.cprint(
    391         "Performance Mode is ON, we will ignore the Memory Tracking, WANDB Logging, and extra information "
    392         "Process.",
    393         color="red",
    394         force_color=True
    395     )
    396 start_time = time.time()
--> 397 sharded_state, shard_fns, gather_fns = self.initialize_state(
    398     model_parameters=model_parameters,
    399     state=state
    400 )
    402 count_model_parameters(sharded_state.params)
    403 with self.mesh:

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:287, in CausalLanguageModelTrainer.initialize_state(self, model_parameters, state)
    280     if not isinstance(model_parameters, flax.core.FrozenDict):
    281         prefix_print(
    282             "Warning",
    283             "Model Parameters should be like FrozenDict({'params': params}) make sure to "
    284             "pass as type FrozenDict in case of not getting UnExcepted Errors "
    285         )
--> 287     model_parameters = model_parameters if not self.arguments.do_shard_fns else jax.tree_util.tree_map(
    288         lambda f, x: f(x),
    289         shard_fns.params,
    290         model_parameters,
    291     )
    292     sharded_state = self.create_sharded_state_from_params_function(model_parameters)
    293 elif model_parameters is not None and self.checkpoint_path is not None:

File /usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py:312, in tree_map(f, tree, is_leaf, *rest)
    310 leaves, treedef = tree_flatten(tree, is_leaf)
    311 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 312 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File /usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py:312, in <genexpr>(.0)
    310 leaves, treedef = tree_flatten(tree, is_leaf)
    311 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 312 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:288, in CausalLanguageModelTrainer.initialize_state.<locals>.<lambda>(f, x)
    280     if not isinstance(model_parameters, flax.core.FrozenDict):
    281         prefix_print(
    282             "Warning",
    283             "Model Parameters should be like FrozenDict({'params': params}) make sure to "
    284             "pass as type FrozenDict in case of not getting UnExcepted Errors "
    285         )
    287     model_parameters = model_parameters if not self.arguments.do_shard_fns else jax.tree_util.tree_map(
--> 288         lambda f, x: f(x),
    289         shard_fns.params,
    290         model_parameters,
    291     )
    292     sharded_state = self.create_sharded_state_from_params_function(model_parameters)
    293 elif model_parameters is not None and self.checkpoint_path is not None:

File /usr/local/lib/python3.10/site-packages/fjformer/partition_utils/mesh_utils.py:50, in make_shard_and_gather_fns.<locals>.make_shard_fn.<locals>.shard_fn(tensor)
     49 def shard_fn(tensor):
---> 50     return jax_shard_function(tensor).block_until_ready()

    [... skipping hidden 12 frame]

File /usr/local/lib/python3.10/site-packages/fjformer/partition_utils/mesh_utils.py:38, in make_shard_and_gather_fns.<locals>.make_to_dtype_fn.<locals>.to_dtype(tensor)
     36 elif hasattr(dtype_spec, 'dtype') and hasattr(tensor, 'dtype'):
     37     return jnp.asarray(tensor).astype(dtype_spec.dtype)
---> 38 return jnp.asarray(tensor)

File /usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:2233, in asarray(a, dtype, order, copy)
   2231 if dtype is not None:
   2232   dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)  # type: ignore[assignment]
-> 2233 return array(a, dtype=dtype, copy=bool(copy), order=order)

File /usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:2186, in array(object, dtype, copy, order, ndmin)
   2184   out = np.array(memoryview(object), copy=copy)
   2185 else:
-> 2186   raise TypeError(f"Unexpected input type for array: {type(object)}")
   2188 out_array: Array = lax_internal._convert_element_type(
   2189     out, dtype, weak_type=weak_type)
   2190 if ndmin > ndim(out_array):

TypeError: Unexpected input type for array: <class 'fjformer.linen.linen.LinearBitKernel'>
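The traceback shows the mechanism: `jax.tree_util.tree_map` applies one shard function per parameter leaf, and the last step of fjformer's shard function calls `jnp.asarray` on that leaf. With load_in_8bit=True, a kernel leaf is a `LinearBitKernel` object rather than an array, so the conversion fails. The sketch below reproduces that failure in isolation. `QuantizedKernel`, `to_array` and `try_shard` are hypothetical names, not fjformer or EasyDeL APIs; the stand-in class is modeled as an unregistered Python class, and the exact error message differs between JAX versions (only the `TypeError` is relied on).

```python
import jax
import jax.numpy as jnp
import numpy as np


class QuantizedKernel:
    """Hypothetical stand-in for an 8-bit kernel object such as LinearBitKernel.

    Modeled as a plain class that is NOT registered as a pytree node,
    so JAX treats each instance as a single opaque leaf.
    """

    def __init__(self, kernel: np.ndarray, scale: np.ndarray):
        self.kernel = kernel  # int8 weights
        self.scale = scale    # per-column float scale


def to_array(tensor):
    # Mirrors the last step of the shard function in the traceback:
    # each leaf is converted with jnp.asarray before it is placed on devices.
    return jnp.asarray(tensor)


def try_shard(tree):
    """Apply to_array to every leaf; return the exception type if one is raised."""
    try:
        jax.tree_util.tree_map(to_array, tree)
    except TypeError as err:
        return type(err)
    return None


params = {
    "dense": {
        "kernel": QuantizedKernel(np.ones((4, 4), np.int8), np.ones((4,), np.float32)),
        "bias": np.zeros((4,), np.float32),
    }
}

print(try_shard(params))                             # the quantized kernel leaf triggers a TypeError
print(try_shard({"bias": np.zeros(4, np.float32)}))  # plain arrays convert without error
```

Because the whole object reaches `jnp.asarray` as one leaf, JAX has no numeric dtype to convert it to, which is the same failure the trainer hits inside `initialize_state`.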

@IvoryTower800 Hi, and thanks for using EasyDeL!

The load_in_8bit=True option is not supported for fine-tuning right now; for the moment it is intended for serving only.
It will soon be supported for LoRA, DPO, SFT, and CLM fine-tuning and pre-training, and when that happens it won't be a TPU-only feature: it will be supported across all chips and devices (GPU/CPU/TPU/...).
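Until 8-bit fine-tuning is supported, a training script could check its parameters before calling `trainer.train` and fail early with a clearer message than the TypeError above. The sketch below is one way to do that, assuming a recent JAX that provides `jax.tree_util.tree_flatten_with_path` and `jax.tree_util.keystr`. `QuantizedKernel`, `find_non_array_leaves` and `check_trainable` are hypothetical helpers for illustration, not part of EasyDeL or fjformer.

```python
import jax
import numpy as np


class QuantizedKernel:
    """Hypothetical stand-in for an 8-bit kernel object (not a real fjformer class)."""

    def __init__(self, kernel, scale):
        self.kernel = kernel
        self.scale = scale


def find_non_array_leaves(params):
    """Return the pytree paths of leaves that are not NumPy or JAX arrays."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(params)
    return [
        jax.tree_util.keystr(path)
        for path, leaf in leaves_with_paths
        if not isinstance(leaf, (np.ndarray, jax.Array))
    ]


def check_trainable(params):
    """Raise before sharding if any parameter is not a plain array."""
    bad = find_non_array_leaves(params)
    if bad:
        raise ValueError(
            "Non-array parameters found (e.g. 8-bit kernels); load the model "
            f"without load_in_8bit before fine-tuning: {bad}"
        )


params = {
    "dense": {
        "kernel": QuantizedKernel(np.ones((4, 4), np.int8), np.ones(4, np.float32)),
        "bias": np.zeros(4, np.float32),
    }
}
print(find_non_array_leaves(params))  # lists the path of the quantized kernel
```

The check is deliberately strict: any leaf that is not an array (including Python scalars) is reported, which matches what the shard functions in the traceback expect.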

@erfanzar That will be really cool! Thank you!