BlinkDL/RWKV-LM

size mismatch for blocks.11.ffn.value.weight: copying a param with shape

zhaodice opened this issue · 3 comments

Command (run in RWKV-v4neo):

python train.py --load_model /home/user/models/LLM/rwkv/RWKV-5-World-3B-v2-20231118-ctx16k.pth --proj_dir ./test --data_file ttt_text_document --data_type binidx --vocab_size 65536 --ctx_len 16384 --epoch_steps 10 --epoch_count 100 --epoch_begin 0 --epoch_save 5 --micro_bsz 1 --n_layer 24 --n_embd 2560 --pre_ffn 0 --head_qk 0 --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 --my_testing "r2r4"

but it reported:

size mismatch for blocks.11.ffn.value.weight: copying a param with shape torch.Size([2560, 8960]) from checkpoint, the shape in current model is torch.Size([2560, 10240]).
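The two widths in this error can be checked by hand: 10240 = 4 × 2560 and 8960 = 3.5 × 2560. The sketch below assumes two things. First, that the RWKV-v4neo model sizes the FFN hidden layer as 4 × n_embd, which is consistent with the 10240 in the message. Second, that the RWKV-5 World checkpoint was trained with 3.5 × n_embd, rounded down to a multiple of 32. The rounding rule and the helper names ffn_dim_v4neo and ffn_dim_35x are hypothetical illustrations, not code taken from the repository.

```python
def ffn_dim_v4neo(n_embd: int) -> int:
    """Assumed FFN hidden width of the RWKV-v4neo model: 4x the embedding size."""
    return 4 * n_embd


def ffn_dim_35x(n_embd: int) -> int:
    """Assumed FFN hidden width of the RWKV-5 checkpoint: 3.5x, floored to a multiple of 32."""
    return int((n_embd * 3.5) // 32 * 32)


if __name__ == "__main__":
    n_embd = 2560
    # ffn.value.weight is stored as [n_embd, dim_ffn]
    print("model (v4neo):", [n_embd, ffn_dim_v4neo(n_embd)])  # [2560, 10240]
    print("checkpoint:   ", [n_embd, ffn_dim_35x(n_embd)])    # [2560, 8960]
```

If the v4neo width really is fixed at 4 × n_embd, changing --n_layer alone cannot remove this particular mismatch. The later run from RWKV-v5 reports no ffn size mismatch at all.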

I have read https://huggingface.co/BlinkDL/rwkv-5-world, but I have no idea how to set the parameters. Please help!
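One way to find the right values is to read them off the checkpoint itself. The sketch below assumes RWKV-style key names. The blocks.<i>.ffn.value.weight key appears in the errors in this issue, but emb.weight for the token embedding is an assumption. The helpers infer_rwkv_config and load_shapes are hypothetical. The core function takes a mapping from parameter name to shape, so it runs without PyTorch. load_shapes shows one way to build that mapping with torch.load.

```python
import re
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def infer_rwkv_config(shapes: Dict[str, Shape]) -> Dict[str, int]:
    """Guess n_layer / n_embd / dim_ffn / vocab_size from parameter shapes (hypothetical helper)."""
    layer_ids = set()
    for name in shapes:
        m = re.match(r"blocks\.(\d+)\.", name)
        if m:
            layer_ids.add(int(m.group(1)))
    # Blocks are numbered from 0, so the layer count is the highest index + 1.
    n_layer = max(layer_ids) + 1
    # ffn.value projects dim_ffn -> n_embd, so its weight is [n_embd, dim_ffn].
    n_embd, dim_ffn = shapes["blocks.0.ffn.value.weight"]
    cfg = {"n_layer": n_layer, "n_embd": n_embd, "dim_ffn": dim_ffn}
    if "emb.weight" in shapes:  # assumed embedding key, shaped [vocab_size, n_embd]
        cfg["vocab_size"] = shapes["emb.weight"][0]
    return cfg


def load_shapes(path: str) -> Dict[str, Shape]:
    """Read a .pth state dict on the CPU and keep only the tensor shapes."""
    import torch  # imported lazily so infer_rwkv_config works without torch

    state = torch.load(path, map_location="cpu")
    return {k: tuple(v.shape) for k, v in state.items()}
```

Consider infer_rwkv_config(load_shapes("/home/user/models/LLM/rwkv/RWKV-5-World-3B-v2-20231118-ctx16k.pth")). If the key names match, this call should report values consistent with the shapes seen in this issue: n_embd 2560, dim_ffn 8960, and blocks numbered up to 31.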

I then tried again from RWKV-v5 with --n_layer 26, but got:

(rwkv-role) user@calculator:~/git/RWKV_Role_Playing_API/RWKV-LM/RWKV-v5$ python train.py --load_model /home/user/models/LLM/rwkv/RWKV-5-World-3B-v2-20231118-ctx16k.pth --proj_dir ./test --data_file ttt_text_document --data_type binidx --vocab_size 65536 --ctx_len 16384 --epoch_steps 10 --epoch_count 100 --epoch_begin 0 --epoch_save 5 --micro_bsz 1 --n_layer 26 --n_embd 2560 --pre_ffn 0 --head_qk 0 --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 --my_testing "r2r4"
ptxas info    : 3 bytes gmem
ptxas info    : Compiling entry function '_Z15kernel_backwardIN3c108BFloat16EEviiiiPKT_S4_S4_PKfS6_S4_S4_PS2_S7_S7_S7_S7_' for 'sm_89'
ptxas info    : Function properties for _Z15kernel_backwardIN3c108BFloat16EEviiiiPKT_S4_S4_PKfS6_S4_S4_PS2_S7_S7_S7_S7_
    0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 168 registers, 1536 bytes smem, 464 bytes cmem[0]
ptxas info    : Compiling entry function '_Z14kernel_forwardIN3c108BFloat16EEviiiiPKT_S4_S4_PKfS4_PS2_' for 'sm_89'
ptxas info    : Function properties for _Z14kernel_forwardIN3c108BFloat16EEviiiiPKT_S4_S4_PKfS4_PS2_
    0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 100 registers, 1024 bytes smem, 416 bytes cmem[0]
[3/3] c++ wkv5_op.o wkv5_cuda.cuda.o -shared -L/home/user/anaconda3/envs/rwkv-role/lib/python3.9/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda-12.3/lib64 -lcudart -o wkv5.so
Loading extension module wkv5...
INFO:pytorch_lightning.utilities.rank_zero:########## Loading /home/user/models/LLM/rwkv/RWKV-5-World-3B-v2-20231118-ctx16k.pth... ##########
Traceback (most recent call last):
  File "/home/user/git/RWKV_Role_Playing_API/RWKV-LM/RWKV-v5/train.py", line 281, in <module>
    model.load_state_dict(load_dict)
  File "/home/user/anaconda3/envs/rwkv-role/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for RWKV:
        Unexpected key(s) in state_dict: "blocks.26.ln1.weight", "blocks.26.ln1.bias", "blocks.26.ln2.weight", "blocks.26.ln2.bias", "blocks.26.att.time_mix_k", "blocks.26.att.time_mix_v", "blocks.26.att.time_mix_r", "blocks.26.att.time_mix_g", "blocks.26.att.time_decay", "blocks.26.att.time_faaaa", "blocks.26.att.receptance.weight", "blocks.26.att.key.weight", "blocks.26.att.value.weight", "blocks.26.att.output.weight", "blocks.26.att.gate.weight", "blocks.26.att.ln_x.weight", "blocks.26.att.ln_x.bias", "blocks.26.ffn.time_mix_k", "blocks.26.ffn.time_mix_r", "blocks.26.ffn.key.weight", "blocks.26.ffn.receptance.weight", "blocks.26.ffn.value.weight", "blocks.27.ln1.weight", "blocks.27.ln1.bias", "blocks.27.ln2.weight", "blocks.27.ln2.bias", "blocks.27.att.time_mix_k", "blocks.27.att.time_mix_v", "blocks.27.att.time_mix_r", "blocks.27.att.time_mix_g", "blocks.27.att.time_decay", "blocks.27.att.time_faaaa", "blocks.27.att.receptance.weight", "blocks.27.att.key.weight", "blocks.27.att.value.weight", "blocks.27.att.output.weight", "blocks.27.att.gate.weight", "blocks.27.att.ln_x.weight", "blocks.27.att.ln_x.bias", "blocks.27.ffn.time_mix_k", "blocks.27.ffn.time_mix_r", "blocks.27.ffn.key.weight", "blocks.27.ffn.receptance.weight", "blocks.27.ffn.value.weight", "blocks.28.ln1.weight", "blocks.28.ln1.bias", "blocks.28.ln2.weight", "blocks.28.ln2.bias", "blocks.28.att.time_mix_k", "blocks.28.att.time_mix_v", "blocks.28.att.time_mix_r", "blocks.28.att.time_mix_g", "blocks.28.att.time_decay", "blocks.28.att.time_faaaa", "blocks.28.att.receptance.weight", "blocks.28.att.key.weight", "blocks.28.att.value.weight", "blocks.28.att.output.weight", "blocks.28.att.gate.weight", "blocks.28.att.ln_x.weight", "blocks.28.att.ln_x.bias", "blocks.28.ffn.time_mix_k", "blocks.28.ffn.time_mix_r", "blocks.28.ffn.key.weight", "blocks.28.ffn.receptance.weight", "blocks.28.ffn.value.weight", "blocks.29.ln1.weight", "blocks.29.ln1.bias", "blocks.29.ln2.weight", "blocks.29.ln2.bias", "blocks.29.att.time_mix_k", "blocks.29.att.time_mix_v", "blocks.29.att.time_mix_r", "blocks.29.att.time_mix_g", "blocks.29.att.time_decay", "blocks.29.att.time_faaaa", "blocks.29.att.receptance.weight", "blocks.29.att.key.weight", "blocks.29.att.value.weight", "blocks.29.att.output.weight", "blocks.29.att.gate.weight", "blocks.29.att.ln_x.weight", "blocks.29.att.ln_x.bias", "blocks.29.ffn.time_mix_k", "blocks.29.ffn.time_mix_r", "blocks.29.ffn.key.weight", "blocks.29.ffn.receptance.weight", "blocks.29.ffn.value.weight", "blocks.30.ln1.weight", "blocks.30.ln1.bias", "blocks.30.ln2.weight", "blocks.30.ln2.bias", "blocks.30.att.time_mix_k", "blocks.30.att.time_mix_v", "blocks.30.att.time_mix_r", "blocks.30.att.time_mix_g", "blocks.30.att.time_decay", "blocks.30.att.time_faaaa", "blocks.30.att.receptance.weight", "blocks.30.att.key.weight", "blocks.30.att.value.weight", "blocks.30.att.output.weight", "blocks.30.att.gate.weight", "blocks.30.att.ln_x.weight", "blocks.30.att.ln_x.bias", "blocks.30.ffn.time_mix_k", "blocks.30.ffn.time_mix_r", "blocks.30.ffn.key.weight", "blocks.30.ffn.receptance.weight", "blocks.30.ffn.value.weight", "blocks.31.ln1.weight", "blocks.31.ln1.bias", "blocks.31.ln2.weight", "blocks.31.ln2.bias", "blocks.31.att.time_mix_k", "blocks.31.att.time_mix_v", "blocks.31.att.time_mix_r", "blocks.31.att.time_mix_g", "blocks.31.att.time_decay", "blocks.31.att.time_faaaa", "blocks.31.att.receptance.weight", "blocks.31.att.key.weight", "blocks.31.att.value.weight", "blocks.31.att.output.weight", "blocks.31.att.gate.weight", "blocks.31.att.ln_x.weight", "blocks.31.att.ln_x.bias", "blocks.31.ffn.time_mix_k", "blocks.31.ffn.time_mix_r", "blocks.31.ffn.key.weight", "blocks.31.ffn.receptance.weight", "blocks.31.ffn.value.weight".
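When load_state_dict fails, PyTorch reports missing keys, unexpected keys, and size mismatches together. Passing strict=False suppresses only the first two; shape mismatches still raise. A pre-flight check can compare the checkpoint's shapes against the model's before loading. Below is a minimal sketch built on two hypothetical helpers, diff_state_shapes and block_ids. In this run, unexpected keys for blocks.26 through blocks.31 indicate that the checkpoint has 32 blocks.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_state_shapes(
    ckpt: Dict[str, Shape], model: Dict[str, Shape]
) -> Tuple[List[str], List[str], List[Tuple[str, Shape, Shape]]]:
    """Return (missing, unexpected, mismatched), mirroring what load_state_dict reports."""
    missing = sorted(k for k in model if k not in ckpt)
    unexpected = sorted(k for k in ckpt if k not in model)
    mismatched = [
        (k, ckpt[k], model[k])
        for k in sorted(ckpt.keys() & model.keys())
        if ckpt[k] != model[k]
    ]
    return missing, unexpected, mismatched


def block_ids(keys: List[str]) -> List[int]:
    """Collect the distinct block indices in keys such as 'blocks.26.ln1.weight'."""
    return sorted({int(k.split(".")[1]) for k in keys if k.startswith("blocks.")})
```

With PyTorch, build both dicts as {k: tuple(v.shape) for k, v in sd.items()}. Use the loaded checkpoint as sd for the first dict and model.state_dict() for the second. Because the comparison uses only shapes, it never touches the optimizer or DeepSpeed.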

Fixed with n_layer = 32.
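As a rough sanity check that 32 is the right depth, the large weight matrices can be counted by hand. The sketch below makes several assumptions:

- Each block has five n_embd × n_embd attention matrices: receptance, key, value, output, and gate, as listed in the keys above.
- Each block has three FFN matrices (key, value, receptance) using dim_ffn = 8960.
- The embedding and output head are separate vocab_size × n_embd matrices. The separate head is an assumption.

Small vectors such as layer norms, time_mix, and time_decay are ignored. The function name approx_rwkv5_params is hypothetical.

```python
def approx_rwkv5_params(n_layer: int, n_embd: int, dim_ffn: int, vocab_size: int) -> int:
    """Count only the large matrices; vectors such as ln and time_mix are ignored."""
    att = 5 * n_embd * n_embd  # receptance, key, value, output, gate
    ffn = 2 * n_embd * dim_ffn + n_embd * n_embd  # key, value, receptance
    emb_and_head = 2 * vocab_size * n_embd  # assumes untied embedding and head
    return n_layer * (att + ffn) + emb_and_head


for layers in (24, 26, 32):
    total = approx_rwkv5_params(layers, 2560, 8960, 65536)
    print(f"n_layer={layers}: ~{total / 1e9:.2f}B parameters")
```

This prints about 2.38B, 2.55B, and 3.06B. Only the 32-layer figure lands near the 3B in the checkpoint's name.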