eagle705/pytorch-bert-crf-ner

학습 모델 로드시 질문

robinsongh381 opened this issue · 1 comments

안녕하세요

좋은 자료 공유해주셔서 감사의 말씀을 우선 정합니다.

Inference.py 에서

    convert_keys = {}
    for k, v in checkpoint['model_state_dict'].items():
        new_key_name = k.replace("module.", '')
        if new_key_name not in model_dict:
            print("{} is not int model_dict".format(new_key_name))
            continue
        convert_keys[new_key_name] = v

다음과 같이 convert_keys 를 정의하고 model.load_state_dict(convert_keys)를 하셨는데,
왜 바로 model.load_state_dict(checkpoint['model_state_dict']) 하시지 않았는지 아니면 하면 안되는지 궁금하여 질문을 드립니다

감사합니다

아마 모델을 분산학습 시키셔서 모든 weight들의 이름이 module.~ 이런식으로 되어 있을 겁니다.
따라서 그냥 model.load_state_dict(checkpoint['model_state_dict'])를 하면 key error가 발생합니다.