erfanzar/EasyDeL

Error while finetuning Tinyllama on Kaggle TPU

jchauhan opened this issue · 5 comments

Describe the bug
An error occurs while fine-tuning TinyLlama on a Kaggle TPU.

/root
/usr/local/lib/python3.10/site-packages/pydantic/_internal/_fields.py:149: UserWarning: Field "model_name" has conflict with protected namespace "model_".

You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ()`.
  warnings.warn(
Information :  track_memory is set to False by default inorder make make training faster. you can turn it on with just passing `track_memory=True` in TrainArguments
Downloading data files: 100%|███████████████████| 1/1 [00:00<00:00, 7319.90it/s]
Extracting data files: 100%|████████████████████| 1/1 [00:00<00:00, 1059.70it/s]
Generating train split: 176 examples [00:00, 37836.88 examples/s]
Map (num_proc=12): 100%|██████████████| 176/176 [00:00<00:00, 416.75 examples/s]
/usr/local/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by promote_options='default'.
  table = cls._concat_blocks(blocks, axis=0)
Map (num_proc=12):   0%|                         | 0/176 [00:00<?, ? examples/s]
/usr/local/lib/python3.10/site-packages/datasets/table.py:1387: FutureWarning: promote has been superseded by promote_options='default'.
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
Map (num_proc=12): 100%|██████████████| 176/176 [00:00<00:00, 437.34 examples/s]
Warning : In case of using `finetune = True` and Passing `checkpoint_path = None` you should pass parameters in train function
wandb: Currently logged in as: jchauhan (safedep). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.3
wandb: Run data is saved locally in /root/wandb/run-20240216_181547-lgbd2mo5
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run twinkling-fish-7
wandb: ⭐️ View project at https://wandb.ai/safedep/EasyDeL-my_first_model_to_train_using_easydel
wandb: 🚀 View run at https://wandb.ai/safedep/EasyDeL-my_first_model_to_train_using_easydel/runs/lgbd2mo5
Time Took to Complete Task configure dataloaders (microseconds) : 0.4432201385498047
Time Took to Complete Task configure Model, Optimizer, Scheduler and Config (microseconds) : 1331.115484237671
Time Took to Complete Task configure functions and sharding them (microseconds) : 1449.2170810699463
Action : Sharding Passed Parameters
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/train.py", line 140, in <module>
    output = trainer.train(flax.core.FrozenDict({"params": params}))
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 488, in train
    sharded_state, shard_fns, gather_fns = self.initialize_state(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 402, in initialize_state
    model_parameters = model_parameters if not self.arguments.do_shard_fns else jax.tree_util.tree_map(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 403, in <lambda>
    lambda f, x: f(x),
  File "/usr/local/lib/python3.10/site-packages/fjformer/partition_utils/mesh_utils.py", line 50, in shard_fn
    return jax_shard_function(tensor).block_until_ready()
ValueError: Memory kinds passed to jax.jit does not match memory kind on the respective arg. Got pjit memory kind: tpu_hbm, arg memory kind: None for arg shape: float32[2048,32000]
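The traceback shows that `initialize_state` maps the shard functions over the parameters only when `self.arguments.do_shard_fns` is true, and that the failure happens inside fjformer's `shard_fn`, which calls a jitted function. One workaround you could try (not verified on this setup) is to place the parameters on the device mesh yourself with `jax.device_put` and a `NamedSharding`, then skip the trainer's shard functions. This assumes `TrainArguments` accepts a `do_shard_fns` flag, as the attribute name in the traceback suggests. The helper names, the axis names, and the split-the-first-dimension rule below are hypothetical choices for illustration, not what EasyDeL does internally:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical axis names for a 4-D mesh mirroring sharding_array=(1, -1, 1, 1).
AXIS_NAMES = ("dp", "fsdp", "tp", "sp")


def make_mesh(shape=(1, -1, 1, 1)):
    """Arrange all visible devices into a 4-D mesh; the -1 entry absorbs every device."""
    devices = np.array(jax.devices()).reshape(shape)
    return Mesh(devices, AXIS_NAMES)


def spec_for(x, mesh):
    """Shard the leading dimension over 'fsdp' when it divides evenly, else replicate."""
    fsdp_size = mesh.shape["fsdp"]
    if x.ndim > 0 and x.shape[0] % fsdp_size == 0:
        return P("fsdp")
    return P()


def shard_params(params, mesh):
    """Place every leaf of the pytree on the mesh with jax.device_put
    rather than with a jitted identity function."""
    return jax.tree_util.tree_map(
        lambda x: jax.device_put(x, NamedSharding(mesh, spec_for(x, mesh))),
        params,
    )
```

For example, `shard_params({"params": params}, make_mesh())` places every leaf on the mesh. Because `device_put` follows a different code path from the jitted `shard_fn` that raises here, it may avoid the pjit memory-kind check. Whether it actually resolves this error on Kaggle's TPU runtime is untested.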

To Reproduce

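import jax.numpy as jnp
from transformers import AutoTokenizer
from EasyDel import (
    AutoEasyDelModelForCausalLM,
    TrainArguments,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers,
)

huggingface_repo_id_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path)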

max_length = 1024
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_init_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, 1)
}

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="my_first_model_to_train_using_easydel",
    num_train_epochs=3,
    configs_to_initialize_model_class=configs_to_init_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion", "adafactor" are supported
    scheduler=EasyDelSchedulers.LINEAR,
    # "linear", "cosine", "none", "warm_up_cosine" and "warm_up_linear" are supported
    weight_decay=0.01,
    total_batch_size=1,
    max_steps=10000,  # None to let the trainer decide
    do_train=True,
    do_eval=False,  # optional but supported
    backend="tpu",  # the default backend is cpu, so set "tpu", "gpu" or "cpu" explicitly
    max_length=max_length,  # note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),  # how the model is sharded across GPUs, CPUs or TPUs;
    # (1, -1, 1, 1) trains fully with FSDP and shares data between devices automatically
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    dtype=jnp.bfloat16
)
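According to the comment on `sharding_array`, `(1, -1, 1, 1)` means fully FSDP. The snippet below assumes the single `-1` is resolved the way `numpy.reshape` resolves it; the exact logic inside EasyDeL/fjformer is not shown in this thread. Under that assumption, you can work out the mesh shape for a given device count with this hypothetical helper:

```python
import math


def resolve_mesh_shape(sharding_array, num_devices):
    """Resolve at most one -1 entry so the product of the shape equals num_devices."""
    shape = tuple(sharding_array)
    if any(d != -1 and d < 1 for d in shape):
        raise ValueError("entries must be positive integers or -1")
    wildcards = shape.count(-1)
    if wildcards > 1:
        raise ValueError("at most one entry may be -1")
    fixed = math.prod(d for d in shape if d != -1)
    if wildcards == 0:
        if fixed != num_devices:
            raise ValueError(f"{shape} needs {fixed} devices, found {num_devices}")
        return shape
    if num_devices % fixed != 0:
        raise ValueError(f"{num_devices} devices cannot be split by {shape}")
    # The -1 axis takes all devices not claimed by the fixed axes.
    return tuple(num_devices // fixed if d == -1 else d for d in shape)


print(resolve_mesh_shape((1, -1, 1, 1), 8))  # (1, 8, 1, 1)
```

On an 8-device TPU host, every device would therefore land on the second (FSDP) axis.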


@erfanzar, any progress on this?

The code works fine.
Can you share the packages you have installed?
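Answering this requires the exact versions of the packages involved. A small standard-library script like the sketch below prints them. The package list is an assumption; in particular, `libtpu-nightly` is typically present only when the TPU runtime was installed through the `jax[tpu]` extra. The extra `jax_jaxlib_mismatch` check is a heuristic. jax accepts a range of jaxlib versions, so a mismatch is not necessarily the cause, but it is cheap to rule out.

```python
from importlib import metadata

# Assumed distribution names (as used with pip); missing ones are reported as None.
PACKAGES = ("jax", "jaxlib", "libtpu-nightly", "EasyDeL", "fjformer", "flax", "transformers")


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if it is not installed."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def jax_jaxlib_mismatch(versions):
    """True only when both jax and jaxlib are installed and report different versions."""
    jax_v, jaxlib_v = versions.get("jax"), versions.get("jaxlib")
    return jax_v is not None and jaxlib_v is not None and jax_v != jaxlib_v


if __name__ == "__main__":
    found = installed_versions()
    for name, version in found.items():
        print(f"{name}: {version or 'not installed'}")
    if jax_jaxlib_mismatch(found):
        print("warning: jax and jaxlib versions differ")
```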

Here are the dependencies:

# !pip install git+https://github.com/erfanzar/EasyDeL.git
!pip install EasyDeL==0.0.50
!pip install sentencepiece
!pip install jaxlib==0.4.19
!pip install jax[tpu]==0.4.19 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!apt-get update && apt-get upgrade -y
!apt-get install golang -y

I tried both versions of EasyDeL (the GitHub head and 0.0.50).

Use the head version, which works fine, or install from PyPI tomorrow when I release the next version.

This is not fixed; I am still getting the error. My pip installations are:
!pip install git+https://github.com/erfanzar/EasyDeL.git
!pip install jax[tpu]==0.4.20 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
I tested with JAX 0.4.21 as well, and the issue is still not fixed.