LTH14/mar

About diffloss

Closed this issue · 2 comments

Hi,

In this part

mar/models/mar.py

Lines 232 to 238 in fe470ac

    def forward_loss(self, z, target, mask):
        bsz, seq_len, _ = target.shape
        target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        z = z.reshape(bsz*seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        mask = mask.reshape(bsz*seq_len).repeat(self.diffusion_batch_mul)
        loss = self.diffloss(z=z, target=target, mask=mask)
        return loss
Why did you not filter out the unmasked tokens (using mask) and pass only the masked tokens to self.diffloss? As written, all tokens are passed in, and the mask is then used so that the loss is computed only over the masked tokens.

Thanks

LTH14 commented

Both implementations are ok

Thanks for the prompt response.
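
For readers comparing the two options discussed above, here is a minimal sketch of why they can agree. It assumes the loss is reduced as a mask-weighted mean over per-token losses; `per_token_loss` is a hypothetical deterministic stand-in, not the repository's diffusion loss, and the tensor shapes are made up for illustration.

```python
import torch


def per_token_loss(z, target):
    # Hypothetical stand-in for a per-token loss: one scalar per row.
    return ((z - target) ** 2).mean(dim=-1)


def masked_mean_loss(z, target, mask):
    # Option A: compute the loss for every token, then weight by the mask.
    loss = per_token_loss(z, target)
    return (loss * mask).sum() / mask.sum()


def filtered_loss(z, target, mask):
    # Option B: keep only the masked tokens first, then take a plain mean.
    keep = mask.bool()
    return per_token_loss(z[keep], target[keep]).mean()


torch.manual_seed(0)
z = torch.randn(8, 4, requires_grad=True)   # illustrative shapes only
target = torch.randn(8, 4)
mask = torch.tensor([1., 0., 1., 1., 0., 0., 1., 0.])  # 1 = masked token

loss_a = masked_mean_loss(z, target, mask)
loss_b = filtered_loss(z, target, mask)

# Gradients with respect to z for both variants.
grad_a, = torch.autograd.grad(loss_a, z)
grad_b, = torch.autograd.grad(loss_b, z)
```

Under this assumption both variants give the same loss value and the same gradients, and unmasked rows receive zero gradient either way. With a real diffusion loss that samples timesteps and noise per row, the two would likely match only in expectation rather than number for number. Filtering first skips computation on unmasked tokens but makes the batch size vary from step to step.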