loss computation wrong?
tt6746690 opened this issue · 2 comments
It seems that the loss implementation (https://github.com/sangmichaelxie/doremi/blob/main/doremi/trainer.py#L360) does not exactly match the loss in the paper. In the implementation, the normalizer is $\sum_i \alpha_i \sum_{x \in D_i} |x|$, but for samples from the $i$-th domain it should just be $\sum_{x \in D_i} |x|$. Any comments on this observation?
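To make the difference concrete, here is a sketch in notation introduced only for illustration: let $\ell(x)$ be the loss of sample $x$ summed over its tokens, $S_i = \sum_{x \in D_i} \ell(x)$ the summed loss of the batch's samples from domain $i$, and $T_i = \sum_{x \in D_i} |x|$ their token count. Assuming uniform sampling (so the per-token weights in the implementation are just $\alpha_i$, with $\sum_i \alpha_i = 1$), the two objectives are:

$$
\mathcal{L}_{\text{impl}} = \frac{\sum_i \alpha_i S_i}{\sum_j \alpha_j T_j},
\qquad
\mathcal{L}_{\text{paper}} = \sum_i \alpha_i \frac{S_i}{T_i}.
$$

If every domain contributes the same number of tokens, $T_i = T$, both reduce to $\frac{1}{T}\sum_i \alpha_i S_i$; they differ when the per-domain token counts in a batch are unbalanced.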
Here is code that implements the loss as written in the paper. It seems to give smoother domain weights than the current implementation.
```python
import torch
import torch.distributed as dist

# Inside the trainer's loss computation: `pertoken_loss` and `token_mask`
# have shape (bsz, seq_len-1), and `inputs['domain_ids']` has shape (bsz,).

# compute the rescaled loss, divide by domain weights
train_domain_weights = self.read_weights().to(pertoken_loss.device)
# if doing non-uniform sampling, normalize by inverse sampling weight
train_domain_weights = train_domain_weights / self.sampling_weights.to(train_domain_weights.device)
train_domain_weights = train_domain_weights / train_domain_weights.sum()
# (#domains,) total number of tokens among samples from each domain
perdomain_num_tokens = []
for domain_id in range(len(train_domain_weights)):
    domain_mask = (inputs['domain_ids'] == domain_id)
    if domain_mask.sum() > 0:
        # cast to float so every entry has the same dtype for stack/all_reduce
        num_tokens = token_mask[domain_mask].sum().float()
    else:
        num_tokens = torch.tensor(0., device=token_mask.device)
    perdomain_num_tokens.append(num_tokens)
perdomain_num_tokens = torch.stack(perdomain_num_tokens)
# sync `perdomain_num_tokens` across processes, since different processes
# might process micro-batch samples from the same domain.
dist.all_reduce(perdomain_num_tokens, op=dist.ReduceOp.SUM)
# scale by world size because DDP averages gradients
perdomain_num_tokens = perdomain_num_tokens / self.args.world_size
# avoid division by zero
perdomain_num_tokens[perdomain_num_tokens == 0] = 1.
# (#domains,) equivalent to alpha_i / sum_{x in D_i} |x|
perdomain_coeff = train_domain_weights / perdomain_num_tokens
# (bsz, seq_len-1)
coeff = perdomain_coeff[inputs['domain_ids']].unsqueeze(-1) * token_mask
loss = (pertoken_loss * coeff.detach()).sum()
```
As long as the batch contains enough tokens, the two objectives should be about the same, and they converge to the same value as the number of samples grows. In my tests, the two objectives give almost the same average domain weights. Our implementation follows PyTorch's class-weighting implementation for `CrossEntropyLoss`, and it preserves the loss scale even when some domains are missing from a batch. This can matter when there are many domains, since any given batch is then more likely to miss some of them.
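Both claims can be checked with a toy, pure-Python sketch under deliberately simplified assumptions: every token from domain $i$ has the same loss, each token's domain is drawn independently with sampling probabilities $p_i$, the global-normalizer version rescales $\alpha_i$ by $1/p_i$ and renormalizes (as in the snippet above), and the per-domain version uses $\alpha_i$ directly. The helper names `sampling_corrected`, `both_losses`, and `sample_counts` are hypothetical and not part of the DoReMi code.

```python
import random


def sampling_corrected(alpha, sampling):
    """Return alpha_i / p_i renormalized to sum to 1 (mirrors the weight rescaling above)."""
    w = [a / p for a, p in zip(alpha, sampling)]
    z = sum(w)
    return [x / z for x in w]


def both_losses(alpha, sampling, domain_loss, counts):
    """Return (global_normalizer_loss, per_domain_normalizer_loss) for a toy batch.

    Toy assumption: each of the counts[i] tokens from domain i has loss domain_loss[i].
    """
    w = sampling_corrected(alpha, sampling)
    # One normalizer for the whole batch: sum_i w_i * S_i / sum_j w_j * T_j.
    num = sum(wi * n * l for wi, n, l in zip(w, counts, domain_loss))
    den = sum(wi * n for wi, n in zip(w, counts))
    global_loss = num / den
    # Paper-style: alpha_i times the mean token loss of each domain present in the batch.
    per_domain_loss = sum(a * l for a, n, l in zip(alpha, counts, domain_loss) if n > 0)
    return global_loss, per_domain_loss


def sample_counts(sampling, n_tokens, seed=0):
    """Draw each token's domain independently with probabilities `sampling`."""
    rng = random.Random(seed)
    counts = [0] * len(sampling)
    for d in rng.choices(range(len(sampling)), weights=sampling, k=n_tokens):
        counts[d] += 1
    return counts


alpha, sampling, domain_loss = [0.5, 0.5], [0.8, 0.2], [1.0, 3.0]
# Many sampled tokens: both objectives land close to 2.0.
print(both_losses(alpha, sampling, domain_loss, sample_counts(sampling, 200_000)))
# Domain 1 missing: the global version stays at 1.0 (the mean loss of the tokens seen),
# while the per-domain version drops to 0.5 because alpha_1 contributes nothing.
print(both_losses(alpha, sampling, domain_loss, [10, 0]))
```

The second call illustrates the loss-scale point: with a single global normalizer the result is always a weighted mean of the losses actually present, whereas per-domain normalization shrinks the total by the weight of every absent domain.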
Regarding your code: if you normalize by the observed per-domain token counts, I think you don't need this part, since dividing by each domain's own token count already removes the effect of how often that domain was sampled:

```python
# if doing non-uniform sampling, normalize by inverse sampling weight
train_domain_weights = train_domain_weights / self.sampling_weights.to(train_domain_weights.device)
train_domain_weights = train_domain_weights / train_domain_weights.sum()
```
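A minimal arithmetic check of this point, assuming two domains with equal weights, sampling probabilities 0.8 and 0.2, and per-domain mean token losses 1.0 and 3.0 (all values made up for illustration): once each domain's loss is averaged over its own tokens, additionally dividing by the sampling weights shifts weight toward the rarely sampled domain. The helpers `sampling_corrected` and `per_domain_objective` are hypothetical.

```python
def sampling_corrected(alpha, sampling):
    """Return alpha_i / p_i renormalized to sum to 1 (what the quoted lines compute)."""
    w = [a / p for a, p in zip(alpha, sampling)]
    z = sum(w)
    return [x / z for x in w]


def per_domain_objective(weights, domain_mean_loss):
    """Sum_i weights[i] * (mean token loss of domain i)."""
    return sum(w * m for w, m in zip(weights, domain_mean_loss))


alpha = [0.5, 0.5]      # domain weights
sampling = [0.8, 0.2]   # sampling probabilities p_i
mean_loss = [1.0, 3.0]  # per-domain mean token loss, already normalized by token counts

target = per_domain_objective(alpha, mean_loss)
double = per_domain_objective(sampling_corrected(alpha, sampling), mean_loss)
print(target)  # 2.0
print(double)  # roughly 2.6: the rarely sampled domain now gets weight about 0.8
```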
Thanks for the clarification!