BlinkDL/RWKV-LM

AssertionError while finetuning RWKVv5

Ethan-Chen-plus opened this issue · 8 comments

While finetuning RWKV, I used the script below (with the demo dataset generated by make_data.py; demo.bin and demo.idx are placed in ./data):

#!/bin/bash

BASE_NAME="model/demo"
N_LAYER="12"
N_EMBD="768"
M_BSZ="16" # takes 16G VRAM (reduce this to save VRAM)
LR_INIT="6e-4"
LR_FINAL="6e-5"
GRAD_CP=0 # set to 1 to save VRAM (will be slower)
EPOCH_SAVE=10

# magic_prime = the largest 3n+2 prime smaller than datalen/ctxlen-1 (= 1498226207/512-1 = 2926222.06 in this case)
# use https://www.dcode.fr/prime-numbers-search

python train.py --load_model "../../models/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth" --wandb "RWKV-5-Test" --proj_dir $BASE_NAME \
 --ctx_len 512 --my_pile_stage 3 --epoch_count 999999 --epoch_begin 0 \
 --data_file "data/demo" --my_exit_tokens 1498226207 --magic_prime 2926181 \
 --num_nodes 1 --micro_bsz $M_BSZ --n_layer $N_LAYER --n_embd $N_EMBD --pre_ffn 0 --head_qk 0 \
 --lr_init $LR_INIT --lr_final $LR_FINAL --warmup_steps 10 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 --my_pile_edecay 0 --data_type "binidx" --vocab_size 65536 \
 --weight_decay 0.001 --epoch_save $EPOCH_SAVE --head_size_a 64 \
 --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2 --grad_cp $GRAD_CP --ds_bucket_mb 200

I got this error:

(rwkv) ubuntu@ip-172-31-67-197:~/MedicalGPT/rwkv/RWKV-LM/RWKV-v5$ CUDA_VISIBLE_DEVICES=2 bash demo-training-run-demo.sh
INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmpw45qi2d_
INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmpw45qi2d_/_remote_module_non_scriptable.py
INFO:pytorch_lightning.utilities.rank_zero:########## work in progress ##########
[2024-01-13 12:22:27,924] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
INFO:pytorch_lightning.utilities.rank_zero:
############################################################################
#
# RWKV-5 BF16 on 1x1 GPU, bsz 1x1x16=16, deepspeed_stage_2 
#
# Data = data/demo (binidx), ProjDir = model/demo
#
# Epoch = 0 to 71 (will continue afterwards), save every 10 epoch
#
# Each "epoch" = 2520 steps, 40320 samples, 20643840 tokens
#
# Model = 12 n_layer, 768 n_embd, 512 ctx_len
#
# Adam = lr 0.0006 to 6e-05, warmup 10 steps, beta (0.9, 0.99), eps 1e-08
#
# Found torch 1.13.1+cu117, recommend 1.13.1+cu117 or newer
# Found deepspeed 0.12.6, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning 1.9.5, recommend 1.9.5
#
############################################################################

INFO:pytorch_lightning.utilities.rank_zero:{'load_model': '../../models/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth', 'wandb': 'RWKV-5-Test', 'proj_dir': 'model/demo', 'random_seed': -1, 'data_file': 'data/demo', 'data_type': 'binidx', 'vocab_size': 65536, 'ctx_len': 512, 'epoch_steps': 2520, 'epoch_count': 72, 'epoch_begin': 0, 'epoch_save': 10, 'micro_bsz': 16, 'n_layer': 12, 'n_embd': 768, 'dim_att': 768, 'dim_ffn': 2688, 'pre_ffn': 0, 'head_qk': 0, 'tiny_att_dim': 0, 'tiny_att_layer': -999, 'lr_init': 0.0006, 'lr_final': 6e-05, 'warmup_steps': 10, 'beta1': 0.9, 'beta2': 0.99, 'adam_eps': 1e-08, 'grad_cp': 0, 'dropout': 0, 'weight_decay': 0.001, 'weight_decay_final': -1, 'my_pile_version': 1, 'my_pile_stage': 3, 'my_pile_shift': 0, 'my_pile_edecay': 0, 'layerwise_lr': 1, 'ds_bucket_mb': 200, 'my_sample_len': 0, 'my_ffn_shift': 1, 'my_att_shift': 1, 'head_size_a': 64, 'head_size_divisor': 8, 'my_pos_emb': 0, 'load_partial': 0, 'magic_prime': 2926181, 'my_qa_mask': 0, 'my_random_steps': 0, 'my_testing': '', 'my_exit': 99999999, 'my_exit_tokens': 1498226207, 'logger': False, 'enable_checkpointing': False, 'default_root_dir': None, 'gradient_clip_val': 1.0, 'gradient_clip_algorithm': None, 'num_nodes': 1, 'num_processes': None, 'devices': '1', 'gpus': None, 'auto_select_gpus': None, 'tpu_cores': None, 'ipus': None, 'enable_progress_bar': True, 'overfit_batches': 0.0, 'track_grad_norm': -1, 'check_val_every_n_epoch': 100000000000000000000, 'fast_dev_run': False, 'accumulate_grad_batches': None, 'max_epochs': -1, 'min_epochs': None, 'max_steps': -1, 'min_steps': None, 'max_time': None, 'limit_train_batches': None, 'limit_val_batches': None, 'limit_test_batches': None, 'limit_predict_batches': None, 'val_check_interval': None, 'log_every_n_steps': 100000000000000000000, 'accelerator': 'gpu', 'strategy': 'deepspeed_stage_2', 'sync_batchnorm': False, 'precision': 'bf16', 'enable_model_summary': True, 'num_sanity_val_steps': 0, 'resume_from_checkpoint': None, 'profiler': 
None, 'benchmark': None, 'reload_dataloaders_every_n_epochs': 0, 'auto_lr_find': False, 'replace_sampler_ddp': False, 'detect_anomaly': False, 'auto_scale_batch_size': False, 'plugins': None, 'amp_backend': None, 'amp_level': None, 'move_metrics_to_cpu': False, 'multiple_trainloader_mode': 'max_size_cycle', 'inference_mode': True, 'my_timestamp': '2024-01-13-12-22-29', 'betas': (0.9, 0.99), 'real_bsz': 16, 'run_name': '65536 ctx512 L12 D768'}

INFO:pytorch_lightning.utilities.rank_zero:Current vocab size = 65536 (make sure it's correct)
INFO:pytorch_lightning.utilities.rank_zero:Data has 200499 tokens.
INFO:pytorch_lightning.utilities.rank_zero:########## Pile 20b-tokenized stage 3 ##########
Traceback (most recent call last):
  File "/home/ubuntu/MedicalGPT/rwkv/RWKV-LM/RWKV-v5/train.py", line 248, in <module>
    train_data = MyDataset(args)
  File "/home/ubuntu/MedicalGPT/rwkv/RWKV-LM/RWKV-v5/src/dataset.py", line 56, in __init__
    assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
AssertionError

The log says "Data has 200499 tokens", so set my_exit_tokens to 200499.

Also note:
magic_prime = the largest 3n+2 prime smaller than datalen/ctxlen-1 (= 200499/512-1 = 390.599609375 in this case)
use https://www.dcode.fr/prime-numbers-search

So set magic_prime = 389.
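
If you would rather compute the value than look it up on the website, here is a minimal sketch of the rule above. The function names (`is_prime`, `magic_prime`, `passes_ratio_check`) are my own and are not part of RWKV-LM. `passes_ratio_check` mirrors the assert shown in the traceback, under the assumption that `dataset_slot` is `datalen // ctxlen`; the traceback does not show how it is actually computed.

```python
def is_prime(n: int) -> bool:
    """Trial-division primality test; fine for numbers of this size."""
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    i = 3
    while i * i <= n:
        if n % i == 0:
            return False
        i += 2
    return True


def magic_prime(data_len: int, ctx_len: int) -> int:
    """Largest prime p with p % 3 == 2 and p < data_len / ctx_len - 1."""
    # Largest integer strictly below data_len / ctx_len - 1, in integer arithmetic.
    candidate = (data_len - 1) // ctx_len - 1
    while candidate >= 2:
        if candidate % 3 == 2 and is_prime(candidate):
            return candidate
        candidate -= 1
    raise ValueError("dataset too small for this ctx_len")


def passes_ratio_check(prime: int, data_len: int, ctx_len: int) -> bool:
    """Same condition as the failing assert in dataset.py.

    ASSUMPTION: dataset_slot = data_len // ctx_len (not visible in the traceback).
    """
    dataset_slot = data_len // ctx_len
    ratio = prime / dataset_slot
    return 0.99 < ratio <= 1


if __name__ == "__main__":
    p = magic_prime(200499, 512)
    print(p, passes_ratio_check(p, 200499, 512))
```

For the demo data (200499 tokens, ctx_len 512) this gives 389. Under the same assumption, the original value 2926181 fails the check because it is far larger than the number of 512-token slots in the demo data.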

Thanks for answering. But now a different error occurs:

CUDA_VISIBLE_DEVICES=2 bash demo-training-run-demo.sh

INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmpyy254i6t
INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmpyy254i6t/_remote_module_non_scriptable.py
INFO:pytorch_lightning.utilities.rank_zero:########## work in progress ##########
[2024-01-16 12:35:16,735] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
INFO:pytorch_lightning.utilities.rank_zero:
############################################################################
#
# RWKV-5 BF16 on 1x1 GPU, bsz 1x1x16=16, deepspeed_stage_2 
#
# Data = data/demo (binidx), ProjDir = model/demo
#
# Epoch = 0 to -1 (will continue afterwards), save every 10 epoch
#
# Each "epoch" = 2520 steps, 40320 samples, 20643840 tokens
#
# Model = 12 n_layer, 768 n_embd, 512 ctx_len
#
# Adam = lr 0.0006 to 6e-05, warmup 10 steps, beta (0.9, 0.99), eps 1e-08
#
# Found torch 1.13.1+cu117, recommend 1.13.1+cu117 or newer
# Found deepspeed 0.12.6, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning 1.9.5, recommend 1.9.5
#
############################################################################

INFO:pytorch_lightning.utilities.rank_zero:{'load_model': '../../models/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth', 'wandb': 'RWKV-5-Test', 'proj_dir': 'model/demo', 'random_seed': -1, 'data_file': 'data/demo', 'data_type': 'binidx', 'vocab_size': 65536, 'ctx_len': 512, 'epoch_steps': 2520, 'epoch_count': 0, 'epoch_begin': 0, 'epoch_save': 10, 'micro_bsz': 16, 'n_layer': 12, 'n_embd': 768, 'dim_att': 768, 'dim_ffn': 2688, 'pre_ffn': 0, 'head_qk': 0, 'tiny_att_dim': 0, 'tiny_att_layer': -999, 'lr_init': 0.0006, 'lr_final': 6e-05, 'warmup_steps': 10, 'beta1': 0.9, 'beta2': 0.99, 'adam_eps': 1e-08, 'grad_cp': 0, 'dropout': 0, 'weight_decay': 0.001, 'weight_decay_final': -1, 'my_pile_version': 1, 'my_pile_stage': 3, 'my_pile_shift': 0, 'my_pile_edecay': 0, 'layerwise_lr': 1, 'ds_bucket_mb': 200, 'my_sample_len': 0, 'my_ffn_shift': 1, 'my_att_shift': 1, 'head_size_a': 64, 'head_size_divisor': 8, 'my_pos_emb': 0, 'load_partial': 0, 'magic_prime': 389, 'my_qa_mask': 0, 'my_random_steps': 0, 'my_testing': '', 'my_exit': 99999999, 'my_exit_tokens': 1498226207, 'logger': False, 'enable_checkpointing': False, 'default_root_dir': None, 'gradient_clip_val': 1.0, 'gradient_clip_algorithm': None, 'num_nodes': 1, 'num_processes': None, 'devices': '1', 'gpus': None, 'auto_select_gpus': None, 'tpu_cores': None, 'ipus': None, 'enable_progress_bar': True, 'overfit_batches': 0.0, 'track_grad_norm': -1, 'check_val_every_n_epoch': 100000000000000000000, 'fast_dev_run': False, 'accumulate_grad_batches': None, 'max_epochs': -1, 'min_epochs': None, 'max_steps': -1, 'min_steps': None, 'max_time': None, 'limit_train_batches': None, 'limit_val_batches': None, 'limit_test_batches': None, 'limit_predict_batches': None, 'val_check_interval': None, 'log_every_n_steps': 100000000000000000000, 'accelerator': 'gpu', 'strategy': 'deepspeed_stage_2', 'sync_batchnorm': False, 'precision': 'bf16', 'enable_model_summary': True, 'num_sanity_val_steps': 0, 'resume_from_checkpoint': None, 'profiler': None, 
'benchmark': None, 'reload_dataloaders_every_n_epochs': 0, 'auto_lr_find': False, 'replace_sampler_ddp': False, 'detect_anomaly': False, 'auto_scale_batch_size': False, 'plugins': None, 'amp_backend': None, 'amp_level': None, 'move_metrics_to_cpu': False, 'multiple_trainloader_mode': 'max_size_cycle', 'inference_mode': True, 'my_timestamp': '2024-01-16-12-35-18', 'betas': (0.9, 0.99), 'real_bsz': 16, 'run_name': '65536 ctx512 L12 D768'}

INFO:pytorch_lightning.utilities.rank_zero:Current vocab size = 65536 (make sure it's correct)
INFO:pytorch_lightning.utilities.rank_zero:Data has 200499 tokens.
INFO:pytorch_lightning.utilities.rank_zero:########## Pile 20b-tokenized stage 3 ##########
RWKV_MY_TESTING 
Using /home/ubuntu/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/ubuntu/.cache/torch_extensions/py310_cu117/wkv5/build.ninja...
Building extension module wkv5...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/2] /usr/bin/g++-10 -MMD -MF wkv5_op.o.d -DTORCH_EXTENSION_NAME=wkv5 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/include -isystem /home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/include/TH -isystem /home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/include/THC -isystem /home/ubuntu/micromamba/envs/rwkv/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++14 -c /home/ubuntu/MedicalGPT/rwkv/RWKV-LM/RWKV-v5/cuda/wkv5_op.cpp -o wkv5_op.o 
[2/2] /usr/bin/g++-10 wkv5_op.o wkv5_cuda.cuda.o -shared -L/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda_cu -ltorch_cuda_cpp -ltorch -ltorch_python -L/usr/lib64 -lcudart -o wkv5.so
Loading extension module wkv5...
INFO:pytorch_lightning.utilities.rank_zero:########## Loading ../../models/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth... ##########
Traceback (most recent call last):
  File "/home/ubuntu/MedicalGPT/rwkv/RWKV-LM/RWKV-v5/train.py", line 284, in <module>
    model.load_state_dict(load_dict)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for RWKV:
        Unexpected key(s) in state_dict: "blocks.12.ln1.weight", "blocks.12.ln1.bias", "blocks.12.ln2.weight", "blocks.12.ln2.bias", "blocks.12.att.time_mix_k", "blocks.12.att.time_mix_v", "blocks.12.att.time_mix_r", "blocks.12.att.time_mix_g", "blocks.12.att.time_decay", "blocks.12.att.time_faaaa", "blocks.12.att.receptance.weight", "blocks.12.att.key.weight", "blocks.12.att.value.weight", "blocks.12.att.output.weight", "blocks.12.att.gate.weight", "blocks.12.att.ln_x.weight", "blocks.12.att.ln_x.bias", "blocks.12.ffn.time_mix_k", "blocks.12.ffn.time_mix_r", "blocks.12.ffn.key.weight", "blocks.12.ffn.receptance.weight", "blocks.12.ffn.value.weight", "blocks.13.ln1.weight", "blocks.13.ln1.bias", "blocks.13.ln2.weight", "blocks.13.ln2.bias", "blocks.13.att.time_mix_k", "blocks.13.att.time_mix_v", "blocks.13.att.time_mix_r", "blocks.13.att.time_mix_g", "blocks.13.att.time_decay", "blocks.13.att.time_faaaa", "blocks.13.att.receptance.weight", "blocks.13.att.key.weight", "blocks.13.att.value.weight", "blocks.13.att.output.weight", "blocks.13.att.gate.weight", "blocks.13.att.ln_x.weight", "blocks.13.att.ln_x.bias", "blocks.13.ffn.time_mix_k", "blocks.13.ffn.time_mix_r", "blocks.13.ffn.key.weight", "blocks.13.ffn.receptance.weight", "blocks.13.ffn.value.weight", "blocks.14.ln1.weight", "blocks.14.ln1.bias", "blocks.14.ln2.weight", "blocks.14.ln2.bias", "blocks.14.att.time_mix_k", "blocks.14.att.time_mix_v", "blocks.14.att.time_mix_r", "blocks.14.att.time_mix_g", "blocks.14.att.time_decay", "blocks.14.att.time_faaaa", "blocks.14.att.receptance.weight", "blocks.14.att.key.weight", "blocks.14.att.value.weight", "blocks.14.att.output.weight", "blocks.14.att.gate.weight", "blocks.14.att.ln_x.weight", "blocks.14.att.ln_x.bias", "blocks.14.ffn.time_mix_k", "blocks.14.ffn.time_mix_r", "blocks.14.ffn.key.weight", "blocks.14.ffn.receptance.weight", "blocks.14.ffn.value.weight", "blocks.15.ln1.weight", "blocks.15.ln1.bias", "blocks.15.ln2.weight", "blocks.15.ln2.bias", 
"blocks.15.att.time_mix_k", "blocks.15.att.time_mix_v", "blocks.15.att.time_mix_r", "blocks.15.att.time_mix_g", "blocks.15.att.time_decay", "blocks.15.att.time_faaaa", "blocks.15.att.receptance.weight", "blocks.15.att.key.weight", "blocks.15.att.value.weight", "blocks.15.att.output.weight", "blocks.15.att.gate.weight", "blocks.15.att.ln_x.weight", "blocks.15.att.ln_x.bias", "blocks.15.ffn.time_mix_k", "blocks.15.ffn.time_mix_r", "blocks.15.ffn.key.weight", "blocks.15.ffn.receptance.weight", "blocks.15.ffn.value.weight", "blocks.16.ln1.weight", "blocks.16.ln1.bias", "blocks.16.ln2.weight", "blocks.16.ln2.bias", "blocks.16.att.time_mix_k", "blocks.16.att.time_mix_v", "blocks.16.att.time_mix_r", "blocks.16.att.time_mix_g", "blocks.16.att.time_decay", "blocks.16.att.time_faaaa", "blocks.16.att.receptance.weight", "blocks.16.att.key.weight", "blocks.16.att.value.weight", "blocks.16.att.output.weight", "blocks.16.att.gate.weight", "blocks.16.att.ln_x.weight", "blocks.16.att.ln_x.bias", "blocks.16.ffn.time_mix_k", "blocks.16.ffn.time_mix_r", "blocks.16.ffn.key.weight", "blocks.16.ffn.receptance.weight", "blocks.16.ffn.value.weight", "blocks.17.ln1.weight", "blocks.17.ln1.bias", "blocks.17.ln2.weight", "blocks.17.ln2.bias", "blocks.17.att.time_mix_k", "blocks.17.att.time_mix_v", "blocks.17.att.time_mix_r", "blocks.17.att.time_mix_g", "blocks.17.att.time_decay", "blocks.17.att.time_faaaa", "blocks.17.att.receptance.weight", "blocks.17.att.key.weight", "blocks.17.att.value.weight", "blocks.17.att.output.weight", "blocks.17.att.gate.weight", "blocks.17.att.ln_x.weight", "blocks.17.att.ln_x.bias", "blocks.17.ffn.time_mix_k", "blocks.17.ffn.time_mix_r", "blocks.17.ffn.key.weight", "blocks.17.ffn.receptance.weight", "blocks.17.ffn.value.weight", "blocks.18.ln1.weight", "blocks.18.ln1.bias", "blocks.18.ln2.weight", "blocks.18.ln2.bias", "blocks.18.att.time_mix_k", "blocks.18.att.time_mix_v", "blocks.18.att.time_mix_r", "blocks.18.att.time_mix_g", "blocks.18.att.time_decay", 
"blocks.18.att.time_faaaa", "blocks.18.att.receptance.weight", "blocks.18.att.key.weight", "blocks.18.att.value.weight", "blocks.18.att.output.weight", "blocks.18.att.gate.weight", "blocks.18.att.ln_x.weight", "blocks.18.att.ln_x.bias", "blocks.18.ffn.time_mix_k", "blocks.18.ffn.time_mix_r", "blocks.18.ffn.key.weight", "blocks.18.ffn.receptance.weight", "blocks.18.ffn.value.weight", "blocks.19.ln1.weight", "blocks.19.ln1.bias", "blocks.19.ln2.weight", "blocks.19.ln2.bias", "blocks.19.att.time_mix_k", "blocks.19.att.time_mix_v", "blocks.19.att.time_mix_r", "blocks.19.att.time_mix_g", "blocks.19.att.time_decay", "blocks.19.att.time_faaaa", "blocks.19.att.receptance.weight", "blocks.19.att.key.weight", "blocks.19.att.value.weight", "blocks.19.att.output.weight", "blocks.19.att.gate.weight", "blocks.19.att.ln_x.weight", "blocks.19.att.ln_x.bias", "blocks.19.ffn.time_mix_k", "blocks.19.ffn.time_mix_r", "blocks.19.ffn.key.weight", "blocks.19.ffn.receptance.weight", "blocks.19.ffn.value.weight", "blocks.20.ln1.weight", "blocks.20.ln1.bias", "blocks.20.ln2.weight", "blocks.20.ln2.bias", "blocks.20.att.time_mix_k", "blocks.20.att.time_mix_v", "blocks.20.att.time_mix_r", "blocks.20.att.time_mix_g", "blocks.20.att.time_decay", "blocks.20.att.time_faaaa", "blocks.20.att.receptance.weight", "blocks.20.att.key.weight", "blocks.20.att.value.weight", "blocks.20.att.output.weight", "blocks.20.att.gate.weight", "blocks.20.att.ln_x.weight", "blocks.20.att.ln_x.bias", "blocks.20.ffn.time_mix_k", "blocks.20.ffn.time_mix_r", "blocks.20.ffn.key.weight", "blocks.20.ffn.receptance.weight", "blocks.20.ffn.value.weight", "blocks.21.ln1.weight", "blocks.21.ln1.bias", "blocks.21.ln2.weight", "blocks.21.ln2.bias", "blocks.21.att.time_mix_k", "blocks.21.att.time_mix_v", "blocks.21.att.time_mix_r", "blocks.21.att.time_mix_g", "blocks.21.att.time_decay", "blocks.21.att.time_faaaa", "blocks.21.att.receptance.weight", "blocks.21.att.key.weight", "blocks.21.att.value.weight", 
"blocks.21.att.output.weight", "blocks.21.att.gate.weight", "blocks.21.att.ln_x.weight", "blocks.21.att.ln_x.bias", "blocks.21.ffn.time_mix_k", "blocks.21.ffn.time_mix_r", "blocks.21.ffn.key.weight", "blocks.21.ffn.receptance.weight", "blocks.21.ffn.value.weight", "blocks.22.ln1.weight", "blocks.22.ln1.bias", "blocks.22.ln2.weight", "blocks.22.ln2.bias", "blocks.22.att.time_mix_k", "blocks.22.att.time_mix_v", "blocks.22.att.time_mix_r", "blocks.22.att.time_mix_g", "blocks.22.att.time_decay", "blocks.22.att.time_faaaa", "blocks.22.att.receptance.weight", "blocks.22.att.key.weight", "blocks.22.att.value.weight", "blocks.22.att.output.weight", "blocks.22.att.gate.weight", "blocks.22.att.ln_x.weight", "blocks.22.att.ln_x.bias", "blocks.22.ffn.time_mix_k", "blocks.22.ffn.time_mix_r", "blocks.22.ffn.key.weight", "blocks.22.ffn.receptance.weight", "blocks.22.ffn.value.weight", "blocks.23.ln1.weight", "blocks.23.ln1.bias", "blocks.23.ln2.weight", "blocks.23.ln2.bias", "blocks.23.att.time_mix_k", "blocks.23.att.time_mix_v", "blocks.23.att.time_mix_r", "blocks.23.att.time_mix_g", "blocks.23.att.time_decay", "blocks.23.att.time_faaaa", "blocks.23.att.receptance.weight", "blocks.23.att.key.weight", "blocks.23.att.value.weight", "blocks.23.att.output.weight", "blocks.23.att.gate.weight", "blocks.23.att.ln_x.weight", "blocks.23.att.ln_x.bias", "blocks.23.ffn.time_mix_k", "blocks.23.ffn.time_mix_r", "blocks.23.ffn.key.weight", "blocks.23.ffn.receptance.weight", "blocks.23.ffn.value.weight". 
        size mismatch for emb.weight: copying a param with shape torch.Size([65536, 1024]) from checkpoint, the shape in current model is torch.Size([65536, 768]).
        size mismatch for blocks.0.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.ln0.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.ln0.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.0.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.0.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.0.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.0.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.0.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.0.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.0.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.0.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.0.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.0.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.0.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.0.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.0.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.0.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.0.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.1.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.1.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.1.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.1.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.1.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.1.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.1.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.1.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.1.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.1.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.1.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.1.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.1.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.1.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.1.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.1.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.1.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.1.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.1.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.1.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.1.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.1.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.2.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.2.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.2.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.2.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.2.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.2.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.2.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.2.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.2.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.2.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.2.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.2.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.2.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.2.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.2.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.2.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.2.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.2.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.2.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.2.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.2.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.2.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.3.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.3.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.3.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.3.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.3.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.3.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.3.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.3.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.3.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.3.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.3.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.3.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.3.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.3.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.3.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.3.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.3.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.3.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.3.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.3.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.3.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.3.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.4.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.4.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.4.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.4.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.4.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.4.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.4.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.4.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.4.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.4.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.4.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.4.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.4.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.4.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.4.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.4.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.4.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.4.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.4.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.4.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.4.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.4.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.5.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.5.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.5.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.5.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.5.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.5.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.5.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.5.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.5.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.5.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.5.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.5.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.5.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.5.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.5.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.5.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.5.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.5.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.5.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.5.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.5.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.5.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.6.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.6.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.6.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.6.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.6.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.6.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.6.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.6.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.6.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.6.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.6.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.6.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.6.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.6.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.6.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.6.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.6.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.6.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.6.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.6.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.6.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.6.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.7.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.7.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.7.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.7.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.7.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.7.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.7.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.7.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.7.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.7.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.7.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.7.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.7.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.7.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.7.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.7.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.7.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.7.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.7.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.7.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.7.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.7.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.8.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.8.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.8.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.8.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.8.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.8.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.8.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.8.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.8.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.8.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.8.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.8.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.8.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.8.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.8.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.8.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.8.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.8.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.8.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.8.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.8.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.8.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.9.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.9.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.9.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.9.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.9.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.9.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.9.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.9.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.9.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.9.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.9.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.9.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.9.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.9.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.9.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.9.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.9.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.9.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.9.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.9.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.9.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.9.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.10.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.10.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.10.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.10.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.10.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.10.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.10.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.10.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.10.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.10.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.10.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.10.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.10.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.10.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.10.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.10.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.10.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.10.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.10.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.10.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.10.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.10.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for blocks.11.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.11.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.11.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.11.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.11.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.11.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.11.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.11.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.11.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.11.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.11.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.11.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.11.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.11.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.11.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.11.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.11.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.11.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.11.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.11.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.11.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.11.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
        size mismatch for ln_out.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for ln_out.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for head.weight: copying a param with shape torch.Size([65536, 1024]) from checkpoint, the shape in current model is torch.Size([65536, 768]).

For 0.4B finetuning, the model config must match the checkpoint, so set:
N_LAYER="24"
N_EMBD="1024"
LR_INIT="2e-5"
LR_FINAL="2e-5"
GRAD_CP="1"
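The size-mismatch errors above arise because the script builds a 768-wide model while the checkpoint holds 1024-wide tensors. A minimal sketch of how the right `N_LAYER` and `N_EMBD` could be read off a checkpoint before training is shown below. The helper name `infer_dims` and the toy `demo_shapes` mapping are hypothetical; the parameter names (`blocks.N.…`, `head.weight`) are the ones that appear in the error log.

```python
import re


def infer_dims(shapes):
    """Infer (n_layer, n_embd, vocab_size) from a {param_name: shape} mapping.

    Assumes RWKV-style names as seen in the error log:
    'blocks.<i>.…' for per-layer tensors and 'head.weight' of shape
    (vocab_size, n_embd).
    """
    layer_ids = {
        int(m.group(1))
        for name in shapes
        if (m := re.match(r"blocks\.(\d+)\.", name))
    }
    vocab_size, n_embd = shapes["head.weight"]
    return max(layer_ids) + 1, n_embd, vocab_size


# Toy stand-in for a real checkpoint (illustrative values only).
# With a real file, and assuming the .pth is a plain state dict, the mapping
# could be built as:
#   sd = torch.load(path, map_location="cpu")
#   shapes = {k: tuple(v.shape) for k, v in sd.items()}
demo_shapes = {
    "blocks.0.att.key.weight": (1024, 1024),
    "blocks.23.att.key.weight": (1024, 1024),
    "ln_out.weight": (1024,),
    "head.weight": (65536, 1024),
}

n_layer, n_embd, vocab_size = infer_dims(demo_shapes)
print(f'N_LAYER="{n_layer}" N_EMBD="{n_embd}" vocab_size={vocab_size}')
```

If the printed values differ from the ones in the training script, `load_state_dict` will fail with size mismatches like those above.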

Thanks for helping! But why are LR_INIT and LR_FINAL set to the same value?
Another question: if I set GRAD_CP=0, memory usage goes up and I get an OOM error:

INFO:pytorch_lightning.strategies.deepspeed:initializing deepspeed distributed: GLOBAL_RANK: 0, MEMBER: 1/1
INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:1 to store for rank: 0
INFO:torch.distributed.distributed_c10d:Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.
INFO:pytorch_lightning.utilities.rank_zero:Enabling DeepSpeed BF16.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]
Using /home/ubuntu/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/ubuntu/.cache/torch_extensions/py310_cu117/fused_adam/build.ninja...
Building extension module fused_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module fused_adam...
Time to load fused_adam op: 0.06461381912231445 seconds
INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:2 to store for rank: 0
INFO:torch.distributed.distributed_c10d:Rank 0: Completed store-based barrier for key:store_based_barrier_key:2 with 1 nodes.
INFO:pytorch_lightning.callbacks.model_summary:
  | Name   | Type       | Params
--------------------------------------
0 | emb    | Embedding  | 67.1 M
1 | blocks | ModuleList | 327 M 
2 | ln_out | LayerNorm  | 2.0 K 
3 | head   | Linear     | 67.1 M
--------------------------------------
461 M     Trainable params
0         Non-trainable params
461 M     Total params
1,846.886 Total estimated model params size (MB)
Epoch 0:   0%|                                           | 0/2520 [00:00<?, ?it/s]
{'zero_allow_untested_optimizer': True, 'zero_optimization': {'stage': 2, 'contiguous_gradients': True, 'overlap_comm': True, 'allgather_partitions': True, 'reduce_scatter': True, 'allgather_bucket_size': 200000000, 'reduce_bucket_size': 200000000, 'sub_group_size': 1000000000000}, 'activation_checkpointing': {'partition_activations': False, 'cpu_checkpointing': False, 'contiguous_memory_optimization': False, 'synchronize_checkpoint_boundary': False}, 'aio': {'block_size': 1048576, 'queue_depth': 8, 'single_submit': False, 'overlap_events': True, 'thread_count': 1}, 'gradient_accumulation_steps': 1, 'train_micro_batch_size_per_gpu': 16, 'gradient_clipping': 1.0, 'bf16': {'enabled': True}}

Login to wandb...
wandb: Currently logged in as: keweichen (aicolab). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.2
wandb: Run data is saved locally in /home/ubuntu/MedicalGPT/rwkv/RWKV-LM/RWKV-v5/wandb/run-20240116_132216-ck3u9wok
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run 65536 ctx512 L24 D1024 2024-01-16-13-22-02
wandb: ⭐️ View project at https://wandb.ai/aicolab/RWKV-5-Test
wandb: 🚀 View run at https://wandb.ai/aicolab/RWKV-5-Test/runs/ck3u9wok
Traceback (most recent call last):
  File "/home/ubuntu/MedicalGPT/rwkv/RWKV-LM/RWKV-v5/train.py", line 312, in <module>
    trainer.fit(model, data_loader)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit
    call._call_and_handle_interrupt(
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 36, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 88, in launch
    return function(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1112, in _run
    results = self._run_stage()
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1191, in _run_stage
    self._run_train()
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1214, in _run_train
    self.fit_loop.run()
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 213, in advance
    batch_output = self.batch_loop.run(kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
    outputs = self.optimizer_loop.run(optimizers, kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 202, in advance
    result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 249, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 370, in _optimizer_step
    self.trainer._call_lightning_module_hook(
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1356, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1754, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py", line 169, in step
    step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/strategies/ddp.py", line 280, in optimizer_step
    optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 234, in optimizer_step
    return self.precision_plugin.optimizer_step(
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/deepspeed.py", line 132, in optimizer_step
    closure_result = closure()
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 149, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 144, in closure
    self._backward_fn(step_output.closure_loss)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 305, in backward_fn
    self.trainer._call_strategy_hook("backward", loss, optimizer, opt_idx)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1494, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 207, in backward
    self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, optimizer_idx, *args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/deepspeed.py", line 118, in backward
    deepspeed_engine.backward(tensor, *args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1955, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 2019, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 21.99 GiB total capacity; 20.70 GiB already allocated; 287.00 MiB free; 20.88 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
wandb: 🚀 View run 65536 ctx512 L24 D1024 2024-01-16-13-22-02 at: https://wandb.ai/aicolab/RWKV-5-Test/runs/ck3u9wok
wandb: ️⚡ View job at https://wandb.ai/aicolab/RWKV-5-Test/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEyOTkxNjc0MQ==/version_details/v1
wandb: Synced 5 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20240116_132216-ck3u9wok/logs
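A rough sanity check on the numbers in this log, as a sketch rather than an exact accounting. The 16 bytes per parameter is an assumption (bf16 weights and gradients plus fp32 master weights and two Adam moments), and DeepSpeed's actual layout may differ. The function names are hypothetical.

```python
def static_memory_gib(n_params, bytes_per_param):
    """Memory for per-parameter state (weights, grads, optimizer), in GiB."""
    return n_params * bytes_per_param / 2**30


def logits_mib(micro_bsz, ctx_len, vocab_size, bytes_per_el=2):
    """Size in MiB of one (micro_bsz, ctx_len, vocab_size) tensor; 2 bytes = bf16."""
    return micro_bsz * ctx_len * vocab_size * bytes_per_el / 2**20


N_PARAMS = 461e6  # "461 M Total params" from the model summary (rounded)

# fp32 weights alone: close to the "Total estimated model params size (MB)" line.
print(f"fp32 weights: {N_PARAMS * 4 / 1e6:.0f} MB")

# ASSUMPTION: 2 (bf16 weights) + 2 (bf16 grads) + 4 (fp32 master) + 4 + 4 (Adam m, v).
print(f"static state: {static_memory_gib(N_PARAMS, 16):.2f} GiB")

# One bf16 logits tensor for micro_bsz=16, ctx_len=512, vocab_size=65536.
print(f"logits tensor: {logits_mib(16, 512, 65536):.0f} MiB")
```

Under these assumptions the per-parameter state takes only about a third of the 20.70 GiB already allocated, which suggests activations account for most of the rest. The logits tensor is the same size as the 1024.00 MiB allocation that failed, although the log alone does not prove that is the tensor being allocated.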

I currently have four A10 GPUs (22 GB each). How can I maximize the utilization of compute and memory?

Set --devices 4 to use all 4 GPUs, and make sure all four are visible:

CUDA_VISIBLE_DEVICES=0,1,2,3
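Note that raising `--devices` also raises the effective batch size, since the startup banner reports it as `num_nodes x devices x micro_bsz` (for example `bsz 1x1x16=16`). A small sketch of that arithmetic follows; the function names are hypothetical.

```python
def effective_batch(num_nodes, devices, micro_bsz):
    """Sequences per optimizer step, matching the 'bsz AxBxC=N' banner line."""
    return num_nodes * devices * micro_bsz


def tokens_per_step(num_nodes, devices, micro_bsz, ctx_len):
    """Tokens consumed per optimizer step across all GPUs."""
    return effective_batch(num_nodes, devices, micro_bsz) * ctx_len


# Single GPU, as in the original script.
print(effective_batch(1, 1, 16))
# Four GPUs with the same per-GPU micro batch.
print(effective_batch(1, 4, 16))
print(tokens_per_step(1, 4, 16, 512))
```

Each GPU still processes `micro_bsz` sequences per step, so per-GPU activation memory does not shrink when GPUs are added. This is data parallelism: with ZeRO stage 2 only optimizer state and gradients are sharded across devices.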

Thanks again @BlinkDL

I have another question. I'm currently doing a full fine-tune of rwkv5, a model with only 0.4B parameters, at a context length (ctx_len) of 1024, and it almost maxes out the memory on all four of my A10 GPUs. However, llama2-7b can run full-parameter training on four A10 cards with a context length of 4096. Is there a way to train my v5 model at a context length of 4096 using model parallelism across the four GPUs?

Check your gradient checkpointing flag. Disabling it gives a speed boost at the cost of much more VRAM usage (llama setups typically have it enabled).

@Ethan-Chen-plus set GRAD_CP=1
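To illustrate why this flag matters so much for long contexts, here is a back-of-the-envelope sketch of activation memory with and without gradient checkpointing. It is a toy model, not a measurement: `acts_per_layer` (how many `micro_bsz x ctx_len x n_embd` tensors a layer keeps for backward) is a hypothetical value chosen for illustration, and the function name is an assumption.

```python
def activation_gib(n_layer, micro_bsz, ctx_len, n_embd,
                   acts_per_layer, grad_cp, bytes_per_el=2):
    """Rough activation memory in GiB (bytes_per_el=2 corresponds to bf16).

    Without checkpointing, every layer keeps all of its intermediates.
    With checkpointing, only each layer's input is kept, plus the
    intermediates of the one layer currently being recomputed.
    """
    per_tensor = micro_bsz * ctx_len * n_embd * bytes_per_el
    if grad_cp:
        stored = n_layer + acts_per_layer
    else:
        stored = n_layer * acts_per_layer
    return stored * per_tensor / 2**30


ACTS = 30  # HYPOTHETICAL number of saved tensors per layer

for ctx_len in (512, 1024, 4096):
    off = activation_gib(24, 16, ctx_len, 1024, ACTS, grad_cp=False)
    on = activation_gib(24, 16, ctx_len, 1024, ACTS, grad_cp=True)
    print(f"ctx_len={ctx_len}: GRAD_CP=0 ~{off:.2f} GiB, GRAD_CP=1 ~{on:.2f} GiB")
```

In this model activation memory grows linearly with `ctx_len` and `micro_bsz`, so moving from 512 to 4096 multiplies it by 8. With GRAD_CP=1 the stored tensors drop from `n_layer * acts_per_layer` to roughly `n_layer + acts_per_layer`, in exchange for recomputing each layer's forward pass during backward. Lowering `M_BSZ` is the other lever if memory is still tight.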