issue on ckpt loading in train_distill.py
folkartist opened this issue · 1 comment
I tried to run train_distill.py to replicate the C2F method, intending to then modify the code to fine-tune the model on my own datasets. I downloaded both the clip and frame checkpoints and set the corresponding path arguments, but I got an error on this line:
state_dict_cls = cls_s["state_dict"]
Error message:
KeyError: 'state_dict' (note: full exception trace is shown but execution is paused at: _run_module_as_main)
  File "/work/wpd/audiossl/train_distill.py", line 50, in <module>
    state_dict_cls = cls_s["state_dict"]
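For debugging, a minimal sketch (the path below is a placeholder for the downloaded checkpoint, not an actual file name from the repository) that prints which top-level keys the checkpoint actually contains:

```python
import torch

# Placeholder path; point it at the checkpoint passed via --pretrain_ckpt_path_clip
ckpt_path = "path/to/downloaded_clip_checkpoint.ckpt"

# Load on CPU so no GPU is needed just for inspection
cls_s = torch.load(ckpt_path, map_location="cpu")

# A Lightning fine-tuned checkpoint exposes "state_dict";
# other checkpoints may store their weights under different keys
print(list(cls_s.keys()))
```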
I tried changing "state_dict" to "student", but that caused a new error...
How do I fix this?
Thank you for any help you can offer!
And really great work you have done.
Hi, very sorry for the late response!
Which checkpoint did you use for the argument --pretrain_ckpt_path_clip? It should be the ATST-Clip-Audioset checkpoint, i.e., the clip model fine-tuned on AudioSet. I guess you used the ATST-Clip checkpoint, which is wrong. I hope this solves your problem.
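To make the failure mode explicit, a hedged sketch of a guard around the loading line (the path variable is a placeholder standing in for the --pretrain_ckpt_path_clip argument, not the script's actual variable name):

```python
import torch

# Placeholder: this should be the value passed via --pretrain_ckpt_path_clip
pretrain_ckpt_path_clip = "path/to/atst_clip_audioset.ckpt"

cls_s = torch.load(pretrain_ckpt_path_clip, map_location="cpu")

if "state_dict" not in cls_s:
    # A fine-tuned ATST-Clip-Audioset checkpoint should contain "state_dict";
    # a pretraining-only ATST-Clip checkpoint fails this check with a clearer
    # message instead of a bare KeyError later on
    raise KeyError(
        "Expected a fine-tuned ATST-Clip-Audioset checkpoint with a 'state_dict' key, "
        f"got keys {list(cls_s.keys())}. Did you pass the ATST-Clip pretraining checkpoint instead?"
    )

state_dict_cls = cls_s["state_dict"]
```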
I also checked train_distill.py; the code was not adapted to PyTorch 2.1.1 and Lightning 2.2.1. Please use the newest commit.