erfanzar/EasyDeL

Out-of-memory issue in the new EasyDeL version

nyl199310 opened this issue · 6 comments

Hi, I can run the code below with a previous EasyDeL version without any problem:

!pip install git+https://github.com/erfanzar/EasyDeL.git@d06931e79cc3ef63920007d9e4f95fd0289df3cf # This version works well.
!pip install fjformer==0.0.51

But when I use the latest EasyDeL, it fails with:

XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 15.86G of 15.48G hbm. Exceeded hbm capacity by 382.92M.
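A rough static-memory budget helps put this 382.92M overshoot in perspective. The sketch below assumes 8 devices (e.g. a TPU v3-8, which the ~15.5 G per-device HBM figure suggests), about 8.03B parameters for Llama-3-8B, bf16 weights and gradients, float32 AdamW moments, and perfectly even sharding. None of these are confirmed by the report, and activations, attention buffers and XLA scratch space are left out entirely.

```python
def static_training_bytes_per_device(
    n_params: float,
    n_devices: int,
    param_bytes: int = 2,       # bf16 weights (dtype/param_dtype in the report)
    grad_bytes: int = 2,        # assumption: gradients stored in the parameter dtype
    opt_state_bytes: int = 8,   # assumption: AdamW mu and nu kept in float32 (2 x 4 bytes)
) -> float:
    """Per-device bytes for weights + gradients + AdamW state, assuming all three
    are sharded evenly over n_devices. Activations, attention buffers and XLA
    scratch space are deliberately NOT counted."""
    return n_params * (param_bytes + grad_bytes + opt_state_bytes) / n_devices


# Assumptions: ~8.03B parameters for Llama-3-8B, 8 devices (e.g. a TPU v3-8).
estimate = static_training_bytes_per_device(8.03e9, n_devices=8)
print(f"{estimate / 1e9:.2f} GB per device before activations")
```

Under these assumptions, weights, gradients and optimizer state alone take roughly 12 GB per device, so a few extra gigabytes of activations or attention buffers at `max_length = 8192` are enough to exceed HBM.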

from easydel import (
    AutoEasyDeLModelForCausalLM,
    AutoEasyDeLConfig,
    EasyDeLState,
    TrainArguments,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    SFTTrainer,
    ORPOTrainer,
    EasyDeLGradientCheckPointers,
    easystate_to_huggingface_model,
    get_modules_by_type
)
from datasets import load_dataset
from transformers import AutoTokenizer, LlamaForCausalLM, AutoConfig
from jax import numpy as jnp, lax
import jax
import flax
from huggingface_hub import HfApi


huggingface_model_repo_id = "NousResearch/Hermes-2-Pro-Llama-3-8B"
max_length = 8192


model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    huggingface_model_repo_id,
    device=jax.devices('cpu')[0],
    input_shape=(1,8192),
    device_map="auto",
    sharding_axis_dims=(1, 1, 1, -1),
    config_kwargs=dict(
        use_scan_mlp=False,
        attn_mechanism='sharded_vanilla',
    ),
)



config = AutoEasyDeLConfig.from_pretrained(
    huggingface_model_repo_id
)

tokenizer = AutoTokenizer.from_pretrained(
    huggingface_model_repo_id,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token


configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, max_length)
}


train_arguments = TrainArguments(
    model_class=get_modules_by_type(model.config.model_type)[1],
    model_name="llama3",
    num_train_epochs=1,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    learning_rate=2e-5,
#     step_start_point=step_start_point,
    learning_rate_end=2e-7,
    optimizer=EasyDeLOptimizers.ADAMW,
    scheduler=EasyDeLSchedulers.LINEAR,
    weight_decay=0.01,
    #dataloader_num_workers=96,
    total_batch_size=1,
    max_training_steps=None,
    do_train=True,
    do_eval=False,
    backend="tpu",
    max_sequence_length=max_length,
    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, 1, 1, -1),
    init_input_shape=(1,max_length),
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=3,
    training_time="8H",
    track_memory=True,
    neftune_noise_alpha=5.0,
    force_batch_and_gradient_accumulation_steps_calculation=True,
    loss_re_mat="",
    dtype=jnp.bfloat16
)


train_dataset = load_dataset('csv',data_files="/kaggle/input/insert-p1/insert_p1.csv")['train']
desired_indices = range(0, 200)
train_dataset = train_dataset.select(desired_indices)

trainer = SFTTrainer(
    arguments=train_arguments,
    train_dataset=train_dataset,
    eval_dataset=None,
    tokenizer=tokenizer,
    dataset_text_field="text",
    dataset_num_proc=96,
    packing=False,
)
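The script sets both `sharding_axis_dims` and `sharding_array` to `(1, 1, 1, -1)`. Assuming the `-1` entry is resolved the way `numpy.reshape` resolves it (our reading of how the device mesh is built, not something the report confirms), on an 8-device host this places all 8 devices on the last mesh axis. `resolve_axis_dims` below is a hypothetical helper that only illustrates that arithmetic.

```python
import math
from typing import Sequence, Tuple


def resolve_axis_dims(dims: Sequence[int], device_count: int) -> Tuple[int, ...]:
    """Resolve at most one -1 entry numpy.reshape-style: it absorbs every device
    the explicit entries do not use."""
    dims = tuple(dims)
    if any(d == 0 or d < -1 for d in dims):
        raise ValueError(f"axis sizes must be positive or -1, got {dims}")
    if dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = math.prod(d for d in dims if d != -1)  # devices claimed explicitly
    if device_count % known:
        raise ValueError(f"{device_count} devices cannot be split as {dims}")
    resolved = tuple(device_count // known if d == -1 else d for d in dims)
    if math.prod(resolved) != device_count:  # only reachable when no -1 is given
        raise ValueError(f"{dims} does not use exactly {device_count} devices")
    return resolved
```

For example, `resolve_axis_dims((1, 1, 1, -1), 8)` gives `(1, 1, 1, 8)`.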

Hello, and thanks for using EasyDeL.
This is likely caused by recent changes to the attention mechanism that fixed miscomputation problems.
I'll test the same code, track down the issue, and fix it.
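To see why the attention implementation can decide whether this run fits, consider the score tensor that plain (non-flash, non-blockwise) attention materializes, of shape (batch, heads, query_len, key_len). The sketch assumes Llama-3-8B's 32 query heads, batch size 1 and `max_length = 8192` from the report; how much of that tensor lands on each device depends on how the chosen `attn_mechanism` shards it, which the sketch does not model.

```python
def attention_scores_bytes(batch: int, num_heads: int, q_len: int, kv_len: int,
                           bytes_per_elem: int = 2) -> int:
    """Bytes for one layer's full attention-score tensor of shape
    (batch, num_heads, q_len, kv_len), as vanilla attention materializes it."""
    return batch * num_heads * q_len * kv_len * bytes_per_elem


GiB = 2 ** 30
scores_bf16 = attention_scores_bytes(1, 32, 8192, 8192, bytes_per_elem=2)
# e.g. if the softmax is computed in float32
scores_fp32 = attention_scores_bytes(1, 32, 8192, 8192, bytes_per_elem=4)
print(scores_bf16 / GiB, scores_fp32 / GiB)  # → 4.0 8.0
```

That is 4 GiB in bf16 (8 GiB in float32) per layer before sharding, so even a small change in how an attention kernel shards this tensor or keeps it alive can move peak HBM by more than the 382.92M overshoot.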

Thank you so much, @erfanzar. There is another issue: with ORPOTrainer, tokenization is very slow, about 1 to 2 examples per second. Unlike SFTTrainer, it has no parameter such as dataset_num_proc=96; on the same hardware, SFTTrainer tokenizes about 3000 examples per second.
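For context on what `dataset_num_proc` controls: as far as we understand the `datasets` library, `Dataset.map(..., num_proc=N)` splits the dataset into N contiguous shards, maps each shard in its own worker process, and concatenates the results in the original order. The sketch below reproduces only the splitting and in-order concatenation, sequentially and in pure Python; `contiguous_shard_bounds` and `sharded_map` are hypothetical names, not part of `datasets` or EasyDeL.

```python
from typing import Callable, List, Tuple, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def contiguous_shard_bounds(n_rows: int, num_shards: int, index: int) -> Tuple[int, int]:
    """Half-open [start, end) row range of shard `index` when n_rows rows are cut
    into num_shards contiguous pieces; the first n_rows % num_shards shards get
    one extra row."""
    div, mod = divmod(n_rows, num_shards)
    start = div * index + min(index, mod)
    end = start + div + (1 if index < mod else 0)
    return start, end


def sharded_map(rows: List[T], fn: Callable[[T], U], num_proc: int) -> List[U]:
    """Split rows into num_proc contiguous shards, map each, and concatenate in
    shard order. Runs sequentially; a real num_proc implementation would give
    each shard to a separate worker process."""
    out: List[U] = []
    for rank in range(num_proc):
        start, end = contiguous_shard_bounds(len(rows), num_proc, rank)
        out.extend(fn(row) for row in rows[start:end])
    return out
```

With the 200-row subset and `dataset_num_proc=96` from the SFTTrainer call, the first 8 shards hold 3 rows each and the remaining 88 hold 2 rows each.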

ORPOTrainer supports dataset_map_arguments, a dict that is passed through to Dataset.map, but I've added dataset_num_proc for you anyway.

@erfanzar Thank you!

@nyl199310 you can use legacy_sharded_vanilla to get the old attention, but that one has a lot of miscomputations across different devices.
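Switching back only requires changing `attn_mechanism` in the `config_kwargs` passed to `from_pretrained`. A minimal sketch that keeps the rest of the original call unchanged; `legacy_sharded_vanilla` is the name given in this reply, and per the same reply its results may be wrong across devices.

```python
# Same config_kwargs as in the original from_pretrained call; only the
# attention backend changes.
config_kwargs = dict(
    use_scan_mlp=False,
    # Pre-fix attention path: may avoid the new memory peak, but is reported
    # to miscompute across devices, so validate losses and outputs.
    attn_mechanism="legacy_sharded_vanilla",
)
```

Pass it as `config_kwargs=config_kwargs` in the `AutoEasyDeLModelForCausalLM.from_pretrained` call from the report.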

Hi, adding the dataset_num_proc parameter introduced a bug:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[2], line 7
      3 train_dataset = train_dataset.select(desired_indices)
      4 train_dataset = train_dataset.rename_column('question', 'prompt')
----> 7 trainer = ORPOTrainer(
      8     arguments=train_arguments,
      9     max_length = 8192,
     10     max_prompt_length = 8192,
     11     max_completion_length = 2048,
     12     beta = 0.1,
     13     train_dataset=train_dataset,
     14     eval_dataset=None,
     15     tokenizer=tokenizer,
     16     low_mem_usage=True,
     17 )
     19 output = trainer.train(flax.core.FrozenDict({"params": params}))

File /usr/local/lib/python3.10/site-packages/easydel/trainer/orpo/orpo_trainer.py:168, in ORPOTrainer.__init__(self, arguments, max_length, max_prompt_length, max_completion_length, beta, disable_dropout, label_pad_token_id, is_encoder_decoder, padding_value, data_collator, train_dataset, eval_dataset, tokenizer, dataset_num_proc, _do_init_fns, dataset_map_arguments, low_mem_usage)
    166 if dataset_map_arguments is None:
    167     dataset_map_arguments = {}
--> 168 train_dataset = train_dataset.map(
    169     self.tokenize_row,
    170     dataset_num_proc=dataset_num_proc,
    171     **dataset_map_arguments
    172 )
    173 if eval_dataset is not None:
    174     eval_dataset = eval_dataset.map(
    175         self.tokenize_row,
    176         dataset_num_proc=dataset_num_proc,
    177         **dataset_map_arguments
    178     )

File /usr/local/lib/python3.10/site-packages/datasets/arrow_dataset.py:592, in transmit_tasks.<locals>.wrapper(*args, **kwargs)
    590     self: "Dataset" = kwargs.pop("self")
    591 # apply actual function
--> 592 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    593 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    594 for dataset in datasets:
    595     # Remove task templates if a column mapping of the template is no longer valid

File /usr/local/lib/python3.10/site-packages/datasets/arrow_dataset.py:557, in transmit_format.<locals>.wrapper(*args, **kwargs)
    550 self_format = {
    551     "type": self._format_type,
    552     "format_kwargs": self._format_kwargs,
    553     "columns": self._format_columns,
    554     "output_all_columns": self._output_all_columns,
    555 }
    556 # apply actual function
--> 557 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    558 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    559 # re-apply format to the output

TypeError: Dataset.map() got an unexpected keyword argument 'dataset_num_proc'
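The traceback shows the cause: the new `dataset_num_proc` value is forwarded to `Dataset.map` under its own name, but `Dataset.map` calls this option `num_proc`. A minimal sketch of one possible fix, assuming the trainer keeps both `dataset_num_proc` and `dataset_map_arguments`: `build_map_kwargs` is a hypothetical helper (not EasyDeL code), and `fake_map` is a hypothetical stand-in so the sketch runs without `datasets` installed.

```python
from typing import Any, Dict, Optional


def build_map_kwargs(
    dataset_num_proc: Optional[int],
    dataset_map_arguments: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
    """Turn the trainer-level options into keywords that Dataset.map accepts.

    Dataset.map calls this option `num_proc`, so `dataset_num_proc` must be
    renamed rather than forwarded as-is. A `num_proc` the user already put in
    dataset_map_arguments takes precedence, which also avoids passing the
    keyword twice.
    """
    kwargs = dict(dataset_map_arguments or {})  # copy: never mutate the caller's dict
    if dataset_num_proc is not None:
        kwargs.setdefault("num_proc", dataset_num_proc)
    return kwargs


def fake_map(function, num_proc=None, batched=False):
    """Hypothetical stand-in exposing only two of Dataset.map's keywords."""
    return {"num_proc": num_proc, "batched": batched}


# The call shape from the traceback fails the same way on the stand-in.
try:
    fake_map(str, dataset_num_proc=96)
except TypeError as err:
    print("buggy call:", err)

# Renamed keyword, merged with any user-supplied map arguments.
print("fixed call:", fake_map(str, **build_map_kwargs(96, {"batched": True})))
```

Inside the trainer, both `.map` calls would then become something like `train_dataset.map(self.tokenize_row, **build_map_kwargs(dataset_num_proc, dataset_map_arguments))`. Note that, with the code shown in the traceback, passing `dataset_map_arguments={"num_proc": ...}` does not avoid the error, because `dataset_num_proc=` is always forwarded.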