yang-song/score_sde_pytorch

Round operation for discrete models

SANCHES-Pedro opened this issue · 2 comments

Hello,

Firstly, congratulations on the amazing work. The ICLR award was well deserved!

I don't want to be pedantic, but I noticed that get_score_fn for discrete models doesn't apply a torch.round() operation, even though the time label t is an integer at training time. As a result, sampling evaluates the model at slightly different labels than those used during training (e.g. 500.1 instead of 500). I'm not sure whether this actually affects performance; it's just an observation.

I would suggest adding labels = torch.round(labels) after line 155 of models/utils.py.
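To make the observation concrete, here is a minimal sketch of a continuous-time-to-label conversion with and without the proposed rounding. It assumes the discrete model was trained on integer steps 0..N-1 and that continuous time t in [0, 1] is mapped to a label as t * (N - 1). The helper time_to_label and its round_labels flag are hypothetical illustrations, not the repository's code.

```python
import torch


def time_to_label(t: torch.Tensor, num_steps: int, round_labels: bool = False) -> torch.Tensor:
    """Hypothetical helper: map continuous time t in [0, 1] to a timestep label.

    Assumes labels = t * (num_steps - 1), i.e. a model trained on integer
    steps 0..num_steps-1 (an assumption for illustration only).
    """
    labels = t * (num_steps - 1)
    if round_labels:
        # Snap to the nearest integer step that was actually seen in training.
        labels = torch.round(labels)
    return labels


# A sampler that discretizes time on its own grid need not land on integer steps.
t = torch.tensor([500.1 / 999])
raw = time_to_label(t, num_steps=1000)                         # approximately 500.1
rounded = time_to_label(t, num_steps=1000, round_labels=True)  # snapped to 500.0
print(raw.item(), rounded.item())
```

With rounding, the network only ever sees labels it was trained on; without it, the label drifts slightly off the training grid, which is the small mismatch described above.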

Many thanks,
Pedro

Thanks for the comment! You are right: it is better to add a torch.round operation for discrete models, though I don't think the results would change much. The current code does have an additional benefit, however: it lets you use the continuous SDE framework even with models pre-trained on discrete losses (such as the DDPM and NCSN models released by previous work), which allows you to compute log-likelihoods, for example.
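One way to see why unrounded labels make this possible: a continuous-time solver (for instance, an ODE solver used for likelihood computation) queries the model at arbitrary t, and a DDPM-style network embeds its timestep with sinusoidal functions that are defined for real-valued inputs. The sketch below is a simplified sinusoidal embedding written for illustration (the function name and dimension are assumptions, not the repository's implementation). It shows that a fractional label yields an embedding distinct from its integer neighbours, whereas rounding collapses it onto one of them, making the model's input piecewise constant in t.

```python
import math

import torch


def sinusoidal_embedding(labels: torch.Tensor, dim: int = 8) -> torch.Tensor:
    """Simplified DDPM-style sinusoidal timestep embedding (illustrative sketch).

    Accepts float labels, so fractional timesteps are valid inputs.
    """
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to 1/10000.
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / (half - 1))
    args = labels.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=1)


labels = torch.tensor([500.0, 500.4, 501.0])

# Unrounded: the fractional label 500.4 gets its own embedding.
emb = sinusoidal_embedding(labels)

# Rounded: 500.4 becomes 500, so its embedding equals that of step 500.
rounded_emb = sinusoidal_embedding(torch.round(labels))
```

Keeping the labels continuous therefore lets a discretely trained network act as a function of continuous time, at the cost of evaluating it slightly off the integer grid it was trained on.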

I see, I hadn't understood the motivation for not adding it. Thanks for clarifying!