Falcon-11B: Dict key mismatch; expected keys: ['input_layernorm', 'mlp', 'self_attention']; dict: {'self_attention': {'query_key_value': {'kernel': Array
s-smits opened this issue · 9 comments
Describe the bug
Layer names do not coincide: the shard-function pytree expects ['input_layernorm', 'mlp', 'self_attention'], but the converted parameter dict is missing 'input_layernorm'.
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[51], line 1
----> 1 output = trainer.train(
2 model_parameters=model_parameters if not use_lora else None,
3 state=None
4 )
5 # api.create_repo(new_repo_id, private=True, exist_ok=True)
6 # api.upload_file(
7 # path_or_fileobj=output.checkpoint_path,
8 # repo_id=new_repo_id,
9 # path_in_repo=output.last_save_file_name
10 # )
File /usr/local/lib/python3.10/site-packages/easydel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:402, in CausalLanguageModelTrainer.train(self, model_parameters, state)
395 termcolor.cprint(
396 "Performance Mode is ON, we will ignore the Memory Tracking, WANDB Logging, and extra information "
397 "Process.",
398 color="red",
399 force_color=True
400 )
401 start_time = time.time()
--> 402 sharded_state, shard_fns, gather_fns = self.initialize_state(
403 model_parameters=model_parameters,
404 state=state
405 )
407 count_model_parameters(sharded_state.params)
408 with self.mesh:
File /usr/local/lib/python3.10/site-packages/easydel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:290, in CausalLanguageModelTrainer.initialize_state(self, model_parameters, state)
283 if not isinstance(model_parameters, flax.core.FrozenDict):
284 prefix_print(
285 "Warning",
286 "Model Parameters should be like FrozenDict({'params': params}) make sure to "
287 "pass as type FrozenDict in case of not getting UnExcepted Errors "
288 )
--> 290 model_parameters = model_parameters if not self.arguments.do_shard_fns else jax.tree_util.tree_map(
291 lambda f, x: f(x),
292 shard_fns.params,
293 model_parameters,
294 )
295 sharded_state = self.create_sharded_state_from_params_function(model_parameters)
296 elif model_parameters is not None and self.checkpoint_path is not None:
File /usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py:243, in tree_map(f, tree, is_leaf, *rest)
210 """Maps a multi-input function over pytree args to produce a new pytree.
211
212 Args:
(...)
240 [[5, 7, 9], [6, 1, 2]]
241 """
242 leaves, treedef = tree_flatten(tree, is_leaf)
--> 243 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
244 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File /usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py:243, in <listcomp>(.0)
210 """Maps a multi-input function over pytree args to produce a new pytree.
211
212 Args:
(...)
240 [[5, 7, 9], [6, 1, 2]]
241 """
242 leaves, treedef = tree_flatten(tree, is_leaf)
--> 243 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
244 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
ValueError: Dict key mismatch; expected keys: ['input_layernorm', 'mlp', 'self_attention']; dict: {'self_attention': {'query_key_value': {'kernel': Array([[ 1.355e-02, -5.432e-03, 3.442e-02, ..., -2.722e-02, -1.955e-05,
3.613e-02],
[ 3.015e-02, -7.751e-03, 7.568e-03, ..., 2.197e-02, 4.639e-02,
-9.644e-03],
[ 2.454e-02, -2.487e-03, 5.280e-03, ..., -1.880e-02, 1.770e-02,
-6.989e-03],
...,
[ 2.332e-02, -2.820e-02, 3.198e-02, ..., -3.198e-02, 4.395e-03,
-1.233e-02],
[ 1.855e-02, 1.465e-02, -7.935e-03, ..., 1.373e-02, -4.834e-02,
-4.370e-02],
[-6.683e-03, -1.288e-02, -1.575e-02, ..., 5.542e-02, 2.307e-02,
2.490e-02]], dtype=float16)}, 'dense': {'kernel': Array([[-0.00836 , -0.001678 , -0.00867 , ..., -0.01227 , 0.01355 ,
-0.01868 ],
[-0.03296 , 0.02063 , 0.04102 , ..., 0.0083 , 0.00928 ,
0.00806 ],
[-0.01276 , -0.01794 , -0.01587 , ..., 0.00525 , -0.006165 ,
-0.007812 ],
...,
[-0.02185 , -0.00206 , 0.01477 , ..., 0.003387 , -0.01953 ,
0.03613 ],
[-0.04712 , -0.00595 , -0.03174 , ..., -0.012695 , -0.02344 ,
-0.00772 ],
[ 0.03064 , -0.00653 , -0.002441 , ..., -0.01855 , 0.0013275,
0.002121 ]], dtype=float16)}}, 'mlp': {'dense_h_to_4h': {'kernel': Array([[-0.02246 , -0.00586 , -0.01007 , ..., -0.06006 , -0.002808 ,
-0.01611 ],
[-0.06006 , 0.02307 , -0.00958 , ..., 0.0199 , -0.05127 ,
-0.008484 ],
[ 0.00203 , 0.006897 , 0.02405 , ..., 0.02368 , 0.0008583,
-0.03613 ],
...,
[ 0.00296 , 0.03076 , 0.02747 , ..., -0.01538 , 0.00409 ,
-0.01227 ],
[-0.03784 , -0.03198 , 0.01672 , ..., 0.01831 , 0.01538 ,
0.002716 ],
[ 0.02734 , -0.01904 , -0.0718 , ..., -0.005554 , -0.0166 ,
-0.03735 ]], dtype=float16)}, 'dense_4h_to_h': {'kernel': Array([[ 0.0058 , -0.02832 , 0.01575 , ..., 0.005707, -0.00132 ,
0.0119 ],
[-0.0094 , 0.02283 , -0.02063 , ..., 0.01434 , -0.02417 ,
0.02783 ],
[-0.02002 , 0.01294 , 0.00116 , ..., 0.01099 , 0.003296,
-0.0371 ],
...,
[-0.02246 , 0.01733 , 0.00638 , ..., 0.01892 , -0.01746 ,
-0.0376 ],
[ 0.003616, -0.01575 , -0.00867 , ..., 0.01007 , 0.02771 ,
0.0152 ],
[-0.009766, -0.01855 , -0.00836 , ..., -0.01477 , 0.007935,
-0.0271 ]], dtype=float16)}}}.
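For context (an editor's note, not from the original report): this ValueError is JAX's generic pytree-structure mismatch, raised when `jax.tree_util.tree_map` is handed a shard-function dict and a parameter dict whose keys differ at some level. A minimal sketch with hypothetical keys reproduces the same failure:

```python
import jax

# Hypothetical pytrees: the shard functions expect a Falcon block with three
# submodules, but the converted parameters only contain two of them.
shard_fns = {
    "input_layernorm": lambda x: x,
    "mlp": lambda x: x,
    "self_attention": lambda x: x,
}
params = {"mlp": 1.0, "self_attention": 2.0}  # "input_layernorm" is missing

# Raises: ValueError: Dict key mismatch; expected keys:
# ['input_layernorm', 'mlp', 'self_attention']; dict: {...}
jax.tree_util.tree_map(lambda f, x: f(x), shard_fns, params)
```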
To Reproduce

```python
!pip install -U -q fjformer datasets gradio wandb sentencepiece transformers==4.41.0  # or: git+https://github.com/huggingface/transformers
!pip install jax[tpu]==0.4.22 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install tensorflow --upgrade

HF_TOKEN = "hf_***"  # token redacted
!python -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('{HF_TOKEN}')"

!apt-get update && apt-get upgrade -y && apt-get install golang -y
!pip install git+https://github.com/erfanzar/EasyDeL.git -U
from easydel import (
TrainArguments,
CausalLanguageModelTrainer,
AutoEasyDeLModelForCausalLM,
EasyDeLOptimizers,
EasyDeLSchedulers,
EasyDeLGradientCheckPointers,
EasyDeLState,
EasyDeLXRapTureConfig,
get_modules_by_type,
easystate_to_huggingface_model,
SFTTrainer,
conversations_formatting_function,
AutoEasyDeLConfig
)
from datasets import load_dataset
from flax.core import FrozenDict
from transformers import AutoTokenizer
from jax import numpy as jnp, sharding
import jax
from transformers import AutoConfig
from huggingface_hub import HfApi, hf_hub_download
import os
import datasets
import re
PartitionSpec = sharding.PartitionSpec
api = HfApi()
import safetensors.torch
import jax.numpy as jnp
import numpy as np
# Correct repository ID
repo_id = "ssmits/Falcon2-5.5B-Dutch"

# Define the base filename and the number of files you expect
base_filename = "model-"
num_files = 12  # Adjust this number based on how many files you have

# List to store the file paths
file_paths = []

# Loop through each file number and download
for i in range(1, num_files + 1):
    file_suffix = f"{i:05d}-of-{num_files:05d}.safetensors"  # Formats the number as 00001-of-00012, etc.
    filename = f"{base_filename}{file_suffix}"
    # Downloading the model file
    file_path = hf_hub_download(repo_id=repo_id, filename=filename, token=HF_TOKEN)
    if file_path:
        file_paths.append(file_path)

# Load each state
states = []
for path in file_paths:
    with jax.default_device(jax.devices("cpu")[0]):
        state_dict = safetensors.torch.load_file(path)
        state_dict = jax.tree_map(lambda x: jnp.array(x.float().numpy(), dtype=jnp.float16), state_dict)
    states.append(state_dict)
sharding_axis_dims = (1, 1, -1, 1)
max_length = 2048
input_shape = (1, max_length)
pretrained_model_name_or_path = "ssmits/Falcon2-5.5B-Dutch"
pretrained_model_name_or_path_tokenizer = "ssmits/Falcon2-5.5B-Dutch"
new_repo_id = "EasyDeL-SFT-Tuned-Model"
checkpoint_path = None
if checkpoint_path is None:
    model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        device=jax.devices('cpu')[0],
        input_shape=input_shape,
        device_map="auto",
        sharding_axis_dims=sharding_axis_dims,
        config_kwargs=dict(
            use_scan_mlp=False,
            ffn_hidden_size=16384
        )
    )
    config = model.config
    model_parameters = FrozenDict({"params": params})
else:
    model_parameters = None
    config = AutoEasyDeLConfig.from_pretrained(pretrained_model_name_or_path)

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path_tokenizer,
    trust_remote_code=True
)
tokenizer.padding_side = 'right'
tokenizer.padding_side = 'right'
block_size = 128
dtype = jnp.float16
use_lora = None
config.add_basic_configurations(
    attn_mechanism="flash",
    block_b=1,
    block_q=block_size,
    block_k=block_size,
    block_k_major=block_size,
    shard_attention_computation=True,
)
configs_to_initialize_model_class = {
    "config": config,
    "dtype": dtype,
    "param_dtype": dtype,
    "input_shape": input_shape
}
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

rapture_config = EasyDeLXRapTureConfig(
    model_parameters,
    lora_dim=128,
    fully_fine_tune_parameters=["embed_tokens"],
    lora_fine_tune_parameters=["query_key_value"],
    verbose=True
) if use_lora else None
from datasets import load_dataset
from transformers import AutoTokenizer
dataset_train = load_dataset(
    "BramVanroy/ultra_feedback_dutch_cleaned",
    name="sft_gpt4_hq",
    split="train_sft"
)
tokenizer = AutoTokenizer.from_pretrained("ssmits/Falcon2-5.5B-Dutch")
def apply_template(example):
    formatted_chat = []
    for message in example['messages']:
        role = message['role']
        content = message['content']
        formatted_chat.append(f"<|{role}|>\n{content}")
    return {"formatted_chat": "".join(formatted_chat)}
dataset_train = dataset_train.map(apply_template, remove_columns=dataset_train.column_names)
print(dataset_train)
train_arguments = TrainArguments(
    model_class=get_modules_by_type(config.model_type)[1],
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    custom_rule=config.get_partition_rules(True),
    num_train_epochs=1,
    learning_rate=1e-5,
    learning_rate_end=2e-6,
    warmup_steps=400,
    optimizer=EasyDeLOptimizers.ADAMW,
    scheduler=EasyDeLSchedulers.WARM_UP_COSINE,
    weight_decay=0.02,
    total_batch_size=8,
    max_sequence_length=max_length,
    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=sharding_axis_dims,
    use_pjit_attention_force=False,
    gradient_accumulation_steps=1,
    init_input_shape=input_shape,
    dtype=dtype,
    param_dtype=dtype,
    step_start_point=0,
    model_name="EasyDeL-SFT",
    training_time="1H",  # setting a maximum time on the trainer for the Kaggle session
    force_batch_and_gradient_accumulation_steps_calculation=False,
    rapture_config=rapture_config,
)
trainer = SFTTrainer(
    arguments=train_arguments,
    train_dataset=dataset_train,
    eval_dataset=None,
    tokenizer=tokenizer,
    dataset_text_field='formatted_chat',
    packing=False,
    num_of_sequences=2048,
    checkpoint_path=checkpoint_path,
)
output = trainer.train(
    model_parameters=model_parameters if not use_lora else None,
    state=None
)
api.create_repo(new_repo_id, private=True, exist_ok=True)
api.upload_file(
    path_or_fileobj=output.checkpoint_path,
    repo_id=new_repo_id,
    path_in_repo=output.last_save_file_name
)
print(config.multi_query)  # Output should be False

import transformers

with jax.default_device(jax.devices("cpu")[0]):
    state = EasyDeLState.load_state(
        output.checkpoint_path,
        input_shape=(8, 1),
    )

model_use_tie_word_embedding = config.tie_word_embeddings
if model_use_tie_word_embedding:
    state_new_params = {
        "params": state.params["params"] | {
            "lm_head": {
                "kernel": state.params["params"]["model"]["embed_tokens"]["embedding"].T
            }
        }
    }
    state = state.replace(params=state_new_params)

config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
config.__dict__.pop("_name_or_path", None)

with jax.default_device(jax.devices("cpu")[0]):
    model = easystate_to_huggingface_model(
        state=state,
        base_huggingface_module=transformers.FalconForCausalLM,  # was Qwen2ForCausalLM in the original paste; Falcon is the model being converted
        config=config
    )

half_model = model.half()

tokenizer.push_to_hub(new_repo_id)
half_model.push_to_hub(new_repo_id, private=True)
```
I'm debugging at the moment, and the problem is actually inside Hugging Face: they change the configurations automatically, so the new decoder architecture will not be applied for Falcon models.
I have two options:
1. Recreate the whole reading and hub system of Hugging Face.
2. Use simple static model converters, like what EasyLM or MaxText do for Llama.
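As a hedged aside (editor's note, not from the thread): one way to see what Hugging Face actually loads is to inspect the Falcon config flags that select the decoder layout. `FalconConfig` exposes `new_decoder_architecture` and `multi_query`; if these are rewritten on load, the converted parameter tree will not match the layer names EasyDeL expects. The checkpoint name below is assumed to be the upstream Falcon 2 repo:

```python
from transformers import AutoConfig

# Sketch: inspect the architecture flags the loaded config reports.
cfg = AutoConfig.from_pretrained("tiiuae/falcon-11B")
print(cfg.new_decoder_architecture)  # which decoder layout variant is selected
print(cfg.multi_query)               # multi-query vs. grouped attention
```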
Thank you for your quick reply. So the best course of action would be to wait until this Hugging Face problem is fixed? I can add the issue to the HF Falcon-11B repo.
@s-smits it's fixed; I can now fine-tune a Falcon-11B model myself, so you can try that.
Awesome, testing it right now.
Action : Sharding Passed Parameters
XlaRuntimeError Traceback (most recent call last)
Cell In[25], line 1
----> 1 output = trainer.train(
2 model_parameters=model_parameters if not use_lora else None,
3 state=None
4 )
5 # api.create_repo(new_repo_id, private=True, exist_ok=True)
6 # api.upload_file(
7 # path_or_fileobj=output.checkpoint_path,
8 # repo_id=new_repo_id,
9 # path_in_repo=output.last_save_file_name
10 # )
File /usr/local/lib/python3.10/site-packages/easydel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:406, in CausalLanguageModelTrainer.train(self, model_parameters, state)
399 termcolor.cprint(
400 "Performance Mode is ON, we will ignore the Memory Tracking, WANDB Logging, and extra information "
401 "Process.",
402 color="red",
403 force_color=True
404 )
405 start_time = time.time()
--> 406 sharded_state, shard_fns, gather_fns = self.initialize_state(
407 model_parameters=model_parameters,
408 state=state
409 )
411 count_model_parameters(sharded_state.params)
412 with self.mesh:
File /usr/local/lib/python3.10/site-packages/easydel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:299, in CausalLanguageModelTrainer.initialize_state(self, model_parameters, state)
288 prefix_print(
289 "Warning",
290 "Model Parameters should be like FrozenDict({'params': params}) make sure to "
291 "pass as type FrozenDict in case of not getting UnExcepted Errors "
292 )
294 model_parameters = model_parameters if not self.arguments.do_shard_fns else jax.tree_util.tree_map(
295 lambda f, x: f(x),
296 shard_fns.params,
297 model_parameters,
298 )
--> 299 sharded_state = self.create_sharded_state_from_params_function(model_parameters)
300 elif model_parameters is not None and self.checkpoint_path is not None:
    301 raise EasyDeLTimerError(
    302     "You can't pass `model_parameters` and `checkpoint_path` at same time"
    303 )
[... skipping hidden 10 frame]
File /usr/local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1159, in ExecuteReplicated.call(self, *args)
1157 self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
1158 else:
-> 1159 results = self.xla_executable.execute_sharded(input_bufs)
1160 if dispatch.needs_check_special():
1161 out_arrays = results.disassemble_into_single_device_arrays()
XlaRuntimeError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 128.00M. That was not possible. There are 88.58M free.; (1x1x0_HBM1): while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
Got an OOM with batch size 1 on a TPU v3-8, which has 128 GB of HBM in total. At least the model gets loaded now, so that's a great addition.
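For a rough sense of why this OOMs, a back-of-the-envelope (editor's sketch; it assumes fp32 AdamW moment buffers, which may not match what EasyDeL/optax actually allocates here):

```python
# Rough HBM budget for full fine-tuning of an ~11B-parameter model.
n_params = 11e9
weights_fp16 = n_params * 2 / 2**30      # ~20.5 GiB for fp16 weights
adamw_moments = n_params * (4 + 4) / 2**30  # ~82 GiB if both moments are fp32
total = weights_fp16 + adamw_moments     # ~102 GiB before activations/gradients
print(f"{total:.0f} GiB of ~128 GiB HBM, leaving little room for activations")
```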
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".XX"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"
adding these would not help unfortunately.
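Side note (editor's addition): these variables configure the XLA Python client allocator, which mainly governs GPU memory behavior, so it is plausible they have no effect on TPU HBM. To see what each device actually reports, something like this sketch can help; `memory_stats()` may return None on some backends:

```python
import jax

# Print per-device HBM usage as reported by the runtime.
for device in jax.devices():
    stats = device.memory_stats()  # may be None depending on backend/version
    if stats:
        print(device, stats.get("bytes_in_use"), "/", stats.get("bytes_limit"))
```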
You actually can train the model; let me check, and I'll give you the code for how to train your model more efficiently.
Look, I didn't actually fine-tune Falcon-11B, but I'm fine-tuning Phi-3 14B, so this code should work just fine.
Configs

I'm training with batch size 6, since I'm using the sequence sharding method on Kaggle TPUs.

```python
sharding_axis_dims = (1, 1, 1, -1)
max_length = 2048
# max_length = 1792
input_shape = (1, max_length)
pretrained_model_name_or_path = "microsoft/Phi-3-medium-4k-instruct"
pretrained_model_name_or_path_tokenizer = pretrained_model_name_or_path
new_repo_id = "erfanzar/Phi-3-Medium-Instruct-v0.1"
dtype = jnp.bfloat16
use_lora = False
block_size = 512
attn_mechanism = "sharded_vanilla"
attention_partitions = dict(
    query_partition_spec=PartitionSpec(("dp", "fsdp"), "sp", None, "tp"),
    key_partition_spec=PartitionSpec(("dp", "fsdp"), "sp", None, "tp"),
    value_partition_spec=PartitionSpec(("dp", "fsdp"), "sp", None, "tp"),
    bias_partition_spec=PartitionSpec(("dp", "fsdp"), "sp", None, None),
    attention_partition_spec=PartitionSpec(("dp", "fsdp"), "sp", None, "tp"),
)
```

Loading the model

```python
model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    device=jax.devices('cpu')[0],
    input_shape=input_shape,
    device_map="auto",
    sharding_axis_dims=sharding_axis_dims,
    verbose_params=True,
    config_kwargs=dict(
        use_scan_mlp=False,
        attn_mechanism=attn_mechanism,
        **attention_partitions
    ),
    **attention_partitions
)
config = model.config
model_use_tie_word_embedding = config.tie_word_embeddings
model_parameters = FrozenDict({"params" : params})
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path_tokenizer,
    trust_remote_code=True
)
config.add_basic_configurations(
    attn_mechanism=attn_mechanism,
    shard_attention_computation=True,
    **attention_partitions
)
configs_to_initialize_model_class = {
    "config": config,
    "dtype": dtype,
    "param_dtype": dtype,
    "input_shape": input_shape
}
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
```
It is working! Apparently the problem was needing to explicitly set `backend="tpu"` in the TrainArguments; I assumed this was already set automatically through the CausalLanguageModelTrainer.
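For anyone landing here later, a minimal sketch of the resolution (editor's note: it reuses names from the repro above, and the exact `backend` keyword is taken from the comment in this thread rather than verified against the EasyDeL API):

```python
# Sketch: pass backend="tpu" explicitly; the trainer does not infer it.
train_arguments = TrainArguments(
    model_class=get_modules_by_type(config.model_type)[1],
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    custom_rule=config.get_partition_rules(True),
    model_name="EasyDeL-SFT",
    backend="tpu",  # the missing piece that caused the failures above
    num_train_epochs=1,
    total_batch_size=8,
    max_sequence_length=max_length,
    sharding_array=sharding_axis_dims,
)
```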