Error while training GPT-2 on Kaggle
Closed this issue · 3 comments
jchauhan commented
Describe the bug
Error while training GPT-2 on Kaggle.
/root
Downloading base model...
<class 'EasyDel.modules.gpt2.modelling_gpt2_flax.FlaxGPT2LMHeadModel'>
<class 'EasyDel.modules.gpt2.gpt2_configuration.GPT2Config'>
Downloading data files: 100%|██████████████████| 1/1 [00:00<00:00, 11214.72it/s]
Extracting data files: 100%|████████████████████| 1/1 [00:00<00:00, 1438.38it/s]
Generating train split: 186074 examples [00:00, 353355.15 examples/s]
Map (num_proc=12): 100%|██████| 186074/186074 [00:03<00:00, 47310.49 examples/s]
/usr/local/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by promote_options='default'.
table = cls._concat_blocks(blocks, axis=0)
Map (num_proc=12): 0%| | 0/186074 [00:00<?, ? examples/s]/usr/local/lib/python3.10/site-packages/datasets/table.py:1387: FutureWarning: promote has been superseded by promote_options='default'.
return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
/usr/local/lib/python3.10/site-packages/datasets/table.py:1387: FutureWarning: promote has been superseded by promote_options='default'.
return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
/usr/local/lib/python3.10/site-packages/datasets/table.py:1387: FutureWarning: promote has been superseded by promote_options='default'.
return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
/usr/local/lib/python3.10/site-packages/datasets/table.py:1387: FutureWarning: promote has been superseded by promote_options='default'.
return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
/usr/local/lib/python3.10/site-packages/datasets/table.py:1387: FutureWarning: promote has been superseded by promote_options='default'.
return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
/usr/local/lib/python3.10/site-packages/datasets/table.py:1387: FutureWarning: promote has been superseded by promote_options='default'.
return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
/usr/local/lib/python3.10/site-packages/datasets/table.py:1387: FutureWarning: promote has been superseded by promote_options='default'.
return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
/usr/local/lib/python3.10/site-packages/datasets/table.py:1387: FutureWarning: promote has been superseded by promote_options='default'.
return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
/usr/local/lib/python3.10/site-packages/datasets/table.py:1387: FutureWarning: promote has been superseded by promote_options='default'.
return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
/usr/local/lib/python3.10/site-packages/datasets/table.py:1387: FutureWarning: promote has been superseded by promote_options='default'.
return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
/usr/local/lib/python3.10/site-packages/datasets/table.py:1387: FutureWarning: promote has been superseded by promote_options='default'.
return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
/usr/local/lib/python3.10/site-packages/datasets/table.py:1387: FutureWarning: promote has been superseded by promote_options='default'.
return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
Map (num_proc=12): 100%|███████| 186074/186074 [00:21<00:00, 8523.98 examples/s]
Warning : In case of using `finetune = True` and Passing `checkpoint_path = None` you should pass parameters in train function
wandb: Currently logged in as: jchauhan (safedep). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.2
wandb: Run data is saved locally in /root/wandb/run-20240201_154611-g3pguwrw
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run avid-sky-14
wandb: ⭐️ View project at https://wandb.ai/safedep/EasyDeL-raven_gpt2.easydel
wandb: 🚀 View run at https://wandb.ai/safedep/EasyDeL-raven_gpt2.easydel/runs/g3pguwrw
Time Took to Complete Task configure dataloaders (microseconds) : 0.4191398620605469
Time Took to Complete Task configure Model ,Optimizer ,Scheduler and Config (microseconds) : 597.6324081420898
Time Took to Complete Task configure functions and sharding them (microseconds) : 745.0320720672607
Action : Sharding Passed Parameters
Traceback (most recent call last):
File "/root/train.py", line 123, in <module>
output = trainer.train(flax.core.FrozenDict({"params": params}))
File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 478, in train
sharded_state, shard_fns, gather_fns = self.initialize_state(
File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 393, in initialize_state
params = model_parameters if not self.arguments.do_shard_fns else jax.tree_util.tree_map(
File "/usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 243, in tree_map
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
File "/usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 243, in <listcomp>
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: Dict key mismatch; expected keys: ['transformer']; dict: {'transformer': {'wte': {'embedding':
To Reproduce
%%writefile /root/train.py
import os
import jax.numpy
import EasyDel
from EasyDel import (
TrainArguments,
CausalLanguageModelTrainer,
AutoEasyDelModelForCausalLM,
EasyDelOptimizers,
EasyDelSchedulers,
EasyDelGradientCheckPointers
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer
from kaggle_secrets import UserSecretsClient
# --- Weights & Biases credentials --------------------------------------------
# The API key is read from the Kaggle secret named "WAND_KEY" and handed to
# wandb through the environment. (Setting WANDB_DISABLED in the environment is
# the usual switch for turning W&B logging off.)
user_secrets = UserSecretsClient()
wand_key = user_secrets.get_secret("WAND_KEY")
os.environ.update({"WANDB_API_KEY": wand_key})

# --- Run configuration ---------------------------------------------------------
# Values shown as "****" are placeholders in this script.
base_model_hf_repo_id_or_path = "gpt2"
max_length = 1024
trained_model_name = "****"
trained_model_hf_repo_id = "****/" + trained_model_name
easydel_trained_model_name = trained_model_name + ".easydel"
training_data_files = "****"
import json
import sys

# Flatten the multi-turn conversations into a JSON Lines file with one
# {"train": <text>} record per turn.
jcdataset = load_dataset('****', split='train')
# The `with` block guarantees the file is flushed and closed before anything
# later in the script reads it back. The original left the handle open, so the
# buffered tail of the file could be missing or truncated at read time.
with open("./lmsys-toxic-gpt.json", "w", encoding="utf-8") as f:
    for conversation in jcdataset['chunks']:
        # NOTE(review): nesting reconstructed from a paste that lost its
        # indentation -- confirm the writes belong to the inner loop.
        # The first record of a conversation has an empty input section.
        out = "<|input|><|response|>"
        for req_res in conversation:
            out = out + req_res['prompt']
            f.write(json.dumps({'train': out}))
            f.write("\n")
            # The next record starts with this turn's response as its input.
            out = "<|input|>" + req_res['response'] + "<|response|>"
# --- Base model and tokenizer --------------------------------------------------
print("Downloading base model...")
model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    base_model_hf_repo_id_or_path,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(base_model_hf_repo_id_or_path, trust_remote_code=True)
# Reuse EOS as the padding token; tokenization pads every row to max_length.
tokenizer.pad_token = tokenizer.eos_token

# Keyword arguments the trainer uses to instantiate the model class. The dict
# holds a reference to model.config, so the attribute set just below is
# visible through it as well.
configs_to_init_model_class = dict(
    config=model.config,
    dtype=jnp.bfloat16,
    param_dtype=jnp.bfloat16,
    input_shape=(1, 1),
)
# NOTE(review): attribute name kept exactly as in the original; confirm that
# GPT2Config actually reads `use_sacn_mlp`, otherwise this has no effect.
model.config.use_sacn_mlp = False
print(type(model))
print(type(model.config))

# --- Training arguments (grouped by concern; keyword order is irrelevant) -----
train_arguments = TrainArguments(
    # Model identity / construction
    model_class=type(model),
    model_name=easydel_trained_model_name,
    configs_to_initialize_model_class=configs_to_init_model_class,
    custom_rule=model.config.get_partition_rules(True),
    dtype=jnp.bfloat16,
    # Optimisation: optimizer accepts "adamw", "lion", "adafactor";
    # scheduler accepts "linear", "cosine", "none", "warm_up_cosine", "warm_up_linear"
    optimizer=EasyDelOptimizers.ADAMW,
    scheduler=EasyDelSchedulers.LINEAR,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    weight_decay=0.01,
    # Schedule and batching (max_steps=None lets the trainer decide)
    num_train_epochs=3,
    max_steps=None,
    total_batch_size=8,
    gradient_accumulation_steps=8,
    # Must agree with the sequence length the model config expects
    max_length=max_length,
    # Which phases run; evaluation is optional but supported
    do_train=True,
    do_eval=False,
    # Hardware / parallelism: the backend must be set explicitly (cpu, gpu or tpu).
    # sharding_array (1, -1, 1, 1) shards parameters fully FSDP-style across devices.
    backend="tpu",
    sharding_array=(1, -1, 1, 1),
    use_pjit_attention_force=False,
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    # Misc
    remove_ckpt_after_load=True,
    loss_re_mat="",
)
def ultra_chat_prompting_process(data_chunk):
    """Turn one raw JSON record into the ``prompt`` column used for tokenization.

    Args:
        data_chunk: a mapping with a ``"train"`` entry holding the text.

    Returns:
        A new dict ``{"prompt": <the "train" text>}``. Any other keys of
        ``data_chunk`` are ignored. A missing ``"train"`` key raises ``KeyError``.
    """
    text = data_chunk["train"]
    return dict(prompt=text)
def tokenization_process(data_chunk):
    """Tokenize the ``prompt`` column into fixed-length ``max_length`` rows.

    ``truncation=True`` is required together with ``padding="max_length"``.
    With padding alone, the tokenizer only pads short rows: longer prompts keep
    their full length, which yields ragged rows and sequences longer than the
    configured ``max_length``.
    """
    return tokenizer(
        data_chunk["prompt"],
        add_special_tokens=False,
        max_length=max_length,
        padding="max_length",
        truncation=True,
    )


dataset = load_dataset("json", data_files=training_data_files)
dataset_train = dataset["train"].map(ultra_chat_prompting_process, num_proc=12)
# Replace the text columns with the tokenizer outputs.
dataset_train = dataset_train.map(
    tokenization_process,
    num_proc=12,
    remove_columns=dataset_train.column_names,
)
# An evaluation dataset can be prepared the same way.
trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,
    checkpoint_path=None,
)
# With checkpoint_path=None, the pretrained parameters are passed to train().
output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}")
import tempfile
import os
from huggingface_hub import Repository, create_repo
from transformers import LlamaForCausalLM, LlamaTokenizer
import jax
from EasyDel import (
AutoEasyDelConfig,
EasyDelState,
easystate_to_huggingface_model
)
# Function to create a Hugging Face repository
def create_hf_repo(repo_name, hub_token=None):
    """Create (or reuse) a Hugging Face Hub repo and clone it locally.

    Args:
        repo_name: Hub repository name. When ``None``, the basename of the newly
            created local directory is used as the name.
        hub_token: token passed through to ``create_repo`` and ``Repository``.

    Returns:
        The cloned ``Repository``. Its working copy lives in a fresh directory
        from ``tempfile.mkdtemp``. That directory is NOT removed automatically;
        delete it yourself once you have finished with the repository.

    Raises:
        Whatever ``create_repo`` / ``Repository`` raise. In that case the
        temporary directory is removed before the exception propagates.
    """
    import shutil

    # The clone must outlive this function. A TemporaryDirectory is deleted on
    # cleanup() or garbage collection, which would leave the returned
    # Repository without a working copy, so use mkdtemp instead.
    tmp_output_dir = tempfile.mkdtemp()
    if repo_name is None:
        repo_name = os.path.basename(tmp_output_dir)
    try:
        # Create the repo (exist_ok=True reuses an existing one) and get its id.
        repo_id = create_repo(repo_name, exist_ok=True, token=hub_token).repo_id
        # Clone it into the local directory.
        return Repository(tmp_output_dir, clone_from=repo_id, token=hub_token)
    except Exception:
        # Don't leave an orphaned directory behind when creation or cloning fails.
        shutil.rmtree(tmp_output_dir, ignore_errors=True)
        raise
# The base model is GPT-2, so the PyTorch side must use the GPT-2 architecture.
# The original used LlamaForCausalLM, which does not match a GPT-2 checkpoint.
from transformers import GPT2LMHeadModel

# Checkpoint produced by the training run above.
chkpoint_path = output.checkpoint_path
# Configuration of the base model, used to build the PyTorch module.
config = AutoEasyDelConfig.from_pretrained(base_model_hf_repo_id_or_path)
# Run the conversion with JAX's default device set to the first CPU device.
with jax.default_device(jax.devices("cpu")[0]):
    model = easystate_to_huggingface_model(
        state=EasyDelState.load_state(chkpoint_path),
        base_huggingface_module=GPT2LMHeadModel,
        config=config,
    )
model = model.half()  # cast the converted weights to float16 before uploading
hub_token = None  # relies on an existing Hugging Face login
# Optional: create/clone the target repo first.
# repo = create_hf_repo(trained_model_hf_repo_id, hub_token)
# Optional: the untouched base model can be pushed to the target repo as well.
base_model = GPT2LMHeadModel.from_pretrained(base_model_hf_repo_id_or_path)
# base_model.push_to_hub(trained_model_hf_repo_id, token=hub_token)
tokenizer.push_to_hub(trained_model_hf_repo_id, token=hub_token)
# Upload the fine-tuned model.
model.push_to_hub(trained_model_hf_repo_id, token=hub_token)
erfanzar commented
Hello, the issue with GPT-2 is now fixed; you can rerun your script.
erfanzar commented
Is the issue fixed?
erfanzar commented
This issue is being closed because no response has been given.