Little-Podi/GRM

About the 'threshold' inference

kun-dragon opened this issue · 1 comment

Could you please explain why certain datasets require a threshold and why there are different thresholds for them during inference?
```python
if self.training:
    # During training
    decision = F.gumbel_softmax(divide_prediction, hard=True)
else:
    # During inference
    if threshold:
        # Manual rank based selection
        decision_rank = (F.softmax(divide_prediction, dim=-1)[:, :, 0] < threshold).long()
    else:
        # Auto rank based selection
        decision_rank = torch.argsort(divide_prediction, dim=-1, descending=True)[:, :, 0]

    decision = F.one_hot(decision_rank, num_classes=2)
```

Hi. This threshold controls how many search tokens interact with the template tokens. On some benchmarks, the challenging scenarios make it hard to precisely determine the region that should interact, so we may need to relax the constraint to improve recall. Conversely, on other datasets the division results are usually confident, and we can be stricter to promote the feature interaction process and thus improve the tracking performance. Since this is a binary classification formulation, the default threshold is 0.5, which I implemented as "auto rank" here. In my experience, it already works well enough, and minor tuning may improve the overall performance a little.
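
To make the relation between the two branches concrete, here is a minimal standalone sketch (the tensor shapes and the `divide_prediction` values are made up for illustration, not taken from GRM). It shows that the default "auto rank" path matches a manual threshold of 0.5, and how moving the threshold shifts tokens between the two classes:

```python
# Minimal sketch (not part of the repo) illustrating the threshold behaviour.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical division logits: (batch, number of search tokens, 2 classes).
divide_prediction = torch.randn(1, 8, 2)

# Auto rank: take the index of the higher logit per token.
auto_rank = torch.argsort(divide_prediction, dim=-1, descending=True)[:, :, 0]

# Manual selection with threshold 0.5 reproduces the same decision, because with
# two classes the class-0 softmax probability drops below 0.5 exactly when class 1 wins.
prob_class0 = F.softmax(divide_prediction, dim=-1)[:, :, 0]
manual = (prob_class0 < 0.5).long()
assert torch.equal(auto_rank, manual)

# Shifting the threshold changes how many tokens fall on each side of the decision,
# i.e. how many search tokens are routed to interact with the template tokens.
for t in (0.3, 0.5, 0.7):
    count = (prob_class0 < t).long().sum().item()
    print(f"threshold={t}: {count}/8 tokens assigned to class 1")
```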