Unable to Load EasyDeL State
w11wo opened this issue · 6 comments
Hi, thanks for making EasyDeL.
Describe the bug
We have finished training a Mistral model using EasyDeL and are trying to convert the model back to HuggingFace. However, we faced the following issue:
Traceback (most recent call last):
File "/mnt/disks/persist/home/scripts/convert_easydel_hf.py", line 74, in <module>
main()
File "/mnt/disks/persist/home/scripts/convert_easydel_hf.py", line 64, in main
state = EasyDelState.load_state(args.checkpoint_path)
File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/etils/easystate.py", line 309, in load_state
module_in = module(
File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 576, in __init__
super().__init__(config, module, input_shape=input_shape,
File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/easydel_modelling_utils.py", line 390, in __init__
super().__init__(
File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/modeling_flax_utils.py", line 219, in __init__
random_params = self.init_weights(self.key, input_shape)
File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 619, in init_weights
module_init_outputs = self.module.init(
File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 1000, in __call__
outputs = self.model(
File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 893, in __call__
outputs = self.layers(
File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 781, in __call__
output = layer(
File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 528, in __call__
attention_output = self.self_attn(
File "/home/davidsamuel101/.local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 553, in inner
return rematted(variable_groups, rng_groups, *dyn_args)
File "/home/davidsamuel101/.local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 550, in rematted
y = fn(scope, *args)
File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 403, in __call__
attentions = self.attention_performer.__call__(
File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/easy_attention.py", line 193, in __call__
attentions = self._qkv_normal_flash_op(
File "/home/davidsamuel101/miniconda3/envs/llm/lib/python3.10/site-packages/EasyDel/modules/easy_attention.py", line 440, in _qkv_normal_flash_op
attention_o = shard_map(
ValueError: shard_map applied to the function 'functools.partial(<PjitFunction of <function flash_attention at 0x7ef070706710>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=1, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=1, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=1, block_k_major_dq=128, block_k_dq=128, block_q_dq=1), debug=False)' was given argument arrays with axis sizes that are not evenly divisible by the corresponding mesh axis sizes:
The mesh given has shape (1, 4, 1, 1) with corresponding axis names ('dp', 'fsdp', 'tp', 'sp').
* args[0] of shape float32[1,32,1,128], where args[0] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7ef070706710>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=1, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=1, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=1, block_k_major_dq=128, block_k_dq=128, block_q_dq=1), debug=False)'s parameter 'q', corresponds to in_specs[0] of value PartitionSpec(('dp', 'fsdp'), None, 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 4), but 4 does not evenly divide 1
* args[1] of shape float32[1,32,1,128], where args[1] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7ef070706710>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=1, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=1, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=1, block_k_major_dq=128, block_k_dq=128, block_q_dq=1), debug=False)'s parameter 'k', corresponds to in_specs[1] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 4), but 4 does not evenly divide 1
* args[2] of shape float32[1,32,1,128], where args[2] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7ef070706710>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=1, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=1, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=1, block_k_major_dq=128, block_k_dq=128, block_q_dq=1), debug=False)'s parameter 'v', corresponds to in_specs[2] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 4), but 4 does not evenly divide 1
* args[3] of shape float32[1,32,1,1], where args[3] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7ef070706710>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=1, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=1, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=1, block_k_major_dq=128, block_k_dq=128, block_q_dq=1), debug=False)'s parameter 'ab', corresponds to in_specs[3] of value PartitionSpec(('dp', 'fsdp'), None, None, None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 4), but 4 does not evenly divide 1
Array arguments' axis sizes must be evenly divisible by the mesh axis or axes indicated by the corresponding elements of the argument's in_specs entry. Consider checking that in_specs are correct, and if so consider changing the mesh axis sizes or else padding the input and adapting 'functools.partial(<PjitFunction of <function flash_attention at 0x7ef070706710>>, causal=False, sm_scale=0.08838834764831843, block_sizes=BlockSizes(block_q=1, block_k_major=128, block_k=128, block_b=1, block_q_major_dkv=1, block_k_major_dkv=128, block_k_dkv=128, block_q_dkv=1, block_k_major_dq=128, block_k_dq=128, block_q_dq=1), debug=False)' appropriately.
To Reproduce
We load the EasyDeL state as follows:
output = trainer.train(model_parameters=model_parameters, state=None)
state = EasyDelState.load_state(output.checkpoint_path)
FYI, we are training on a TPU v4-8. Thanks in advance.
Hi, and thanks for using EasyDeL.
The error you are seeing is caused by sharding issues; there are 4 ways to fix it:
- Use CPU for offloading and don't use sharding functions.
- Change the input shape to (4, 1024) to prevent this error (the error says that 4 does not evenly divide 1); see the sketch after this list.
- Use sharding axes (1, 1, 1, -1).
- Don't load the state using flash attention.
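To make the input-shape suggestion concrete, here is a minimal sketch of the check that is failing, outside EasyDeL, assuming 4 local devices (e.g. a TPU v4-8) arranged like the ('dp', 'fsdp', 'tp', 'sp') mesh in your traceback. shard_map rejects any array axis that is not evenly divisible by the mesh axes it is split across, which is why a leading (batch) dimension of 4 passes while 1 does not:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Same mesh layout as the traceback: shape (1, 4, 1, 1) over ('dp', 'fsdp', 'tp', 'sp').
mesh = Mesh(
    np.array(jax.devices()).reshape(1, 4, 1, 1),
    axis_names=("dp", "fsdp", "tp", "sp"),
)

# Identity function as a stand-in for the flash-attention kernel.
attn = shard_map(
    lambda q: q,
    mesh=mesh,
    in_specs=P(("dp", "fsdp"), None, "tp", None),
    out_specs=P(("dp", "fsdp"), None, "tp", None),
)

try:
    attn(jnp.ones((1, 32, 1, 128)))  # batch axis of 1 split across dp * fsdp = 4 devices
except ValueError as err:
    print(err)  # "... but 4 does not evenly divide 1"

print(attn(jnp.ones((4, 32, 1, 128))).shape)  # (4, 32, 1, 128): shards cleanly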
@erfanzar, thanks for the swift response. I am writing a separate model conversion script (apart from training). It currently looks like this:
state = EasyDelState.load_state(args.checkpoint_path)
Just this line alone will break. Could you please give the exact solution? I am a bit confused about how to apply the changes you suggested, since the additional parameters to the EasyDelState.load_state method are only for adding sharding functions. Thanks.
- Use CPU for offloading and don't use sharding functions:
with jax.default_device(jax.devices("cpu")[0]):
state = EasyDelState.load_state(args.checkpoint_path)
- Change the input shape to (4, 1024) to prevent this error (the error says that 4 does not evenly divide 1):
state = EasyDelState.load_state(args.checkpoint_path, input_shape=(4,1024))
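If you'd rather not hard-code the 4: the leading dimension of input_shape is only used for the dummy init_weights pass shown in your traceback, so any multiple of dp * fsdp works. A sketch combining both fixes (the top-level EasyDelState import is an assumption; adjust it to wherever it lives in your version):

import jax
from EasyDel import EasyDelState  # import path may differ across versions

# The leading dim must be divisible by dp * fsdp (4 on a TPU v4-8).
input_shape = (jax.device_count(), 1024)
with jax.default_device(jax.devices("cpu")[0]):
    state = EasyDelState.load_state(args.checkpoint_path, input_shape=input_shape)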
Thanks for the quick fix and the response, @erfanzar. I am going to test it once my current training iteration finishes (in ~40 hours), and I'll let you know if it works. Cheers.
Hi @erfanzar, I can confirm that the latest code changes work!
I ended up running:
device_num = len(jax.devices())
input_shape = (device_num, args.model_max_length) # (4, 1024)
with jax.default_device(jax.devices("cpu")[0]):
state = EasyDelState.load_state(args.checkpoint_path, input_shape=input_shape)
Cheers and thanks again!
Thanks for contributing and using EasyDeL, cheers ❤️