/torch-reparametrised-mixture-distribution

PyTorch implementation of the mixture distribution family with implicit reparametrisation gradients.


Reparametrisable PyTorch MixtureSameFamily distribution

PyTorch implementation of the implicit reparametrisation trick for mixture distributions, based on Figurnov et al., 2019, "Implicit Reparameterization Gradients" and on the implementation in TensorFlow Probability.

The distribution can be used directly for variational inference with mixture-distribution variational families.
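
As a rough illustration of the implicit reparametrisation idea for the univariate case, the sketch below draws a sample `z` without gradients and then attaches the gradient `dz/dθ = -(∂F/∂θ) / q(z)`, where `F` is the mixture CDF and `q` the mixture density. It uses only the stock `torch.distributions` classes; the function name `rsample_mixture_of_normals` is hypothetical and is not this repository's API, and the surrogate-term construction is one possible way to write the estimator, not necessarily how the repository does it.

```python
import torch
from torch import distributions as D


def rsample_mixture_of_normals(logits, loc, scale, sample_shape=torch.Size()):
    """Hypothetical helper: implicit-reparametrisation sample from a 1-D Normal mixture."""
    mixture = D.MixtureSameFamily(
        D.Categorical(logits=logits),  # mixture weights
        D.Normal(loc, scale),          # univariate components
    )
    # Plain (non-differentiable) sample from the mixture.
    z = mixture.sample(sample_shape).detach()
    # CDF at the fixed sample; differentiable w.r.t. logits, loc and scale.
    cdf = mixture.cdf(z)
    # Mixture density at the sample, treated as a constant.
    pdf = mixture.log_prob(z).exp().detach()
    # Zero-valued surrogate whose gradient is -(dF/dtheta) / q(z).
    surrogate = -(cdf - cdf.detach()) / pdf
    return z + surrogate


logits = torch.zeros(2, requires_grad=True)
loc = torch.tensor([-1.0, 1.0], requires_grad=True)
scale = torch.tensor([0.5, 0.5], requires_grad=True)

z = rsample_mixture_of_normals(logits, loc, scale, torch.Size([8]))
z.mean().backward()  # gradients now reach logits, loc and scale
```

With a single component the estimator reduces to the familiar explicit form, `dz/dloc = 1` and `dz/dscale = (z - loc) / scale`, which gives a convenient sanity check.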

Remarks:

  • For multivariate mixtures, the class currently supports only the case where the mixture component distributions fully factorise.
  • The repository also adds a StableNormal distribution, which overrides the default cdf method with a more stable implementation from pytorch/pytorch#52973 (comment). The implementation also provides a _log_cdf method, although it is not used for the implicit reparametrisation.
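
To make the first remark concrete, here is a minimal sketch of a multivariate mixture whose components fully factorise, built with the stock `torch.distributions` classes. The source does not show this repository's constructor, so the example deliberately uses PyTorch's own `MixtureSameFamily`; the sizes `K` and `d` are arbitrary illustrative values.

```python
import torch
from torch import distributions as D

K, d = 3, 2  # illustrative: 3 components, 2 dimensions

mix = D.Categorical(logits=torch.zeros(K))
# Independent(..., 1) turns the last batch dimension into an event dimension,
# so each component is a product of d univariate Normals (diagonal covariance).
comp = D.Independent(D.Normal(torch.randn(K, d), torch.ones(K, d)), 1)
gmm = D.MixtureSameFamily(mix, comp)

x = gmm.sample((4,))     # shape (4, d)
log_p = gmm.log_prob(x)  # shape (4,)
```

A component family with full covariance, such as `MultivariateNormal`, would not factorise in this sense.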