erfanzar/EasyDeL

oom when llama2-7b sft

kuangdao opened this issue · 5 comments

i try to stf llama2-7b and oom, can it support fsdp or tensor parallel

who can tell me why ?

the error is :

截屏2024-06-20 14 55 08

and the code is :
'''
from easydel import (
TrainArguments,
AutoEasyDeLModelForCausalLM,
EasyDeLOptimizers,
EasyDeLSchedulers,
EasyDeLGradientCheckPointers,
SFTTrainer,
conversations_formatting_function # i have added this one for newcomers so if they
# don't know what's going on they can use this pre created prompter
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer

#huggingface_repo_id_or_path = "/cfs/models/Llama2-Chinese-7b-Chat"
huggingface_repo_id_or_path="TinyLlama-1.1B-intermediate-step-1431k-3T"
model, params = AutoEasyDeLModelForCausalLM.from_pretrained(huggingface_repo_id_or_path, )

max_length = 4096
tokenizer = AutoTokenizer.from_pretrained(
huggingface_repo_id_or_path,
trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_initialize_model_class = {
"config": model.config,
"dtype": jnp.bfloat16,
"param_dtype": jnp.bfloat16,
"input_shape": (1, 1)
}

#sharding_axis_dims = (1, -1, 1, 1)

train_arguments = TrainArguments(
model_class=type(model),
model_name="SFT-EasyDeL",
num_train_epochs=3,
configs_to_initialize_model_class=configs_to_initialize_model_class,
learning_rate=5e-5,
learning_rate_end=1e-6,
optimizer=EasyDeLOptimizers.ADAMW,
scheduler=EasyDeLSchedulers.WARM_UP_COSINE,
weight_decay=0.01,
total_batch_size=1,
max_training_steps=None, # None to let trainer Decide
do_train=True,
fully_sharded_data_parallel=True,
#sharding_array=sharding_axis_dims,
#step_partition_spec='fsdp',
do_eval=False, # it's optional but supported
backend="gpu", # default backed is set to cpu, so you must define you want to use tpu cpu or gpu
max_sequence_length=max_length, # Note that you have to change this in the model config too
gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
sharding_array=(1, -1, 1, 1), # the way to shard model across gpu,cpu or TPUs using sharding array (1, -1, 1, 1)
# everything training will be in sequence and model parallel automatic and share data between devices
remove_ckpt_after_load=True,
gradient_accumulation_steps=8,
loss_re_mat="",
dtype=jnp.bfloat16
)

def prompter(sample):
return [conversations_formatting_function(tokenizer, messages_field="messages")(sample)]

#train_dataset = load_dataset("HuggingFaceH4/deita-10k-v0-sft", split="train_sft")

data_dict = {'train':'./ultrachat_200k/ultrachat_200k-train_sft-00000-of-00003.arrow'}

train_dataset = load_dataset('arrow', data_files=data_dict)['train']

trainer = SFTTrainer(
arguments=train_arguments,
train_dataset=train_dataset,
eval_dataset=None, # we don't have eval dataset rn :)
tokenizer=tokenizer,
dataset_text_field=None,
formatting_func=prompter,
packing=True,
num_of_sequences=max_length,
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}")

'''
i run scripts as such python train.py

i run it in A800 which 1 node and 8 gpus

hello @kuangdao and thanks for using EasyDeL, and sorry for late response
you can try given code and it's using FSDP but you can also change to Sequence Parallelization.

from easydel import (
    TrainArguments,
    AutoEasyDeLModelForCausalLM,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    EasyDeLGradientCheckPointers,
    SFTTrainer,
    PartitionAxis,
    conversations_formatting_function  # i have added this one for newcomers so if they
    # don't know what's going on they can use this pre created prompter
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer

# huggingface_repo_id_or_path = "/cfs/models/Llama2-Chinese-7b-Chat"
dtype = jnp.bfloat16
block_size = 512
attn_mechanism = "sharded_vanilla"
partition_axis = PartitionAxis()
huggingface_repo_id_or_path = "TinyLlama-1.1B-intermediate-step-1431k-3T"
sharding_axis_dims = (1, -1, 1, 1)  # Change to 1,1,1,-1 for Sequence Sharding
max_length = 4096
model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    huggingface_repo_id_or_path,
    param_dtype=dtype,
    dtype=dtype,
    input_shape=(8, 8),  # since you said you have 8 GPUs
    auto_shard_params=True,
    sharding_axis_dims=sharding_axis_dims,
    verbose_params=True,
    config_kwargs=dict(
        use_scan_mlp=False,
        attn_mechanism=attn_mechanism,
        partition_axis=partition_axis
    ),
    partition_axis=partition_axis,
)

tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": dtype,
    "param_dtype": dtype,
    "input_shape": (8, 8)
}

# sharding_axis_dims = (1, -1, 1, 1)

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="SFT-EasyDeL",
    num_train_epochs=3,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDeLOptimizers.ADAMW,
    scheduler=EasyDeLSchedulers.WARM_UP_COSINE,
    weight_decay=0.01,
    total_batch_size=8,  # Note if you are using FSDP you can't use batch size 1 since you have 8 GPUs
    max_training_steps=None,  # None to let trainer Decide
    do_train=True,
    fully_sharded_data_parallel=True,
    force_batch_and_gradient_accumulation_steps_calculation=False,
    do_eval=False,  # it's optional but supported
    backend="gpu",  # default backed is set to cpu, so you must define you want to use tpu cpu or gpu
    max_sequence_length=max_length,  # Note that you have to change this in the model config too
    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=sharding_axis_dims,
    # the way to shard model across gpu,cpu or TPUs using sharding array (1, -1, 1, 1)
    # everything training will be in sequence and model parallel automatic and share data between devices
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    use_pjit_attention_force=False,
    init_input_shape=(8, 8),
    dtype=dtype,
    param_dtype=dtype,
    step_start_point=0,
    do_last_save=False,
    do_shard_fns=False,
    track_memory=False,  # Install GO lang first ...
)


def prompter(sample):
    return [conversations_formatting_function(tokenizer, messages_field="messages")(sample)]


# train_dataset = load_dataset("HuggingFaceH4/deita-10k-v0-sft", split="train_sft")

data_dict = {'train': './ultrachat_200k/ultrachat_200k-train_sft-00000-of-00003.arrow'}

train_dataset = load_dataset('arrow', data_files=data_dict)['train']

trainer = SFTTrainer(
    arguments=train_arguments,
    train_dataset=train_dataset,
    eval_dataset=None,  # we don't have eval dataset rn :)
    tokenizer=tokenizer,
    dataset_text_field=None,
    formatting_func=prompter,
    packing=True,
    num_of_sequences=max_length,
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}")

Is that fixed @kuangdao

this issue is closed because no response has been given