facebookresearch/msn

Why do we need to multiply num_epochs by 1.25?

andy-yangz opened this issue · 2 comments

Hi! First, thank you for sharing such interesting research.

I have a question about this part of the source code:

    # -- momentum schedule
    _start_m, _final_m = 0.996, 1.0
    _increment = (_final_m - _start_m) / (ipe * num_epochs * 1.25)
    momentum_scheduler = (_start_m + (_increment*i) for i in range(int(ipe*num_epochs*1.25)+1))

    # -- sharpening schedule
    _increment_T = (_final_T - _start_T) / (ipe * num_epochs * 1.25)
    sharpen_scheduler = (_start_T + (_increment_T*i) for i in range(int(ipe*num_epochs*1.25)+1))

Why do we need to multiply num_epochs by 1.25? The paper says "with a momentum value of 0.996, and linearly increase this value to 1.0 by the end of training", but with this code the momentum only reaches 0.996 + (1.0 - 0.996) / 1.25 = 0.9992 after ipe * num_epochs iterations.

msn/src/msn_train.py, lines 214 to 215 in 81cb855:

    ipe = len(unsupervised_loader)
    logger.info(f'iterations per epoch: {ipe}')

msn/src/msn_train.py, lines 252 to 259 in 81cb855:

    # -- momentum schedule
    _start_m, _final_m = 0.996, 1.0
    _increment = (_final_m - _start_m) / (ipe * num_epochs * 1.25)
    momentum_scheduler = (_start_m + (_increment*i) for i in range(int(ipe*num_epochs*1.25)+1))
    # -- sharpening schedule
    _increment_T = (_final_T - _start_T) / (ipe * num_epochs * 1.25)
    sharpen_scheduler = (_start_T + (_increment_T*i) for i in range(int(ipe*num_epochs*1.25)+1))

The first value is _start_m. The last value is _start_m + (_increment*i) where:

  • _increment is (_final_m - _start_m) / delta
  • i is int(delta)

and delta is (ipe * num_epochs * 1.25)

So the last value is pretty close to _final_m, no? It is _start_m + (_final_m - _start_m) * int(delta) / delta.
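The closed form above can be checked numerically. This sketch uses hypothetical sizes (ipe = 1001, num_epochs = 3) chosen so that delta is not an integer, and the helper name last_generated_value is made up for illustration.

```python
def last_generated_value(ipe, num_epochs, start=0.996, final=1.0):
    """Last value yielded by the quoted generator, and how many values it yields."""
    delta = ipe * num_epochs * 1.25
    increment = (final - start) / delta
    values = [start + increment * i for i in range(int(delta) + 1)]
    # Closed form from the comment: start + (final - start) * int(delta) / delta
    closed_form = start + (final - start) * int(delta) / delta
    return values[-1], closed_form, len(values)

# Hypothetical sizes: delta = 1001 * 3 * 1.25 = 3753.75, not an integer.
last, closed_form, count = last_generated_value(ipe=1001, num_epochs=3)
print(last, closed_form, count)
```

The last yielded value is within about 1e-6 of _final_m. Note that the generator yields int(delta) + 1 = 3754 values here, more than the ipe * num_epochs = 3003 training iterations.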

If ipe * num_epochs is a multiple of 4 (for example, a multiple of 100), then delta is an integer anyway, since 1.25 = 5/4.
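Because 1.25 = 5/4, delta = ipe * num_epochs * 5/4 is an integer exactly when ipe * num_epochs is divisible by 4; a multiple of 100 is one such case. A small exact-arithmetic check with Python's fractions module (the function name delta_is_integer and the sample sizes are hypothetical):

```python
from fractions import Fraction

def delta_is_integer(ipe, num_epochs):
    """True when ipe * num_epochs * 1.25 has no fractional part."""
    # 1.25 == 5/4, so delta is an integer exactly when ipe * num_epochs % 4 == 0.
    delta = Fraction(ipe * num_epochs) * Fraction(5, 4)
    return delta.denominator == 1

for ipe, num_epochs in [(100, 3), (250, 2), (1001, 3), (6, 2)]:
    n = ipe * num_epochs
    print(n, delta_is_integer(ipe, num_epochs), n % 4 == 0)
```

The (6, 2) case gives n = 12, which is not a multiple of 100 but still yields an integer delta of 15.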

Hi @andy-yangz,

Thanks for your question.

By stretching the schedule over 1.25 × num_epochs, we only use "80%" (1 / 1.25) of the cosine schedule. You can get a similar effect by running your pre-training with a regular cosine schedule and applying early stopping. As @woctezuma pointed out, the last value is quite close to _final_m, but I will make sure this is mentioned in the implementation details.
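To make the "80%" concrete: with the 1.25 factor the schedule is laid out over 1.25 times the real number of iterations, so training stops at 1 / 1.25 = 0.8 of the way through it. The sketch below assumes a hypothetical run of 1000 iterations and uses a generic cosine-annealing formula (from 1.0 to 0.0, purely illustrative) next to the linear ramp from the momentum code; it is not MSN's own scheduler class.

```python
import math

def linear_value(step, total_steps, start, final):
    """Linear ramp from `start` to `final` over `total_steps`, like the momentum code."""
    return start + (final - start) * step / total_steps

def cosine_value(step, total_steps, start, final):
    """Generic cosine annealing from `start` to `final` over `total_steps`."""
    progress = step / total_steps
    return final + 0.5 * (start - final) * (1.0 + math.cos(math.pi * progress))

# Hypothetical number of iterations actually trained (ipe * num_epochs).
train_steps = 1000
# Length the schedule is laid out over when the 1.25 factor is applied.
stretched_steps = int(train_steps * 1.25)

# Share of the schedule that training gets through before it stops.
fraction_used = train_steps / stretched_steps

# Values reached at the last training step for each schedule shape.
m_linear = linear_value(train_steps, stretched_steps, 0.996, 1.0)
v_cosine = cosine_value(train_steps, stretched_steps, 1.0, 0.0)

print(f"fraction of schedule used: {fraction_used:.2f}")
print(f"linear momentum at stop:   {m_linear:.4f}")
print(f"cosine value at stop:      {v_cosine:.4f}")
```

Planning a regular schedule for stretched_steps iterations and stopping after train_steps of them visits exactly the same values, which is the early-stopping equivalence described in the reply.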

Please let me know if this answers your question! Otherwise I'd be happy to reopen the issue.