twitter-archive/torch-autograd

Wrong implementation for logSoftMax?

callowbird opened this issue · 1 comment

Hi,

This autograd tool for torch is wonderful! I like it a lot :)

I just have a small question. I found that the output of logSoftMax does not match the standard nn.LogSoftMax when a minibatch is used (e.g., batch size > 1). I think the problem is in util.logSumExp(): one shouldn't take max = torch.max(array), which is the maximum over the whole tensor. Instead, one should take the maximum of each row. Is that right? Is there an easy fix for it? (see below)

--Thanks!

function util.logSumExp(array)
   -- torch.max(array) is the maximum over the whole tensor, and torch.sum(...)
   -- sums over all elements, so every row of a minibatch shares one normalizer
   local max = torch.max(array)
   return torch.log(torch.sum(torch.exp(array - max))) + max
end

function util.logSoftMax(array)
   return array - util.logSumExp(array)
end
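To make the mismatch concrete, here is a tiny sketch (plain torch, nothing autograd-specific; the example tensor is made up) of the difference between the global max and the per-row max on a two-row minibatch:

require 'torch'

-- A 2-row minibatch: torch.max(x) collapses it to a single number,
-- while torch.max(x, 2) keeps one maximum per row.
local x = torch.Tensor{{1, 2, 3}, {10, 20, 30}}
print(torch.max(x))              -- 30: one global max for the whole minibatch
local rowMax = torch.max(x, 2)   -- first return value is a 2x1 tensor of per-row maxima
print(rowMax)                    -- {3, 30}: what a batched logSumExp needs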

I found the following fix (is it good enough?)

local function logSumExp(array)
   -- take the maximum of each row (dim 2) and sum along dim 2 as well,
   -- so every sample in the minibatch is normalized independently
   local max = torch.max(array, 2)
   local c = torch.expand(max, array:size(1), array:size(2))
   return torch.log(torch.sum(torch.exp(array - c), 2)) + max
end

local function logSoftMax(array)
   -- subtract each row's logSumExp from every element of that row
   local rlt = logSumExp(array)
   local c = torch.expand(rlt, array:size(1), array:size(2))
   return array - c
end