kekmodel/MPL-pytorch

Teacher Gradients

monney opened this issue · 13 comments

Hi,
As I understand it, in the reference implementation the MPL loss on the teacher essentially does nothing. To fix this, we use hard labels rather than soft ones.

For this to work, I believe we should not be detaching t_logits_us here:
https://github.com/kekmodel/MPL-pytorch/blob/main/main.py#L208
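A minimal PyTorch sketch of why the soft-label version does nothing. It assumes the teacher's MPL term is a cross-entropy between the teacher's logits and a target built from those same logits; the random tensor below only stands in for t_logits_us. The gradient of cross-entropy with respect to the logits is softmax(logits) minus the target, so a detached softmax of the same logits cancels it exactly, while an argmax (hard) target does not:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Random stand-in for the teacher's logits on strongly augmented unlabeled images.
t_logits_us = torch.randn(4, 10, requires_grad=True)

# Soft target: the teacher's own softmax, detached, with no temperature.
soft_target = torch.softmax(t_logits_us.detach(), dim=-1)
soft_loss = -(soft_target * F.log_softmax(t_logits_us, dim=-1)).sum(dim=-1).mean()
(grad_soft,) = torch.autograd.grad(soft_loss, t_logits_us)

# Hard target: the argmax of the teacher's own prediction.
hard_target = t_logits_us.detach().argmax(dim=-1)
hard_loss = F.cross_entropy(t_logits_us, hard_target)
(grad_hard,) = torch.autograd.grad(hard_loss, t_logits_us)

# In both cases d(loss)/d(logits) = (softmax(logits) - target) / batch_size.
print("max |grad| with soft target:", grad_soft.abs().max().item())  # zero up to rounding
print("max |grad| with hard target:", grad_hard.abs().max().item())
```

Multiplying the soft-label term by a detached scalar such as the dot-product coefficient still leaves a zero gradient, which is why switching to hard labels changes anything at all.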

That seems right. I'll try again. Thank you.

Test accuracy is 94.25%. Unfortunately, performance decreased. What could be the reason?

Interesting result.

Here: google-research/google-research#534 (comment)
The author says they use hard labels for large-scale tasks, and that for smaller ones soft labels converge faster. However, I don't know whether they used temperature scaling to make that work. If we use soft labels without scaling, training is almost identical to the UDA paper: the only benefit over UDA is distillation to the student, and the entire dot-product code does nothing.

The released code has an unused flag for MPL temperature:
https://github.com/google-research/google-research/blob/master/meta_pseudo_labels/flag_utils.py#L150
so perhaps the unreleased code used temperature scaling for soft labels. I'm not really sure; it might be best to ask the author. The large-scale results show that Meta Pseudo Labels helps, but perhaps at small scales they accidentally ended up using just UDA for the teacher.
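To illustrate the temperature point: if the target is the teacher's own detached softmax at temperature T, the logit gradient of the cross-entropy becomes (softmax(z) - softmax(z / T)) / batch_size, which vanishes only at T = 1. A small sketch that checks this closed form, with random logits standing in for the teacher's and arbitrarily chosen temperatures (whether the unreleased code does anything like this is not known):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
z = torch.randn(4, 10, requires_grad=True)  # random stand-in for teacher logits


def self_target_grad(temperature):
    """Gradient of CE(z, stop_grad(softmax(z / T))) with respect to z."""
    target = torch.softmax(z.detach() / temperature, dim=-1)
    loss = -(target * F.log_softmax(z, dim=-1)).sum(dim=-1).mean()
    (grad,) = torch.autograd.grad(loss, z)
    return grad


for T in (1.0, 0.7, 2.0):
    grad = self_target_grad(T)
    # Closed form: (softmax(z) - softmax(z / T)) / batch_size
    closed = (torch.softmax(z.detach(), -1) - torch.softmax(z.detach() / T, -1)) / z.shape[0]
    print(f"T={T}: max |grad| = {grad.abs().max().item():.2e}, "
          f"matches closed form: {torch.allclose(grad, closed, atol=1e-6)}")
```

T < 1 sharpens the target toward its argmax, so it moves this term toward the hard-label case; T > 1 flattens it.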

google-research/google-research#534 (comment)
The author said that with hard pseudo labels, accuracy reached 96.06%.

I had read that comment previously but wasn't sure whether he meant hard labels or the repo as is. I think you're right, though.

I looked through this repo's code some more. Here are some other differences which might be impacting performance:

  • In the original repo, in addition to warmup steps, the author also uses wait steps, during which the student doesn't train at all.
    Like this:
import math

from torch.optim.lr_scheduler import LambdaLR


def get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps,
    num_wait_steps,
    num_training_steps,
    num_cycles=0.5,
    last_epoch=-1,
):
    def lr_lambda(current_step):
        # Wait phase: the LR multiplier is 0, so parameters are not updated.
        if current_step < num_wait_steps:
            return 0.0
        # Linear warmup; the wait steps are counted inside this window.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Cosine decay (from 1 to 0 for the default num_cycles=0.5).
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        return max(
            0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)
  • In the original repo, weight decay is not applied to batch norm layers. In PyTorch it is applied to every parameter passed to the optimizer, so you have to put the BN layers in their own parameter group and set weight decay to 0 for that group.
    Like this:
def add_weight_decay(model, weight_decay=1e-4):
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        # Frozen parameters are not handed to the optimizer at all.
        if not param.requires_grad:
            continue
        # Relies on BN layers having "bn" in their names (e.g. bn1, bn2).
        if 'bn' in name:
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]
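In the schedule above, the wait steps are counted inside the warmup window: with num_wait_steps=10 and num_warmup_steps=20, the multiplier stays at 0 until step 10 and then jumps straight to 10/20 = 0.5. If a ramp from zero after the wait is wanted instead, one possible variant (an assumption, not taken from either repo; the function name is hypothetical) restarts the step counter once the wait ends:

```python
import math


def wait_warmup_cosine(step, num_wait_steps, num_warmup_steps, num_training_steps, num_cycles=0.5):
    """LR multiplier: 0 while waiting, then linear warmup from 0, then cosine decay.

    Hypothetical variant: warmup and decay are measured from the end of the wait period.
    """
    if step < num_wait_steps:
        return 0.0
    step -= num_wait_steps  # restart the clock once waiting is over
    remaining = num_training_steps - num_wait_steps
    if step < num_warmup_steps:
        return float(step) / float(max(1, num_warmup_steps))
    progress = float(step - num_warmup_steps) / float(max(1, remaining - num_warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))


for s in (0, 9, 10, 20, 30, 80, 130):
    print(s, round(wait_warmup_cosine(s, 10, 20, 130), 4))
```

It plugs into PyTorch the same way as the original, e.g. `LambdaLR(optimizer, lambda s: wait_warmup_cosine(s, 10, 20, 130))`.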

I'll look through some more later and see if I see anything else.

On the wait steps: this might be important. I'll apply it. Thanks!

On excluding BN layers from weight decay: I've already tested this, but it's rather less accurate...

Hi!
As of right now, the only other difference I think might be important is that the BatchNorm momentum should be set to 0.01, not 0.001, to match the original's 0.99 momentum (PyTorch's BN momentum is 1 minus TensorFlow's).
(Perhaps excluding BN layers from weight decay would work correctly once this is fixed?)
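PyTorch's BatchNorm `momentum` is the weight given to the new batch statistic (running = (1 - momentum) * running + momentum * batch), the complement of TensorFlow's decay-style momentum, so a 0.99 decay corresponds to momentum=0.01. A sketch that sets this on every BN layer and, in the same pass, builds the no-weight-decay group by module type rather than by the 'bn' name substring; the helper name, its defaults, and the optimizer settings in the usage lines are illustrative assumptions:

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def configure_batchnorm(model, weight_decay=1e-4, momentum=0.01):
    """Hypothetical helper: set BN momentum and exempt BN affine params from weight decay."""
    bn_param_ids = set()
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            module.momentum = momentum  # PyTorch momentum = 1 - TF decay (0.99 -> 0.01)
            bn_param_ids.update(id(p) for p in module.parameters(recurse=False))

    decay, no_decay = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        (no_decay if id(param) in bn_param_ids else decay).append(param)
    return [
        {"params": no_decay, "weight_decay": 0.0},
        {"params": decay, "weight_decay": weight_decay},
    ]


# Usage on a toy model (lr and momentum values are placeholders).
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
param_groups = configure_batchnorm(model)
optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)
```

Matching on module type avoids depending on how layers happen to be named, though it only covers the standard BatchNorm classes listed in `BN_TYPES`.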

Nice Work!

Hi. Firstly, thank you for your wonderful implementation!

I was wondering about the implementation you used for the dot product in the calculation of the teacher MPL loss. Please correct me if I'm wrong, but I believe the paper uses cosine similarity (although it calls it cosine distance, as mentioned in Appendix C.3). In particular, it says the value of this operation lies in [-1, 1], which suggests cosine similarity, since that normalizes the inner product <a, b> by the magnitudes of a and b.

However, in the implementation this is calculated as just a subtraction. Is this inconsistent with the paper? I may just be missing something.

Again, appreciate your awesome work.
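For reference, the distinction being raised: the raw inner product of two gradient vectors scales with their magnitudes, while cosine similarity divides by both norms and therefore always lies in [-1, 1]. A tiny example with made-up vectors standing in for the two flattened student gradients:

```python
import torch
import torch.nn.functional as F

# Made-up stand-ins for two flattened gradient vectors.
a = torch.tensor([1.0, 2.0, 2.0])  # norm 3
b = torch.tensor([2.0, 0.0, 1.0])  # norm sqrt(5)

for scale in (1.0, 100.0):
    dot = torch.dot(a, scale * b).item()                   # grows with the scale
    cos = F.cosine_similarity(a, scale * b, dim=0).item()  # unchanged, within [-1, 1]
    print(f"scale={scale}: dot={dot:.3f}, cosine={cos:.3f}")
```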

@as2626 Please see the reference implementation in the google-research repo; it does the same thing. In the released implementation, they drop cosine similarity in favor of the raw dot product, then approximate that dot product with the first-order term of a Taylor expansion, which results in the subtraction.
https://github.com/google-research/google-research/blob/master/meta_pseudo_labels/training_utils.py#L472
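A quick numerical check of that approximation, using a hypothetical linear softmax classifier as the student with random data and random pseudo labels: for a small student learning rate eta, eta times the dot product between the labeled-loss gradient after one student step and the pseudo-label-loss gradient before it comes out very close to the labeled loss before the step minus the labeled loss after it.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy student: a linear softmax classifier with weights theta.
x_l = torch.randn(16, 5, dtype=torch.float64)  # labeled inputs
y_l = torch.randint(0, 3, (16,))               # true labels
x_u = torch.randn(16, 5, dtype=torch.float64)  # unlabeled inputs
y_hat = torch.randint(0, 3, (16,))             # stand-in for teacher pseudo labels
theta = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)
eta = 1e-4                                     # small student learning rate


def labeled_loss(w):
    return F.cross_entropy(x_l @ w, y_l)


def pseudo_loss(w):
    return F.cross_entropy(x_u @ w, y_hat)


# One SGD step of the student on the pseudo-labeled batch: theta' = theta - eta * g_u.
(g_u,) = torch.autograd.grad(pseudo_loss(theta), theta)
theta_new = (theta - eta * g_u).detach().requires_grad_(True)

# Exact coefficient: eta * <grad of labeled loss at theta', g_u>.
(g_l_new,) = torch.autograd.grad(labeled_loss(theta_new), theta_new)
h_exact = eta * torch.sum(g_l_new * g_u).item()

# First-order approximation: labeled loss before the step minus after it.
with torch.no_grad():
    h_approx = (labeled_loss(theta) - labeled_loss(theta_new)).item()

print(f"exact: {h_exact:.6e}  approx: {h_approx:.6e}")
```

The gap between the two shrinks roughly with eta squared, so the subtraction is only a good proxy when the student's step is small.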

Thanks for the note, @monney! :) Could you elaborate on the Taylor expansion used to approximate the dot product, or point me in the right direction to read more?

@as2626
See the authors note here:
google-research/google-research#534

And the Taylor expansion here:
https://mathworld.wolfram.com/TaylorSeries.html

Plugging in values and rearranging terms gives you the first-order approximation.
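Written out with notation assumed here (not from the thread): θ_S for the student's parameters, η_S for its learning rate, L_u for its cross-entropy on the teacher's pseudo labels, and L_l for its cross-entropy on labeled data. After one student step, expanding L_l around the updated parameters θ'_S to first order gives:

$$
\theta_S' = \theta_S - \eta_S \nabla \mathcal{L}_u(\theta_S)
$$

$$
\mathcal{L}_l(\theta_S) \approx \mathcal{L}_l(\theta_S') + \nabla \mathcal{L}_l(\theta_S')^\top (\theta_S - \theta_S')
= \mathcal{L}_l(\theta_S') + \eta_S \, \nabla \mathcal{L}_l(\theta_S')^\top \nabla \mathcal{L}_u(\theta_S)
$$

$$
\eta_S \, \nabla \mathcal{L}_l(\theta_S')^\top \nabla \mathcal{L}_u(\theta_S) \approx \mathcal{L}_l(\theta_S) - \mathcal{L}_l(\theta_S')
$$

So the scaled dot product reduces to the student's labeled loss before its update minus the loss after it, which only needs the labeled loss evaluated at two points rather than two full gradient vectors.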

@monney Hello, thank you for the nice explanation!
Sorry, I'm a beginner in this field. I'd like to know the role of the CE loss in the teacher gradients at line 218.
Is it designed to make the teacher's output distribution sharper, like a hard pseudo label?
Thanks in advance!

@kekmodel If I need to generate labels for the unlabeled data, will I have to use the teacher model at the end of training?