thunlp/OpenPrompt

loss function error in tutorial/0_basic.py

Juanhui28 opened this issue

Hello,

Thanks for providing this amazing repo. I have a question about the loss function in tutorial/0_basic.py.

The loss function is loss_func = torch.nn.CrossEntropyLoss(), which applies softmax and log internally. However, when PromptForClassification is used as the model and ManualVerbalizer (from openprompt.prompts import ManualVerbalizer) as the verbalizer, with the default settings the output of logits = prompt_model(inputs) has already gone through softmax and log: https://github.com/thunlp/OpenPrompt/blob/main/openprompt/prompts/manual_verbalizer.py#L147. These logits are then passed to loss_func, which applies softmax and log a second time.

I suspect something is wrong here: it doesn't make sense for the logits to go through softmax and log twice. Thanks.
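A quick way to see what the double application does is to check it numerically. The sketch below is written in pure Python to avoid a PyTorch dependency. It mirrors the default computation of torch.nn.CrossEntropyLoss (log-softmax over classes, then the mean negative log-likelihood of the target) and of torch.nn.NLLLoss (which expects log-probabilities). The helper names (log_softmax, cross_entropy, nll) and the toy numbers are hypothetical. The sketch also assumes that the verbalizer output is exactly a per-class log-softmax. It does not reproduce OpenPrompt's ManualVerbalizer, which may aggregate several label words per class.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax: x_i - logsumexp(x)."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def cross_entropy(batch_logits: List[List[float]], targets: List[int]) -> float:
    """Mean cross-entropy over a batch, mirroring torch.nn.CrossEntropyLoss
    with its default reduction: log-softmax each row, then average the
    negative log-likelihood of the target class."""
    losses = [-log_softmax(row)[t] for row, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


def nll(batch_log_probs: List[List[float]], targets: List[int]) -> float:
    """Mean negative log-likelihood, mirroring torch.nn.NLLLoss:
    the input is assumed to already be log-probabilities."""
    losses = [-row[t] for row, t in zip(batch_log_probs, targets)]
    return sum(losses) / len(losses)


# Hypothetical raw scores for 2 examples and 3 classes.
raw = [[2.0, -1.0, 0.5], [0.1, 0.3, 3.0]]
labels = [0, 2]

# Simulate a verbalizer that already returns per-class log-probabilities.
log_probs = [log_softmax(row) for row in raw]

loss_once = cross_entropy(raw, labels)         # softmax + log applied once
loss_twice = cross_entropy(log_probs, labels)  # softmax + log applied twice
loss_nll = nll(log_probs, labels)              # log-probs fed to an NLL loss

print(loss_once, loss_twice, loss_nll)
```

Under that assumption, the three losses agree up to floating-point rounding. Log-softmax applied to log-probabilities changes nothing, because the log-sum-exp of a normalized log-distribution is log(1) = 0. If the verbalizer output is not a normalized log-distribution over classes (for example, after averaging log-probabilities over several label words), this identity need not hold exactly.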