Computing EPEs on flow fields with invalid values
Closed this issue · 1 comments
On some optical flow datasets, the flow fields contain many invalid flow values. When we train the network on these datasets, do we need to change the way EPE is computed so that the invalid flow values are excluded during forward and backward propagation? It seems that your code does not exclude invalid flow values when it computes the EPE.
What do the invalid values look like? Are they NaNs?
I am not very familiar with Lua anymore (it's been a year!), but you can use a ByteTensor mask to do a masked select of the values and then compute the mean over them. Explanation here. Note that x:maskedSelect(mask) is the same as x[mask], provided mask is the right type.
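The idea behind the masked select can be illustrated without any tensor library. The following is a minimal sketch in plain Python, not Torch code: `masked_mean` is a hypothetical helper, and the mask is built from the fact that NaN is never equal to itself.

```python
def masked_mean(values, mask):
    """Mean of the entries of `values` whose mask entry is True."""
    selected = [v for v, keep in zip(values, mask) if keep]
    if not selected:
        # No valid entry at all: return 0 rather than dividing by zero.
        return 0.0
    return sum(selected) / len(selected)

values = [1.0, float("nan"), 3.0]
# NaN != NaN, so comparing a value with itself flags NaN entries as invalid.
mask = [v == v for v in values]
result = masked_mean(values, mask)
```

With a tensor library the list comprehension becomes a single masked select, but the logic is the same: build the mask from the target, select, then average.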
Try this code for the EPE forward function (although I'm not 100% sure it works, since Torch is not installed on my computer anymore):
function EPECriterion:updateOutput(input, target)
   assert(input:nDimension() == 4 or input:nDimension() == 3)
   local diffMap = input - target
   local validPixels, epeMap
   if input:nDimension() == 4 then
      -- NaN ~= NaN, so eq(target, target) is 0 exactly where the flow is NaN.
      -- Selecting channel 1 gives a B x H x W ByteTensor mask.
      validPixels = torch.eq(target, target)[{{}, 1}]
      epeMap = diffMap:norm(2, 2):squeeze(2) -- B x H x W per-pixel EPE
   else
      validPixels = torch.eq(target, target)[1] -- H x W ByteTensor mask
      epeMap = diffMap:norm(2, 1):squeeze(1) -- H x W per-pixel EPE
   end
   -- maskedSelect returns a 1D tensor with only the valid EPE values. It cannot
   -- be viewed as B x N because each sample may have a different valid count.
   self.EPE = epeMap:maskedSelect(validPixels)
   self.zeroEPE = self.EPE.new():resizeAs(self.EPE):zero()
   self.output = self.criterion:forward(self.EPE, self.zeroEPE)
   return self.output
end
You then do more or less the same thing for the updateGradInput function.
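For the backward pass, the gradient should be zero at every invalid pixel and normalized by the number of valid pixels elsewhere. The following is a framework-free sketch in plain Python. It assumes the loss is the mean EPE over valid pixels and that invalid targets are marked with NaN. `masked_epe_and_grad` is a hypothetical name and is not part of the original criterion.

```python
import math

def masked_epe_and_grad(pred, target, eps=1e-12):
    """pred, target: lists of (u, v) flow vectors, one per pixel.

    Returns the mean EPE over valid pixels and its gradient w.r.t. pred.
    A pixel is treated as invalid when either target component is NaN.
    """
    valid = [not (math.isnan(tu) or math.isnan(tv)) for tu, tv in target]
    n_valid = sum(valid)
    grad = [(0.0, 0.0)] * len(pred)  # invalid pixels keep a zero gradient
    if n_valid == 0:
        return 0.0, grad
    total = 0.0
    for i, ((pu, pv), (tu, tv)) in enumerate(zip(pred, target)):
        if not valid[i]:
            continue
        du, dv = pu - tu, pv - tv
        epe = math.hypot(du, dv)
        total += epe
        if epe > eps:
            # d(mean EPE)/d(pred) = (pred - target) / (EPE * n_valid)
            grad[i] = (du / (epe * n_valid), dv / (epe * n_valid))
    return total / n_valid, grad

pred = [(3.0, 4.0), (1.0, 1.0)]
target = [(0.0, 0.0), (float("nan"), float("nan"))]
loss, grad = masked_epe_and_grad(pred, target)
```

In a tensor implementation the same effect is usually obtained by scattering the gradient of the selected values back into a zero-filled tensor of the input's shape.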
By the way, there is a PyTorch version that is more up to date, if you are willing to move from Torch to PyTorch (and I advise you to do so if not too much of your code is already in Torch).
In that version, invalid flow values are already taken care of (for the case where invalid means zero flow, but that check is easily replaced with a NaN check) here:
https://github.com/ClementPinard/FlowNetPytorch
https://github.com/ClementPinard/FlowNetPytorch/blob/master/multiscaleloss.py#L7
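Switching from the zero-flow convention to a NaN convention only changes the validity test. The sketch below is plain Python and is not the code from the linked repository. It assumes that "zero flow" means both components are exactly 0, and all function names are hypothetical.

```python
import math

def is_valid_zero_convention(flow):
    """Assumed convention: invalid when both components are exactly 0."""
    u, v = flow
    return not (u == 0.0 and v == 0.0)

def is_valid_nan_convention(flow):
    """Alternative convention: invalid when either component is NaN."""
    u, v = flow
    return not (math.isnan(u) or math.isnan(v))

def sparse_mean_epe(pred, target, is_valid):
    """Mean endpoint error over the pixels whose target passes is_valid."""
    errors = [math.hypot(pu - tu, pv - tv)
              for (pu, pv), (tu, tv) in zip(pred, target)
              if is_valid((tu, tv))]
    return sum(errors) / len(errors) if errors else 0.0

pred = [(3.0, 4.0), (1.0, 1.0)]
zero_marked = [(0.0, 0.0), (1.0, 2.0)]
nan_marked = [(float("nan"), float("nan")), (1.0, 2.0)]
epe_zero = sparse_mean_epe(pred, zero_marked, is_valid_zero_convention)
epe_nan = sparse_mean_epe(pred, nan_marked, is_valid_nan_convention)
```

The NaN convention has the advantage that a genuinely static pixel (true zero flow) is not mistaken for a missing measurement.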