학습 모델 로드시 질문
robinsongh381 opened this issue · 1 comments
robinsongh381 commented
안녕하세요
좋은 자료 공유해주셔서 감사의 말씀을 우선 정합니다.
Inference.py 에서
convert_keys = {}
for k, v in checkpoint['model_state_dict'].items():
new_key_name = k.replace("module.", '')
if new_key_name not in model_dict:
print("{} is not int model_dict".format(new_key_name))
continue
convert_keys[new_key_name] = v
다음과 같이 convert_keys
를 정의하고 model.load_state_dict(convert_keys)
를 하셨는데,
왜 바로 model.load_state_dict(checkpoint['model_state_dict'])
하시지 않았는지 아니면 하면 안되는지 궁금하여 질문을 드립니다
감사합니다
dave-rtzr commented
아마 모델을 분산학습 시키셔서 모든 weight들의 이름이 module.~ 이런식으로 되어 있을 겁니다.
따라서 그냥 model.load_state_dict(checkpoint['model_state_dict'])를 하면 key error가 발생합니다.