
Can't load checkpoints to continue training

Hi, sorry to bother you again.

Currently, sharding_axis_dims and input_shape have to be specified to load the model normally, which is different from how it used to be. Otherwise, it reports that the mesh shape is not correct: 8 does not evenly divide 1.

model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path,

I was finetuning a model and saved a checkpoint. Training ran without error. However, when I try to continue finetuning by passing my checkpoint_path to the trainer, I get the error below. Could you tell me what I should do? Thank you so much!

huggingface_repo_id_or_path = "/kaggle/input/llama-3-8b-it"
max_length = 8192
model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path,
                                                            config_kwargs={"attn_mechanism":"sharded_vanilla",'max_position_embeddings': max_length},

tokenizer = AutoTokenizer.from_pretrained(
tokenizer.pad_token = tokenizer.eos_token
configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, max_length)

train_arguments = TrainArguments(
    max_training_steps=None,  # None to let trainer Decide
    do_eval=False,  # optional, but supported
    backend="tpu",  # the default backend is cpu, so you must specify whether to use tpu, cpu, or gpu
    max_sequence_length=max_length,  # Note that you have to change this in the model config too
    sharding_array=(1, 1, 1, -1),  # how the model is sharded across GPUs, CPUs, or TPUs, e.g. (1, -1, 1, 1)
    # training then runs sequence- and model-parallel automatically and shares data between devices

train_dataset = load_dataset('csv',data_files="/kaggle/input/data_8192_v3.csv")['train']

trainer = SFTTrainer(
    checkpoint_path='/root/' + ckpt_name

output = trainer.train()

Time Took to Complete Task configure dataloaders (microseconds) : 318.15624237060547
Time Took to Complete Task configure Model, Optimizer, Scheduler and Config (microseconds) : 2573.209047317505
Time Took to Complete Task configure functions and sharding them (microseconds) : 3353.928804397583
Action : Loading Model From /root/writer_llama3_8b-S3413.easy
Loading Checkpoints From File: 374it [02:02,  3.06it/s, shard_functions_mismatch=73]
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/, in FlaxLlamaForCausalLMModule.__call__(self, input_ids, attention_mask, position_ids, deterministic, init_cache, output_attentions, output_hidden_states, return_dict, extra_embedding)
   1062 if position_ids is None:
   1063     position_ids = jnp.broadcast_to(
   1064         jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
   1065         (batch_size, seq_length)
   1066     )
-> 1067 outputs = self.model(
   1068     input_ids,
   1069     attention_mask,
   1070     position_ids,
   1071     deterministic=deterministic,
   1072     init_cache=init_cache,
   1073     output_attentions=output_attentions,
   1074     output_hidden_states=output_hidden_states,
   1075     return_dict=return_dict,
   1076     extra_embedding=extra_embedding
   1077 )
   1079 hidden_states = outputs[0]
   1081 if self.config.tie_word_embeddings:

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/, in FlaxLlamaModule.__call__(self, input_ids, attention_mask, position_ids, deterministic, inputs_embeds, init_cache, output_attentions, output_hidden_states, return_dict, extra_embedding)
    959 inputs_embeds = inputs_embeds + \
    960                 extra_embedding if extra_embedding is not None else inputs_embeds
    961 hidden_states = self.dropout(
    962     inputs_embeds, deterministic=deterministic)
--> 964 outputs = self.layers(
    965     hidden_states=hidden_states,
    966     freq_cis=self.freq_cis,
    967     attention_mask=attention_mask,
    968     position_ids=position_ids,
    969     causal_mask=self.causal_mask,
    970     deterministic=deterministic,
    971     init_cache=init_cache,
    972     output_attentions=output_attentions,
    973     output_hidden_states=output_hidden_states,
    974     return_dict=return_dict,
    975 )
    977 hidden_states = outputs[0]
    978 hidden_states = self.norm(hidden_states)

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/, in FlaxLlamaBlockCollection.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, deterministic, init_cache, output_attentions, output_hidden_states, return_dict)
    847 if output_hidden_states:
    848     all_hidden_states += (hidden_states,)
--> 850 layer_outputs = block(
    851     hidden_states=hidden_states,
    852     freq_cis=freq_cis,
    853     attention_mask=attention_mask,
    854     position_ids=position_ids,
    855     causal_mask=causal_mask,
    856     deterministic=deterministic,
    857     init_cache=init_cache,
    858     output_attentions=output_attentions,
    859     fcm_mask=fcm_mask,
    860 )
    861 hidden_states = layer_outputs[0]
    863 if output_attentions:

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/, in FlaxLlamaBlock.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, segment_ids, deterministic, init_cache, output_attentions, fcm_mask)
    497 def __call__(
    498         self,
    499         hidden_states: chex.Array,
    508         fcm_mask: Optional[jnp.ndarray] = None,
    509 ):
    510     """
    511     The __call__ function is the main function of a TransformerEncoderLayer.
    512     It takes in hidden states, frequency-domain inputs, and masks as input. It then
    529     """
--> 530     attn_outputs = self.self_attn(
    531         self.input_layernorm(hidden_states),
    532         freq_cis,
    533         attention_mask,
    534         position_ids,
    535         causal_mask,
    536         segment_ids,
    537         deterministic,
    538         init_cache,
    539         output_attentions,
    540         fcm_mask,
    541     )
    542     attn_output = attn_outputs[0]
    543     hidden_states = hidden_states + attn_output

File /usr/local/lib/python3.10/site-packages/flax/linen/, in core_remat_static.<locals>.inner(scope_fn, repack_fn, variable_groups, rng_groups, *args)
    564   y = fn(scope, *args)
    565   return y, repack_fn(scope)
--> 567 return rematted(variable_groups, rng_groups, *dyn_args)

File /usr/local/lib/python3.10/site-packages/flax/linen/, in core_remat_static.<locals>.inner.<locals>.rematted(variable_groups, rng_groups, *dyn_args)
    562 args = _repack_remat_args(dyn_args, static_args)
    563 scope = scope_fn(variable_groups, rng_groups)
--> 564 y = fn(scope, *args)
    565 return y, repack_fn(scope)

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/, in FlaxLlamaAttention.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, segment_ids, deterministic, init_cache, output_attentions, fcm_mask)
    347 attention_bias =
    348     attention_mask > 0,
    349     jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
    350     jnp.full(attention_mask.shape, jnp.finfo(
    351         self.dtype).min).astype(self.dtype),
    352 )
    354 query_length, key_length = query_states.shape[1], key_states.shape[1]
--> 356 attentions = self.attention_performer.__call__(
    357     query_states=query_states,
    358     key_states=key_states,
    359     value_states=value_states,
    360     bias=attention_bias,
    361     attention_mask=attention_mask,
    362     causal=False,
    363     dropout_rng=dropout_rng,
    364     deterministic=deterministic,
    365     query_sequence_length=query_length,
    366     key_value_sequence_length=key_length,
    367     uses_cache=self.has_variable("cache", "cached_key") or init_cache,
    368     segment_ids=segment_ids
    369 )
    372 attn_output = self._merge_heads(attentions.attention_outputs)
    373 if self.config.shard_attention_computation:

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/, in AttentionModule.__call__(self, query_states, key_states, value_states, query_sequence_length, key_value_sequence_length, bias, attention_mask, segment_ids, causal, deterministic, dropout_rng, uses_cache)
    268     return self.vanilla_attention(
    269         query_states=query_states,
    270         key_states=key_states,
    274         deterministic=deterministic,
    275     )
    276 elif self.attn_mechanism == "sharded_vanilla":
--> 277     return self.sharded_vanilla_attention(
    278         query_states=query_states,
    279         key_states=key_states,
    280         value_states=value_states,
    281         bias=bias,
    282         dropout_rng=dropout_rng,
    283         deterministic=deterministic,
    284     )
    285 elif self.attn_mechanism == "ring":
    286     return self.ring_attention(
    287         query_states=query_states,
    288         key_states=key_states,
    295         attention_mask=attention_mask
    296     )

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/, in AttentionModule.sharded_vanilla_attention(self, query_states, key_states, value_states, bias, deterministic, dropout_rng)
    537 out_spec = self.generation_attention_partition_spec if is_generating else self.attention_partition_spec
    538 with self.mesh:
--> 539     output = shard_map(
    540         partial(
    541             shard_vanilla_attention,
    542             deterministic=deterministic,
    543             dropout_rng=dropout_rng,
    544             dtype=dtype,
    545             precision=self.precision,
    546             attention_dropout=self.attention_dropout
    547         ),
    548         mesh=self.mesh,
    549         in_specs=(
    550             query_sequence_partition,
    551             self.key_partition_spec,
    552             self.value_partition_spec,
    553             PartitionSpec(("dp", "fsdp"), None, None, None),
    554         ),
    555         out_specs=out_spec,
    556     )(query_states, key_states, value_states, bias)
    557     output = fjformer.with_sharding_constraint(output, out_spec)
    559     return AttentionOutput(
    560         attention_weights=None,
    561         attention_outputs=output
    562     )

File /usr/local/lib/python3.10/site-packages/jax/experimental/, in _check_specs_vs_args(f, mesh, in_tree, in_specs, in_specs_flat, xs)
    197 if any(f is not no_fail for f in fail):
    198   msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail)
--> 199   raise ValueError(msg)

ValueError: shard_map applied to the function 'functools.partial(<function shard_vanilla_attention at 0x79f014655c60>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)' was given argument arrays with axis sizes that are not evenly divisible by the corresponding mesh axis sizes:

The mesh given has shape (1, 8, 1, 1) with corresponding axis names ('dp', 'fsdp', 'tp', 'sp').

* args[0] of shape float32[1,8192,32,128], where args[0] is bound to functools.partial(<function shard_vanilla_attention at 0x79f014655c60>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'query_states', corresponds to in_specs[0] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1

* args[1] of shape float32[1,8192,32,128], where args[1] is bound to functools.partial(<function shard_vanilla_attention at 0x79f014655c60>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'key_states', corresponds to in_specs[1] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1

* args[2] of shape float32[1,8192,32,128], where args[2] is bound to functools.partial(<function shard_vanilla_attention at 0x79f014655c60>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'value_states', corresponds to in_specs[2] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1

* args[3] of shape float32[1,1,8192,8192], where args[3] is bound to functools.partial(<function shard_vanilla_attention at 0x79f014655c60>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'bias', corresponds to in_specs[3] of value PartitionSpec(('dp', 'fsdp'), None, None, None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1

Array arguments' axis sizes must be evenly divisible by the mesh axis or axes indicated by the corresponding elements of the argument's in_specs entry. Consider checking that in_specs are correct, and if so consider changing the mesh axis sizes or else padding the input and adapting 'functools.partial(<function shard_vanilla_attention at 0x79f014655c60>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)' appropriately.

I cannot load the .easy file either; the error is similar.

from EasyDel import EasyDelState
import jax
state = EasyDelState.load_state(
#     state_shard_fns=state_shard_fns,  # You can pass that

Hello, and thanks for using EasyDel.

EasyDelState.load_state always uses mesh_dim = (1,-1,1,1) so that the model is easy to load. In cases like this, just pass input_shape to EasyDelState.load_state as (8,8),

and set init_input_shape in TrainArguments to (8,8) too.

For now this will fix your issue; in the next version I'll make that automatic.
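
For anyone hitting the same error, the arithmetic behind this workaround can be sketched in plain Python. This is only an illustration, not EasyDel code: resolve_axis_dims and batch_is_shardable are hypothetical helpers, and the sketch assumes that a -1 entry is filled with whatever device count remains, which matches the (1, 8, 1, 1) mesh reported in the traceback on eight devices.

```python
from math import prod


def resolve_axis_dims(axis_dims, device_count):
    """Replace a single -1 entry with the size that uses up the remaining devices.

    Hypothetical helper: it mimics how a reshape would infer the -1 axis.
    """
    if axis_dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = prod(d for d in axis_dims if d != -1)
    if device_count % known != 0:
        raise ValueError(f"{device_count} devices cannot fill axes {axis_dims}")
    return tuple(device_count // known if d == -1 else d for d in axis_dims)


def batch_is_shardable(batch_size, mesh_shape):
    """The batch axis is split over ('dp', 'fsdp'), so it must be a multiple of dp * fsdp."""
    dp, fsdp = mesh_shape[0], mesh_shape[1]
    return batch_size % (dp * fsdp) == 0


mesh_shape = resolve_axis_dims((1, -1, 1, 1), device_count=8)
print(mesh_shape)                         # (1, 8, 1, 1)
print(batch_is_shardable(1, mesh_shape))  # False: input_shape (1, 8192) fails
print(batch_is_shardable(8, mesh_shape))  # True: input_shape (8, 8) works
```

With the mesh fixed at (1, 8, 1, 1), any leading dimension that is a multiple of 8 passes the check, which is why an input shape of (8,8) loads while (1, 8192) does not.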

@erfanzar Thank you. I can load the checkpoint with EasyDelState.load_state now. However, if I want to continue training and set init_input_shape=(8,1024), I also have to change my batch_size to 8, which does not match my previous training. Otherwise, it shows the error below. My previous batch_size was 3, and my sequence length is 8192.

Action : Loading Model From /root/writer_llama3_8b-S3413.easy
Loading Checkpoints From File: 374it [01:21, 4.59it/s, shard_functions_mismatch=73]
Model Contain 8.030261248 Billion Parameters
0%| | 0/170236 [00:00<?, ?it/s]

[... skipping hidden 12 frame]

[... skipping hidden 8 frame]

[... skipping hidden 6 frame]

[... skipping hidden 2 frame]

[... skipping hidden 2 frame]

[... skipping hidden 2 frame]

[... skipping hidden 2 frame]

[... skipping hidden 7 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/, in
347 attention_bias =
348 attention_mask > 0,
349 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
350 jnp.full(attention_mask.shape, jnp.finfo(
351 self.dtype).min).astype(self.dtype),
352 )
354 query_length, key_length = query_states.shape[1], key_states.shape[1]
--> 356 attentions =
357 query_states=query_states,
358 key_states=key_states,
359 value_states=value_states,
360 bias=attention_bias,
361 attention_mask=attention_mask,
362 causal=False,
363 dropout_rng=dropout_rng,
364 deterministic=deterministic,
365 query_sequence_length=query_length,
366 key_value_sequence_length=key_length,
367 uses_cache=self.has_variable("cache", "cached_key") or init_cache,
368 segment_ids=segment_ids
369 )
372 attn_output = self._merge_heads(attentions.attention_outputs)
373 if self.config.shard_attention_computation:

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/, in
268 return self.vanilla_attention(
269 query_states=query_states,
270 key_states=key_states,
274 deterministic=deterministic,
275 )
276 elif self.attn_mechanism == "sharded_vanilla":
--> 277 return self.sharded_vanilla_attention(
278 query_states=query_states,
279 key_states=key_states,
280 value_states=value_states,
281 bias=bias,
282 dropout_rng=dropout_rng,
283 deterministic=deterministic,
284 )
285 elif self.attn_mechanism == "ring":
286 return self.ring_attention(
287 query_states=query_states,
288 key_states=key_states,
295 attention_mask=attention_mask
296 )

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/, in AttentionModule.sharded_vanilla_attention(self, query_states, key_states, value_states, bias, deterministic, dropout_rng)
537 out_spec = self.generation_attention_partition_spec if is_generating else self.attention_partition_spec
538 with self.mesh:
--> 539 output = shard_map(
540 partial(
541 shard_vanilla_attention,
542 deterministic=deterministic,
543 dropout_rng=dropout_rng,
544 dtype=dtype,
545 precision=self.precision,
546 attention_dropout=self.attention_dropout
547 ),
548 mesh=self.mesh,
549 in_specs=(
550 query_sequence_partition,
551 self.key_partition_spec,
552 self.value_partition_spec,
553 PartitionSpec(("dp", "fsdp"), None, None, None),
554 ),
555 out_specs=out_spec,
556 )(query_states, key_states, value_states, bias)
557 output = fjformer.with_sharding_constraint(output, out_spec)
559 return AttentionOutput(
560 attention_weights=None,
561 attention_outputs=output
562 )

File /usr/local/lib/python3.10/site-packages/jax/experimental/, in _check_specs_vs_args(f, mesh, in_tree, in_specs, in_specs_flat, xs)
197 if any(f is not no_fail for f in fail):
198 msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail)
--> 199 raise ValueError(msg)

ValueError: shard_map applied to the function 'functools.partial(<function shard_vanilla_attention at 0x7fe2880ec430>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)' was given argument arrays with axis sizes that are not evenly divisible by the corresponding mesh axis sizes:

The mesh given has shape (1, 8, 1, 1) with corresponding axis names ('dp', 'fsdp', 'tp', 'sp').

  • args[0] of shape float32[3,8192,32,128], where args[0] is bound to functools.partial(<function shard_vanilla_attention at 0x7fe2880ec430>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'query_states', corresponds to in_specs[0] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 3) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 3

  • args[1] of shape float32[3,8192,32,128], where args[1] is bound to functools.partial(<function shard_vanilla_attention at 0x7fe2880ec430>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'key_states', corresponds to in_specs[1] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 3) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 3

  • args[2] of shape float32[3,8192,32,128], where args[2] is bound to functools.partial(<function shard_vanilla_attention at 0x7fe2880ec430>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'value_states', corresponds to in_specs[2] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 3) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 3

  • args[3] of shape float32[3,1,8192,8192], where args[3] is bound to functools.partial(<function shard_vanilla_attention at 0x7fe2880ec430>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'bias', corresponds to in_specs[3] of value PartitionSpec(('dp', 'fsdp'), None, None, None), which maps array axis 0 (of size 3) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 3

Array arguments' axis sizes must be evenly divisible by the mesh axis or axes indicated by the corresponding elements of the argument's in_specs entry. Consider checking that in_specs are correct, and if so consider changing the mesh axis sizes or else padding the input and adapting 'functools.partial(<function shard_vanilla_attention at 0x7fe2880ec430>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)' appropriately.

TrainArguments now accepts loaded_model_config_kwargs, so you can re-customize your model config:

argument = TrainArguments(
   loaded_model_config_kwargs = {"axis_dims":(1,1,1,-1)} # these arguments will replace the older configs

And thanks for letting me know! If there's any other issue, I'll be happy to help you or add new features.
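
Why overriding axis_dims helps with a batch size of 3 can be checked with the same divisibility rule that shard_map reports in the traceback. The following is a minimal sketch, not library code: failing_axes is a hypothetical helper, the partition spec is copied from the error message, and it assumes that axis_dims (1,1,1,-1) resolves to a (1, 1, 1, 8) mesh on eight devices.

```python
from math import prod

AXIS_NAMES = ("dp", "fsdp", "tp", "sp")  # mesh axis names as printed in the error


def failing_axes(shape, spec, mesh_shape):
    """Return (array_axis, array_size, mesh_total) for each axis that cannot be split evenly."""
    mesh = dict(zip(AXIS_NAMES, mesh_shape))
    failures = []
    for axis, (size, names) in enumerate(zip(shape, spec)):
        if names is None:
            continue  # this array axis is replicated, nothing to check
        if isinstance(names, str):
            names = (names,)
        total = prod(mesh[name] for name in names)
        if size % total != 0:
            failures.append((axis, size, total))
    return failures


# query_states from the second traceback: batch 3, sequence 8192, 32 heads, head dim 128
query_shape = (3, 8192, 32, 128)
query_spec = (("dp", "fsdp"), "sp", "tp", None)

print(failing_axes(query_shape, query_spec, (1, 8, 1, 1)))  # [(0, 3, 8)]
print(failing_axes(query_shape, query_spec, (1, 1, 1, 8)))  # []
```

Under that assumption the eight devices split the sequence axis (8192 is a multiple of 8) instead of the batch axis, so the batch size no longer has to be a multiple of 8.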

@erfanzar Thank you! You are so considerate!