yandex-research/rtdl

How to get the probability of each class in multiclass classification?

jerronl opened this issue · 1 comment

Other than the top class from argmax, can we also get the probability of each class? Currently the predictions often contain a lot of negative values, and I don't know the right way to convert them to probabilities.

Yura52 commented

The predictions are logits (raw, unnormalized scores, so negative values are expected). To convert them to probabilities, apply softmax over the class dimension:

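```python
import torch.nn.functional as F

logits = model(x)  # raw class scores, shape (batch_size, n_classes)
probabilities = F.softmax(logits, dim=-1)  # normalize over the last (class) axis; each row sums to 1
```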
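For intuition about what softmax does to negative logits, here is a minimal pure-Python sketch that needs no PyTorch. The helper name `softmax` and the example logits are hypothetical; it mirrors the math of `F.softmax` for a single sample but is not its implementation.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of raw class scores (logits)."""
    # Subtracting the max leaves the result unchanged but keeps exp() from overflowing.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for one sample with 3 classes; negative values are normal.
logits = [-2.0, 1.0, -0.5]
probs = softmax(logits)
print([round(p, 3) for p in probs])  # → [0.039, 0.786, 0.175]
```

Because softmax is monotonic, the most probable class is always the same one that argmax picks from the raw logits, so computing probabilities never changes the predicted label.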