technicolor-research/sodeep

Shape of inputs to SpearmanLoss?

What input shapes does the `forward` method of `SpearmanLoss` expect? Does it accept batched data?

I think `SpearmanLoss.forward` may have to be modified to support batched input, e.g.:
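```python
def forward(self, mem_pred, mem_gt, pr=False):
    # Rank the ground truth along the last dimension (-1).
    rank_gt = get_rank(mem_gt, -1)

    if len(mem_pred.shape) == 1:
        # Single sequence: add a batch dimension for the sorter, then flatten again.
        rank_pred = self.sorter(mem_pred.unsqueeze(0)).view(-1)
    else:
        # Batched input: pass it to the sorter unchanged.
        rank_pred = self.sorter(mem_pred)

    return self.criterion_mse(rank_pred, rank_gt) + self.lbd * self.criterionl1(mem_pred, mem_gt)
```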
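
To sanity-check the shape handling without a trained sorter, here is a minimal self-contained sketch. `BatchSpearmanLoss` and `ExactSorter` are hypothetical names, not part of the repo, and `get_rank` is reimplemented here under the assumption that it computes descending ranks with a double argsort along `dim`. The real sorter is a learned network, so this stand-in only checks shapes and values, not gradients.

```python
import torch


def get_rank(scores, dim=-1):
    # Assumed helper: descending ranks via double argsort, scaled into (0, 1].
    order = torch.argsort(scores, dim=dim)
    rank = torch.argsort(order, dim=dim)
    n = scores.size(dim)
    return (n - rank).float() / n


class ExactSorter(torch.nn.Module):
    # Hypothetical stand-in for the trained sorter: exact, non-differentiable ranks.
    def forward(self, scores):
        return get_rank(scores, -1)


class BatchSpearmanLoss(torch.nn.Module):
    # Hypothetical wrapper around the proposed forward. `sorter` is assumed to
    # map (batch, seq_len) scores to (batch, seq_len) approximate ranks.
    def __init__(self, sorter, lbd=0.0):
        super().__init__()
        self.sorter = sorter
        self.lbd = lbd
        self.criterion_mse = torch.nn.MSELoss()
        self.criterionl1 = torch.nn.L1Loss()

    def forward(self, mem_pred, mem_gt):
        rank_gt = get_rank(mem_gt, -1)
        if mem_pred.dim() == 1:
            # Single sequence: add a batch dimension, then flatten again.
            rank_pred = self.sorter(mem_pred.unsqueeze(0)).view(-1)
        else:
            # Batched input: one row per sequence.
            rank_pred = self.sorter(mem_pred)
        return self.criterion_mse(rank_pred, rank_gt) + self.lbd * self.criterionl1(mem_pred, mem_gt)


loss_fn = BatchSpearmanLoss(ExactSorter())
single = torch.tensor([0.1, 0.9, 0.5])                    # shape (3,)
batch = torch.tensor([[0.1, 0.9, 0.5], [3.0, 1.0, 2.0]])  # shape (2, 3)

loss_single = loss_fn(single, single)      # identical ranking
loss_batch = loss_fn(batch, batch)         # identical ranking, batched
loss_flipped = loss_fn(batch.flip(-1), batch)  # rankings disagree in each row
```

Note that with the default mean reduction of `MSELoss`, the batched loss is averaged over all elements of all rows, so it is a single scalar rather than one value per sequence.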