erfanzar/EasyDeL

ValueError when using flash attention

Closed this issue · 1 comment

I'm using the documentation sample exactly as written to fine-tune TinyLlama 1.1B on the UltraChat 200k dataset. After logging in to wandb, this error happened:

How can I fix this?
Documentation link: https://easydel.readthedocs.io/en/latest/finetuning_example.html
If you don't mind, please create a Colab gist of the documentation example and link it here.

```
ValueError: shard_map applied to the function 'functools.partial(<PjitFunction of <function flash_attention at 0x7d5fa01e96c0>>, causal=True, sm_scale=0.125, block_sizes=BlockSizes(block_q=1, block_k_major=1, block_k=1, block_b=1, block_q_major_dkv=1, block_k_major_dkv=1, block_k_dkv=1, block_q_dkv=1, block_k_major_dq=1, block_k_dq=1, block_q_dq=1), debug=False)' was given argument arrays with axis sizes that are not evenly divisible by the corresponding mesh axis sizes:

The mesh given has shape (1, 8, 1, 1) with corresponding axis names ('dp', 'fsdp', 'tp', 'sp').

  • args[0] of shape float32[1,32,1,64], where args[0] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7d5fa01e96c0>>, causal=True, sm_scale=0.125, block_sizes=BlockSizes(block_q=1, block_k_major=1, block_k=1, block_b=1, block_q_major_dkv=1, block_k_major_dkv=1, block_k_dkv=1, block_q_dkv=1, block_k_major_dq=1, block_k_dq=1, block_q_dq=1), debug=False)'s parameter 'q', corresponds to in_specs[0] of value PartitionSpec(('dp', 'fsdp'), None, 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1

  • args[1] of shape float32[1,32,1,64], where args[1] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7d5fa01e96c0>>, causal=True, sm_scale=0.125, block_sizes=BlockSizes(block_q=1, block_k_major=1, block_k=1, block_b=1, block_q_major_dkv=1, block_k_major_dkv=1, block_k_dkv=1, block_q_dkv=1, block_k_major_dq=1, block_k_dq=1, block_q_dq=1), debug=False)'s parameter 'k', corresponds to in_specs[1] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1

  • args[2] of shape float32[1,32,1,64], where args[2] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7d5fa01e96c0>>, causal=True, sm_scale=0.125, block_sizes=BlockSizes(block_q=1, block_k_major=1, block_k=1, block_b=1, block_q_major_dkv=1, block_k_major_dkv=1, block_k_dkv=1, block_q_dkv=1, block_k_major_dq=1, block_k_dq=1, block_q_dq=1), debug=False)'s parameter 'v', corresponds to in_specs[2] of value PartitionSpec(('dp', 'fsdp'), 'sp', 'tp', None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1

  • args[3] of shape bfloat16[1,32,1,1], where args[3] is bound to functools.partial(<PjitFunction of <function flash_attention at 0x7d5fa01e96c0>>, causal=True, sm_scale=0.125, block_sizes=BlockSizes(block_q=1, block_k_major=1, block_k=1, block_b=1, block_q_major_dkv=1, block_k_major_dkv=1, block_k_dkv=1, block_q_dkv=1, block_k_major_dq=1, block_k_dq=1, block_q_dq=1), debug=False)'s parameter 'ab', corresponds to in_specs[3] of value PartitionSpec(('dp', 'fsdp'), None, None, None), which maps array axis 0 (of size 1) to mesh axes ('dp', 'fsdp') (of total size 8), but 8 does not evenly divide 1

Array arguments' axis sizes must be evenly divisible by the mesh axis or axes indicated by the corresponding elements of the argument's in_specs entry. Consider checking that in_specs are correct, and if so consider changing the mesh axis sizes or else padding the input and adapting 'functools.partial(<PjitFunction of <function flash_attention at 0x7d5fa01e96c0>>, causal=True, sm_scale=0.125, block_sizes=BlockSizes(block_q=1, block_k_major=1, block_k=1, block_b=1, block_q_major_dkv=1, block_k_major_dkv=1, block_k_dkv=1, block_q_dkv=1, block_k_major_dq=1, block_k_dq=1, block_q_dq=1), debug=False)' appropriately.
```
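The rule `shard_map` enforces here is simple: every array axis that a `PartitionSpec` entry maps to one or more mesh axes must be divisible by the product of those mesh-axis sizes. The sketch below re-implements that check in plain Python so it can be run without TPUs. The helper name `indivisible_axes` and the tuple-based stand-in for `PartitionSpec` are hypothetical illustrations, not JAX or EasyDeL APIs; the mesh sizes and the `q` shape/spec are copied from the traceback.

```python
from math import prod

# Mesh from the traceback: axis name -> size, shape (1, 8, 1, 1) = 8 devices.
MESH = {"dp": 1, "fsdp": 8, "tp": 1, "sp": 1}


def indivisible_axes(shape, spec, mesh):
    """Return (array_axis, dim_size, mesh_size) for each axis that cannot be split evenly.

    `spec` is a hypothetical stand-in for a PartitionSpec: one entry per array
    axis, each being None, a mesh-axis name, or a tuple of mesh-axis names.
    """
    problems = []
    for axis, (dim, entry) in enumerate(zip(shape, spec)):
        if entry is None:
            continue  # replicated axis: no divisibility constraint
        names = entry if isinstance(entry, tuple) else (entry,)
        mesh_size = prod(mesh[name] for name in names)
        if dim % mesh_size != 0:
            problems.append((axis, dim, mesh_size))
    return problems


# q from the traceback: float32[1, 32, 1, 64] with spec (('dp', 'fsdp'), None, 'tp', None).
q_spec = (("dp", "fsdp"), None, "tp", None)
print(indivisible_axes((1, 32, 1, 64), q_spec, MESH))  # [(0, 1, 8)]  -> batch 1 over 8 shards
print(indivisible_axes((8, 32, 8, 64), q_spec, MESH))  # []          -> batch 8 splits evenly
```

The first call reproduces the complaint in the error ("8 does not evenly divide 1"): the batch axis of size 1 is asked to split across the 8-way `('dp', 'fsdp')` group.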

Hi,
Actually, you first have to learn JAX's sharding methods and strategies and read the whole documentation if you want to use the advanced options. I also noticed you are using Colab, and Colab gives you a TPU v2. EasyDeL currently works on TPU v2 and v3, but support for those accelerators may be dropped in the next version.

There's a bug in the documentation on Read the Docs, so the pages are not available at the moment, but they will be fixed soon.

To fix this issue, set `"input_shape": (8, 8)`. The right value for this attribute depends on your devices and accelerators.
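Why (8, 8) works on this machine: the mesh is (1, 8, 1, 1), so the batch axis is split over `dp * fsdp = 8` devices and must be a multiple of 8. The sketch below picks the smallest valid batch size for a given mesh. It assumes that `input_shape` is laid out as `(batch_size, sequence_length)` and that the batch axis is sharded over `('dp', 'fsdp')` as in the traceback; the helper `smallest_valid_batch` is hypothetical. On real hardware the total device count can be read with `jax.device_count()`.

```python
from math import prod


def smallest_valid_batch(mesh, batch_axes=("dp", "fsdp"), requested=1):
    """Round `requested` up to the nearest multiple of the product of the batch mesh axes."""
    shards = prod(mesh[name] for name in batch_axes)
    return -(-requested // shards) * shards  # ceiling division, then scale back up


# On an 8-core Colab TPU v2 the mesh in the traceback is (1, 8, 1, 1).
mesh = {"dp": 1, "fsdp": 8, "tp": 1, "sp": 1}
batch = smallest_valid_batch(mesh)  # 8
input_shape = (batch, 8)  # assumed (batch_size, sequence_length), matching the (8, 8) suggestion
print(input_shape)  # (8, 8)
```

With a different mesh, for example `dp=2, fsdp=4`, the same rule still requires a batch that is a multiple of 8, while a single-device mesh would accept a batch of 1.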