Error while finetuning TinyLlama on Kaggle TPU
jchauhan opened this issue · 5 comments
jchauhan commented
Describe the bug
An error occurs while training TinyLlama on a Kaggle TPU. Full log:
/root
/usr/local/lib/python3.10/site-packages/pydantic/_internal/_fields.py:149: UserWarning: Field "model_name" has conflict with protected namespace "model_".
You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ()`.
warnings.warn(
Information : track_memory is set to False by default inorder make make training faster. you can turn it on with just passing `track_memory=True` in TrainArguments
Downloading data files: 100%|███████████████████| 1/1 [00:00<00:00, 7319.90it/s]
Extracting data files: 100%|████████████████████| 1/1 [00:00<00:00, 1059.70it/s]
Generating train split: 176 examples [00:00, 37836.88 examples/s]
Map (num_proc=12): 100%|██████████████| 176/176 [00:00<00:00, 416.75 examples/s]
/usr/local/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by promote_options='default'.
table = cls._concat_blocks(blocks, axis=0)
Map (num_proc=12):   0%|          | 0/176 [00:00<?, ? examples/s]
/usr/local/lib/python3.10/site-packages/datasets/table.py:1387: FutureWarning: promote has been superseded by promote_options='default'.
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
(the FutureWarning above is repeated once per map worker; duplicate lines omitted)
Map (num_proc=12): 100%|██████████████| 176/176 [00:00<00:00, 437.34 examples/s]
Warning : In case of using `finetune = True` and Passing `checkpoint_path = None` you should pass parameters in train function
wandb: Currently logged in as: jchauhan (safedep). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.3
wandb: Run data is saved locally in /root/wandb/run-20240216_181547-lgbd2mo5
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run twinkling-fish-7
wandb: ⭐️ View project at https://wandb.ai/safedep/EasyDeL-my_first_model_to_train_using_easydel
wandb: 🚀 View run at https://wandb.ai/safedep/EasyDeL-my_first_model_to_train_using_easydel/runs/lgbd2mo5
Time Took to Complete Task configure dataloaders (microseconds) : 0.4432201385498047
Time Took to Complete Task configure Model, Optimizer, Scheduler and Config (microseconds) : 1331.115484237671
Time Took to Complete Task configure functions and sharding them (microseconds) : 1449.2170810699463
Action : Sharding Passed Parameters
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/root/train.py", line 140, in <module>
output = trainer.train(flax.core.FrozenDict({"params": params}))
File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 488, in train
sharded_state, shard_fns, gather_fns = self.initialize_state(
File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 402, in initialize_state
model_parameters = model_parameters if not self.arguments.do_shard_fns else jax.tree_util.tree_map(
File "/usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 403, in <lambda>
lambda f, x: f(x),
File "/usr/local/lib/python3.10/site-packages/fjformer/partition_utils/mesh_utils.py", line 50, in shard_fn
return jax_shard_function(tensor).block_until_ready()
ValueError: Memory kinds passed to jax.jit does not match memory kind on the respective arg. Got pjit memory kind: tpu_hbm, arg memory kind: None for arg shape: float32[2048,32000]
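For context: the mismatch is between where the parameter arrays currently live (host memory, memory kind None) and the TPU HBM memory kind the jitted shard function was compiled for; the float32[2048,32000] shape matches the embedding table. A rough sketch of a possible workaround (not a confirmed fix), assuming `params` and `trainer` are the objects from train.py in the traceback above:

# Possible workaround sketch (assumption, not the library's official fix):
# explicitly place the parameter tree on a TPU device before handing it to the
# trainer, so the arrays carry a TPU memory kind instead of None.
import jax
import flax

tpu_device = jax.devices("tpu")[0]  # assumes the Kaggle TPU is visible to JAX
params = jax.tree_util.tree_map(lambda x: jax.device_put(x, tpu_device), params)
output = trainer.train(flax.core.FrozenDict({"params": params}))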
To Reproduce
# imports (not shown in the original snippet; names follow EasyDeL's examples)
import jax.numpy as jnp
from transformers import AutoTokenizer
from EasyDel import (
    AutoEasyDelModelForCausalLM,
    TrainArguments,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers,
)

huggingface_repo_id_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path)
max_length = 1024
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token

configs_to_init_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, 1)
}
train_arguments = TrainArguments(
    model_class=type(model),
    model_name="my_first_model_to_train_using_easydel",
    num_train_epochs=3,
    configs_to_initialize_model_class=configs_to_init_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion", and "adafactor" are supported
    scheduler=EasyDelSchedulers.LINEAR,
    # "linear", "cosine", "none", "warm_up_cosine", and "warm_up_linear" are supported
    weight_decay=0.01,
    total_batch_size=1,
    max_steps=10000,  # None lets the trainer decide
    do_train=True,
    do_eval=False,  # optional but supported
    backend="tpu",  # the default backend is cpu, so you must specify tpu, gpu, or cpu
    max_length=max_length,  # note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),  # how to shard the model across CPUs, GPUs, or TPUs
    # with (1, -1, 1, 1), training is fully FSDP and data is shared between devices
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    dtype=jnp.bfloat16
)
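The snippet above stops before the trainer is constructed. Based on the traceback (train.py, line 140), the failing call looks roughly like the following; only the trainer.train(...) line is taken from the traceback, while the CausalLanguageModelTrainer import path and constructor arguments are an assumption based on EasyDeL's examples:

import flax
from EasyDel import CausalLanguageModelTrainer

trainer = CausalLanguageModelTrainer(
    train_arguments,  # the TrainArguments defined above
    dataset_train,    # hypothetical name for the tokenized training split (176 examples in the log)
    checkpoint_path=None,
)
output = trainer.train(flax.core.FrozenDict({"params": params}))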
erfanzar commented
Code works just fine
Can you share the packages that you have installed?
jchauhan commented
Here are the dependencies:
# !pip install git+https://github.com/erfanzar/EasyDeL.git
!pip install EasyDeL==0.0.50
!pip install sentencepiece
!pip install jaxlib==0.4.19
!pip install jax[tpu]==0.4.19 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!apt-get update && apt-get upgrade -y
!apt-get install golang -y
I tried both versions of EasyDeL.
erfanzar commented
Use the head version, that's working fine, or you can install from PyPI tomorrow when I release the next version.
saidineshpola commented
This is not fixed and I am still getting the error. My pip installations are:
!pip install git+https://github.com/erfanzar/EasyDeL.git
!pip install jax[tpu]==0.4.20 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Tested with 0.4.21 as well. The issue is not fixed
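To compare environments, a quick version check like the following can show what actually got installed on the Kaggle VM (a generic sketch; it assumes the installed distribution name is "EasyDeL"):

import importlib.metadata as md
import jax
import jaxlib

print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)
print("EasyDeL:", md.version("EasyDeL"))  # assumes the PyPI distribution name
print("devices:", jax.devices())          # should list the TPU cores on a TPU VM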