mhamilton723/STEGO

Ground-truth mask labels are used during training

bio-mlhui opened this issue

I noticed that the ground-truth label is used when computing `linear_loss`:

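```python
detached_code = torch.clone(code.detach())  # cut the graph: no gradient flows back into `code`
linear_logits = self.linear_probe(detached_code)
...
# the loss is computed against the ground-truth labels
linear_loss = self.linear_probe_loss_fn(linear_logits[mask], flat_label[mask]).mean()
loss += linear_loss
...
self.manual_backward(loss)
```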
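The elided lines hide how `linear_logits` and `flat_label` are shaped and how `mask` is built. Here is a minimal sketch of a masked per-pixel cross-entropy. It assumes that logits of shape (B, C, H, W) are flattened to one row per pixel and that `mask` keeps only pixels with a valid class label. The -1 "no label" convention and all shapes are hypothetical choices, not taken from STEGO.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical shapes: 2 images, 5 classes, 4x4 feature map.
B, C, H, W = 2, 5, 4, 4
linear_logits = torch.randn(B, C, H, W)

# Hypothetical ground-truth labels; -1 marks pixels without a valid label.
label = torch.randint(-1, C, (B, H, W))

# Flatten to one row per pixel: logits (B*H*W, C), labels (B*H*W,).
flat_logits = linear_logits.permute(0, 2, 3, 1).reshape(-1, C)
flat_label = label.reshape(-1)

# Assumed masking rule: keep only pixels whose label is a valid class index.
mask = (flat_label >= 0) & (flat_label < C)

# Mean cross-entropy over the valid pixels only.
linear_loss = F.cross_entropy(flat_logits[mask], flat_label[mask])
```

In this sketch `F.cross_entropy` already averages over the selected pixels, so a trailing `.mean()` on its scalar result would be a no-op.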

As I understand it, `detached_code` does not require grad, but computing `linear_logits` from it builds a new autograd graph that includes `linear_probe`. So backpropagating `linear_loss` still accumulates gradients into `linear_probe`, and those gradients are driven by the ground-truth labels.
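The autograd behaviour described above can be checked on a toy model. This is a minimal sketch in which hypothetical `seg_head` and `linear_probe` modules (1x1 convolutions) stand in for STEGO's networks. It reproduces the `detach` pattern from the training step and checks which parameters receive gradients from the label-based loss.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: a head that produces `code`, and a linear probe on top.
seg_head = nn.Conv2d(3, 8, kernel_size=1)
linear_probe = nn.Conv2d(8, 5, kernel_size=1)

img = torch.randn(2, 3, 4, 4)
label = torch.randint(0, 5, (2, 4, 4))  # hypothetical ground-truth mask

code = seg_head(img)

# Same pattern as the training step: cut the graph before the probe.
detached_code = torch.clone(code.detach())
linear_logits = linear_probe(detached_code)  # new graph containing the probe's weights
linear_loss = nn.functional.cross_entropy(linear_logits, label)
linear_loss.backward()

# The probe's parameters receive gradients from the label-based loss...
print(linear_probe.weight.grad is not None)  # True
# ...but nothing flows back into the head that produced `code`.
print(seg_head.weight.grad is None)  # True
```

The sketch shows that `detach` only stops gradients from reaching modules upstream of `code`. Any module applied after the detach point is still trained by `linear_loss`.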
Meanwhile, the evaluation code is:

```python
feats, code1 = par_model(img)
feats, code2 = par_model(img.flip(dims=[3]))  # horizontally flipped input
code = (code1 + code2.flip(dims=[3])) / 2  # flip the second output back and average
...
linear_probs = torch.log_softmax(model.linear_probe(code), dim=1)
...
linear_preds = linear_probs.argmax(1)  # per-pixel class prediction from the probe
model.test_linear_metrics.update(linear_preds, label)
```

This means that `linear_probe` is used when computing the metrics.
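The evaluation snippet averages the network output over the image and its horizontal flip, then applies the probe. Here is a minimal sketch of that flow. It uses hypothetical 1x1-conv stand-ins for `par_model` (returning only `code`, not `(feats, code)`) and for `linear_probe`; these are not STEGO's real networks. Because a 1x1 conv acts on each pixel independently, flipping its output back recovers `code1` exactly. For a real backbone the two codes differ, and the averaging acts as a form of test-time augmentation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical per-pixel stand-ins for the network and the linear probe.
net = nn.Conv2d(3, 8, kernel_size=1)
linear_probe = nn.Conv2d(8, 5, kernel_size=1)

img = torch.randn(1, 3, 4, 6)  # (B, C, H, W); dim 3 is the width axis

with torch.no_grad():
    code1 = net(img)
    # Run on the horizontally flipped image, then flip the output back
    # so that its pixels line up with code1 again.
    code2 = net(img.flip(dims=[3]))
    code = (code1 + code2.flip(dims=[3])) / 2

    linear_probs = torch.log_softmax(linear_probe(code), dim=1)
    linear_preds = linear_probs.argmax(1)  # (B, H, W) class index per pixel

    # log_softmax subtracts the same per-pixel constant from every class score,
    # so it leaves the argmax unchanged.
    raw_preds = linear_probe(code).argmax(1)
```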

I can't figure this out, and I would be very grateful if you could explain where my reasoning goes wrong. Thanks so much!