Out-of-memory issue in new EasyDeL version
nyl199310 opened this issue · 6 comments
Hi, I can run the code below with a previous EasyDeL version without any problem.
!pip install git+https://github.com/erfanzar/EasyDeL.git@d06931e79cc3ef63920007d9e4f95fd0289df3cf # This version works well.
!pip install fjformer==0.0.51
but when I use the latest EasyDeL, it fails with:
XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 15.86G of 15.48G hbm. Exceeded hbm capacity by 382.92M.
from easydel import (
    AutoEasyDeLModelForCausalLM,
    AutoEasyDeLConfig,
    EasyDeLState,
    TrainArguments,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    SFTTrainer,
    ORPOTrainer,
    EasyDeLGradientCheckPointers,
    easystate_to_huggingface_model,
    get_modules_by_type
)
from datasets import load_dataset
from transformers import AutoTokenizer, LlamaForCausalLM, AutoConfig
from jax import numpy as jnp, lax
import jax
import flax
from huggingface_hub import HfApi
huggingface_model_repo_id = "NousResearch/Hermes-2-Pro-Llama-3-8B"
max_length = 8192
model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    huggingface_model_repo_id,
    device=jax.devices('cpu')[0],
    input_shape=(1, 8192),
    device_map="auto",
    sharding_axis_dims=(1, 1, 1, -1),
    config_kwargs=dict(
        use_scan_mlp=False,
        attn_mechanism='sharded_vanilla',
    ),
)
config = AutoEasyDeLConfig.from_pretrained(
    huggingface_model_repo_id
)
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_model_repo_id,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, max_length)
}
train_arguments = TrainArguments(
    model_class=get_modules_by_type(model.config.model_type)[1],
    model_name="llama3",
    num_train_epochs=1,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    learning_rate=2e-5,
    # step_start_point=step_start_point,
    learning_rate_end=2e-7,
    optimizer=EasyDeLOptimizers.ADAMW,
    scheduler=EasyDeLSchedulers.LINEAR,
    weight_decay=0.01,
    # dataloader_num_workers=96,
    total_batch_size=1,
    max_training_steps=None,
    do_train=True,
    do_eval=False,
    backend="tpu",
    max_sequence_length=max_length,
    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, 1, 1, -1),
    init_input_shape=(1, max_length),
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=3,
    training_time="8H",
    track_memory=True,
    neftune_noise_alpha=5.0,
    force_batch_and_gradient_accumulation_steps_calculation=True,
    loss_re_mat="",
    dtype=jnp.bfloat16
)
train_dataset = load_dataset('csv',data_files="/kaggle/input/insert-p1/insert_p1.csv")['train']
desired_indices = range(0, 200)
train_dataset = train_dataset.select(desired_indices)
trainer = SFTTrainer(
    arguments=train_arguments,
    train_dataset=train_dataset,
    eval_dataset=None,
    tokenizer=tokenizer,
    dataset_text_field="text",
    dataset_num_proc=96,
    packing=False,
)
Hello, and thanks for using EasyDeL.
It's likely due to recent changes in the attention mechanism that fix miscomputation problems.
I'll test the same code and try to find and fix the issue.
Thank you so much @erfanzar. There is another issue as well: when using the ORPOTrainer, tokenization is very slow, about 1-2 examples per second, and there isn't a parameter like SFTTrainer's dataset_num_proc=96. The same hardware achieves about 3000 examples per second with SFTTrainer.
ORPOTrainer supports dataset_map_arguments, which is a dict that will be passed to Dataset.map, but I've added dataset_num_proc to it for you anyway.
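For reference, a minimal usage sketch (constructor arguments mirror the call shown in the traceback further down; num_proc is assumed to work here because dataset_map_arguments is unpacked into Dataset.map, where num_proc is the standard argument):

```python
# Sketch: forwarding multiprocessing options to Dataset.map via dataset_map_arguments.
trainer = ORPOTrainer(
    arguments=train_arguments,
    max_length=8192,
    max_prompt_length=8192,
    max_completion_length=2048,
    beta=0.1,
    train_dataset=train_dataset,
    eval_dataset=None,
    tokenizer=tokenizer,
    low_mem_usage=True,
    dataset_map_arguments={"num_proc": 96},  # passed through to Dataset.map
)
```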
@nyl199310 you can use legacy_sharded_vanilla for the old attention, but that one has a lot of miscomputations on different devices.
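If you want to try that while waiting for a fix, here is a minimal sketch; it is the same from_pretrained call as in the original snippet, with only the attention mechanism swapped:

```python
# Sketch: selecting the legacy attention kernel suggested above.
model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    huggingface_model_repo_id,
    device=jax.devices('cpu')[0],
    input_shape=(1, max_length),
    device_map="auto",
    sharding_axis_dims=(1, 1, 1, -1),
    config_kwargs=dict(
        use_scan_mlp=False,
        attn_mechanism='legacy_sharded_vanilla',  # old attention; may miscompute on some devices
    ),
)
```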
Hi, there is a bug after adding the dataset_num_proc parameter.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[2], line 7
3 train_dataset = train_dataset.select(desired_indices)
4 train_dataset = train_dataset.rename_column('question', 'prompt')
----> 7 trainer = ORPOTrainer(
8 arguments=train_arguments,
9 max_length = 8192,
10 max_prompt_length = 8192,
11 max_completion_length = 2048,
12 beta = 0.1,
13 train_dataset=train_dataset,
14 eval_dataset=None,
15 tokenizer=tokenizer,
16 low_mem_usage=True,
17 )
19 output = trainer.train(flax.core.FrozenDict({"params": params}))
File /usr/local/lib/python3.10/site-packages/easydel/trainer/orpo/orpo_trainer.py:168, in ORPOTrainer.__init__(self, arguments, max_length, max_prompt_length, max_completion_length, beta, disable_dropout, label_pad_token_id, is_encoder_decoder, padding_value, data_collator, train_dataset, eval_dataset, tokenizer, dataset_num_proc, _do_init_fns, dataset_map_arguments, low_mem_usage)
166 if dataset_map_arguments is None:
167 dataset_map_arguments = {}
--> 168 train_dataset = train_dataset.map(
169 self.tokenize_row,
170 dataset_num_proc=dataset_num_proc,
171 **dataset_map_arguments
172 )
173 if eval_dataset is not None:
174 eval_dataset = eval_dataset.map(
175 self.tokenize_row,
176 dataset_num_proc=dataset_num_proc,
177 **dataset_map_arguments
178 )
File /usr/local/lib/python3.10/site-packages/datasets/arrow_dataset.py:592, in transmit_tasks.<locals>.wrapper(*args, **kwargs)
590 self: "Dataset" = kwargs.pop("self")
591 # apply actual function
--> 592 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
593 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
594 for dataset in datasets:
595 # Remove task templates if a column mapping of the template is no longer valid
File /usr/local/lib/python3.10/site-packages/datasets/arrow_dataset.py:557, in transmit_format.<locals>.wrapper(*args, **kwargs)
550 self_format = {
551 "type": self._format_type,
552 "format_kwargs": self._format_kwargs,
553 "columns": self._format_columns,
554 "output_all_columns": self._output_all_columns,
555 }
556 # apply actual function
--> 557 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
558 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
559 # re-apply format to the output
TypeError: Dataset.map() got an unexpected keyword argument 'dataset_num_proc'
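The traceback shows dataset_num_proc being forwarded verbatim to Dataset.map, which only accepts num_proc. A minimal sketch of the likely fix inside ORPOTrainer.__init__ (hypothetical, not the actual patch):

```python
# Sketch of the presumed fix: Dataset.map takes `num_proc`, so the trainer's
# `dataset_num_proc` value has to be renamed when it is forwarded.
if dataset_map_arguments is None:
    dataset_map_arguments = {}
train_dataset = train_dataset.map(
    self.tokenize_row,
    num_proc=dataset_num_proc,  # was dataset_num_proc=..., which Dataset.map rejects
    **dataset_map_arguments,
)
if eval_dataset is not None:
    eval_dataset = eval_dataset.map(
        self.tokenize_row,
        num_proc=dataset_num_proc,
        **dataset_map_arguments,
    )
```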