Can't load checkpoints continue training
IvoryTower800 opened this issue · 7 comments
Hi, sorry to bother you again.
Currently I have to specify sharding_axis_dims and input_shape to load the model normally, which is different from how it used to be. Otherwise it complains that the mesh shape is not correct (8 does not evenly divide 1):
model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    huggingface_repo_id_or_path,
    sharding_axis_dims=(1, 1, 1, -1),
    input_shape=(1, max_length),
)
I was fine-tuning a model and saved a checkpoint. I can train without error. However, when I want to continue fine-tuning and pass my checkpoint_path to the trainer, I get the error below. Could you tell me what I should do? Thank you so much!
huggingface_repo_id_or_path = "/kaggle/input/llama-3-8b-it"
max_length = 8192

model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    huggingface_repo_id_or_path,
    config_kwargs={"attn_mechanism": "sharded_vanilla", "max_position_embeddings": max_length},
    sharding_axis_dims=(1, 1, 1, -1),
    input_shape=(1, max_length),
)

tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True,
)
tokenizer.pad_token = tokenizer.eos_token

configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, max_length),
}
train_arguments = TrainArguments(
    model_class=type(model),
    model_name="llama3_8b",
    num_train_epochs=2,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    learning_rate=2e-6,
    learning_rate_end=2e-7,
    optimizer=EasyDelOptimizers.ADAMW,
    scheduler=EasyDelSchedulers.LINEAR,
    weight_decay=0.01,
    total_batch_size=1,
    max_training_steps=None,  # None lets the trainer decide
    do_train=True,
    do_eval=False,  # optional but supported
    backend="tpu",  # the default backend is cpu, so specify whether you want tpu, gpu, or cpu
    max_sequence_length=max_length,  # note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, 1, 1, -1),  # how to shard the model across CPUs, GPUs, or TPUs; the default sharding array is (1, -1, 1, 1)
    init_input_shape=(1, max_length),
    # training will be sequence- and model-parallel automatically, sharing data between devices
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=3,
    training_time="8H",
    track_memory=True,
    neftune_noise_alpha=5.0,
    force_batch_and_gradient_accumulation_steps_calculation=True,
    loss_re_mat="",
    dtype=jnp.bfloat16,
)
train_dataset = load_dataset('csv',data_files="/kaggle/input/data_8192_v3.csv")['train']
trainer = SFTTrainer(
    arguments=train_arguments,
    train_dataset=train_dataset,
    eval_dataset=None,
    tokenizer=tokenizer,
    dataset_text_field="text",
    dataset_num_proc=96,
    packing=False,
    checkpoint_path='/root/' + ckpt_name,
)
output = trainer.train()
Time Took to Complete Task configure dataloaders (microseconds) : 318.15624237060547
Time Took to Complete Task configure Model, Optimizer, Scheduler and Config (microseconds) : 2573.209047317505
Time Took to Complete Task configure functions and sharding them (microseconds) : 3353.928804397583
Action : Loading Model From /root/writer_llama3_8b-S3413.easy
Loading Checkpoints From File: 374it [02:02, 3.06it/s, shard_functions_mismatch=73]
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[2], line 14
1 trainer = SFTTrainer(
2 arguments=train_arguments,
3 train_dataset=train_dataset,
(...)
11 # formatting_func=prompter,
12 )
---> 14 output = trainer.train()
File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:398, in CausalLanguageModelTrainer.train(self, model_parameters, state)
391 termcolor.cprint(
392 "Performance Mode is ON, we will ignore the Memory Tracking, WANDB Logging, and extra information "
393 "Process.",
394 color="red",
395 force_color=True
396 )
397 start_time = time.time()
--> 398 sharded_state, shard_fns, gather_fns = self.initialize_state(
399 model_parameters=model_parameters,
400 state=state
401 )
403 count_model_parameters(sharded_state.params)
404 with self.mesh:
File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:237, in CausalLanguageModelTrainer.initialize_state(self, model_parameters, state)
233 prefix_print(
234 "Action", f"Loading Model From {self.checkpoint_path}"
235 )
236 with jax.default_device(self.arguments.offload_device):
--> 237 sharded_state = EasyDelState.load_state(
238 verbose=self.arguments.verbose,
239 state_shard_fns=shard_fns,
240 init_optimizer_state=True,
241 checkpoint_path=self.checkpoint_path,
242 input_shape=(1,8192),
243 )
244 state_shape = jax.eval_shape(lambda: sharded_state)
245 state_partition_spec = match_partition_rules(
246 self.config.get_partition_rules(
247 fully_sharded_data_parallel=self.arguments.fully_sharded_data_parallel
248 ) if self.arguments.custom_rule is None else self.arguments.custom_rule,
249 state_shape
250 )
File /usr/local/lib/python3.10/site-packages/EasyDel/etils/easystate.py:310, in EasyDelState.load_state(cls, checkpoint_path, dtype, param_dtype, precision, init_optimizer_state, state_shard_fns, verbose, input_shape)
308 cfg_behave[k] = eval(v)
309 module_config = cfg.from_dict(cfg_behave)
--> 310 module_in = module(
311 config=module_config,
312 dtype=dtype,
313 param_dtype=param_dtype,
314 precision=precision,
315 input_shape=input_shape
316 )
317 else:
318 raise TypeError(
319 "Om seems like i couldn't read model correctly ;("
320 )
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:597, in FlaxLlamaPreTrainedModel.__init__(self, config, input_shape, seed, dtype, _do_init, **kwargs)
579 """
580 The __init__ function is called when the class is instantiated.
581 It sets up the instance of the class, and defines what happens when it's created.
(...)
594
595 """
596 module = self.module_class(config=config, dtype=dtype, **kwargs)
--> 597 super().__init__(config, module, input_shape=input_shape,
598 seed=seed, dtype=dtype, _do_init=_do_init)
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/easydel_modelling_utils.py:447, in EasyDelFlaxPretrainedModel.__init__(self, config, module, input_shape, seed, dtype, param_dtype, precision, _do_init)
436 def __init__(
437 self,
438 config: PretrainedConfig,
(...)
445 _do_init: bool = True,
446 ):
--> 447 super().__init__(
448 config=config,
449 module=module,
450 input_shape=input_shape,
451 seed=seed,
452 dtype=dtype,
453 _do_init=_do_init
454 )
File /usr/local/lib/python3.10/site-packages/transformers/modeling_flax_utils.py:220, in FlaxPreTrainedModel.__init__(self, config, module, input_shape, seed, dtype, _do_init)
216 self._is_initialized = _do_init
218 if _do_init:
219 # randomly initialized parameters
--> 220 random_params = self.init_weights(self.key, input_shape)
221 params_shape_tree = jax.eval_shape(lambda params: params, random_params)
222 else:
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:632, in FlaxLlamaPreTrainedModel.init_weights(self, rng, input_shape, params)
622 module_init_outputs = self.module.init(
623 rngs,
624 input_ids,
(...)
629 return_dict=False,
630 )
631 else:
--> 632 module_init_outputs = self.module.init(
633 rngs, input_ids, attention_mask, position_ids, return_dict=False)
635 random_params = module_init_outputs["params"]
637 if params is not None:
[... skipping hidden 9 frame]
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:1067, in FlaxLlamaForCausalLMModule.__call__(self, input_ids, attention_mask, position_ids, deterministic, init_cache, output_attentions, output_hidden_states, return_dict, extra_embedding)
1062 if position_ids is None:
1063 position_ids = jnp.broadcast_to(
1064 jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
1065 (batch_size, seq_length)
1066 )
-> 1067 outputs = self.model(
1068 input_ids,
1069 attention_mask,
1070 position_ids,
1071 deterministic=deterministic,
1072 init_cache=init_cache,
1073 output_attentions=output_attentions,
1074 output_hidden_states=output_hidden_states,
1075 return_dict=return_dict,
1076 extra_embedding=extra_embedding
1077 )
1079 hidden_states = outputs[0]
1081 if self.config.tie_word_embeddings:
[... skipping hidden 2 frame]
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:964, in FlaxLlamaModule.__call__(self, input_ids, attention_mask, position_ids, deterministic, inputs_embeds, init_cache, output_attentions, output_hidden_states, return_dict, extra_embedding)
959 inputs_embeds = inputs_embeds + \
960 extra_embedding if extra_embedding is not None else inputs_embeds
961 hidden_states = self.dropout(
962 inputs_embeds, deterministic=deterministic)
--> 964 outputs = self.layers(
965 hidden_states=hidden_states,
966 freq_cis=self.freq_cis,
967 attention_mask=attention_mask,
968 position_ids=position_ids,
969 causal_mask=self.causal_mask,
970 deterministic=deterministic,
971 init_cache=init_cache,
972 output_attentions=output_attentions,
973 output_hidden_states=output_hidden_states,
974 return_dict=return_dict,
975 )
977 hidden_states = outputs[0]
978 hidden_states = self.norm(hidden_states)
[... skipping hidden 2 frame]
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:850, in FlaxLlamaBlockCollection.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, deterministic, init_cache, output_attentions, output_hidden_states, return_dict)
847 if output_hidden_states:
848 all_hidden_states += (hidden_states,)
--> 850 layer_outputs = block(
851 hidden_states=hidden_states,
852 freq_cis=freq_cis,
853 attention_mask=attention_mask,
854 position_ids=position_ids,
855 causal_mask=causal_mask,
856 deterministic=deterministic,
857 init_cache=init_cache,
858 output_attentions=output_attentions,
859 fcm_mask=fcm_mask,
860 )
861 hidden_states = layer_outputs[0]
863 if output_attentions:
[... skipping hidden 2 frame]
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:530, in FlaxLlamaBlock.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, segment_ids, deterministic, init_cache, output_attentions, fcm_mask)
497 def __call__(
498 self,
499 hidden_states: chex.Array,
(...)
508 fcm_mask: Optional[jnp.ndarray] = None,
509 ):
510 """
511 The __call__ function is the main function of a TransformerEncoderLayer.
512 It takes in hidden states, frequency-domain inputs, and masks as input. It then
(...)
528
529 """
--> 530 attn_outputs = self.self_attn(
531 self.input_layernorm(hidden_states),
532 freq_cis,
533 attention_mask,
534 position_ids,
535 causal_mask,
536 segment_ids,
537 deterministic,
538 init_cache,
539 output_attentions,
540 fcm_mask,
541 )
542 attn_output = attn_outputs[0]
543 hidden_states = hidden_states + attn_output
[... skipping hidden 2 frame]
File /usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py:567, in core_remat_static.<locals>.inner(scope_fn, repack_fn, variable_groups, rng_groups, *args)
564 y = fn(scope, *args)
565 return y, repack_fn(scope)
--> 567 return rematted(variable_groups, rng_groups, *dyn_args)
[... skipping hidden 7 frame]
File /usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py:564, in core_remat_static.<locals>.inner.<locals>.rematted(variable_groups, rng_groups, *dyn_args)
562 args = _repack_remat_args(dyn_args, static_args)
563 scope = scope_fn(variable_groups, rng_groups)
--> 564 y = fn(scope, *args)
565 return y, repack_fn(scope)
[... skipping hidden 3 frame]
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:356, in FlaxLlamaAttention.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, segment_ids, deterministic, init_cache, output_attentions, fcm_mask)
347 attention_bias = lax.select(
348 attention_mask > 0,
349 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
350 jnp.full(attention_mask.shape, jnp.finfo(
351 self.dtype).min).astype(self.dtype),
352 )
354 query_length, key_length = query_states.shape[1], key_states.shape[1]
--> 356 attentions = self.attention_performer.__call__(
357 query_states=query_states,
358 key_states=key_states,
359 value_states=value_states,
360 bias=attention_bias,
361 attention_mask=attention_mask,
362 causal=False,
363 dropout_rng=dropout_rng,
364 deterministic=deterministic,
365 query_sequence_length=query_length,
366 key_value_sequence_length=key_length,
367 uses_cache=self.has_variable("cache", "cached_key") or init_cache,
368 segment_ids=segment_ids
369 )
372 attn_output = self._merge_heads(attentions.attention_outputs)
373 if self.config.shard_attention_computation:
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/attention_module.py:277, in AttentionModule.__call__(self, query_states, key_states, value_states, query_sequence_length, key_value_sequence_length, bias, attention_mask, segment_ids, causal, deterministic, dropout_rng, uses_cache)
268 return self.vanilla_attention(
269 query_states=query_states,
270 key_states=key_states,
(...)
274 deterministic=deterministic,
275 )
276 elif self.attn_mechanism == "sharded_vanilla":
--> 277 return self.sharded_vanilla_attention(
278 query_states=query_states,
279 key_states=key_states,
280 value_states=value_states,
281 bias=bias,
282 dropout_rng=dropout_rng,
283 deterministic=deterministic,
284 )
285 elif self.attn_mechanism == "ring":
286 return self.ring_attention(
287 query_states=query_states,
288 key_states=key_states,
(...)
295 attention_mask=attention_mask
296 )
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/attention_module.py:539, in AttentionModule.sharded_vanilla_attention(self, query_states, key_states, value_states, bias, deterministic, dropout_rng)
537 out_spec = self.generation_attention_partition_spec if is_generating else self.attention_partition_spec
538 with self.mesh:
--> 539 output = shard_map(
540 partial(
541 shard_vanilla_attention,
542 deterministic=deterministic,
543 dropout_rng=dropout_rng,
544 dtype=dtype,
545 precision=self.precision,
546 attention_dropout=self.attention_dropout
547 ),
548 mesh=self.mesh,
549 in_specs=(
550 query_sequence_partition,
551 self.key_partition_spec,
552 self.value_partition_spec,
553 PartitionSpec(("dp", "fsdp"), None, None, None),
554 ),
555 out_specs=out_spec,
556 )(query_states, key_states, value_states, bias)
557 output = fjformer.with_sharding_constraint(output, out_spec)
559 return AttentionOutput(
560 attention_weights=None,
561 attention_outputs=output
562 )
[... skipping hidden 2 frame]
File /usr/local/lib/python3.10/site-packages/jax/experimental/shard_map.py:199, in _check_specs_vs_args(f, mesh, in_tree, in_specs, in_specs_flat, xs)
197 if any(f is not no_fail for f in fail):
198 msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail)
--> 199 raise ValueError(msg)
ValueError: shard_map applied to the function 'functools.partial(<function shard_vanilla_attention at 0x79f014655c60>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)' was given argument arrays with axis sizes that are not evenly divisible by the corresponding mesh axis sizes:
The mesh given has shape (1, 8, 1, 1) with corresponding axis names ('dp', 'fsdp', 'tp', 'sp').
* args[0] of shape float32[1,8192,32,128], where args[0] is bound to functools.partial(<function shard_vanilla_attention at 0x79f014655c60>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'query_states', corresponds to in_specs[0] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1
* args[1] of shape float32[1,8192,32,128], where args[1] is bound to functools.partial(<function shard_vanilla_attention at 0x79f014655c60>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'key_states', corresponds to in_specs[1] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1
* args[2] of shape float32[1,8192,32,128], where args[2] is bound to functools.partial(<function shard_vanilla_attention at 0x79f014655c60>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'value_states', corresponds to in_specs[2] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1
* args[3] of shape float32[1,1,8192,8192], where args[3] is bound to functools.partial(<function shard_vanilla_attention at 0x79f014655c60>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'bias', corresponds to in_specs[3] of value PartitionSpec(('dp', 'fsdp'), None, None, None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1
Array arguments' axis sizes must be evenly divisible by the mesh axis or axes indicated by the corresponding elements of the argument's in_specs entry. Consider checking that in_specs are correct, and if so consider changing the mesh axis sizes or else padding the input and adapting 'functools.partial(<function shard_vanilla_attention at 0x79f014655c60>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)' appropriately.
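(For context, the constraint shard_map enforces here is generic JAX behaviour rather than anything EasyDel-specific: every array axis mapped to mesh axes must be evenly divisible by the product of those mesh axis sizes. A minimal standalone sketch with made-up shapes, assuming 8 devices arranged as (1, 8, 1, 1):)

# Minimal illustration of the shard_map divisibility rule (hypothetical shapes,
# not EasyDel code): axis 0 is mapped to ("dp", "fsdp"), whose total size is 8,
# so its length must be a multiple of 8.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()).reshape(1, -1, 1, 1), axis_names=("dp", "fsdp", "tp", "sp"))
spec = PartitionSpec(("dp", "fsdp"), None)

double = shard_map(lambda x: x * 2, mesh=mesh, in_specs=(spec,), out_specs=spec)

double(jnp.ones((8, 4)))    # ok: axis 0 of size 8 is divisible by the 8 devices on ("dp", "fsdp")
# double(jnp.ones((1, 4)))  # ValueError: 8 does not evenly divide 1, same as the error above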
I cannot load the .easy file directly either; the error is similar.
from EasyDel import EasyDelState
import jax
state = EasyDelState.load_state(
    checkpoint_path='llama3_8b-S3413.easy',
    verbose=True,
    # state_shard_fns=state_shard_fns,  # you can pass this too
    init_optimizer_state=False,
    input_shape=(1, 8192),
)
Hello, and thanks for using EasyDel.
EasyDelState.load_state always uses mesh_dim = (1, -1, 1, 1) so it can load the model easily. In cases like this, just pass input_shape in EasyDelState.load_state as (8, 8), and init_input_shape in TrainArguments as (8, 8) too.
For now this will fix your issue; in the next version I'll make it automatic.
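Applied to the snippets above, that would look roughly like this (a sketch only; everything not shown stays as in the original code):

# Sketch of the suggested workaround: load_state builds the model on a
# (1, -1, 1, 1) mesh internally, so the dummy init shape only needs a batch
# axis that is divisible by the device count, e.g. (8, 8).
from EasyDel import EasyDelState

state = EasyDelState.load_state(
    checkpoint_path='llama3_8b-S3413.easy',
    verbose=True,
    init_optimizer_state=False,
    input_shape=(8, 8),  # instead of (1, 8192)
)

# ...and mirror the same shape in TrainArguments:
# train_arguments = TrainArguments(..., init_input_shape=(8, 8), ...)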
@erfanzar Thank you. I can load the checkpoint with EasyDelState.load_state now. However, if I want to continue training and set init_input_shape=(8, 1024), I also have to change my batch_size to 8, which is not in line with my previous training; otherwise it shows the error below. My previous batch_size is 3 and my sequence length is 8192.
Action : Loading Model From /root/writer_llama3_8b-S3413.easy
Loading Checkpoints From File: 374it [01:21, 4.59it/s, shard_functions_mismatch=73]
Model Contain 8.030261248 Billion Parameters
0%| | 0/170236 [00:00<?, ?it/s]
ValueError Traceback (most recent call last)
Cell In[2], line 54
37 # desired_indices = range(0, 100)
38 # train_dataset = train_dataset.select(desired_indices)
41 trainer = SFTTrainer(
42 arguments=train_arguments,
43 train_dataset=train_dataset,
(...)
51 # formatting_func=prompter,
52 )
---> 54 output = trainer.train()
File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer/causal_language_model_trainer.py:450, in CausalLanguageModelTrainer.train(self, model_parameters, state)
443 for ssb in self.arguments.ids_to_pop_from_dataset:
444 _ = batch.pop(ssb, None)
446 (
447 sharded_state,
448 loss,
449 metrics,
--> 450 ) = self.sharded_train_step_function(sharded_state, batch)
452 trained_tokens = jnp.multiply(
453 self.arguments.max_sequence_length, jnp.multiply(
454 current_step,
455 self.arguments.total_batch_size
456 )
457 ) # It's faster
459 with jax.spmd_mode("allow_all"):
[... skipping hidden 12 frame]
File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer/fwd_bwd_functions.py:88, in create_casual_language_model_train_step.<locals>.casual_language_model_train_step(state, batch)
85 return loss, (accuracy, z_loss_computed, aux_loss)
87 grad_fn = jax.value_and_grad(calculate_loss, has_aux=True)
---> 88 (loss__, (accuracy__, z_loss_computed__, aux_loss__)), grad = grad_fn(state.params)
89 state = state.apply_gradients(grads=grad)
91 grad_norms = jax.tree_map(jnp.linalg.norm, grad)
[... skipping hidden 8 frame]
File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer/fwd_bwd_functions.py:53, in create_casual_language_model_train_step.<locals>.casual_language_model_train_step.<locals>.calculate_loss(params)
51 else:
52 labels = labels[..., 1:]
---> 53 model_outputs = state.apply_fn(params=params, **batch, return_dict=True)
54 logits = model_outputs.logits
55 aux_loss = getattr(model_outputs, "aux_loss", None)
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:746, in FlaxLlamaPreTrainedModel.__call__(self, input_ids, attention_mask, position_ids, params, past_key_values, dropout_rng, train, output_attentions, output_hidden_states, return_dict, extra_embedding, add_params_field, **kwargs)
743 else:
744 mutable = False
--> 746 outputs = self.module.apply(
747 inputs,
748 jnp.array(input_ids, dtype="i4"),
749 jnp.array(attention_mask, dtype="i4"),
750 jnp.array(position_ids, dtype="i4"),
751 not train,
752 False,
753 output_attentions,
754 output_hidden_states,
755 return_dict,
756 extra_embedding,
757 rngs=rngs,
758 mutable=mutable,
759 )
761 if past_key_values is not None and return_dict:
762 outputs, past_key_values = outputs
[... skipping hidden 6 frame]
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:1067, in FlaxLlamaForCausalLMModule.__call__(self, input_ids, attention_mask, position_ids, deterministic, init_cache, output_attentions, output_hidden_states, return_dict, extra_embedding)
1062 if position_ids is None:
1063 position_ids = jnp.broadcast_to(
1064 jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
1065 (batch_size, seq_length)
1066 )
-> 1067 outputs = self.model(
1068 input_ids,
1069 attention_mask,
1070 position_ids,
1071 deterministic=deterministic,
1072 init_cache=init_cache,
1073 output_attentions=output_attentions,
1074 output_hidden_states=output_hidden_states,
1075 return_dict=return_dict,
1076 extra_embedding=extra_embedding
1077 )
1079 hidden_states = outputs[0]
1081 if self.config.tie_word_embeddings:
[... skipping hidden 2 frame]
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:964, in FlaxLlamaModule.__call__(self, input_ids, attention_mask, position_ids, deterministic, inputs_embeds, init_cache, output_attentions, output_hidden_states, return_dict, extra_embedding)
959 inputs_embeds = inputs_embeds + \
960 extra_embedding if extra_embedding is not None else inputs_embeds
961 hidden_states = self.dropout(
962 inputs_embeds, deterministic=deterministic)
--> 964 outputs = self.layers(
965 hidden_states=hidden_states,
966 freq_cis=self.freq_cis,
967 attention_mask=attention_mask,
968 position_ids=position_ids,
969 causal_mask=self.causal_mask,
970 deterministic=deterministic,
971 init_cache=init_cache,
972 output_attentions=output_attentions,
973 output_hidden_states=output_hidden_states,
974 return_dict=return_dict,
975 )
977 hidden_states = outputs[0]
978 hidden_states = self.norm(hidden_states)
[... skipping hidden 2 frame]
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:850, in FlaxLlamaBlockCollection.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, deterministic, init_cache, output_attentions, output_hidden_states, return_dict)
847 if output_hidden_states:
848 all_hidden_states += (hidden_states,)
--> 850 layer_outputs = block(
851 hidden_states=hidden_states,
852 freq_cis=freq_cis,
853 attention_mask=attention_mask,
854 position_ids=position_ids,
855 causal_mask=causal_mask,
856 deterministic=deterministic,
857 init_cache=init_cache,
858 output_attentions=output_attentions,
859 fcm_mask=fcm_mask,
860 )
861 hidden_states = layer_outputs[0]
863 if output_attentions:
[... skipping hidden 2 frame]
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:530, in FlaxLlamaBlock.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, segment_ids, deterministic, init_cache, output_attentions, fcm_mask)
497 def __call__(
498 self,
499 hidden_states: chex.Array,
(...)
508 fcm_mask: Optional[jnp.ndarray] = None,
509 ):
510 """
511 The __call__ function is the main function of a TransformerEncoderLayer.
512 It takes in hidden states, frequency-domain inputs, and masks as input. It then
(...)
528
529 """
--> 530 attn_outputs = self.self_attn(
531 self.input_layernorm(hidden_states),
532 freq_cis,
533 attention_mask,
534 position_ids,
535 causal_mask,
536 segment_ids,
537 deterministic,
538 init_cache,
539 output_attentions,
540 fcm_mask,
541 )
542 attn_output = attn_outputs[0]
543 hidden_states = hidden_states + attn_output
[... skipping hidden 2 frame]
File /usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py:567, in core_remat_static.<locals>.inner(scope_fn, repack_fn, variable_groups, rng_groups, *args)
564 y = fn(scope, *args)
565 return y, repack_fn(scope)
--> 567 return rematted(variable_groups, rng_groups, *dyn_args)
[... skipping hidden 7 frame]
File /usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py:564, in core_remat_static.<locals>.inner.<locals>.rematted(variable_groups, rng_groups, *dyn_args)
562 args = _repack_remat_args(dyn_args, static_args)
563 scope = scope_fn(variable_groups, rng_groups)
--> 564 y = fn(scope, *args)
565 return y, repack_fn(scope)
[... skipping hidden 3 frame]
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:356, in FlaxLlamaAttention.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, segment_ids, deterministic, init_cache, output_attentions, fcm_mask)
347 attention_bias = lax.select(
348 attention_mask > 0,
349 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
350 jnp.full(attention_mask.shape, jnp.finfo(
351 self.dtype).min).astype(self.dtype),
352 )
354 query_length, key_length = query_states.shape[1], key_states.shape[1]
--> 356 attentions = self.attention_performer.__call__(
357 query_states=query_states,
358 key_states=key_states,
359 value_states=value_states,
360 bias=attention_bias,
361 attention_mask=attention_mask,
362 causal=False,
363 dropout_rng=dropout_rng,
364 deterministic=deterministic,
365 query_sequence_length=query_length,
366 key_value_sequence_length=key_length,
367 uses_cache=self.has_variable("cache", "cached_key") or init_cache,
368 segment_ids=segment_ids
369 )
372 attn_output = self._merge_heads(attentions.attention_outputs)
373 if self.config.shard_attention_computation:
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/attention_module.py:277, in AttentionModule.__call__(self, query_states, key_states, value_states, query_sequence_length, key_value_sequence_length, bias, attention_mask, segment_ids, causal, deterministic, dropout_rng, uses_cache)
268 return self.vanilla_attention(
269 query_states=query_states,
270 key_states=key_states,
(...)
274 deterministic=deterministic,
275 )
276 elif self.attn_mechanism == "sharded_vanilla":
--> 277 return self.sharded_vanilla_attention(
278 query_states=query_states,
279 key_states=key_states,
280 value_states=value_states,
281 bias=bias,
282 dropout_rng=dropout_rng,
283 deterministic=deterministic,
284 )
285 elif self.attn_mechanism == "ring":
286 return self.ring_attention(
287 query_states=query_states,
288 key_states=key_states,
(...)
295 attention_mask=attention_mask
296 )
File /usr/local/lib/python3.10/site-packages/EasyDel/modules/attention_module.py:539, in AttentionModule.sharded_vanilla_attention(self, query_states, key_states, value_states, bias, deterministic, dropout_rng)
537 out_spec = self.generation_attention_partition_spec if is_generating else self.attention_partition_spec
538 with self.mesh:
--> 539 output = shard_map(
540 partial(
541 shard_vanilla_attention,
542 deterministic=deterministic,
543 dropout_rng=dropout_rng,
544 dtype=dtype,
545 precision=self.precision,
546 attention_dropout=self.attention_dropout
547 ),
548 mesh=self.mesh,
549 in_specs=(
550 query_sequence_partition,
551 self.key_partition_spec,
552 self.value_partition_spec,
553 PartitionSpec(("dp", "fsdp"), None, None, None),
554 ),
555 out_specs=out_spec,
556 )(query_states, key_states, value_states, bias)
557 output = fjformer.with_sharding_constraint(output, out_spec)
559 return AttentionOutput(
560 attention_weights=None,
561 attention_outputs=output
562 )
[... skipping hidden 2 frame]
File /usr/local/lib/python3.10/site-packages/jax/experimental/shard_map.py:199, in _check_specs_vs_args(f, mesh, in_tree, in_specs, in_specs_flat, xs)
197 if any(f is not no_fail for f in fail):
198 msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail)
--> 199 raise ValueError(msg)
ValueError: shard_map applied to the function 'functools.partial(<function shard_vanilla_attention at 0x7fe2880ec430>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)' was given argument arrays with axis sizes that are not evenly divisible by the corresponding mesh axis sizes:
The mesh given has shape (1, 8, 1, 1) with corresponding axis names ('dp', 'fsdp', 'tp', 'sp').
* args[0] of shape float32[3,8192,32,128], where args[0] is bound to functools.partial(<function shard_vanilla_attention at 0x7fe2880ec430>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'query_states', corresponds to in_specs[0] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 3) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 3
* args[1] of shape float32[3,8192,32,128], where args[1] is bound to functools.partial(<function shard_vanilla_attention at 0x7fe2880ec430>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'key_states', corresponds to in_specs[1] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 3) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 3
* args[2] of shape float32[3,8192,32,128], where args[2] is bound to functools.partial(<function shard_vanilla_attention at 0x7fe2880ec430>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'value_states', corresponds to in_specs[2] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 3) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 3
* args[3] of shape float32[3,1,8192,8192], where args[3] is bound to functools.partial(<function shard_vanilla_attention at 0x7fe2880ec430>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)'s parameter 'bias', corresponds to in_specs[3] of value PartitionSpec(('dp', 'fsdp'), None, None, None), which maps array axis 0 (of size 3) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 3
Array arguments' axis sizes must be evenly divisible by the mesh axis or axes indicated by the corresponding elements of the argument's in_specs entry. Consider checking that in_specs are correct, and if so consider changing the mesh axis sizes or else padding the input and adapting 'functools.partial(<function shard_vanilla_attention at 0x7fe2880ec430>, deterministic=True, dropout_rng=None, dtype=dtype('float32'), precision=None, attention_dropout=0.0)' appropriately.
Now TrainArguments takes a loaded_model_config_kwargs argument, so you can re-customize your model config:
argument = TrainArguments(
    loaded_model_config_kwargs={"axis_dims": (1, 1, 1, -1)},  # these arguments will replace the older config
)
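Applied to the resume-from-checkpoint setup above, that would look roughly like this (a sketch only; all other arguments stay as in the first post, and the exact keys accepted by loaded_model_config_kwargs may differ between versions):

# Hypothetical sketch: override the axis dims stored in the .easy checkpoint so
# the reloaded model uses the same (1, 1, 1, -1) sharding as the original run.
train_arguments = TrainArguments(
    # ... all other arguments as in the original snippet ...
    model_name="llama3_8b",
    init_input_shape=(1, max_length),
    loaded_model_config_kwargs={"axis_dims": (1, 1, 1, -1)},
)

trainer = SFTTrainer(
    arguments=train_arguments,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    dataset_text_field="text",
    packing=False,
    checkpoint_path='/root/' + ckpt_name,
)
output = trainer.train()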
And thanks for letting me know! If there's any other issue, I'll be happy to help you or add new features.
@erfanzar Thank you! You are so considerate!