arxyzan/data2vec-pytorch

EMA model forward

anhvth opened this issue · 1 comment

        # model forward in online mode (student)
        x = self.encoder(src, mask, **kwargs)['encoder_out']  # fetch the last layer outputs
        if trg is None:
            return x

        # model forward in offline mode (teacher)
        with torch.no_grad():
            self.ema.model.eval()
            y = self.ema.model(trg, ~mask, **kwargs)['encoder_states']  # fetch the last transformer layers outputs

In the teacher forward pass, the mask_time_indices is the inverse of the one used for the student. Is this correct?
I think the mask in the teacher forward pass should be None, since the teacher expects the full, unmasked input.

Hi @anhvth, thanks for your feedback.
The teacher predicts representations from the masked indices in the input: the indices that are masked for src are not masked for trg, and vice versa. The mask must therefore be the inverse of the one used for the student.
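
To make the two options in this thread concrete, here is a minimal plain-Python sketch, not the repository's code. Lists of bools stand in for the boolean mask tensors, so `not m` plays the role of `~mask`. The helper names, the example mask, and the convention that `True` means "masked" are assumptions made for illustration only.

```python
# Plain-Python sketch: lists of bools stand in for the boolean mask tensors.
# Assumption: True marks a position that is masked (hidden) in that forward pass.

def invert_mask(mask):
    """Return the element-wise complement, the list analogue of `~mask`."""
    return [not m for m in mask]

def visible_positions(mask):
    """Indices that are not masked, i.e. the positions the encoder sees."""
    return [i for i, m in enumerate(mask) if not m]

student_mask = [False, True, True, False, True]  # hypothetical example mask
teacher_mask = invert_mask(student_mask)         # what `~mask` hands to the teacher
no_mask = [False] * len(student_mask)            # the issue's alternative: nothing masked

student_visible = visible_positions(student_mask)
teacher_visible = visible_positions(teacher_mask)
full_visible = visible_positions(no_mask)
```

With the inverted mask, the student and the teacher see disjoint sets of positions that together cover the whole sequence. With no mask at all, as proposed in the question, the teacher would see every position.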