Question about the loss
Sunway-s opened this issue · 3 comments
Sunway-s commented
I have a question about the loss_classes computation at line 291 of SparseOcc/models/loss_utils.py:
tgt_class = class_gt[b]
tgt_mask = (tgt_mask.unsqueeze(-1) == torch.arange(num_instances).to(mask_gt.device))
tgt_mask = tgt_mask.permute(1, 0)
src_idx, tgt_idx = indices[b]
src_mask = mask_pred[b][src_idx] # [M, N], M is number of gt instances, N is number of remaining voxels
tgt_mask = tgt_mask[tgt_idx] # [M, N]
src_class = class_pred[b] # [Q, CLS]
# pad non-aligned queries' tgt classes with 'no class'
pad_tgt_class = torch.full(
    (src_class.shape[0], ), self.num_classes - 1, dtype=torch.int64, device=class_pred.device
)  # [Q]
pad_tgt_class[src_idx] = tgt_class
Why is tgt_class not indexed as tgt_class[tgt_idx] here? The code uses tgt_class = class_gt[b] directly, even though tgt_mask is reordered with tgt_idx.
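To make the concern concrete, here is a minimal sketch with plain Python lists, so it runs without PyTorch. The class ids, the matching result, and the helper name `pad_targets` are all hypothetical; only the two assignment patterns come from the snippet above. In tensor form the reordered variant would presumably be `pad_tgt_class[src_idx] = tgt_class[tgt_idx]`.

```python
NO_CLASS = 17  # hypothetical id of the 'no class' label (num_classes - 1 in the snippet)


def pad_targets(num_queries, src_idx, tgt_idx, tgt_class, reorder):
    """Build per-query class targets; unmatched queries get NO_CLASS."""
    padded = [NO_CLASS] * num_queries
    for k, q in enumerate(src_idx):
        # reorder=False mimics: pad_tgt_class[src_idx] = tgt_class
        # reorder=True  mimics: pad_tgt_class[src_idx] = tgt_class[tgt_idx]
        padded[q] = tgt_class[tgt_idx[k]] if reorder else tgt_class[k]
    return padded


# Hypothetical example: 5 queries, 3 ground-truth instances.
tgt_class = [7, 2, 4]   # class of gt instance 0, 1, 2
src_idx = [3, 0, 4]     # matched query indices
tgt_idx = [2, 0, 1]     # gt instance matched to each of those queries

as_posted = pad_targets(5, src_idx, tgt_idx, tgt_class, reorder=False)
reordered = pad_targets(5, src_idx, tgt_idx, tgt_class, reorder=True)

print(as_posted)  # [2, 17, 17, 7, 4]
print(reordered)  # [7, 17, 17, 4, 2]
```

In this example query 3 is matched to gt instance 2 (class 4), but without the tgt_idx reordering it is assigned class 7, so the class target no longer corresponds to the mask target tgt_mask[tgt_idx].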
afterthat97 commented
Sorry for the late reply. It is a BUG. We retrained the model but it doesn't seem to affect the performance. We are working to find out the reason. Thank you very very much!
afterthat97 commented
afterthat97 commented
Thank you so much for reading our code so carefully!