The last dimension of the predicted logits is discarded in inference but used in training
questionstorer opened this issue · 2 comments
While trying out the Mask2Former code, I made the following observations.
Let's say we have a label set with classes in `[0, num_seg)`. Background is labelled `0`.
- During inference, in the function `semantic_inference` in `mask2former/maskformer_model.py`, the predicted class logits `mask_cls` go through a softmax and are then sliced along the last axis, excluding the last entry. This converts them from shape `(..., num_seg+1)` to `(..., num_seg)`. The last entry is the NULL class; its column is dropped after the softmax, so every pixel is later classified as one of the `num_seg` classes. Is this observation correct?
- However, during training, when computing the cross-entropy loss on labels in
`loss_labels` in `mask2former/modeling/criterion.py`, in the line `loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)`, `src_logits.transpose(1, 2)` contains `num_seg+1` channels and the last channel stands for the NULL class, while `target_classes` contains the class of each label, with values in `[0, num_seg)`.
So the last of the `num_seg+1` channels of `src_logits.transpose(1, 2)` can never be indexed as a target, but it does appear in the denominator of the cross-entropy loss. So the last channel is actually used in training.
Is this understanding correct?
- Here is a question about consistency between training and inference. If a channel is used in training but discarded in inference, is there a consistency issue? If this NULL label can never be predicted during inference, does that mean we are not supposed to use this class as a background class? Do we always have to add `background` as one of the classes in our label set?
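The two code paths described above can be sketched side by side. This is a NumPy re-implementation written for illustration (so it runs without PyTorch), not the repository code. `semantic_inference` mirrors the softmax, slice, and `einsum` steps of the function of the same name in `mask2former/maskformer_model.py`: the softmax is taken over all `num_seg+1` classes before the last column is dropped, so the NULL logit still lowers the other class scores at inference time. `loss_labels` mirrors a DETR-style label loss; it assumes, as in the upstream criterion, that `target_classes` is initialised to `num_seg` and only matched queries are overwritten with their ground-truth class, which would mean the NULL channel is an explicit target for every unmatched query rather than only a term in the denominator. The `matched_labels` dict is a hypothetical stand-in for the matcher output, and `eos_coef=0.1` is an illustrative value for the NULL-class weight in `empty_weight`.

```python
import numpy as np


def log_softmax(x, axis=-1):
    # Numerically stable log-softmax along `axis`.
    z = x - x.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))


def semantic_inference(mask_cls, mask_pred):
    """mask_cls: (Q, num_seg + 1) class logits; mask_pred: (Q, H, W) mask logits."""
    # Softmax over ALL num_seg + 1 classes, then drop the last (NULL) column.
    probs = np.exp(log_softmax(mask_cls))[:, :-1]       # (Q, num_seg)
    masks = 1.0 / (1.0 + np.exp(-mask_pred))            # sigmoid -> (Q, H, W)
    # semseg[c, h, w] = sum_q probs[q, c] * masks[q, h, w]
    return np.einsum("qc,qhw->chw", probs, masks)       # (num_seg, H, W)


def loss_labels(src_logits, matched_labels, num_seg, eos_coef=0.1):
    """src_logits: (B, Q, num_seg + 1).

    matched_labels: hypothetical {(batch_idx, query_idx): class} dict standing in
    for the matcher output. Returns (weighted mean cross-entropy, target_classes).
    """
    B, Q, _ = src_logits.shape
    # Every query starts with the NULL index num_seg; matched queries are overwritten.
    target = np.full((B, Q), num_seg, dtype=np.int64)
    for (b, q), label in matched_labels.items():
        target[b, q] = label
    # Per-class weights: 1 for real classes, eos_coef for the NULL class.
    weight = np.ones(num_seg + 1)
    weight[-1] = eos_coef
    log_p = log_softmax(src_logits, axis=-1)            # (B, Q, num_seg + 1)
    nll = -np.take_along_axis(log_p, target[..., None], axis=-1)[..., 0]
    w = weight[target]
    # Weighted mean, as F.cross_entropy does with `weight` and reduction="mean".
    return (w * nll).sum() / w.sum(), target


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    semseg = semantic_inference(rng.normal(size=(4, 4)), rng.normal(size=(4, 2, 2)))
    print(semseg.shape)  # (3, 2, 2): num_seg = 3 channels after dropping NULL
    loss, target = loss_labels(rng.normal(size=(1, 4, 4)), {(0, 1): 0}, num_seg=3)
    print(target[0])     # [3 0 3 3]: only query 1 is matched
```

Under these assumptions the NULL column takes part in both phases: at inference through the softmax normalisation, and in training both through the denominator and as the target of unmatched queries.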
The background (NULL) class is discarded at inference purely because it is not supported by the evaluation metric. If a metric supports predicting background, there is no need to discard it.
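To make "not discarding it" concrete, here is a hypothetical sketch (NumPy, not from the Mask2Former repository) that reuses the softmax, sigmoid, and `einsum` combination of `semantic_inference` but keeps all `num_seg+1` columns, so a pixel can be labelled `num_seg` (no object) when queries predicting NULL dominate it. The function name and the argmax-over-channels decision are assumptions for illustration; whether this output is useful depends on the metric accepting that index as a valid background label.

```python
import numpy as np


def semantic_inference_keep_null(mask_cls, mask_pred):
    """Hypothetical variant that keeps the NULL column instead of slicing it off.

    mask_cls: (Q, num_seg + 1) class logits; mask_pred: (Q, H, W) mask logits.
    Returns an (H, W) label map where index num_seg means "no object / background".
    """
    z = mask_cls - mask_cls.max(axis=-1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)   # (Q, num_seg + 1), no slicing
    masks = 1.0 / (1.0 + np.exp(-mask_pred))                     # sigmoid -> (Q, H, W)
    scores = np.einsum("qc,qhw->chw", probs, masks)              # (num_seg + 1, H, W)
    # Per-pixel argmax over classes, NULL included.
    return scores.argmax(axis=0)
```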
Can we use Mask2Former for semantic segmentation on the PascalVOC2012 dataset? This dataset contains a 'background' class. I tried Mask2Former on it and it does not work: I found that the gradients explode.