openai/jukebox

Size mismatch

martinkoenig opened this issue · 0 comments

Hi, I'm trying to run the example sampling command

python jukebox/sample.py --model=5b_lyrics --name=sample_5b --levels=3 --sample_length_in_seconds=20 \
--total_sample_length_in_seconds=180 --sr=44100 --n_samples=6 --hop_fraction=0.5,0.5,0.125

But I'm getting some strange errors due to a size mismatch when the pretrained level-2 prior is loaded. Does anybody know what's wrong here?

Restored from C:\Users\Martin/.cache\jukebox/models/5b_lyrics/prior_level_2.pth.tar
Traceback (most recent call last):
  File "jukebox/sample.py", line 279, in <module>
    fire.Fire(run)
  File "C:\Users\Martin\miniconda3\envs\jukebox\lib\site-packages\fire\core.py", line 127, in Fire
    component_trace = _Fire(component, args, context, name)
  File "C:\Users\Martin\miniconda3\envs\jukebox\lib\site-packages\fire\core.py", line 366, in _Fire
    component, remaining_args)
  File "C:\Users\Martin\miniconda3\envs\jukebox\lib\site-packages\fire\core.py", line 542, in _CallCallable
    result = fn(*varargs, **kwargs)
  File "jukebox/sample.py", line 276, in run
    save_samples(model, device, hps, sample_hps)
  File "jukebox/sample.py", line 181, in save_samples
    vqvae, priors = make_model(model, device, hps)
  File "c:\users\martin\pycharmprojects\jukebox\jukebox\make_models.py", line 195, in make_model
    priors = [make_prior(setup_hparams(priors[level], dict()), vqvae, 'cpu') for level in levels]
  File "c:\users\martin\pycharmprojects\jukebox\jukebox\make_models.py", line 195, in <listcomp>
    priors = [make_prior(setup_hparams(priors[level], dict()), vqvae, 'cpu') for level in levels]
  File "c:\users\martin\pycharmprojects\jukebox\jukebox\make_models.py", line 179, in make_prior
    restore_model(hps, prior, hps.restore_prior)
  File "c:\users\martin\pycharmprojects\jukebox\jukebox\make_models.py", line 61, in restore_model
    model.load_state_dict(checkpoint['model'])
  File "C:\Users\Martin\miniconda3\envs\jukebox\lib\site-packages\torch\nn\modules\module.py", line 777, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for SimplePrior:
        Missing key(s) in state_dict: "prime_prior.start_token", "prime_prior.x_emb.weight", "prime_prior.pos_emb.pos_emb", "prime_prior.transformer._attn_mods.0.attn.c_attn.w", "prime_prior.transformer._attn_mods.0.attn.c_attn.b", "prime_prior.transformer._attn_mods.0.attn.c_proj.w", "prime_prior.transformer._attn_mods.0.attn.c_proj.b", "prime_prior.transformer._attn_mods.0.ln_0.weight", "prime_prior.transformer._attn_mods.0.ln_0.bias", "prime_prior.transformer._attn_mods.0.mlp.c_fc.w", "prime_prior.transformer._attn_mods.0.mlp.c_fc.b", "prime_prior.transformer._attn_mods.0.mlp.c_proj.w", "prime_prior.transformer._attn_mods.0.mlp.c_proj.b", "prime_prior.transformer._attn_mods.0.ln_1.weight", "prime_prior.transformer._attn_mods.0.ln_1.bias", "prime_prior.transformer._attn_mods.1.attn.c_attn.w", "prime_prior.transformer._attn_mods.1.attn.c_attn.b", "prime_prior.transformer._attn_mods.1.attn.c_proj.w", "prime_prior.transformer._attn_mods.1.attn.c_proj.b", "prime_prior.transformer._attn_mods.1.ln_0.weight", "prime_prior.transformer._attn_mods.1.ln_0.bias", "prime_prior.transformer._attn_mods.1.mlp.c_fc.w", "prime_prior.transformer._attn_mods.1.mlp.c_fc.b", "prime_prior.transformer._attn_mods.1.mlp.c_proj.w", "prime_prior.transformer._attn_mods.1.mlp.c_proj.b", "prime_prior.transformer._attn_mods.1.ln_1.weight", "prime_prior.transformer._attn_mods.1.ln_1.bias", "prime_prior.transformer._attn_mods.2.attn.c_attn.w", "prime_prior.transformer._attn_mods.2.attn.c_attn.b", "prime_prior.transformer._attn_mods.2.attn.c_proj.w", "prime_prior.transformer._attn_mods.2.attn.c_proj.b", "prime_prior.transformer._attn_mods.2.ln_0.weight", "prime_prior.transformer._attn_mods.2.ln_0.bias", "prime_prior.transformer._attn_mods.2.mlp.c_fc.w", "prime_prior.transformer._attn_mods.2.mlp.c_fc.b", "prime_prior.transformer._attn_mods.2.mlp.c_proj.w", "prime_prior.transformer._attn_mods.2.mlp.c_proj.b", "prime_prior.transformer._attn_mods.2.ln_1.weight", "prime_prior.transformer._attn_mods.2.ln_1.bias", "prime_prior.transformer._attn_mods.3.attn.c_attn.w", "prime_prior.transformer._attn_mods.3.attn.c_attn.b", "prime_prior.transformer._attn_mods.3.attn.c_proj.w", "prime_prior.transformer._attn_mods.3.attn.c_proj.b", "prime_prior.transformer._attn_mods.3.ln_0.weight", "prime_prior.transformer._attn_mods.3.ln_0.bias", "prime_prior.transformer._attn_mods.3.mlp.c_fc.w", "prime_prior.transformer._attn_mods.3.mlp.c_fc.b", "prime_prior.transformer._attn_mods.3.mlp.c_proj.w", "prime_prior.transformer._attn_mods.3.mlp.c_proj.b", "prime_prior.transformer._attn_mods.3.ln_1.weight", "prime_prior.transformer._attn_mods.3.ln_1.bias", "prime_prior.transformer._attn_mods.4.attn.c_attn.w", "prime_prior.transformer._attn_mods.4.attn.c_attn.b", "prime_prior.transformer._attn_mods.4.attn.c_proj.w", "prime_prior.transformer._attn_mods.4.attn.c_proj.b", "prime_prior.transformer._attn_mods.4.ln_0.weight", "prime_prior.transformer._attn_mods.4.ln_0.bias", "prime_prior.transformer._attn_mods.4.mlp.c_fc.w", "prime_prior.transformer._attn_mods.4.mlp.c_fc.b", "prime_prior.transformer._attn_mods.4.mlp.c_proj.w", "prime_prior.transformer._attn_mods.4.mlp.c_proj.b", "prime_prior.transformer._attn_mods.4.ln_1.weight", "prime_prior.transformer._attn_mods.4.ln_1.bias", "prime_prior.transformer._attn_mods.5.attn.c_attn.w", "prime_prior.transformer._attn_mods.5.attn.c_attn.b", "prime_prior.transformer._attn_mods.5.attn.c_proj.w", "prime_prior.transformer._attn_mods.5.attn.c_proj.b", "prime_prior.transformer._attn_mods.5.ln_0.weight", 
"prime_prior.transformer._attn_mods.5.ln_0.bias", "prime_prior.transformer._attn_mods.5.mlp.c_fc.w", "prime_prior.transformer._attn_mods.5.mlp.c_fc.b", "prime_prior.transformer._attn_mods.5.mlp.c_proj.w", "prime_prior.transformer._attn_mods.5.mlp.c_proj.b", "prime_prior.transformer._attn_mods.5.ln_1.weight", "prime_prior.transformer._attn_mods.5.ln_1.bias", "prime_prior.transformer._attn_mods.6.attn.c_attn.w", "prime_prior.transformer._attn_mods.6.attn.c_attn.b", "prime_prior.transformer._attn_mods.6.attn.c_proj.w", "prime_prior.transformer._attn_mods.6.attn.c_proj.b", "prime_prior.transformer._attn_mods.6.ln_0.weight", "prime_prior.transformer._attn_mods.6.ln_0.bias", "prime_prior.transformer._attn_mods.6.mlp.c_fc.w", "prime_prior.transformer._attn_mods.6.mlp.c_fc.b", "prime_prior.transformer._attn_mods.6.mlp.c_proj.w", "prime_prior.transformer._attn_mods.6.mlp.c_proj.b", "prime_prior.transformer._attn_mods.6.ln_1.weight", "prime_prior.transformer._attn_mods.6.ln_1.bias", "prime_prior.transformer._attn_mods.7.attn.c_attn.w", "prime_prior.transformer._attn_mods.7.attn.c_attn.b", "prime_prior.transformer._attn_mods.7.attn.c_proj.w", "prime_prior.transformer._attn_mods.7.attn.c_proj.b", "prime_prior.transformer._attn_mods.7.ln_0.weight", "prime_prior.transformer._attn_mods.7.ln_0.bias", "prime_prior.transformer._attn_mods.7.mlp.c_fc.w", "prime_prior.transformer._attn_mods.7.mlp.c_fc.b", "prime_prior.transformer._attn_mods.7.mlp.c_proj.w", "prime_prior.transformer._attn_mods.7.mlp.c_proj.b", "prime_prior.transformer._attn_mods.7.ln_1.weight", "prime_prior.transformer._attn_mods.7.ln_1.bias", "prime_prior.transformer._attn_mods.8.attn.c_attn.w", "prime_prior.transformer._attn_mods.8.attn.c_attn.b", "prime_prior.transformer._attn_mods.8.attn.c_proj.w", "prime_prior.transformer._attn_mods.8.attn.c_proj.b", "prime_prior.transformer._attn_mods.8.ln_0.weight", "prime_prior.transformer._attn_mods.8.ln_0.bias", "prime_prior.transformer._attn_mods.8.mlp.c_fc.w", "prime_prior.transformer._attn_mods.8.mlp.c_fc.b", "prime_prior.transformer._attn_mods.8.mlp.c_proj.w", "prime_prior.transformer._attn_mods.8.mlp.c_proj.b", "prime_prior.transformer._attn_mods.8.ln_1.weight", "prime_prior.transformer._attn_mods.8.ln_1.bias", "prime_prior.transformer._attn_mods.9.attn.c_attn.w", "prime_prior.transformer._attn_mods.9.attn.c_attn.b", "prime_prior.transformer._attn_mods.9.attn.c_proj.w", "prime_prior.transformer._attn_mods.9.attn.c_proj.b", "prime_prior.transformer._attn_mods.9.ln_0.weight", "prime_prior.transformer._attn_mods.9.ln_0.bias", "prime_prior.transformer._attn_mods.9.mlp.c_fc.w", "prime_prior.transformer._attn_mods.9.mlp.c_fc.b", "prime_prior.transformer._attn_mods.9.mlp.c_proj.w", "prime_prior.transformer._attn_mods.9.mlp.c_proj.b", "prime_prior.transformer._attn_mods.9.ln_1.weight", "prime_prior.transformer._attn_mods.9.ln_1.bias", "prime_prior.transformer._attn_mods.10.attn.c_attn.w", "prime_prior.transformer._attn_mods.10.attn.c_attn.b", "prime_prior.transformer._attn_mods.10.attn.c_proj.w", "prime_prior.transformer._attn_mods.10.attn.c_proj.b", "prime_prior.transformer._attn_mods.10.ln_0.weight", "prime_prior.transformer._attn_mods.10.ln_0.bias", "prime_prior.transformer._attn_mods.10.mlp.c_fc.w", "prime_prior.transformer._attn_mods.10.mlp.c_fc.b", "prime_prior.transformer._attn_mods.10.mlp.c_proj.w", "prime_prior.transformer._attn_mods.10.mlp.c_proj.b", "prime_prior.transformer._attn_mods.10.ln_1.weight", "prime_prior.transformer._attn_mods.10.ln_1.bias", 
"prime_prior.transformer._attn_mods.11.attn.c_attn.w", "prime_prior.transformer._attn_mods.11.attn.c_attn.b", "prime_prior.transformer._attn_mods.11.attn.c_proj.w", "prime_prior.transformer._attn_mods.11.attn.c_proj.b", "prime_prior.transformer._attn_mods.11.ln_0.weight", "prime_prior.transformer._attn_mods.11.ln_0.bias", "prime_prior.transformer._attn_mods.11.mlp.c_fc.w", "prime_prior.transformer._attn_mods.11.mlp.c_fc.b", "prime_prior.transformer._attn_mods.11.mlp.c_proj.w", "prime_prior.transformer._attn_mods.11.mlp.c_proj.b", "prime_prior.transformer._attn_mods.11.ln_1.weight", "prime_prior.transformer._attn_mods.11.ln_1.bias", "prime_prior.transformer._attn_mods.12.attn.c_attn.w", "prime_prior.transformer._attn_mods.12.attn.c_attn.b", "prime_prior.transformer._attn_mods.12.attn.c_proj.w", "prime_prior.transformer._attn_mods.12.attn.c_proj.b", "prime_prior.transformer._attn_mods.12.ln_0.weight", "prime_prior.transformer._attn_mods.12.ln_0.bias", "prime_prior.transformer._attn_mods.12.mlp.c_fc.w", "prime_prior.transformer._attn_mods.12.mlp.c_fc.b", "prime_prior.transformer._attn_mods.12.mlp.c_proj.w", "prime_prior.transformer._attn_mods.12.mlp.c_proj.b", "prime_prior.transformer._attn_mods.12.ln_1.weight", "prime_prior.transformer._attn_mods.12.ln_1.bias", "prime_prior.transformer._attn_mods.13.attn.c_attn.w", "prime_prior.transformer._attn_mods.13.attn.c_attn.b", "prime_prior.transformer._attn_mods.13.attn.c_proj.w", "prime_prior.transformer._attn_mods.13.attn.c_proj.b", "prime_prior.transformer._attn_mods.13.ln_0.weight", "prime_prior.transformer._attn_mods.13.ln_0.bias", "prime_prior.transformer._attn_mods.13.mlp.c_fc.w", "prime_prior.transformer._attn_mods.13.mlp.c_fc.b", "prime_prior.transformer._attn_mods.13.mlp.c_proj.w", "prime_prior.transformer._attn_mods.13.mlp.c_proj.b", "prime_prior.transformer._attn_mods.13.ln_1.weight", "prime_prior.transformer._attn_mods.13.ln_1.bias", "prime_prior.transformer._attn_mods.14.attn.c_attn.w", "prime_prior.transformer._attn_mods.14.attn.c_attn.b", "prime_prior.transformer._attn_mods.14.attn.c_proj.w", "prime_prior.transformer._attn_mods.14.attn.c_proj.b", "prime_prior.transformer._attn_mods.14.ln_0.weight", "prime_prior.transformer._attn_mods.14.ln_0.bias", "prime_prior.transformer._attn_mods.14.mlp.c_fc.w", "prime_prior.transformer._attn_mods.14.mlp.c_fc.b", "prime_prior.transformer._attn_mods.14.mlp.c_proj.w", "prime_prior.transformer._attn_mods.14.mlp.c_proj.b", "prime_prior.transformer._attn_mods.14.ln_1.weight", "prime_prior.transformer._attn_mods.14.ln_1.bias", "prime_prior.transformer._attn_mods.15.attn.c_attn.w", "prime_prior.transformer._attn_mods.15.attn.c_attn.b", "prime_prior.transformer._attn_mods.15.attn.c_proj.w", "prime_prior.transformer._attn_mods.15.attn.c_proj.b", "prime_prior.transformer._attn_mods.15.ln_0.weight", "prime_prior.transformer._attn_mods.15.ln_0.bias", "prime_prior.transformer._attn_mods.15.mlp.c_fc.w", "prime_prior.transformer._attn_mods.15.mlp.c_fc.b", "prime_prior.transformer._attn_mods.15.mlp.c_proj.w", "prime_prior.transformer._attn_mods.15.mlp.c_proj.b", "prime_prior.transformer._attn_mods.15.ln_1.weight", "prime_prior.transformer._attn_mods.15.ln_1.bias", "prime_prior.transformer._attn_mods.16.attn.c_attn.w", "prime_prior.transformer._attn_mods.16.attn.c_attn.b", "prime_prior.transformer._attn_mods.16.attn.c_proj.w", "prime_prior.transformer._attn_mods.16.attn.c_proj.b", "prime_prior.transformer._attn_mods.16.ln_0.weight", "prime_prior.transformer._attn_mods.16.ln_0.bias", 
"prime_prior.transformer._attn_mods.16.mlp.c_fc.w", "prime_prior.transformer._attn_mods.16.mlp.c_fc.b", "prime_prior.transformer._attn_mods.16.mlp.c_proj.w", "prime_prior.transformer._attn_mods.16.mlp.c_proj.b", "prime_prior.transformer._attn_mods.16.ln_1.weight", "prime_prior.transformer._attn_mods.16.ln_1.bias", "prime_prior.transformer._attn_mods.17.attn.c_attn.w", "prime_prior.transformer._attn_mods.17.attn.c_attn.b", "prime_prior.transformer._attn_mods.17.attn.c_proj.w", "prime_prior.transformer._attn_mods.17.attn.c_proj.b", "prime_prior.transformer._attn_mods.17.ln_0.weight", "prime_prior.transformer._attn_mods.17.ln_0.bias", "prime_prior.transformer._attn_mods.17.mlp.c_fc.w", "prime_prior.transformer._attn_mods.17.mlp.c_fc.b", "prime_prior.transformer._attn_mods.17.mlp.c_proj.w", "prime_prior.transformer._attn_mods.17.mlp.c_proj.b", "prime_prior.transformer._attn_mods.17.ln_1.weight", "prime_prior.transformer._attn_mods.17.ln_1.bias", "prime_state_proj.w", "prime_state_proj.b", "prime_state_ln.weight", "prime_state_ln.bias", "prime_x_out.weight", "prior.transformer._attn_mods.18.attn.c_enc_kv.w", "prior.transformer._attn_mods.18.attn.c_enc_kv.b", "prior.transformer._attn_mods.28.attn.c_enc_kv.w", "prior.transformer._attn_mods.28.attn.c_enc_kv.b", "prior.transformer._attn_mods.38.attn.c_enc_kv.w", "prior.transformer._attn_mods.38.attn.c_enc_kv.b", "prior.transformer._attn_mods.48.attn.c_enc_kv.w", "prior.transformer._attn_mods.48.attn.c_enc_kv.b", "prior.transformer._attn_mods.58.attn.c_enc_kv.w", "prior.transformer._attn_mods.58.attn.c_enc_kv.b", "prior.transformer._attn_mods.68.attn.c_enc_kv.w", "prior.transformer._attn_mods.68.attn.c_enc_kv.b", "prior.transformer._attn_mods.72.attn.c_attn.w", "prior.transformer._attn_mods.72.attn.c_attn.b", "prior.transformer._attn_mods.72.attn.c_proj.w", "prior.transformer._attn_mods.72.attn.c_proj.b", "prior.transformer._attn_mods.72.ln_0.weight", "prior.transformer._attn_mods.72.ln_0.bias", "prior.transformer._attn_mods.72.mlp.c_fc.w", "prior.transformer._attn_mods.72.mlp.c_fc.b", "prior.transformer._attn_mods.72.mlp.c_proj.w", "prior.transformer._attn_mods.72.mlp.c_proj.b", "prior.transformer._attn_mods.72.ln_1.weight", "prior.transformer._attn_mods.72.ln_1.bias", "prior.transformer._attn_mods.73.attn.c_attn.w", "prior.transformer._attn_mods.73.attn.c_attn.b", "prior.transformer._attn_mods.73.attn.c_proj.w", "prior.transformer._attn_mods.73.attn.c_proj.b", "prior.transformer._attn_mods.73.ln_0.weight", "prior.transformer._attn_mods.73.ln_0.bias", "prior.transformer._attn_mods.73.mlp.c_fc.w", "prior.transformer._attn_mods.73.mlp.c_fc.b", "prior.transformer._attn_mods.73.mlp.c_proj.w", "prior.transformer._attn_mods.73.mlp.c_proj.b", "prior.transformer._attn_mods.73.ln_1.weight", "prior.transformer._attn_mods.73.ln_1.bias", "prior.transformer._attn_mods.74.attn.c_attn.w", "prior.transformer._attn_mods.74.attn.c_attn.b", "prior.transformer._attn_mods.74.attn.c_proj.w", "prior.transformer._attn_mods.74.attn.c_proj.b", "prior.transformer._attn_mods.74.ln_0.weight", "prior.transformer._attn_mods.74.ln_0.bias", "prior.transformer._attn_mods.74.mlp.c_fc.w", "prior.transformer._attn_mods.74.mlp.c_fc.b", "prior.transformer._attn_mods.74.mlp.c_proj.w", "prior.transformer._attn_mods.74.mlp.c_proj.b", "prior.transformer._attn_mods.74.ln_1.weight", "prior.transformer._attn_mods.74.ln_1.bias", "prior.transformer._attn_mods.75.attn.c_attn.w", "prior.transformer._attn_mods.75.attn.c_attn.b", "prior.transformer._attn_mods.75.attn.c_proj.w", 
"prior.transformer._attn_mods.75.attn.c_proj.b", "prior.transformer._attn_mods.75.ln_0.weight", "prior.transformer._attn_mods.75.ln_0.bias", "prior.transformer._attn_mods.75.mlp.c_fc.w", "prior.transformer._attn_mods.75.mlp.c_fc.b", "prior.transformer._attn_mods.75.mlp.c_proj.w", "prior.transformer._attn_mods.75.mlp.c_proj.b", "prior.transformer._attn_mods.75.ln_1.weight", "prior.transformer._attn_mods.75.ln_1.bias", "prior.transformer._attn_mods.76.attn.c_attn.w", "prior.transformer._attn_mods.76.attn.c_attn.b", "prior.transformer._attn_mods.76.attn.c_proj.w", "prior.transformer._attn_mods.76.attn.c_proj.b", "prior.transformer._attn_mods.76.ln_0.weight", "prior.transformer._attn_mods.76.ln_0.bias", "prior.transformer._attn_mods.76.mlp.c_fc.w", "prior.transformer._attn_mods.76.mlp.c_fc.b", "prior.transformer._attn_mods.76.mlp.c_proj.w", "prior.transformer._attn_mods.76.mlp.c_proj.b", "prior.transformer._attn_mods.76.ln_1.weight", "prior.transformer._attn_mods.76.ln_1.bias", "prior.transformer._attn_mods.77.attn.c_attn.w", "prior.transformer._attn_mods.77.attn.c_attn.b", "prior.transformer._attn_mods.77.attn.c_proj.w", "prior.transformer._attn_mods.77.attn.c_proj.b", "prior.transformer._attn_mods.77.ln_0.weight", "prior.transformer._attn_mods.77.ln_0.bias", "prior.transformer._attn_mods.77.mlp.c_fc.w", "prior.transformer._attn_mods.77.mlp.c_fc.b", "prior.transformer._attn_mods.77.mlp.c_proj.w", "prior.transformer._attn_mods.77.mlp.c_proj.b", "prior.transformer._attn_mods.77.ln_1.weight", "prior.transformer._attn_mods.77.ln_1.bias", "prior.transformer._attn_mods.78.attn.c_attn.w", "prior.transformer._attn_mods.78.attn.c_attn.b", "prior.transformer._attn_mods.78.attn.c_enc_kv.w", "prior.transformer._attn_mods.78.attn.c_enc_kv.b", "prior.transformer._attn_mods.78.attn.c_proj.w", "prior.transformer._attn_mods.78.attn.c_proj.b", "prior.transformer._attn_mods.78.ln_0.weight", "prior.transformer._attn_mods.78.ln_0.bias", "prior.transformer._attn_mods.78.mlp.c_fc.w", "prior.transformer._attn_mods.78.mlp.c_fc.b", "prior.transformer._attn_mods.78.mlp.c_proj.w", "prior.transformer._attn_mods.78.mlp.c_proj.b", "prior.transformer._attn_mods.78.ln_1.weight", "prior.transformer._attn_mods.78.ln_1.bias".
        size mismatch for prior.transformer._attn_mods.18.attn.c_attn.w: copying a param with shape torch.Size([4800, 3600]) from checkpoint, the shape in current model is torch.Size([4800, 1200]).
        size mismatch for prior.transformer._attn_mods.18.attn.c_attn.b: copying a param with shape torch.Size([3600]) from checkpoint, the shape in current model is torch.Size([1200]).
        size mismatch for prior.transformer._attn_mods.28.attn.c_attn.w: copying a param with shape torch.Size([4800, 3600]) from checkpoint, the shape in current model is torch.Size([4800, 1200]).
        size mismatch for prior.transformer._attn_mods.28.attn.c_attn.b: copying a param with shape torch.Size([3600]) from checkpoint, the shape in current model is torch.Size([1200]).
        size mismatch for prior.transformer._attn_mods.38.attn.c_attn.w: copying a param with shape torch.Size([4800, 3600]) from checkpoint, the shape in current model is torch.Size([4800, 1200]).
        size mismatch for prior.transformer._attn_mods.38.attn.c_attn.b: copying a param with shape torch.Size([3600]) from checkpoint, the shape in current model is torch.Size([1200]).
        size mismatch for prior.transformer._attn_mods.48.attn.c_attn.w: copying a param with shape torch.Size([4800, 3600]) from checkpoint, the shape in current model is torch.Size([4800, 1200]).
        size mismatch for prior.transformer._attn_mods.48.attn.c_attn.b: copying a param with shape torch.Size([3600]) from checkpoint, the shape in current model is torch.Size([1200]).
        size mismatch for prior.transformer._attn_mods.58.attn.c_attn.w: copying a param with shape torch.Size([4800, 3600]) from checkpoint, the shape in current model is torch.Size([4800, 1200]).
        size mismatch for prior.transformer._attn_mods.58.attn.c_attn.b: copying a param with shape torch.Size([3600]) from checkpoint, the shape in current model is torch.Size([1200]).
        size mismatch for prior.transformer._attn_mods.68.attn.c_attn.w: copying a param with shape torch.Size([4800, 3600]) from checkpoint, the shape in current model is torch.Size([4800, 1200]).
        size mismatch for prior.transformer._attn_mods.68.attn.c_attn.b: copying a param with shape torch.Size([3600]) from checkpoint, the shape in current model is torch.Size([1200]).
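Not a fix, but a quick way to narrow this down. The error comes from PyTorch's load_state_dict, which requires every tensor in the checkpoint to match the freshly constructed model exactly, so either the downloaded prior_level_2.pth.tar is incomplete/corrupt, or the local code is building the 5b_lyrics prior with different hyperparameters (the missing prime_prior.* keys appear to be the lyric-conditioning modules). Below is a minimal diagnostic sketch, not part of the repo: the path is copied from the "Restored from ..." line above, and checkpoint['model'] is the key jukebox's restore_model reads per the traceback.

import torch

# Path copied from the "Restored from ..." line in the log above.
ckpt_path = r"C:\Users\Martin/.cache\jukebox/models/5b_lyrics/prior_level_2.pth.tar"

checkpoint = torch.load(ckpt_path, map_location="cpu")
state_dict = checkpoint["model"]  # same key restore_model loads from

# If the checkpoint is intact, it should contain the lyric-conditioning
# weights that the error lists as missing from the constructed model.
n_prime = sum(k.startswith("prime_prior") for k in state_dict)
print("prime_prior.* keys in checkpoint:", n_prime)

# One of the layers the error complains about: the checkpoint stores
# [4800, 3600], while the locally built model expects [4800, 1200].
w = state_dict["prior.transformer._attn_mods.18.attn.c_attn.w"]
print("c_attn.w shape in checkpoint:", tuple(w.shape))

If the checkpoint does contain the prime_prior keys and the [4800, 3600] shape, the download is probably fine and the mismatch is on the model-construction side, e.g. an out-of-date checkout or locally modified hparams; re-cloning the repo or deleting the cached .tar so it re-downloads are cheap things to try.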