Wrong implementation for logSoftMax?
callowbird opened this issue · 1 comment
Hi,
This autograd tool for torch is wonderful! I like it a lot :)
I just have a small question. I found that the output of logSoftMax does not match nn.LogSoftMax when a minibatch is used (e.g., batch size > 1). I think the problem is in util.logSumExp(): one shouldn't take max = torch.max(array), which is the maximum over the whole batch. Instead, one should take the maximum of each row. Is that right? Is there an easy fix for it? (See below.)
Thanks!
function util.logSumExp(array)
   -- Global maximum over *all* elements, not per row.
   local max = torch.max(array)
   return torch.log(torch.sum(torch.exp(array - max))) + max
end

function util.logSoftMax(array)
   -- Subtracts one scalar logSumExp from the whole batch.
   return array - util.logSumExp(array)
end
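To see the mismatch, here is a minimal sketch (the tensor values are made up; assumes torch and nn are installed):

require 'torch'
require 'nn'

-- A minibatch of two rows (made-up values).
local x = torch.Tensor{{1, 2, 3},
                       {4, 5, 6}}

-- Reference: nn.LogSoftMax normalizes each row independently.
local ref = nn.LogSoftMax():forward(x)

-- util.logSoftMax as defined above subtracts a single scalar logSumExp
-- computed over the whole batch, so individual rows come out wrong.
local max = x:max()
local lse = math.log(torch.exp(x - max):sum()) + max
local out = x - lse

print(ref)  -- each row exp-sums to 1
print(out)  -- rows do not; only the whole batch exp-sums to 1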
I found the following fix (is it good enough?)
local function logSumExp(array)
   -- Row-wise maximum: torch.max along dim 2 returns an (n x 1) tensor
   -- of values (its second return value, the indices, is ignored here).
   local max = torch.max(array, 2)
   local c = torch.expand(max, array:size(1), array:size(2))
   return torch.log(torch.sum(torch.exp(array - c), 2)) + max
end
local function logSoftMax(array)
   -- logSumExp is (n x 1); expand it back to the input shape before subtracting.
   local rlt = logSumExp(array)
   local c = torch.expand(rlt, array:size(1), array:size(2))
   return array - c
end
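For what it's worth, the fix seems to agree with nn on a random batch (a quick sketch, assuming the fixed logSoftMax above is in scope):

local x = torch.randn(4, 10)           -- made-up batch of 4 rows
local ref = nn.LogSoftMax():forward(x)
local out = logSoftMax(x)
print((out - ref):abs():max())         -- should print ~0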