erfanzar/EasyDeL

Unable to Load EasyDeL State

w11wo opened this issue · 6 comments

Hi, thanks for making EasyDeL.

Describe the bug

We have finished training a Mistral model using EasyDeL and are trying to convert the model back to HuggingFace. However, we ran into the following issue:

Traceback (most recent call last):
  File "/mnt/disks/persist/home/scripts/convert_easydel_hf.py", line 74, in <module>
    main()
  File "/mnt/disks/persist/home/scripts/convert_easydel_hf.py", line 64, in main
    state = EasyDelState.load_state(args.checkpoint_path)
  File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/etils/easystate.py", line 309, in load_state
    module_in = module(
  File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 576, in __init__
    super().__init__(config, module, input_shape=input_shape,
  File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/easydel_modelling_utils.py", line 390, in __init__
    super().__init__(
  File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/modeling_flax_utils.py", line 219, in __init__
    random_params = self.init_weights(self.key, input_shape)
  File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 619, in init_weights
    module_init_outputs = self.module.init(
  File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 1000, in __call__
    outputs = self.model(
  File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 893, in __call__
    outputs = self.layers(
  File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 781, in __call__
    output = layer(
  File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 528, in __call__
    attention_output = self.self_attn(
  File "/home/davidsamuel101/.local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 553, in inner
    return rematted(variable_groups, rng_groups, *dyn_args)
  File "/home/davidsamuel101/.local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 550, in rematted
    y = fn(scope, *args)
  File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 403, in __call__
    attentions = self.attention_performer.__call__(
  File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/easy_attention.py", line 193, in __call__
    attentions = self._qkv_normal_flash_op(
  File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/easy_attention.py", line 440, in _qkv_normal_flash_op
    attention_o = shard_map(
ValueError: shard_map applied to the function 'functools.partial(<PjitFunction of <function flash_attention at 0x7ef070706710>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=1, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=1, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=1, block_k_major_dq=128, block_k_dq=128, block_q_dq=1), debug=False)' was given argument arrays with axis sizes that are not evenly divisible by the corresponding mesh axis sizes:

The mesh given has shape (1, 4, 1, 1) with corresponding axis names ('dp', 'fsdp', 'tp', 'sp').

* args[0] of shape float32[1,32,1,128], where args[0] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7ef070706710>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=1, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=1, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=1, block_k_major_dq=128, block_k_dq=128, block_q_dq=1), debug=False)'s parameter 'q', corresponds to in_specs[0] of value PartitionSpec(('dp', 'fsdp'), None, 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 4), but 4 does not evenly divide 1

* args[1] of shape float32[1,32,1,128], where args[1] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7ef070706710>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=1, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=1, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=1, block_k_major_dq=128, block_k_dq=128, block_q_dq=1), debug=False)'s parameter 'k', corresponds to in_specs[1] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 4), but 4 does not evenly divide 1

* args[2] of shape float32[1,32,1,128], where args[2] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7ef070706710>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=1, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=1, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=1, block_k_major_dq=128, block_k_dq=128, block_q_dq=1), debug=False)'s parameter 'v', corresponds to in_specs[2] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 4), but 4 does not evenly divide 1

* args[3] of shape float32[1,32,1,1], where args[3] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7ef070706710>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=1, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=1, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=1, block_k_major_dq=128, block_k_dq=128, block_q_dq=1), debug=False)'s parameter 'ab', corresponds to in_specs[3] of value PartitionSpec(('dp', 'fsdp'), None, None, None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 4), but 4 does not evenly divide 1

Array arguments' axis sizes must be evenly divisible by the mesh axis or axes indicated by the corresponding elements of the argument's in_specs entry. Consider checking that in_specs are correct, and if so consider changing the mesh axis sizes or else padding the input and adapting 'functools.partial(<PjitFunction of <function flash_attention at 0x7ef070706710>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=1, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=1, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=1, block_k_major_dq=128, block_k_dq=128, block_q_dq=1), debug=False)' appropriately.
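In short, shard_map requires every array axis that is mapped to mesh axes to be evenly divisible by the product of those mesh axis sizes; here the batch axis has size 1 while the ('dp', 'fsdp') axes multiply to 4. A minimal sketch (plain JAX, not EasyDeL code; it fakes 4 CPU devices to stand in for the TPU v4-8 mesh) that reproduces the same ValueError:

# Minimal sketch of the divisibility rule behind the error above. The XLA
# flag fakes 4 CPU devices so the 1-D mesh below plays the role of the
# (dp, fsdp) = (1, 4) axes from the log. Set the flag before importing jax.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices("cpu")), axis_names=("fsdp",))

def double(x):
    return x * 2

sharded = shard_map(double, mesh=mesh, in_specs=P("fsdp"), out_specs=P("fsdp"))

ok = jnp.zeros((4, 128))   # axis 0 (size 4) splits evenly across 4 devices
print(sharded(ok).shape)   # (4, 128)

bad = jnp.zeros((1, 128))  # axis 0 (size 1) cannot be split 4 ways
sharded(bad)               # ValueError: ... 4 does not evenly divide 1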

To Reproduce

We used the following code to train and then load the EasyDeL state:

output = trainer.train(model_parameters=model_parameters, state=None)
state = EasyDelState.load_state(output.checkpoint_path)

FYI, we are training on a TPU v4-8. Thanks in advance.

Hi, and thanks for using EasyDeL.
The error you are seeing is caused by sharding. There are four ways to fix it:

  1. Use CPU for offload and don't use sharding functions.

  2. Change the input shape to (4, 1024) to prevent this error (your traceback says 4 does not evenly divide 1).

  3. Use sharding axis dims (1, 1, 1, -1).

  4. Don't load the state using flash attention.

@erfanzar, thanks for the swift response. I am writing a separate model conversion script (apart from training). It currently looks like this:

state = EasyDelState.load_state(args.checkpoint_path)

This line alone breaks. Could you please give the exact solution? I am a bit confused about how to apply your suggestions, since the additional parameters of the EasyDelState.load_state method only add sharding functions. Thanks.

  1. Use CPU for offload and don't use sharding functions:

with jax.default_device(jax.devices("cpu")[0]):
    state = EasyDelState.load_state(args.checkpoint_path)

  2. Change the input shape to (4, 1024) so the batch axis is evenly divisible by the mesh:

state = EasyDelState.load_state(args.checkpoint_path, input_shape=(4, 1024))
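For option 3, here is a rough sketch. `AutoEasyDelModelForCausalLM.from_pretrained` and its `sharding_axis_dims` argument are taken from EasyDel's README of this era; the names may differ in your version, so treat this as a sketch rather than the exact call:

# Sketch of option 3 (names assumed from EasyDel's README; verify against
# your installed version). (dp, fsdp, tp, sp) = (1, 1, 1, -1) puts all
# devices on 'sp', so the ('dp', 'fsdp') group has total size 1 and evenly
# divides any batch axis.
from EasyDel import AutoEasyDelModelForCausalLM

model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",  # hypothetical base model id
    sharding_axis_dims=(1, 1, 1, -1),
)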

Thanks for the quick fix and the response, @erfanzar. I am going to test it once my current training iteration finishes (in about 40 hours), and I'll let you know if it works. Cheers.

Hi @erfanzar, I can confirm that the latest code changes work!

I ended up running:

import jax
from EasyDel.etils.easystate import EasyDelState  # path as in the traceback above

# Make the batch axis of input_shape match the device count, so it is
# evenly divisible by the ('dp', 'fsdp') mesh axes (1 * 4 = 4 on a v4-8).
device_num = len(jax.devices())
input_shape = (device_num, args.model_max_length)  # (4, 1024)

with jax.default_device(jax.devices("cpu")[0]):
    state = EasyDelState.load_state(args.checkpoint_path, input_shape=input_shape)

Cheers and thanks again!
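For reference, the remaining step (EasyDeL state back to a HuggingFace model) looked roughly like the sketch below. `easystate_to_huggingface_model` is the transform helper EasyDeL provides for this, but its import path and signature have moved between versions, so treat this as a hypothetical sketch rather than the exact script:

# Hypothetical sketch: function name, import path, and arguments are
# assumptions and vary by EasyDeL version; check the transform module of
# the release you are on.
from EasyDel.transform import easystate_to_huggingface_model
from transformers import MistralForCausalLM, MistralConfig

config = MistralConfig.from_pretrained("mistralai/Mistral-7B-v0.1")  # hypothetical base id
hf_model = easystate_to_huggingface_model(
    state=state,                                 # the EasyDelState loaded above
    base_huggingface_module=MistralForCausalLM,  # target HF class
    config=config,
)
hf_model.save_pretrained("./mistral-hf")         # hypothetical output directory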

Thanks for contributing and using EasyDeL. Cheers ❤️