NaN loss in ORPOTrainer with legacy_sharded_vanilla
nyl199310 opened this issue · 9 comments
Hi, I tried to use `ORPOTrainer` to finetune a model. I found that if I use `sharded_vanilla` or other attention mechanisms, it reports a memory-resource-exhausted error while the loss stats are normal (it can only run the first step before the out-of-memory error):
```
Training: 1%| | 1/100 [01:26<2:23:09, 86.76s/it, epoch=0, learning_rate=1.98e-5, log_odds_chosen=5451.875, log_odds_ratio=0.0, logits/chosen=-1.4420047, logits/rejected=-1.5062321, logps/chosen=-111799.91, logps/rejected=-117251.78, loss=2.25, mean_loss=2.25, nll_loss=2.2457747, perplexity=9.45, rewards/accuracies=1.0, rewards/chosen=-11179.991, rewards/margins=545.1875, rewards/rejected=-11725.179, step=1, step_time=81.1]
```
Only if I use `legacy_sharded_vanilla` is there no out-of-memory error, but then all the loss stats are NaN:
```
Training: 36%|███▌ | 36/100 [02:49<02:34, 2.41s/it, epoch=0, learning_rate=1.29e-5, log_odds_chosen=nan, log_odds_ratio=nan, logits/chosen=nan, logits/rejected=nan, logps/chosen=nan, logps/rejected=nan, loss=nan, mean_loss=nan, nll_loss=nan, perplexity=nan, rewards/accuracies=0.0, rewards/chosen=nan, rewards/margins=nan, rewards/rejected=nan, step=36, step_time=0.133]
```
Hello, and thanks for using EasyDeL. As I mentioned before, `legacy_sharded_vanilla` has a lot of miscomputations and I recommended not using it; this attention only works well on AMD GPUs, and I don't really know why. But you can run:

```python
from easydel import AttentionModule

print(AttentionModule.test_attentions(axis_dims=(1, 1, 1, -1)))
```
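For example, you can compare more than one sharding layout in a single pass. A minimal sketch (it assumes `test_attentions` returns a printable per-mechanism summary, as in the result tables later in this thread):

```python
from easydel import AttentionModule

# Compare attention kernels under two sharding layouts:
# (1, -1, 1, 1) = FSDP-style sharding, (1, 1, 1, -1) = sequence sharding.
for axis_dims in [(1, -1, 1, 1), (1, 1, 1, -1)]:
    print(f"axis_dims={axis_dims}")
    print(AttentionModule.test_attentions(axis_dims=axis_dims))
```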
Hi, and thank you for your explanation. I may have given you the wrong idea. I first reduced the `max_length`, then I tested four attention mechanisms: `sharded_vanilla`, `local_ring`, `wise_ring`, and `legacy_sharded_vanilla`. All of them produce NaN loss stats (except the first step). In addition, it's LoRA + ORPO. Below is my full code.
```python
from easydel import (
    AutoEasyDeLModelForCausalLM,
    EasyDeLXRapTureConfig,
    AutoEasyDeLConfig,
    EasyDeLState,
    TrainArguments,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    SFTTrainer,
    ORPOTrainer,
    EasyDeLGradientCheckPointers,
    easystate_to_huggingface_model,
    get_modules_by_type,
)
from datasets import load_dataset
from transformers import AutoTokenizer, LlamaForCausalLM, AutoConfig
from jax import numpy as jnp, lax
from flax.core import FrozenDict
import jax
import flax
from huggingface_hub import HfApi

huggingface_model_repo_id = "NousResearch/Meta-Llama-3-8B-Instruct"
max_length = 2048

model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    huggingface_model_repo_id,
    device=jax.devices("cpu")[0],
    input_shape=(1, max_length),
    device_map="auto",
    sharding_axis_dims=(1, 1, 1, -1),
    config_kwargs=dict(
        use_scan_mlp=False,
        attn_mechanism="sharded_vanilla",
        max_length=max_length,
    ),
)

config = AutoEasyDeLConfig.from_pretrained(huggingface_model_repo_id)

tokenizer = AutoTokenizer.from_pretrained(
    huggingface_model_repo_id,
    trust_remote_code=True,
)
tokenizer.pad_token = tokenizer.eos_token

configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, max_length),
}

params = FrozenDict({"params": params})

rapture = EasyDeLXRapTureConfig(
    parameters=params,
    lora_dim=128,
    fully_fine_tune_parameters=[],  # model layers to be fully fine-tuned
    lora_fine_tune_parameters=["q_proj", "v_proj", "k_proj", "o_proj"],  # LoRA layer targets;
    # pass None for layer-only tuning or transfer learning
    verbose=True,
)

train_arguments = TrainArguments(
    model_class=get_modules_by_type(model.config.model_type)[1],
    model_name="llama3",
    num_train_epochs=1,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    learning_rate=2e-5,
    # step_start_point=step_start_point,
    learning_rate_end=2e-7,
    optimizer=EasyDeLOptimizers.ADAMW,
    scheduler=EasyDeLSchedulers.LINEAR,
    weight_decay=0.01,
    # dataloader_num_workers=96,
    total_batch_size=1,
    max_training_steps=None,
    do_train=True,
    do_eval=False,
    backend="tpu",
    max_sequence_length=max_length,
    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, 1, 1, -1),
    init_input_shape=(1, max_length),
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=1,
    training_time="8H",
    track_memory=True,
    neftune_noise_alpha=5.0,
    force_batch_and_gradient_accumulation_steps_calculation=True,
    loss_re_mat="",
    dtype=jnp.bfloat16,
    use_wandb=False,
    rapture_config=rapture,
    merge_lora_rapture_parameters=True,
)

train_dataset = load_dataset("Intel/orca_dpo_pairs")["train"]
desired_indices = range(0, 100)
train_dataset = train_dataset.select(desired_indices)
train_dataset = train_dataset.rename_column("question", "prompt")

trainer = ORPOTrainer(
    arguments=train_arguments,
    max_length=max_length,
    max_prompt_length=max_length,
    max_completion_length=max_length,
    beta=0.1,
    train_dataset=train_dataset,
    eval_dataset=None,
    tokenizer=tokenizer,
    low_mem_usage=True,
    dataset_num_proc=1,
)

output = trainer.train()
```
Besides, I don't know if it's because of the NaN loss issue, but after finetuning, the LoRA merging process reports the error below.
```
Training: 99%|█████████▉| 99/100 [04:06<00:01, 1.32s/it, epoch=0, learning_rate=3.98e-7, log_odds_chosen=nan, log_odds_ratio=nan, logits/chosen=nan, logits/rejected=nan, logps/chosen=nan, logps/rejected=nan, loss=nan, mean_loss=nan, nll_loss=nan, perplexity=nan, rewards/accuracies=0.0, rewards/chosen=nan, rewards/margins=nan, rewards/rejected=nan, step=99, step_time=0.129]
Info : Merging LoRA Parameters.
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
File /usr/local/lib/python3.10/site-packages/fjformer/xrapture/xrapture.py:376, in XRapTure.merge_parameters.<locals>._ensure_delete(val)
375 try:
--> 376 val.device_buffer.delete()
377 except ValueError:
File /usr/local/lib/python3.10/site-packages/jax/_src/array.py:484, in ArrayImpl.device_buffer(self)
483 return self._arrays[0]
--> 484 raise ValueError('Length of buffers is greater than 1. Please use '
485 '`.device_buffers` instead.')
ValueError: Length of buffers is greater than 1. Please use `.device_buffers` instead.
During handling of the above exception, another exception occurred:
AttributeError Traceback (most recent call last)
Cell In[1], line 152
136 train_dataset = train_dataset.rename_column('question', 'prompt')
139 trainer = ORPOTrainer(
140 arguments=train_arguments,
141 max_length = 2048,
(...)
149 dataset_num_proc=1
150 )
--> 152 output = trainer.train()
153 # output = trainer.train(flax.core.FrozenDict({"params": params}))
154
155
(...)
174 # config.push_to_hub("ivt1993/writer_llama3_8b_test", private=True, token='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')
175 # print('done')
File /usr/local/lib/python3.10/site-packages/easydel/trainer/orpo/orpo_trainer.py:1082, in ORPOTrainer.train(self, model_parameters, state)
1072 if self.arguments.merge_lora_rapture_parameters and self.rapture is not None:
1073 print(
1074 termcolor.colored(
1075 "Info : ", color="red", force_color=True
(...)
1079 )
1080 )
1081 self.model_state = self.model_state.replace(
-> 1082 params=self.rapture.merge_parameters(self.model_state.params)
1083 )
1085 shard_fns, gather_fns = make_shard_and_gather_fns(
1086 partition_specs=match_partition_rules(
1087 rules=self.model_state.module.config.get_partition_rules(
(...)
1092 dtype_specs=self.arguments.dtype
1093 )
1094 output = ORPOTrainerOutput(
1095 state=self.model_state,
1096 mesh=self.mesh,
(...)
1099 checkpoint_manager=self.checkpoint_manager,
1100 )
File /usr/local/lib/python3.10/site-packages/fjformer/xrapture/xrapture.py:390, in XRapTure.merge_parameters(lora_parameters, destructive)
387 return result
388 return param
--> 390 return tree_map_with_implicit(map_fn, lora_parameters)
File /usr/local/lib/python3.10/site-packages/fjformer/xrapture/implicit_array.py:643, in combine_leaf_predicate.<locals>.new_fn(new_is_leaf, *args)
641 def combined_is_leaf(arg):
642 return is_leaf(arg) or new_is_leaf(arg)
--> 643 return base_fn(*args, is_leaf=combined_is_leaf)
File /usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py:244, in tree_map(f, tree, is_leaf, *rest)
242 leaves, treedef = tree_flatten(tree, is_leaf)
243 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 244 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File /usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py:244, in <genexpr>(.0)
242 leaves, treedef = tree_flatten(tree, is_leaf)
243 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 244 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File /usr/local/lib/python3.10/site-packages/fjformer/xrapture/xrapture.py:386, in XRapTure.merge_parameters.<locals>.map_fn(param)
384 result = materialize(param)
385 if destructive:
--> 386 jax.tree_map(_ensure_delete, param)
387 return result
388 return param
File /usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py:244, in tree_map(f, tree, is_leaf, *rest)
242 leaves, treedef = tree_flatten(tree, is_leaf)
243 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 244 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File /usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py:244, in <genexpr>(.0)
242 leaves, treedef = tree_flatten(tree, is_leaf)
243 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 244 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File /usr/local/lib/python3.10/site-packages/fjformer/xrapture/xrapture.py:378, in XRapTure.merge_parameters.<locals>._ensure_delete(val)
376 val.device_buffer.delete()
377 except ValueError:
--> 378 val.device_buffers.delete()
AttributeError: 'list' object has no attribute 'delete'
```
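From the traceback, the destructive cleanup in `XRapTure.merge_parameters` falls back to calling `.delete()` on `val.device_buffers`, which on multi-device (sharded) arrays is a list of per-shard buffers, not an object with a `delete()` method. A minimal sketch of what a fix might look like (hypothetical; not the actual fjformer code):

```python
def _ensure_delete(val):
    """Free an array's on-device storage, whether single- or multi-device."""
    try:
        val.device_buffer.delete()
    except ValueError:
        # On sharded arrays `device_buffers` is a list of per-shard
        # buffers, so each one has to be deleted individually.
        for buffer in val.device_buffers:
            buffer.delete()
```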
Thanks for reporting the LoRA issue. By the way, let me explain the process in `AttentionModule`. Attention mechanisms behave very differently in each scenario, on each device, and with each config. For example, Splash Attention, Flash Attention, and Blockwise (the best-supported modules) do not always work well on TPUs. Here is how you can find the attention mechanism that works best for you.

> **Tip:** The following results are from a Kaggle TPU v3 with JAX 0.4.28; they will differ on your device and with other JAX versions.
```python
ed.AttentionModule.test_attentions(axis_dims=(1, -1, 1, 1))  # FSDP attention
```

| METHOD | OUT DIFF | GRADIENT DIFF | TEST PASSED | COMP TIME |
|---|---|---|---|---|
| LOCAL_RING | 4.5249023 | 0.11016822 | False | 6.365082 |
| BLOCKWISE | 2.142334 | 0.1539554 | False | 1.802229 |
| VANILLA | 0.0028076172 | 0.0055647814 | True | 0.019311 |
| WISE_RING | 1917.0201 | 31.526917 | False | 5.255283 |
| SHARDED_VANILLA | 0.0014648438 | 0.0 | True | 0.040799 |
| LEGACY_SHARDED_VANILLA | 0.0014648438 | 0.0 | True | 8.804201 |
| FLASH | 4.2730713 | 0.11256987 | False | 2.643749 |
| SPLASH | 8935.969 | 470.02725 | False | 4.505385 |
| CUDNN | NA | NA | NA | NA |
| PALLAS_FLASH | NA | NA | NA | NA |
As you can see, in this case legacy sharded vanilla, vanilla, and sharded vanilla work fine. But let's change the axis dims to (1, 1, 1, -1), i.e. the sequence-sharding method:
```python
ed.AttentionModule.test_attentions(axis_dims=(1, 1, 1, -1))  # sequence-sharding attention
```

| METHOD | OUT DIFF | GRADIENT DIFF | TEST PASSED | COMP TIME |
|---|---|---|---|---|
| LOCAL_RING | nan | nan | False | 5.171325 |
| BLOCKWISE | 2.1427002 | 0.1539554 | False | 4.047265 |
| VANILLA | 0.0005493164 | 0.0 | True | 1.270251 |
| WISE_RING | nan | nan | False | 5.366389 |
| SHARDED_VANILLA | 0.0005493164 | 0.0 | True | 3.149804 |
| LEGACY_SHARDED_VANILLA | nan | nan | False | 8.098974 |
| FLASH | NA | NA | NA | NA |
| SPLASH | 8935.971 | 470.02725 | False | 5.481151 |
| CUDNN | NA | NA | NA | NA |
| PALLAS_FLASH | NA | NA | NA | NA |
As you can see, some attention mechanisms output NaN with the sequence-sharding method, and only vanilla and sharded vanilla work here. But if you are using other TPU versions or GPUs, all of the attention mechanisms may work for you.
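In your script that means loading the model with one of the mechanisms that passes under sequence sharding, for example (this just reuses the `from_pretrained` call from your code; only `attn_mechanism` matters here):

```python
import jax
from easydel import AutoEasyDeLModelForCausalLM

# Sequence sharding (1, 1, 1, -1) with an attention mechanism that
# passes test_attentions under this layout: 'vanilla' or 'sharded_vanilla'.
model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    "NousResearch/Meta-Llama-3-8B-Instruct",
    device=jax.devices("cpu")[0],
    input_shape=(1, 2048),
    device_map="auto",
    sharding_axis_dims=(1, 1, 1, -1),
    config_kwargs=dict(
        use_scan_mlp=False,
        attn_mechanism="sharded_vanilla",  # or "vanilla"
        max_length=2048,
    ),
)
```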
@nyl199310 Hello, is the issue fixed?
@erfanzar Sorry, I was not able to test it. I cannot start the training with the latest code; it always stops after displaying the output below with JAX version 0.4.25.
```
/usr/local/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
E0601 03:43:44.061192666 2836 oauth2_credentials.cc:238] oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {created_time:"2024-06-01T03:43:44.0611721+00:00", grpc_status:2}
/usr/local/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = os.fork()
/usr/local/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
Converting Model: 100%|██████████| 164/164 [01:22<00:00, 1.99it/s, missed_shardings=0]
/usr/local/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
/usr/local/lib/python3.10/site-packages/easydel/trainer/training_configurations.py:388: UserWarning: setting `log_grad_norms` to off since using log grad norms while using LoRA is not Supported.
warnings.warn(
Warning : You are using LoRA (Low-Rank Adaptation of Large Language Models) and this feature isstill in Beta mode so it might act unexpected
Downloading readme: 100%|██████████| 196/196 [00:00<00:00, 1.14MB/s]
Downloading data files: 0%| | 0/1 [00:00<?, ?it/s]
Downloading data: 0%| | 0.00/36.3M [00:00<?, ?B/s]
Downloading data: 29%|██▉ | 10.5M/36.3M [00:00<00:01, 13.8MB/s]
Downloading data: 87%|████████▋ | 31.5M/36.3M [00:01<00:00, 32.1MB/s]
Downloading data: 100%|██████████| 36.3M/36.3M [00:01<00:00, 28.9MB/s]
Downloading data files: 100%|██████████| 1/1 [00:01<00:00, 1.34s/it]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 959.79it/s]
Generating train split: 12859 examples [00:00, 95186.13 examples/s]
Map: 100%|██████████| 100/100 [01:18<00:00, 1.27 examples/s]
```
With JAX version 0.4.28, it only displays the output below, then stops running the remaining code.
```
/usr/local/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
/usr/local/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = os.fork()
```
I'm still trying to figure out what's wrong with my code; it used to run normally.
It's due to Kaggle environment changes; they must have changed a lot of things in their environment. I'll fix this ASAP.
This should fix that:

```sh
pip install -r https://raw.githubusercontent.com/erfanzar/EasyDeL/main/env_requirements.txt
```
You don't need to do this anymore; it's fixed in the new EasyDeL version 0.0.67.
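In that case, upgrading should be enough (assuming the package is published on PyPI as `easydel`):

```sh
pip install -U "easydel>=0.0.67"
```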