MCG-NJU/MixFormerV2

MixFormerDistillStage1Actor __call__ arguments with option 'use_amp'

MrNeoBlue opened this issue · 3 comments

Hello, I want to train the student network through the distillation process following your tutorial.
However, in stage 1 I ran into a problem with the forward pass:

            # forward pass
            if not self.use_amp:
                loss, stats = self.actor(data, remove_rate_cur_epoch)
            else:
                with autocast():
                    loss, stats = self.actor(data)

I'm not familiar with the use_amp option; could you please explain how it is used? The actor here is MixFormerDistillStage1Actor, whose __call__ method only takes a data argument, so it does not accept the extra remove_rate_cur_epoch argument passed in the non-AMP branch.

Do I have to add an AMP field to cfg in config.py to control this branch?

Hello! AMP stands for automatic mixed precision. We don't use it in this project, i.e. use_amp is always False.
Sorry, there are some issues in the code. You can add a second parameter to the MixFormerDistillStage1Actor.__call__ definition, e.g. __call__(self, data, remove_rate_cur_epoch), where remove_rate_cur_epoch is simply left unused.
I will fix this soon.
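A minimal sketch of the suggested workaround, assuming nothing about the real actor beyond the call signature discussed in this thread. Stage1ActorSketch, forward_pass, and the placeholder loss are hypothetical stand-ins, not code from the repository. Giving remove_rate_cur_epoch a default of None is an extra assumption (the reply above only adds the parameter); it keeps the AMP branch's one-argument call self.actor(data) valid as well. The autocast parameter defaults to contextlib.nullcontext only so the sketch runs without PyTorch.

```python
from contextlib import nullcontext


class Stage1ActorSketch:
    """Hypothetical stand-in for MixFormerDistillStage1Actor.

    Only the call signature matters here; the real actor computes the
    distillation loss and training statistics.
    """

    def __call__(self, data, remove_rate_cur_epoch=None):
        # remove_rate_cur_epoch is accepted so the trainer's non-AMP call
        # self.actor(data, remove_rate_cur_epoch) no longer raises a
        # TypeError; stage 1 simply ignores it.
        loss = sum(data)  # placeholder loss, for illustration only
        stats = {"loss": loss}
        return loss, stats


def forward_pass(actor, data, remove_rate_cur_epoch, use_amp=False,
                 autocast=nullcontext):
    # Mirrors the branch quoted in the issue. `autocast` is injected here
    # (hypothetically) so the sketch runs without PyTorch; in the trainer
    # it would be PyTorch's autocast context manager.
    if not use_amp:
        return actor(data, remove_rate_cur_epoch)
    with autocast():
        return actor(data)


if __name__ == "__main__":
    actor = Stage1ActorSketch()
    loss, stats = forward_pass(actor, [1.0, 2.0], remove_rate_cur_epoch=0.5)
    print(loss)  # 3.0
```

With the optional parameter, both branches of the trainer call the same actor without errors, whichever value use_amp has.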

Thanks, I modified the __call__ function and the training script now works.