What is the meaning of the function `logsumexp`?
fgqile commented:
Why does the code subtract the maximum value from the inputs (`inputs - s`) before taking the exponential? The code is as follows:
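```python
import torch

# Helper function for log-sum-exp calculation:
def logsumexp(inputs, dim=None, keepdim=False):
    # With no dim given, flatten the tensor and reduce over all elements.
    if dim is None:
        inputs = inputs.view(-1)
        dim = 0
    # s is the maximum along dim, kept as a size-1 dim so it broadcasts.
    s, _ = torch.max(inputs, dim=dim, keepdim=True)
    # log(sum(exp(x))) = s + log(sum(exp(x - s))) for any s;
    # with s = max(x), every exponent is <= 0, so exp() cannot overflow.
    outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
    if not keepdim:
        outputs = outputs.squeeze(dim)
    return outputs
```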
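The subtraction relies on the identity $\log \sum_i e^{x_i} = m + \log \sum_i e^{x_i - m}$, which holds for any constant $m$. Choosing $m = \max_i x_i$ makes every shifted exponent $\le 0$, so `exp` cannot overflow, and the largest term becomes $e^0 = 1$, so the sum is at least 1 and the `log` never receives 0 from underflow. The sketch below illustrates this with plain Python floats and the standard `math` module instead of PyTorch (an assumption made so it runs without `torch`); `naive_logsumexp` and `stable_logsumexp` are hypothetical names used only for this comparison.

```python
import math

def naive_logsumexp(xs):
    # Direct formula: exp() overflows for large x and underflows to 0 for very negative x.
    return math.log(sum(math.exp(x) for x in xs))

def stable_logsumexp(xs):
    # Shift by the maximum m (the same trick as the helper above):
    # log(sum(exp(x))) = m + log(sum(exp(x - m))), and the largest term is exp(0) = 1.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

for xs in ([1000.0, 1000.0], [-1000.0, -1000.0]):
    try:
        naive_logsumexp(xs)
    except OverflowError:
        print(xs, "naive: exp() overflowed")
    except ValueError:
        print(xs, "naive: every exp() underflowed to 0, so log(0) failed")
    # Both cases give x + log(2), since the two inputs are equal.
    print(xs, "stable:", stable_logsumexp(xs))
```

For moderate inputs such as `[1.0, 2.0, 3.0]` both versions agree up to floating-point rounding; the shift only matters when the exponentials leave the representable range. PyTorch also ships this reduction as `torch.logsumexp(input, dim, keepdim=False)`.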