erfanzar/EasyDeL

Can't load checkpoints to continue training

IvoryTower800 opened this issue · 7 comments

Hi, Sorry to bother you again.

Currently I have to specify sharding_axis_dims and input_shape to load the model normally, which is different from how it used to work. Otherwise it complains that the mesh shape is not correct, that 8 does not evenly divide 1.

model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path,
                                                            sharding_axis_dims=(1,1,1,-1),
                                                            input_shape=(1,max_length))

I was fine-tuning a model and saved a checkpoint. I can train without error. However, when I want to continue fine-tuning and pass my checkpoint_path to the trainer, it raises the error below. Could you tell me what I should do? Thank you so much!

huggingface_repo_id_or_path = "/kaggle/input/llama-3-8b-it"
max_length = 8192
model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path,
                                                            config_kwargs={"attn_mechanism":"sharded_vanilla",'max_position_embeddings': max_length},
                                                            sharding_axis_dims=(1,1,1,-1),
                                                            input_shape=(1,max_length))


tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, max_length)
}



train_arguments = TrainArguments(
    model_class=type(model),
    model_name="llama3_8b",
    num_train_epochs=2,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    learning_rate=2e-6,
    learning_rate_end=2e-7,
    optimizer=EasyDelOptimizers.ADAMW,
    scheduler=EasyDelSchedulers.LINEAR,
    weight_decay=0.01,
    total_batch_size=1,
    max_training_steps=None,  # None lets the trainer decide
    do_train=True,
    do_eval=False,  # optional, but supported
    backend="tpu",  # the default backend is cpu, so you must specify tpu, cpu or gpu
    max_sequence_length=max_length,  # note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, 1, 1, -1),  # how to shard the model across CPUs, GPUs or TPUs, e.g. a sharding array of (1, -1, 1, 1)
    init_input_shape=(1, max_length),
    # training is sequence- and model-parallel automatically, and data is shared between devices
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=3,
    training_time="8H",
    track_memory=True,
    neftune_noise_alpha=5.0,
    force_batch_and_gradient_accumulation_steps_calculation=True,
    loss_re_mat="",
    dtype=jnp.bfloat16
)


train_dataset = load_dataset('csv',data_files="/kaggle/input/data_8192_v3.csv")['train']

trainer = SFTTrainer(
    arguments=train_arguments,
    train_dataset=train_dataset,
    eval_dataset=None,
    tokenizer=tokenizer,
    dataset_text_field="text",
    dataset_num_proc=96,
    packing=False,
    checkpoint_path='/root/' + ckpt_name
)

output = trainer.train()






Time Took to Complete Task configure dataloaders (microseconds) : 318.15624237060547
Time Took to Complete Task configure Model, Optimizer, Scheduler and Config (microseconds) : 2573.209047317505
Time Took to Complete Task configure functions and sharding them (microseconds) : 3353.928804397583
Action : Loading Model From /root/writer_llama3_8b-S3413.easy
Loading Checkpoints From File: 374it [02:02,  3.06it/s, shard_functions_mismatch=73]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[2], line 14
      1 trainer = SFTTrainer(
      2     arguments=train_arguments,
      3     train_dataset=train_dataset,
   (...)
     11 #     formatting_func=prompter,
     12 )
---> 14 output = trainer.train()

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:398, in CausalLanguageModelTrainer.train(self, model_parameters, state)
    391     termcolor.cprint(
    392         "Performance Mode is ON, we will ignore the Memory Tracking, WANDB Logging, and extra information "
    393         "Process.",
    394         color="red",
    395         force_color=True
    396     )
    397 start_time = time.time()
--> 398 sharded_state, shard_fns, gather_fns = self.initialize_state(
    399     model_parameters=model_parameters,
    400     state=state
    401 )
    403 count_model_parameters(sharded_state.params)
    404 with self.mesh:

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:237, in CausalLanguageModelTrainer.initialize_state(self, model_parameters, state)
    233 prefix_print(
    234     "Action", f"Loading Model From {self.checkpoint_path}"
    235 )
    236 with jax.default_device(self.arguments.offload_device):
--> 237     sharded_state = EasyDelState.load_state(
    238         verbose=self.arguments.verbose,
    239         state_shard_fns=shard_fns,
    240         init_optimizer_state=True,
    241         checkpoint_path=self.checkpoint_path,
    242         input_shape=(1,8192),
    243     )
    244     state_shape = jax.eval_shape(lambda: sharded_state)
    245     state_partition_spec = match_partition_rules(
    246         self.config.get_partition_rules(
    247             fully_sharded_data_parallel=self.arguments.fully_sharded_data_parallel
    248         ) if self.arguments.custom_rule is None else self.arguments.custom_rule,
    249         state_shape
    250     )

File /usr/local/lib/python3.10/site-packages/EasyDel/etils/easystate.py:310, in EasyDelState.load_state(cls, checkpoint_path, dtype, param_dtype, precision, init_optimizer_state, state_shard_fns, verbose, input_shape)
    308                 cfg_behave[k] = eval(v)
    309     module_config = cfg.from_dict(cfg_behave)
--> 310     module_in = module(
    311         config=module_config,
    312         dtype=dtype,
    313         param_dtype=param_dtype,
    314         precision=precision,
    315         input_shape=input_shape
    316     )
    317 else:
    318     raise TypeError(
    319         "Om seems like i couldn't read model correctly ;("
    320     )

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:597, in FlaxLlamaPreTrainedModel.__init__(self, config, input_shape, seed, dtype, _do_init, **kwargs)
    579 """
    580 The __init__ function is called when the class is instantiated.
    581 It sets up the instance of the class, and defines what happens when it's created.
   (...)
    594 
    595 """
    596 module = self.module_class(config=config, dtype=dtype, **kwargs)
--> 597 super().__init__(config, module, input_shape=input_shape,
    598                  seed=seed, dtype=dtype, _do_init=_do_init)

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/easydel_modelling_utils.py:447, in EasyDelFlaxPretrainedModel.__init__(self, config, module, input_shape, seed, dtype, param_dtype, precision, _do_init)
    436 def __init__(
    437         self,
    438         config: PretrainedConfig,
   (...)
    445         _do_init: bool = True,
    446 ):
--> 447     super().__init__(
    448         config=config,
    449         module=module,
    450         input_shape=input_shape,
    451         seed=seed,
    452         dtype=dtype,
    453         _do_init=_do_init
    454     )

File /usr/local/lib/python3.10/site-packages/transformers/modeling_flax_utils.py:220, in FlaxPreTrainedModel.__init__(self, config, module, input_shape, seed, dtype, _do_init)
    216 self._is_initialized = _do_init
    218 if _do_init:
    219     # randomly initialized parameters
--> 220     random_params = self.init_weights(self.key, input_shape)
    221     params_shape_tree = jax.eval_shape(lambda params: params, random_params)
    222 else:

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:632, in FlaxLlamaPreTrainedModel.init_weights(self, rng, input_shape, params)
    622     module_init_outputs = self.module.init(
    623         rngs,
    624         input_ids,
   (...)
    629         return_dict=False,
    630     )
    631 else:
--> 632     module_init_outputs = self.module.init(
    633         rngs, input_ids, attention_mask, position_ids, return_dict=False)
    635 random_params = module_init_outputs["params"]
    637 if params is not None:

    [... skipping hidden 9 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:1067, in FlaxLlamaForCausalLMModule.__call__(self, input_ids, attention_mask, position_ids, deterministic, init_cache, output_attentions, output_hidden_states, return_dict, extra_embedding)
   1062 if position_ids is None:
   1063     position_ids = jnp.broadcast_to(
   1064         jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
   1065         (batch_size, seq_length)
   1066     )
-> 1067 outputs = self.model(
   1068     input_ids,
   1069     attention_mask,
   1070     position_ids,
   1071     deterministic=deterministic,
   1072     init_cache=init_cache,
   1073     output_attentions=output_attentions,
   1074     output_hidden_states=output_hidden_states,
   1075     return_dict=return_dict,
   1076     extra_embedding=extra_embedding
   1077 )
   1079 hidden_states = outputs[0]
   1081 if self.config.tie_word_embeddings:

    [... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:964, in FlaxLlamaModule.__call__(self, input_ids, attention_mask, position_ids, deterministic, inputs_embeds, init_cache, output_attentions, output_hidden_states, return_dict, extra_embedding)
    959 inputs_embeds = inputs_embeds + \
    960                 extra_embedding if extra_embedding is not None else inputs_embeds
    961 hidden_states = self.dropout(
    962     inputs_embeds, deterministic=deterministic)
--> 964 outputs = self.layers(
    965     hidden_states=hidden_states,
    966     freq_cis=self.freq_cis,
    967     attention_mask=attention_mask,
    968     position_ids=position_ids,
    969     causal_mask=self.causal_mask,
    970     deterministic=deterministic,
    971     init_cache=init_cache,
    972     output_attentions=output_attentions,
    973     output_hidden_states=output_hidden_states,
    974     return_dict=return_dict,
    975 )
    977 hidden_states = outputs[0]
    978 hidden_states = self.norm(hidden_states)

    [... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:850, in FlaxLlamaBlockCollection.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, deterministic, init_cache, output_attentions, output_hidden_states, return_dict)
    847 if output_hidden_states:
    848     all_hidden_states += (hidden_states,)
--> 850 layer_outputs = block(
    851     hidden_states=hidden_states,
    852     freq_cis=freq_cis,
    853     attention_mask=attention_mask,
    854     position_ids=position_ids,
    855     causal_mask=causal_mask,
    856     deterministic=deterministic,
    857     init_cache=init_cache,
    858     output_attentions=output_attentions,
    859     fcm_mask=fcm_mask,
    860 )
    861 hidden_states = layer_outputs[0]
    863 if output_attentions:

    [... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:530, in FlaxLlamaBlock.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, segment_ids, deterministic, init_cache, output_attentions, fcm_mask)
    497 def __call__(
    498         self,
    499         hidden_states: chex.Array,
   (...)
    508         fcm_mask: Optional[jnp.ndarray] = None,
    509 ):
    510     """
    511     The __call__ function is the main function of a TransformerEncoderLayer.
    512     It takes in hidden states, frequency-domain inputs, and masks as input. It then
   (...)
    528 
    529     """
--> 530     attn_outputs = self.self_attn(
    531         self.input_layernorm(hidden_states),
    532         freq_cis,
    533         attention_mask,
    534         position_ids,
    535         causal_mask,
    536         segment_ids,
    537         deterministic,
    538         init_cache,
    539         output_attentions,
    540         fcm_mask,
    541     )
    542     attn_output = attn_outputs[0]
    543     hidden_states = hidden_states + attn_output

    [... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py:567, in core_remat_static.<locals>.inner(scope_fn, repack_fn, variable_groups, rng_groups, *args)
    564   y = fn(scope, *args)
    565   return y, repack_fn(scope)
--> 567 return rematted(variable_groups, rng_groups, *dyn_args)

    [... skipping hidden 7 frame]

File /usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py:564, in core_remat_static.<locals>.inner.<locals>.rematted(variable_groups, rng_groups, *dyn_args)
    562 args = _repack_remat_args(dyn_args, static_args)
    563 scope = scope_fn(variable_groups, rng_groups)
--> 564 y = fn(scope, *args)
    565 return y, repack_fn(scope)

    [... skipping hidden 3 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:356, in FlaxLlamaAttention.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, segment_ids, deterministic, init_cache, output_attentions, fcm_mask)
    347 attention_bias = lax.select(
    348     attention_mask > 0,
    349     jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
    350     jnp.full(attention_mask.shape, jnp.finfo(
    351         self.dtype).min).astype(self.dtype),
    352 )
    354 query_length, key_length = query_states.shape[1], key_states.shape[1]
--> 356 attentions = self.attention_performer.__call__(
    357     query_states=query_states,
    358     key_states=key_states,
    359     value_states=value_states,
    360     bias=attention_bias,
    361     attention_mask=attention_mask,
    362     causal=False,
    363     dropout_rng=dropout_rng,
    364     deterministic=deterministic,
    365     query_sequence_length=query_length,
    366     key_value_sequence_length=key_length,
    367     uses_cache=self.has_variable("cache", "cached_key") or init_cache,
    368     segment_ids=segment_ids
    369 )
    372 attn_output = self._merge_heads(attentions.attention_outputs)
    373 if self.config.shard_attention_computation:

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/attention_module.py:277, in AttentionModule.__call__(self, query_states, key_states, value_states, query_sequence_length, key_value_sequence_length, bias, attention_mask, segment_ids, causal, deterministic, dropout_rng, uses_cache)
    268     return self.vanilla_attention(
    269         query_states=query_states,
    270         key_states=key_states,
   (...)
    274         deterministic=deterministic,
    275     )
    276 elif self.attn_mechanism == "sharded_vanilla":
--> 277     return self.sharded_vanilla_attention(
    278         query_states=query_states,
    279         key_states=key_states,
    280         value_states=value_states,
    281         bias=bias,
    282         dropout_rng=dropout_rng,
    283         deterministic=deterministic,
    284     )
    285 elif self.attn_mechanism == "ring":
    286     return self.ring_attention(
    287         query_states=query_states,
    288         key_states=key_states,
   (...)
    295         attention_mask=attention_mask
    296     )

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/attention_module.py:539, in AttentionModule.sharded_vanilla_attention(self, query_states, key_states, value_states, bias, deterministic, dropout_rng)
    537 out_spec = self.generation_attention_partition_spec if is_generating else self.attention_partition_spec
    538 with self.mesh:
--> 539     output = shard_map(
    540         partial(
    541             shard_vanilla_attention,
    542             deterministic=deterministic,
    543             dropout_rng=dropout_rng,
    544             dtype=dtype,
    545             precision=self.precision,
    546             attention_dropout=self.attention_dropout
    547         ),
    548         mesh=self.mesh,
    549         in_specs=(
    550             query_sequence_partition,
    551             self.key_partition_spec,
    552             self.value_partition_spec,
    553             PartitionSpec(("dp", "fsdp"), None, None, None),
    554         ),
    555         out_specs=out_spec,
    556     )(query_states, key_states, value_states, bias)
    557     output = fjformer.with_sharding_constraint(output, out_spec)
    559     return AttentionOutput(
    560         attention_weights=None,
    561         attention_outputs=output
    562     )

    [... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/jax/experimental/shard_map.py:199, in _check_specs_vs_args(f, mesh, in_tree, in_specs, in_specs_flat, xs)
    197 if any(f is not no_fail for f in fail):
    198   msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail)
--> 199   raise ValueError(msg)

ValueError: shard_map applied to the function 'functools.partial(<function shard_vanilla_attention at 0x79f014655c60>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)' was given argument arrays with axis sizes that are not evenly divisible by the corresponding mesh axis sizes:

The mesh given has shape (1, 8, 1, 1) with corresponding axis names ('dp', 'fsdp', 'tp', 'sp').

* args[0] of shape float32[1,8192,32,128], where args[0] is bound to functools.partial(<function shard_vanilla_attention at 0x79f014655c60>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'query_states', corresponds to in_specs[0] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1

* args[1] of shape float32[1,8192,32,128], where args[1] is bound to functools.partial(<function shard_vanilla_attention at 0x79f014655c60>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'key_states', corresponds to in_specs[1] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1

* args[2] of shape float32[1,8192,32,128], where args[2] is bound to functools.partial(<function shard_vanilla_attention at 0x79f014655c60>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'value_states', corresponds to in_specs[2] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1

* args[3] of shape float32[1,1,8192,8192], where args[3] is bound to functools.partial(<function shard_vanilla_attention at 0x79f014655c60>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'bias', corresponds to in_specs[3] of value PartitionSpec(('dp', 'fsdp'), None, None, None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1

Array arguments' axis sizes must be evenly divisible by the mesh axis or axes indicated by the corresponding elements of the argument's in_specs entry. Consider checking that in_specs are correct, and if so consider changing the mesh axis sizes or else padding the input and adapting 'functools.partial(<function shard_vanilla_attention at 0x79f014655c60>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)' appropriately.

I cannot load the .easy file either; the error is similar.

from EasyDel import EasyDelState
import jax
state = EasyDelState.load_state(
    checkpoint_path='llama3_8b-S3413.easy',
    verbose=True,
#     state_shard_fns=state_shard_fns,  # You can pass that
    init_optimizer_state=False,
    input_shape=(1,8192)
)

Hello, and thanks for using EasyDeL.

EasyDelState.load_state always uses mesh_dim = (1, -1, 1, 1), so you can easily load the model. In cases like this, just pass input_shape to EasyDelState.load_state as (8, 8),

and init_input_shape in TrainArguments as (8, 8) too.

For now this will fix your issue; in the next version I'll make that automated.
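
A minimal sketch of that workaround, assuming the 8-device TPU mesh and the checkpoint file from the snippets above (the (8, 8) shape is only used to trace the module; its leading dimension just has to be divisible by the 8 devices placed on the ("dp", "fsdp") axes):

from EasyDel import EasyDelState

# load_state builds its mesh as (1, -1, 1, 1), so all 8 devices end up on the
# "fsdp" axis and the batch dimension of input_shape must be divisible by 8.
state = EasyDelState.load_state(
    checkpoint_path='llama3_8b-S3413.easy',
    verbose=True,
    init_optimizer_state=False,
    input_shape=(8, 8),  # instead of (1, 8192)
)

# mirror the same shape when the trainer re-initializes the module:
# TrainArguments(..., init_input_shape=(8, 8), ...)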

@erfanzar Thank you. I can load the checkpoint with EasyDelState.load_state now. However, if I want to continue training and set init_input_shape=(8, 1024), I also have to change my batch_size to 8, which is not in line with my previous training; otherwise it shows the error below. My previous batch_size was 3 and my sequence length is 8192.

Action : Loading Model From /root/writer_llama3_8b-S3413.easy
Loading Checkpoints From File: 374it [01:21, 4.59it/s, shard_functions_mismatch=73]
Model Contain 8.030261248 Billion Parameters
0%| | 0/170236 [00:00<?, ?it/s]

ValueError Traceback (most recent call last)
Cell In[2], line 54
37 # desired_indices = range(0, 100)
38 # train_dataset = train_dataset.select(desired_indices)
41 trainer = SFTTrainer(
42 arguments=train_arguments,
43 train_dataset=train_dataset,
(...)
51 # formatting_func=prompter,
52 )
---> 54 output = trainer.train()

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:450, in CausalLanguageModelTrainer.train(self, model_parameters, state)
443 for ssb in self.arguments.ids_to_pop_from_dataset:
444 _ = batch.pop(ssb, None)
446 (
447 sharded_state,
448 loss,
449 metrics,
--> 450 ) = self.sharded_train_step_function(sharded_state, batch)
452 trained_tokens = jnp.multiply(
453 self.arguments.max_sequence_length, jnp.multiply(
454 current_step,
455 self.arguments.total_batch_size
456 )
457 ) # It's faster
459 with jax.spmd_mode("allow_all"):

[... skipping hidden 12 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer/fwd_bwd_functions.py:88, in create_casual_language_model_train_step.<locals>.casual_language_model_train_step(state, batch)
85 return loss, (accuracy, z_loss_computed, aux_loss)
87 grad_fn = jax.value_and_grad(calculate_loss, has_aux=True)
---> 88 (loss__, (accuracy__, z_loss_computed__, aux_loss__)), grad = grad_fn(state.params)
89 state = state.apply_gradients(grads=grad)
91 grad_norms = jax.tree_map(jnp.linalg.norm, grad)

[... skipping hidden 8 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer/fwd_bwd_functions.py:53, in create_casual_language_model_train_step.<locals>.casual_language_model_train_step.<locals>.calculate_loss(params)
51 else:
52 labels = labels[..., 1:]
---> 53 model_outputs = state.apply_fn(params=params, **batch, return_dict=True)
54 logits = model_outputs.logits
55 aux_loss = getattr(model_outputs, "aux_loss", None)

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:746, in FlaxLlamaPreTrainedModel.__call__(self, input_ids, attention_mask, position_ids, params, past_key_values, dropout_rng, train, output_attentions, output_hidden_states, return_dict, extra_embedding, add_params_field, **kwargs)
743 else:
744 mutable = False
--> 746 outputs = self.module.apply(
747 inputs,
748 jnp.array(input_ids, dtype="i4"),
749 jnp.array(attention_mask, dtype="i4"),
750 jnp.array(position_ids, dtype="i4"),
751 not train,
752 False,
753 output_attentions,
754 output_hidden_states,
755 return_dict,
756 extra_embedding,
757 rngs=rngs,
758 mutable=mutable,
759 )
761 if past_key_values is not None and return_dict:
762 outputs, past_key_values = outputs

[... skipping hidden 6 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:1067, in FlaxLlamaForCausalLMModule.__call__(self, input_ids, attention_mask, position_ids, deterministic, init_cache, output_attentions, output_hidden_states, return_dict, extra_embedding)
1062 if position_ids is None:
1063 position_ids = jnp.broadcast_to(
1064 jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
1065 (batch_size, seq_length)
1066 )
-> 1067 outputs = self.model(
1068 input_ids,
1069 attention_mask,
1070 position_ids,
1071 deterministic=deterministic,
1072 init_cache=init_cache,
1073 output_attentions=output_attentions,
1074 output_hidden_states=output_hidden_states,
1075 return_dict=return_dict,
1076 extra_embedding=extra_embedding
1077 )
1079 hidden_states = outputs[0]
1081 if self.config.tie_word_embeddings:

[... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:964, in FlaxLlamaModule.__call__(self, input_ids, attention_mask, position_ids, deterministic, inputs_embeds, init_cache, output_attentions, output_hidden_states, return_dict, extra_embedding)
959 inputs_embeds = inputs_embeds + \
960 extra_embedding if extra_embedding is not None else inputs_embeds
961 hidden_states = self.dropout(
962 inputs_embeds, deterministic=deterministic)
--> 964 outputs = self.layers(
965 hidden_states=hidden_states,
966 freq_cis=self.freq_cis,
967 attention_mask=attention_mask,
968 position_ids=position_ids,
969 causal_mask=self.causal_mask,
970 deterministic=deterministic,
971 init_cache=init_cache,
972 output_attentions=output_attentions,
973 output_hidden_states=output_hidden_states,
974 return_dict=return_dict,
975 )
977 hidden_states = outputs[0]
978 hidden_states = self.norm(hidden_states)

[... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:850, in FlaxLlamaBlockCollection.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, deterministic, init_cache, output_attentions, output_hidden_states, return_dict)
847 if output_hidden_states:
848 all_hidden_states += (hidden_states,)
--> 850 layer_outputs = block(
851 hidden_states=hidden_states,
852 freq_cis=freq_cis,
853 attention_mask=attention_mask,
854 position_ids=position_ids,
855 causal_mask=causal_mask,
856 deterministic=deterministic,
857 init_cache=init_cache,
858 output_attentions=output_attentions,
859 fcm_mask=fcm_mask,
860 )
861 hidden_states = layer_outputs[0]
863 if output_attentions:

[... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:530, in FlaxLlamaBlock.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, segment_ids, deterministic, init_cache, output_attentions, fcm_mask)
497 def __call__(
498 self,
499 hidden_states: chex.Array,
(...)
508 fcm_mask: Optional[jnp.ndarray] = None,
509 ):
510 """
511 The __call__ function is the main function of a TransformerEncoderLayer.
512 It takes in hidden states, frequency-domain inputs, and masks as input. It then
(...)
528
529 """
--> 530 attn_outputs = self.self_attn(
531 self.input_layernorm(hidden_states),
532 freq_cis,
533 attention_mask,
534 position_ids,
535 causal_mask,
536 segment_ids,
537 deterministic,
538 init_cache,
539 output_attentions,
540 fcm_mask,
541 )
542 attn_output = attn_outputs[0]
543 hidden_states = hidden_states + attn_output

[... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py:567, in core_remat_static.<locals>.inner(scope_fn, repack_fn, variable_groups, rng_groups, *args)
564 y = fn(scope, *args)
565 return y, repack_fn(scope)
--> 567 return rematted(variable_groups, rng_groups, *dyn_args)

[... skipping hidden 7 frame]

File /usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py:564, in core_remat_static.<locals>.inner.<locals>.rematted(variable_groups, rng_groups, *dyn_args)
562 args = _repack_remat_args(dyn_args, static_args)
563 scope = scope_fn(variable_groups, rng_groups)
--> 564 y = fn(scope, *args)
565 return y, repack_fn(scope)

[... skipping hidden 3 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:356, in FlaxLlamaAttention.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, segment_ids, deterministic, init_cache, output_attentions, fcm_mask)
347 attention_bias = lax.select(
348 attention_mask > 0,
349 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
350 jnp.full(attention_mask.shape, jnp.finfo(
351 self.dtype).min).astype(self.dtype),
352 )
354 query_length, key_length = query_states.shape[1], key_states.shape[1]
--> 356 attentions = self.attention_performer.__call__(
357 query_states=query_states,
358 key_states=key_states,
359 value_states=value_states,
360 bias=attention_bias,
361 attention_mask=attention_mask,
362 causal=False,
363 dropout_rng=dropout_rng,
364 deterministic=deterministic,
365 query_sequence_length=query_length,
366 key_value_sequence_length=key_length,
367 uses_cache=self.has_variable("cache", "cached_key") or init_cache,
368 segment_ids=segment_ids
369 )
372 attn_output = self._merge_heads(attentions.attention_outputs)
373 if self.config.shard_attention_computation:

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/attention_module.py:277, in AttentionModule.__call__(self, query_states, key_states, value_states, query_sequence_length, key_value_sequence_length, bias, attention_mask, segment_ids, causal, deterministic, dropout_rng, uses_cache)
268 return self.vanilla_attention(
269 query_states=query_states,
270 key_states=key_states,
(...)
274 deterministic=deterministic,
275 )
276 elif self.attn_mechanism == "sharded_vanilla":
--> 277 return self.sharded_vanilla_attention(
278 query_states=query_states,
279 key_states=key_states,
280 value_states=value_states,
281 bias=bias,
282 dropout_rng=dropout_rng,
283 deterministic=deterministic,
284 )
285 elif self.attn_mechanism == "ring":
286 return self.ring_attention(
287 query_states=query_states,
288 key_states=key_states,
(...)
295 attention_mask=attention_mask
296 )

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/attention_module.py:539, in AttentionModule.sharded_vanilla_attention(self, query_states, key_states, value_states, bias, deterministic, dropout_rng)
537 out_spec = self.generation_attention_partition_spec if is_generating else self.attention_partition_spec
538 with self.mesh:
--> 539 output = shard_map(
540 partial(
541 shard_vanilla_attention,
542 deterministic=deterministic,
543 dropout_rng=dropout_rng,
544 dtype=dtype,
545 precision=self.precision,
546 attention_dropout=self.attention_dropout
547 ),
548 mesh=self.mesh,
549 in_specs=(
550 query_sequence_partition,
551 self.key_partition_spec,
552 self.value_partition_spec,
553 PartitionSpec(("dp", "fsdp"), None, None, None),
554 ),
555 out_specs=out_spec,
556 )(query_states, key_states, value_states, bias)
557 output = fjformer.with_sharding_constraint(output, out_spec)
559 return AttentionOutput(
560 attention_weights=None,
561 attention_outputs=output
562 )

[... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/jax/experimental/shard_map.py:199, in _check_specs_vs_args(f, mesh, in_tree, in_specs, in_specs_flat, xs)
197 if any(f is not no_fail for f in fail):
198 msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail)
--> 199 raise ValueError(msg)

ValueError: shard_map applied to the function 'functools.partial(<function shard_vanilla_attention at 0x7fe2880ec430>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)' was given argument arrays with axis sizes that are not evenly divisible by the corresponding mesh axis sizes:

The mesh given has shape (1, 8, 1, 1) with corresponding axis names ('dp', 'fsdp', 'tp', 'sp').

  • args[0] of shape float32[3,8192,32,128], where args[0] is bound to functools.partial(<function shard_vanilla_attention at 0x7fe2880ec430>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'query_states', corresponds to in_specs[0] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 3) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 3

  • args[1] of shape float32[3,8192,32,128], where args[1] is bound to functools.partial(<function shard_vanilla_attention at 0x7fe2880ec430>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'key_states', corresponds to in_specs[1] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 3) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 3

  • args[2] of shape float32[3,8192,32,128], where args[2] is bound to functools.partial(<function shard_vanilla_attention at 0x7fe2880ec430>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'value_states', corresponds to in_specs[2] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 3) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 3

  • args[3] of shape float32[3,1,8192,8192], where args[3] is bound to functools.partial(<function shard_vanilla_attention at 0x7fe2880ec430>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'bias', corresponds to in_specs[3] of value PartitionSpec(('dp', 'fsdp'), None, None, None), which maps array axis 0 (of size 3) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 3

Array arguments' axis sizes must be evenly divisible by the mesh axis or axes indicated by the corresponding elements of the argument's in_specs entry. Consider checking that in_specs are correct, and if so consider changing the mesh axis sizes or else padding the input and adapting 'functools.partial(<function shard_vanilla_attention at 0x7fe2880ec430>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)' appropriately.

Now TrainArguments takes a loaded_model_config_kwargs argument, so you can re-customize your model config:

argument = TrainArguments(
    loaded_model_config_kwargs={"axis_dims": (1, 1, 1, -1)}  # these arguments will replace the older configs
)
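
For example, dropping it into the resume script above (just a sketch: every argument except loaded_model_config_kwargs is copied from your earlier snippets, and the batch size of 3 is the one you mentioned):

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="llama3_8b",
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    max_sequence_length=max_length,
    total_batch_size=3,  # the batch size from the previous run
    sharding_array=(1, 1, 1, -1),
    init_input_shape=(1, max_length),
    loaded_model_config_kwargs={"axis_dims": (1, 1, 1, -1)},  # re-apply the original sharding to the restored config
    dtype=jnp.bfloat16,
)

trainer = SFTTrainer(
    arguments=train_arguments,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    dataset_text_field="text",
    packing=False,
    checkpoint_path='/root/' + ckpt_name,  # resume from the saved .easy checkpoint
)

output = trainer.train()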

And thanks for letting me know! If there's any other issue, I'll be happy to help you or add new features.

@erfanzar Thank you! You are so considerate!