perrying/realistic-ssl-evaluation-pytorch

Some questions for your code


Hi, thanks for your code. I have a few small questions.
Q1. train.py, line 161: cls_loss = F.cross_entropy(outputs, target, reduction="none", ignore_index=-1).mean(). The outputs contain both labeled and unlabeled data.
I think this may be inappropriate.
Q2. How do you calculate the loss in pseudo_label.py?

The outputs contain labeled data and unlabeled data.

Yes, but the targets of the unlabeled data are set to -1, and those entries are ignored via the ignore_index option. So cls_loss is calculated with only the labeled data.
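A small check of the ignore_index behaviour described above, assuming a toy batch of 4 samples and 3 classes in which the last two (unlabeled) samples have target -1. With reduction="none", the ignored entries come back as zeros, so only labeled samples contribute non-zero terms:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy batch: 4 samples, 3 classes; the last two samples are "unlabeled" (target -1).
outputs = torch.randn(4, 3)
target = torch.tensor([0, 2, -1, -1])

# With reduction="none", entries whose target equals ignore_index get a loss of 0.
per_sample = F.cross_entropy(outputs, target, reduction="none", ignore_index=-1)
print(per_sample)

# The unlabeled samples contribute zeros; the labeled ones carry the actual loss.
labeled = target != -1
print(per_sample[~labeled])
print(per_sample[labeled])
```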

How do you calculate the loss in pseudo_label.py?

Cross-entropy between the pseudo-labels and the outputs.
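A minimal sketch of a pseudo-label loss in this sense, assuming hard pseudo-labels taken as the argmax of the model's own detached predictions. The helper name pseudo_label_loss is hypothetical and not taken from pseudo_label.py, and the sketch omits any confidence threshold:

```python
import torch
import torch.nn.functional as F


def pseudo_label_loss(outputs: torch.Tensor) -> torch.Tensor:
    """Cross-entropy between the outputs and their own argmax pseudo-labels.

    Hypothetical helper: the pseudo-labels are the predicted classes,
    computed on a detached copy so no gradient flows through the label choice.
    """
    pseudo_labels = outputs.detach().argmax(dim=1)
    return F.cross_entropy(outputs, pseudo_labels)


torch.manual_seed(0)
logits = torch.randn(8, 10, requires_grad=True)
loss = pseudo_label_loss(logits)
loss.backward()  # gradients flow only through the softmax of `logits`
print(loss.item(), logits.grad.shape)
```

Because each pseudo-label is the most probable class, its probability is at least 1/K, so this loss is bounded above by log K (here log 10).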

Thanks for your reply, but .mean() will calculate the mean over 100 values rather than 50: the ignored entries still count as zeros in the average.
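The difference can be checked numerically. Assuming a batch of 100 samples of which 50 are labeled (the numbers from the comment above; the 10 classes and random logits are arbitrary), .mean() over the per-sample losses divides by 100, while the built-in reduction="mean" with ignore_index divides only by the 50 non-ignored targets, so the first value is exactly half the second:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch, n_labeled, n_classes = 100, 50, 10
outputs = torch.randn(batch, n_classes)
target = torch.randint(0, n_classes, (batch,))
target[n_labeled:] = -1  # the second half of the batch is unlabeled

# As in the line quoted in Q1: per-sample losses, then a plain mean over all 100 entries.
loss_all = F.cross_entropy(outputs, target, reduction="none", ignore_index=-1).mean()

# Built-in mean: averages only over the 50 targets that are not ignored.
loss_labeled = F.cross_entropy(outputs, target, ignore_index=-1)

print(loss_all.item(), loss_labeled.item())
print(torch.allclose(loss_all * batch / n_labeled, loss_labeled))  # True
```

When the labeled fraction per batch is fixed, the two differ only by the constant factor n_labeled / batch, so in effect the .mean() version rescales this loss term.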

Also, in pseudo_label.py, I think only the unlabeled samples are needed for calculating the loss.
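A minimal sketch of the suggested change, assuming the same convention as train.py (unlabeled samples are marked with target -1 in the same batch). The helper name unlabeled_pseudo_label_loss is hypothetical, and the pseudo-labels are again taken as the argmax of the detached outputs:

```python
import torch
import torch.nn.functional as F


def unlabeled_pseudo_label_loss(outputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Pseudo-label cross-entropy computed on the unlabeled samples only.

    Hypothetical helper: samples with target == -1 are treated as unlabeled.
    """
    unlabeled = target == -1
    if not unlabeled.any():
        # No unlabeled samples: return a zero that stays attached to the graph.
        return outputs.sum() * 0.0
    unl_outputs = outputs[unlabeled]
    pseudo_labels = unl_outputs.detach().argmax(dim=1)
    # Averaged over the unlabeled samples only, not over the whole batch.
    return F.cross_entropy(unl_outputs, pseudo_labels)


torch.manual_seed(0)
outputs = torch.randn(6, 4)
target = torch.tensor([1, 3, 0, -1, -1, -1])
loss = unlabeled_pseudo_label_loss(outputs, target)
print(loss.item())
```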

but .mean() will calculate the mean over 100 values rather than 50

The original implementation also calculates the loss the same way.

I think only the unlabeled samples are needed for calculating the loss.

You are right, my implementation is wrong.
I'll fix it. Thanks.