Training does not start using the latest EasyDeL
Closed this issue · 6 comments
IvoryTower800 commented
Hi, the training doesn't start with the recent update. I tried different models and parameters; it only shows the information below and then stops running.
Besides, I need to manually set the parameters below to load the model. Otherwise, there is a ValueError.
model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    huggingface_repo_id_or_path,
    auto_shard_params=True,
    sharding_axis_dims=(1, -1, 1, 1),
    input_shape=(8, max_length),
)
Warning : In case of using `finetune = True` and Passing `checkpoint_path = None` you should pass parameters in train function
wandb: Currently logged in as: ivorytower800. Use `wandb login --relogin` to force relogin
Tracking run with wandb version 0.16.6
Run data is saved locally in /kaggle/working/wandb/run-20240424_161106-rims0awm
Syncing run [woven-snowflake-158](https://wandb.ai/ivorytower800/EasyDeL-writer-gemma-2b/runs/rims0awm) to [Weights & Biases](https://wandb.ai/ivorytower800/EasyDeL-writer-gemma-2b) ([docs](https://wandb.me/run))
View project at https://wandb.ai/ivorytower800/EasyDeL-writer-gemma-2b
View run at https://wandb.ai/ivorytower800/EasyDeL-writer-gemma-2b/runs/rims0awm
erfanzar commented
Hi, are you sure you are calling trainer.train()?
erfanzar commented
Can you share the code?
IvoryTower800 commented
@erfanzar Sure, below is the code.
from EasyDel import (
    TrainArguments,
    AutoEasyDelModelForCausalLM,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers,
    SFTTrainer,
    CausalLanguageModelTrainer,
    conversations_formatting_function,
)
from datasets import load_dataset, load_from_disk
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer
max_length = 8192
huggingface_repo_id_or_path = "google/gemma-2b-it"
model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    huggingface_repo_id_or_path,
    auto_shard_params=True,
    sharding_axis_dims=(1, -1, 1, 1),
    input_shape=(8, max_length),
)
# model.config.add_basic_configurations(
#     attn_mechanism="wise_ring",  # using wise_ring attention here; you can simply set this to "normal" or "ring"
# )
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True,
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (8, max_length),
}
train_arguments = TrainArguments(
    model_class=type(model),
    model_name="writer-gemma-2b",
    num_train_epochs=1,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    custom_rule=model.config.get_partition_rules(True),
    learning_rate=0.000001846,
    learning_rate_end=2e-7,
    max_sequence_length=max_length,
    optimizer=EasyDelOptimizers.ADAMW,
    scheduler=EasyDelSchedulers.LINEAR,
    weight_decay=0.01,
    warmup_steps=0,
    total_batch_size=8,
    save_optimizer_state=False,
    max_training_steps=None,
    do_train=True,
    do_eval=False,
    backend="tpu",
    max_length=max_length,
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    init_input_shape=(8, max_length),
    gradient_accumulation_steps=1,
    loss_re_mat="",
    dtype=jnp.bfloat16,
    training_time="8H",
    track_memory=True,
    force_batch_and_gradient_accumulation_steps_calculation=True,
    use_wandb=True,  # enable Weights & Biases logging
)
dataset_train = load_from_disk('/kaggle/input/sadlfkjaslkgma8192')
desired_indices = range(0, len(dataset_train))
dataset_train = dataset_train.select(desired_indices)
trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,
    # checkpoint_path='/root/' + ckpt_name
)
output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}")
erfanzar commented
If you are using auto_shard_params=True for loading the model, you should disable do_shard_fns in TrainArguments.
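A minimal sketch of that change, assuming do_shard_fns is the keyword accepted by TrainArguments as described above (only the relevant arguments are shown; keep the rest as in the snippet above):
train_arguments = TrainArguments(
    model_class=type(model),
    model_name="writer-gemma-2b",
    # ... keep the other arguments from the snippet above ...
    do_shard_fns=False,  # params are already sharded by auto_shard_params=True, so skip re-sharding in the trainer
)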
erfanzar commented
@IvoryTower800 is that fixed?
IvoryTower800 commented
@erfanzar Hi, sorry for the late reply. It was fixed. Thank you!