Shape of inputs to SpearmanLoss?
visrra commented
What are the expected shapes of the inputs to the forward function of SpearmanLoss? Does it accept batch data?
I think SpearmanLoss's forward method may need to be modified to support batched input, e.g.:
```python
def forward(self, mem_pred, mem_gt, pr=False):
    rank_gt = get_rank(mem_gt, -1)
    if len(mem_pred.shape) == 1:
        # 1-D input: add a batch dimension for the sorter, then flatten back
        rank_pred = self.sorter(mem_pred.unsqueeze(0)).view(-1)
    else:
        # input already has a batch dimension: pass it through unchanged
        rank_pred = self.sorter(mem_pred)
    return self.criterion_mse(rank_pred, rank_gt) + self.lbd * self.criterionl1(mem_pred, mem_gt)
```
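To make the shape question concrete, here is a minimal plain-Python sketch (no torch) of the per-row ranking I would expect for batched input. The helper names `rank_1d` and `rank_batched` are hypothetical and not part of this repo, and the 0-based ascending rank convention is my assumption; the repo's `get_rank` may use a different one.

```python
def rank_1d(values):
    """Return the 0-based ascending rank of each element (ties broken by position)."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0] * len(values)
    for rank, idx in enumerate(order):
        ranks[idx] = rank
    return ranks


def rank_batched(x):
    """Accept a 1-D list or a 2-D list of rows; always return a 2-D list of ranks."""
    if x and not isinstance(x[0], (list, tuple)):
        x = [x]  # promote a single sequence to a batch of one
    return [rank_1d(row) for row in x]


single = rank_batched([0.3, 0.1, 0.2])
batch = rank_batched([[0.3, 0.1, 0.2], [1.0, 2.0, 3.0]])
```

Unlike the snippet above, this sketch always returns a 2-D result, so the caller never has to branch on the input shape. Either convention seems workable as long as the predicted and ground-truth ranks end up with the same shape before the MSE term.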