erfanzar/EasyDeL

Falcon-11B: Dict key mismatch; expected keys: ['input_layernorm', 'mlp', 'self_attention']; dict: {'self_attention': {'query_key_value': {'kernel': Array

s-smits opened this issue · 9 comments

Describe the bug
Layer names do not coincide

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[51], line 1
----> 1 output = trainer.train(
      2     model_parameters=model_parameters if not use_lora else None,
      3     state=None
      4 )
      5 # api.create_repo(new_repo_id, private=True, exist_ok=True)
      6 # api.upload_file(
      7 #     path_or_fileobj=output.checkpoint_path,
      8 #     repo_id=new_repo_id,
      9 #     path_in_repo=output.last_save_file_name
     10 # )

File /usr/local/lib/python3.10/site-packages/easydel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:402, in CausalLanguageModelTrainer.train(self, model_parameters, state)
    395     termcolor.cprint(
    396         "Performance Mode is ON, we will ignore the Memory Tracking, WANDB Logging, and extra information "
    397         "Process.",
    398         color="red",
    399         force_color=True
    400     )
    401 start_time = time.time()
--> 402 sharded_state, shard_fns, gather_fns = self.initialize_state(
    403     model_parameters=model_parameters,
    404     state=state
    405 )
    407 count_model_parameters(sharded_state.params)
    408 with self.mesh:

File /usr/local/lib/python3.10/site-packages/easydel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:290, in CausalLanguageModelTrainer.initialize_state(self, model_parameters, state)
    283     if not isinstance(model_parameters, flax.core.FrozenDict):
    284         prefix_print(
    285             "Warning",
    286             "Model Parameters should be like FrozenDict({'params': params}) make sure to "
    287             "pass as type FrozenDict in case of not getting UnExcepted Errors "
    288         )
--> 290     model_parameters = model_parameters if not self.arguments.do_shard_fns else jax.tree_util.tree_map(
    291         lambda f, x: f(x),
    292         shard_fns.params,
    293         model_parameters,
    294     )
    295     sharded_state = self.create_sharded_state_from_params_function(model_parameters)
    296 elif model_parameters is not None and self.checkpoint_path is not None:

File /usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py:243, in tree_map(f, tree, is_leaf, *rest)
    210 """Maps a multi-input function over pytree args to produce a new pytree.
    211 
    212 Args:
   (...)
    240   [[5, 7, 9], [6, 1, 2]]
    241 """
    242 leaves, treedef = tree_flatten(tree, is_leaf)
--> 243 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    244 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File /usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py:243, in <listcomp>(.0)
    210 """Maps a multi-input function over pytree args to produce a new pytree.
    211 
    212 Args:
   (...)
    240   [[5, 7, 9], [6, 1, 2]]
    241 """
    242 leaves, treedef = tree_flatten(tree, is_leaf)
--> 243 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    244 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

ValueError: Dict key mismatch; expected keys: ['input_layernorm', 'mlp', 'self_attention']; dict: {'self_attention': {'query_key_value': {'kernel': Array([[ 1.355e-02, -5.432e-03,  3.442e-02, ..., -2.722e-02, -1.955e-05,
         3.613e-02],
       [ 3.015e-02, -7.751e-03,  7.568e-03, ...,  2.197e-02,  4.639e-02,
        -9.644e-03],
       [ 2.454e-02, -2.487e-03,  5.280e-03, ..., -1.880e-02,  1.770e-02,
        -6.989e-03],
       ...,
       [ 2.332e-02, -2.820e-02,  3.198e-02, ..., -3.198e-02,  4.395e-03,
        -1.233e-02],
       [ 1.855e-02,  1.465e-02, -7.935e-03, ...,  1.373e-02, -4.834e-02,
        -4.370e-02],
       [-6.683e-03, -1.288e-02, -1.575e-02, ...,  5.542e-02,  2.307e-02,
         2.490e-02]], dtype=float16)}, 'dense': {'kernel': Array([[-0.00836  , -0.001678 , -0.00867  , ..., -0.01227  ,  0.01355  ,
        -0.01868  ],
       [-0.03296  ,  0.02063  ,  0.04102  , ...,  0.0083   ,  0.00928  ,
         0.00806  ],
       [-0.01276  , -0.01794  , -0.01587  , ...,  0.00525  , -0.006165 ,
        -0.007812 ],
       ...,
       [-0.02185  , -0.00206  ,  0.01477  , ...,  0.003387 , -0.01953  ,
         0.03613  ],
       [-0.04712  , -0.00595  , -0.03174  , ..., -0.012695 , -0.02344  ,
        -0.00772  ],
       [ 0.03064  , -0.00653  , -0.002441 , ..., -0.01855  ,  0.0013275,
         0.002121 ]], dtype=float16)}}, 'mlp': {'dense_h_to_4h': {'kernel': Array([[-0.02246  , -0.00586  , -0.01007  , ..., -0.06006  , -0.002808 ,
        -0.01611  ],
       [-0.06006  ,  0.02307  , -0.00958  , ...,  0.0199   , -0.05127  ,
        -0.008484 ],
       [ 0.00203  ,  0.006897 ,  0.02405  , ...,  0.02368  ,  0.0008583,
        -0.03613  ],
       ...,
       [ 0.00296  ,  0.03076  ,  0.02747  , ..., -0.01538  ,  0.00409  ,
        -0.01227  ],
       [-0.03784  , -0.03198  ,  0.01672  , ...,  0.01831  ,  0.01538  ,
         0.002716 ],
       [ 0.02734  , -0.01904  , -0.0718   , ..., -0.005554 , -0.0166   ,
        -0.03735  ]], dtype=float16)}, 'dense_4h_to_h': {'kernel': Array([[ 0.0058  , -0.02832 ,  0.01575 , ...,  0.005707, -0.00132 ,
         0.0119  ],
       [-0.0094  ,  0.02283 , -0.02063 , ...,  0.01434 , -0.02417 ,
         0.02783 ],
       [-0.02002 ,  0.01294 ,  0.00116 , ...,  0.01099 ,  0.003296,
        -0.0371  ],
       ...,
       [-0.02246 ,  0.01733 ,  0.00638 , ...,  0.01892 , -0.01746 ,
        -0.0376  ],
       [ 0.003616, -0.01575 , -0.00867 , ...,  0.01007 ,  0.02771 ,
         0.0152  ],
       [-0.009766, -0.01855 , -0.00836 , ..., -0.01477 ,  0.007935,
        -0.0271  ]], dtype=float16)}}}.
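
For context, the ValueError above is raised by jax.tree_util.tree_map: the shard-function pytree built from the partition rules and the parameter pytree being sharded must have exactly the same dict keys. A standalone sketch of the same failure (hypothetical keys, not EasyDeL code):

import jax

# shard_fns expects three sub-modules per layer; params is missing one of them
shard_fns = {
    "input_layernorm": lambda x: x,
    "mlp": lambda x: x,
    "self_attention": lambda x: x,
}
params = {"mlp": 1.0, "self_attention": 2.0}

# Raises: ValueError: Dict key mismatch; expected keys:
#   ['input_layernorm', 'mlp', 'self_attention']; dict: {...}
jax.tree_util.tree_map(lambda f, x: f(x), shard_fns, params)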

To Reproduce
!pip install fjformer datasets gradio wandb sentencepiece transformers==4.41.0 #git+https://github.com/huggingface/transformers -U -q
!pip install jax[tpu]==0.4.22 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install tensorflow --upgrade
HF_TOKEN = "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"
!python -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('<HF_TOKEN>')"
!apt-get update && apt-get upgrade -y && apt-get install golang -y

!pip install git+https://github.com/erfanzar/EasyDeL.git -U

from easydel import (
    TrainArguments,
    CausalLanguageModelTrainer,
    AutoEasyDeLModelForCausalLM,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    EasyDeLGradientCheckPointers,
    EasyDeLState,
    EasyDeLXRapTureConfig,
    get_modules_by_type,
    easystate_to_huggingface_model,
    SFTTrainer,
    conversations_formatting_function,
    AutoEasyDeLConfig
)

from datasets import load_dataset
from flax.core import FrozenDict
from transformers import AutoTokenizer
from jax import numpy as jnp, sharding
import jax
from transformers import AutoConfig
from huggingface_hub import HfApi, hf_hub_download
import os
import datasets
import re

PartitionSpec = sharding.PartitionSpec
api = HfApi()

import safetensors.torch
import jax.numpy as jnp
import numpy as np

# Correct repository ID
repo_id = "ssmits/Falcon2-5.5B-Dutch"

# Define the base filename and the number of files you expect
base_filename = "model-"
num_files = 12  # Adjust this number based on how many files you have

# List to store the file paths
file_paths = []

# Loop through each file number and download
for i in range(1, num_files + 1):
    file_suffix = f"{i:05d}-of-{num_files:05d}.safetensors"  # Formats the number as 00001-of-00012, etc.
    filename = f"{base_filename}{file_suffix}"

    # Downloading the model file
    file_path = hf_hub_download(repo_id=repo_id, filename=filename, token=HF_TOKEN)

    if file_path:
        file_paths.append(file_path)

# Load each state
states = []
for path in file_paths:
    with jax.default_device(jax.devices("cpu")[0]):
        state_dict = safetensors.torch.load_file(path)
        state_dict = jax.tree_map(lambda x: jnp.array(x.float().numpy(), dtype=jnp.float16), state_dict)
        states.append(state_dict)

sharding_axis_dims = (1, 1, -1, 1)
max_length = 2048
input_shape = (1, max_length)

pretrained_model_name_or_path = "ssmits/Falcon2-5.5B-Dutch"
pretrained_model_name_or_path_tokenizer = "ssmits/Falcon2-5.5B-Dutch"
new_repo_id = "EasyDeL-SFT-Tuned-Model"

checkpoint_path = None
if checkpoint_path is None:
    model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        device=jax.devices('cpu')[0],
        input_shape=input_shape,
        device_map="auto",
        sharding_axis_dims=sharding_axis_dims,
        config_kwargs=dict(
            use_scan_mlp=False,
            ffn_hidden_size=16384
        )
    )

    config = model.config
    model_parameters = FrozenDict({"params": params})

else:
    model_parameters = None
    config = AutoEasyDeLConfig.from_pretrained(pretrained_model_name_or_path)

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path_tokenizer,
    trust_remote_code=True
)
tokenizer.padding_side = 'right'

block_size = 128
dtype = jnp.float16
use_lora = None

config.add_basic_configurations(
    attn_mechanism="flash",
    block_b=1,
    block_q=block_size,
    block_k=block_size,
    block_k_major=block_size,
    shard_attention_computation=True,
)

configs_to_initialize_model_class = {
    "config": config,
    "dtype": dtype,
    "param_dtype": dtype,
    "input_shape": input_shape
}

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

rapture_config = EasyDeLXRapTureConfig(
    model_parameters,
    lora_dim=128,
    fully_fine_tune_parameters=["embed_tokens"],
    lora_fine_tune_parameters=["query_key_value"],
    verbose=True
) if use_lora else None

from datasets import load_dataset
from transformers import AutoTokenizer

dataset_train = load_dataset(
    "BramVanroy/ultra_feedback_dutch_cleaned",
    name="sft_gpt4_hq",
    split="train_sft"
)

tokenizer = AutoTokenizer.from_pretrained("ssmits/Falcon2-5.5B-Dutch")

def apply_template(example):
    formatted_chat = []
    for message in example['messages']:
        role = message['role']
        content = message['content']
        formatted_chat.append(f"<|{role}|>\n{content}")
    return {"formatted_chat": "".join(formatted_chat)}

dataset_train = dataset_train.map(apply_template, remove_columns=dataset_train.column_names)

print(dataset_train)

train_arguments = TrainArguments(
    model_class=get_modules_by_type(config.model_type)[1],
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    custom_rule=config.get_partition_rules(True),

    num_train_epochs=1,
    learning_rate=1e-5,
    learning_rate_end=2e-6,
    warmup_steps=400,
    optimizer=EasyDeLOptimizers.ADAMW,
    scheduler=EasyDeLSchedulers.WARM_UP_COSINE,
    weight_decay=0.02,
    total_batch_size=8,
    max_sequence_length=max_length,
    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=sharding_axis_dims,
    use_pjit_attention_force=False,
    gradient_accumulation_steps=1,

    init_input_shape=input_shape,
    dtype=dtype,
    param_dtype=dtype,
    step_start_point=0,

    model_name="EasyDeL-SFT",
    training_time="1H",  # setting maximum time on trainer for kaggle session
    force_batch_and_gradient_accumulation_steps_calculation=False,
    rapture_config=rapture_config,
)

trainer = SFTTrainer(
    arguments=train_arguments,
    train_dataset=dataset_train,
    eval_dataset=None,
    tokenizer=tokenizer,
    dataset_text_field='formatted_chat',
    packing=False,
    num_of_sequences=2048,
    checkpoint_path=checkpoint_path,
)

output = trainer.train(
    model_parameters=model_parameters if not use_lora else None,
    state=None
)

api.create_repo(new_repo_id, private=True, exist_ok=True)

api.upload_file(
    path_or_fileobj=output.checkpoint_path,
    repo_id=new_repo_id,
    path_in_repo=output.last_save_file_name
)

print(config.multi_query) # Output should be False

import transformers

with jax.default_device(jax.devices("cpu")[0]):
    state = EasyDeLState.load_state(
        output.checkpoint_path,
        input_shape=(8, 1),
    )

model_use_tie_word_embedding = config.tie_word_embeddings

if model_use_tie_word_embedding:
    state_new_params = {
        "params": state.params["params"] | {
            "lm_head": {
                "kernel": state.params["params"]["model"]["embed_tokens"]["embedding"].T
            }
        }
    }

    state = state.replace(params=state_new_params)

config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)

config.__dict__.pop("_name_or_path", None)

with jax.default_device(jax.devices("cpu")[0]):
    model = easystate_to_huggingface_model(
        state=state,
        base_huggingface_module=transformers.FalconForCausalLM,
        config=config
    )

half_model = model.half()

tokenizer.push_to_hub(new_repo_id)
half_model.push_to_hub(new_repo_id, private=True)

I'm debugging at the moment, and the problem is actually inside huggingface:
they change the configurations automatically, so the new decoder architecture will not be applied for Falcon models.
I have two options:

1. Recreate the whole reading and hub system of huggingface.

2. Use simple static model converters, like what EasyLM or MaxText do for LLaMA (roughly sketched below).
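
A rough sketch of what the second option could look like (hypothetical helper, assuming HF-style dotted safetensors keys; a real converter would also remap layer-norm and embedding leaves explicitly):

import numpy as np

def torch_key_to_flax_path(key: str):
    # e.g. "transformer.h.0.self_attention.query_key_value.weight"
    #   -> ("transformer", "h", "0", "self_attention", "query_key_value", "kernel")
    parts = key.split(".")
    if parts[-1] == "weight":
        parts[-1] = "kernel"  # Flax Dense convention; layer norms/embeddings need different leaf names
    return tuple(parts)

def convert_state_dict(state_dict):
    params = {}
    for key, tensor in state_dict.items():
        path = torch_key_to_flax_path(key)
        value = np.asarray(tensor)
        if path[-1] == "kernel" and value.ndim == 2:
            value = value.T  # torch nn.Linear stores (out, in); Flax Dense expects (in, out)
        node = params
        for p in path[:-1]:
            node = node.setdefault(p, {})
        node[path[-1]] = value
    return params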

Thank you for your quick reply. So the best course of action would be to wait until this huggingface problem is fixed? I can add the issue to the HF Falcon-11B repo.

@s-smits it's fixed; I can now finetune a Falcon-11B model myself, you can try that.

Awesome, testing it right now.

Action : Sharding Passed Parameters

XlaRuntimeError                           Traceback (most recent call last)
Cell In[25], line 1
----> 1 output = trainer.train(
      2     model_parameters=model_parameters if not use_lora else None,
      3     state=None
      4 )
      5 # api.create_repo(new_repo_id, private=True, exist_ok=True)
      6 # api.upload_file(
      7 #     path_or_fileobj=output.checkpoint_path,
      8 #     repo_id=new_repo_id,
      9 #     path_in_repo=output.last_save_file_name
     10 # )

File /usr/local/lib/python3.10/site-packages/easydel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:406, in CausalLanguageModelTrainer.train(self, model_parameters, state)
    399     termcolor.cprint(
    400         "Performance Mode is ON, we will ignore the Memory Tracking, WANDB Logging, and extra information "
    401         "Process.",
    402         color="red",
    403         force_color=True
    404     )
    405 start_time = time.time()
--> 406 sharded_state, shard_fns, gather_fns = self.initialize_state(
    407     model_parameters=model_parameters,
    408     state=state
    409 )
    411 count_model_parameters(sharded_state.params)
    412 with self.mesh:

File /usr/local/lib/python3.10/site-packages/easydel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:299, in CausalLanguageModelTrainer.initialize_state(self, model_parameters, state)
    288         prefix_print(
    289             "Warning",
    290             "Model Parameters should be like FrozenDict({'params': params}) make sure to "
    291             "pass as type FrozenDict in case of not getting UnExcepted Errors "
    292         )
    294     model_parameters = model_parameters if not self.arguments.do_shard_fns else jax.tree_util.tree_map(
    295         lambda f, x: f(x),
    296         shard_fns.params,
    297         model_parameters,
    298     )
--> 299     sharded_state = self.create_sharded_state_from_params_function(model_parameters)
    300 elif model_parameters is not None and self.checkpoint_path is not None:
    301     raise EasyDeLTimerError(
    302         "You can't pass model_parameters and checkpoint_path at same time"
    303     )

    [... skipping hidden 10 frame]

File /usr/local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1159, in ExecuteReplicated.__call__(self, *args)
   1157     self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
   1158 else:
-> 1159     results = self.xla_executable.execute_sharded(input_bufs)
   1160 if dispatch.needs_check_special():
   1161     out_arrays = results.disassemble_into_single_device_arrays()

XlaRuntimeError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 128.00M. That was not possible. There are 88.58M free.; (1x1x0_HBM1): while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

Got an OOM with batch size 1 on a TPU v3-8, so 128 GB of HBM in total. At least the model gets loaded now, so that's a great addition.

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".XX"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"

Adding these would not help, unfortunately.
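
As a side note, the XLA_PYTHON_CLIENT_* variables target the GPU client allocator, so they are not expected to change TPU HBM behaviour. Where the backend supports it, jax.Device.memory_stats() can report per-device HBM usage, which helps confirm which shard runs out of memory; a minimal sketch:

import jax

for d in jax.local_devices():
    stats = d.memory_stats()  # may return None on backends without support
    if stats:
        used = stats.get("bytes_in_use", 0) / 2**30
        limit = stats.get("bytes_limit", 0) / 2**30
        print(f"{d}: {used:.2f} GiB used / {limit:.2f} GiB limit")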

You can actually train the model; let me check and I'll give you the code for training your model more efficiently.

Look, I didn't actually finetune Falcon 11B, but I'm finetuning Phi 14B, so this code should work just fine.

Configs

I'm training with batch size 6 since I'm using the sequence sharding method on Kaggle TPUs.

sharding_axis_dims = (1, 1, 1, -1)
max_length = 2048
# max_length = 1792
input_shape = (1, max_length)

pretrained_model_name_or_path = "microsoft/Phi-3-medium-4k-instruct"
pretrained_model_name_or_path_tokenizer = pretrained_model_name_or_path
new_repo_id = "erfanzar/Phi-3-Medium-Instruct-v0.1"

dtype = jnp.bfloat16
use_lora = False
block_size = 512
attn_mechanism = "sharded_vanilla"

attention_partitions = dict(
    query_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
    key_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
    value_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
    bias_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, None),
    attention_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
)
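
For clarity on the sharding above: with the 8 cores of a TPU v3-8, the -1 resolves to 8, i.e. all devices land on the last axis, which (assuming EasyDeL's usual ("dp", "fsdp", "tp", "sp") axis ordering) is the sequence axis; a quick check:

import numpy as np
import jax
from jax.sharding import Mesh

# (1, 1, 1, -1) over 8 devices -> mesh shape (1, 1, 1, 8): pure sequence sharding
devices = np.array(jax.devices()).reshape(1, 1, 1, -1)
mesh = Mesh(devices, axis_names=("dp", "fsdp", "tp", "sp"))
print(mesh.shape)  # {'dp': 1, 'fsdp': 1, 'tp': 1, 'sp': 8} on a v3-8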

LOADING MODEL

model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    device = jax.devices('cpu')[0],
    input_shape = input_shape,
    device_map = "auto",
    sharding_axis_dims = sharding_axis_dims,
    verbose_params=True,
    config_kwargs=dict(
        use_scan_mlp=False,
        attn_mechanism=attn_mechanism,
        **attention_partitions
    ),
    **attention_partitions
)

config = model.config

model_use_tie_word_embedding = config.tie_word_embeddings

model_parameters = FrozenDict({"params" : params})

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path_tokenizer,
    trust_remote_code=True
)

config.add_basic_configurations(
    attn_mechanism=attn_mechanism,
    shard_attention_computation=True,
    **attention_partitions
)

configs_to_initialize_model_class={
    "config" : config,
    "dtype" : dtype,
    "param_dtype" : dtype,
    "input_shape" : input_shape
}

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

tokenizer.padding_side = "right"

It is working; apparently the problem was that backend="tpu" needs to be stated explicitly in the TrainArguments. I assumed this was already set automatically by the CausalLanguageModelTrainer.
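
For reference, a sketch of the change described above (all other TrainArguments stay as in the earlier snippet):

train_arguments = TrainArguments(
    model_class=get_modules_by_type(config.model_type)[1],
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    backend="tpu",  # must be set explicitly; not inferred by CausalLanguageModelTrainer
    # ... remaining arguments unchanged ...
)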