Loss function error in tutorial/0_basic.py
Juanhui28 opened this issue
Hello,
Thanks for providing this amazing repo. I have a question about the loss function in tutorial/0_basic.py. The loss function there is `loss_func = torch.nn.CrossEntropyLoss()`, which applies softmax and log internally. When we use `PromptForClassification` as the model and `ManualVerbalizer` (`from openprompt.prompts import ManualVerbalizer`) as the verbalizer, the output of `logits = prompt_model(inputs)` has, by default, already gone through softmax and log: https://github.com/thunlp/OpenPrompt/blob/main/openprompt/prompts/manual_verbalizer.py#L147. These logits are then passed to `loss_func`, which applies softmax and log a second time. I suspect something is wrong here, because it doesn't make sense for the logits to go through softmax and log twice. Thanks.
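To make the concern concrete, here is a minimal plain-PyTorch sketch. It does not call OpenPrompt at all. The tensors are random stand-ins, and the second half assumes that the verbalizer averages the log-probabilities of several label words per class (my reading of its aggregation step, not verified here). With one label word per class, the second log_softmax inside `CrossEntropyLoss` has no effect. This is because log_softmax(log p) = log p - log(sum p) = log p whenever p sums to 1. Once label words are averaged, however, the per-class scores are no longer a normalized distribution, so the extra log_softmax does change them.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
labels = torch.tensor([0, 2, 1, 2])
loss_func = torch.nn.CrossEntropyLoss()

# Case 1: hypothetical raw scores, one label word per class (4 examples, 3 classes).
raw_logits = torch.randn(4, 3)

# Mimic the default verbalizer path: softmax, then log.
verbalizer_out = torch.log(F.softmax(raw_logits, dim=-1))

# CrossEntropyLoss applies log_softmax again internally.
loss_double = loss_func(verbalizer_out, labels)

# Reference: a single log_softmax followed by NLL loss.
loss_single = F.nll_loss(F.log_softmax(raw_logits, dim=-1), labels)
print("one word per class:", loss_double.item(), loss_single.item())

# Case 2: hypothetical 2 label words per class. Softmax over all 6 label words,
# take the log, then average the 2 log-probs of each class (assumed aggregation).
word_logits = torch.randn(4, 3, 2)
word_logprobs = torch.log(F.softmax(word_logits.reshape(4, -1), dim=-1)).reshape(4, 3, 2)
class_scores = word_logprobs.mean(dim=-1)

# exp(class_scores) no longer sums to 1 per example, so the log_softmax
# inside CrossEntropyLoss renormalizes the scores and changes the loss.
print("sum of exp(class_scores):", torch.exp(class_scores).sum(dim=-1))
print("CE vs plain NLL:", loss_func(class_scores, labels).item(),
      F.nll_loss(class_scores, labels).item())
```

So the question is really whether that renormalization in the multi-word case is intended.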