metaopt/torchopt

[Feature Request] Implement differentiable discrete randomness

Benjamin-eecs opened this issue · 0 comments

Motivation

I would like torchopt to support differentiating through programs with discrete randomness, such as a coin flip that comes up heads with probability p. This would enable gradient-based optimization of stochastic models that involve discrete choices or events. Currently, torchopt does not support automatic differentiation (AD) of such programs because each sample depends discontinuously on the parameters: the outcome is piecewise constant in p, so the sample-wise derivative is zero almost everywhere even though the expectation varies smoothly with p.
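To make the problem concrete, here is a minimal sketch in plain Python (no torchopt involved). The helper names `coin` and `finite_difference` are hypothetical and exist only for this illustration. It holds the underlying uniform draw fixed and differentiates each sample by finite differences, which is roughly what a naive pathwise approach would see.

```python
import random


def coin(p: float, u: float) -> float:
    """Return 1.0 (heads) if the uniform draw u falls below p, else 0.0."""
    return 1.0 if u < p else 0.0


def finite_difference(p: float, u: float, eps: float = 1e-6) -> float:
    """Derivative of a single sample w.r.t. p, holding the draw u fixed."""
    return (coin(p + eps, u) - coin(p, u)) / eps


rng = random.Random(0)
draws = [rng.random() for _ in range(10_000)]
p = 0.3

# E[X] = p, so the derivative of the expectation w.r.t. p is exactly 1.
mean_heads = sum(coin(p, u) for u in draws) / len(draws)

# A sample only changes if u lands in [p, p + eps), which almost never happens.
nonzero_fraction = sum(finite_difference(p, u) != 0.0 for u in draws) / len(draws)

print(f"estimate of E[X]: {mean_heads:.3f}")
print(f"fraction of samples with a nonzero derivative: {nonzero_fraction:.4f}")
```

Almost every sample reports a derivative of zero, although the expectation has slope 1 in p. That gap is what this feature request is about.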

Solution

A possible solution is to implement the method proposed by Arya et al. (2022) in their paper "Automatic Differentiation of Programs with Discrete Randomness". The method is a reparameterization-based technique that generates new programs whose expectation is the derivative of the expectation of the original program. The paper shows that this gives an unbiased, low-variance estimator that is as automated as traditional AD mechanisms. It also demonstrates unbiased forward-mode AD of discrete-time Markov chains and of agent-based models such as Conway's Game of Life, as well as unbiased reverse-mode AD of a particle filter.
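As a rough illustration of the idea, the sketch below hand-derives such a "derivative program" for a single Bernoulli(p) draw. It is an assumption-laden toy, not the paper's general algorithm and not a torchopt API: the function names are hypothetical, and the weight 1 / (1 - p) with alternative outcome 1 is worked out by hand for this one distribution. The identity it relies on is d/dp E[f(X)] = f(1) - f(0), which follows from E[f(X)] = p f(1) + (1 - p) f(0).

```python
import random


def f(x: int) -> float:
    """Any downstream function of the discrete outcome; f(0) = 4, f(1) = 9."""
    return float((x + 2) ** 2)


def bernoulli_derivative_sample(p: float, rng: random.Random) -> float:
    """One sample whose expectation is d/dp E[f(X)] for X ~ Bernoulli(p).

    If the draw is 0, a small increase in p could flip it to 1; that
    potential jump is weighted by 1 / (1 - p). If the draw is 1, increasing
    p cannot change it, so the contribution is 0.
    """
    x = 1 if rng.random() < p else 0
    if x == 1:
        return 0.0
    weight = 1.0 / (1.0 - p)
    return weight * (f(1) - f(x))


rng = random.Random(0)
p = 0.3
n = 20_000

estimate = sum(bernoulli_derivative_sample(p, rng) for _ in range(n)) / n
exact = f(1) - f(0)  # derivative of p * f(1) + (1 - p) * f(0) w.r.t. p

print(f"estimated derivative: {estimate:.3f}, exact: {exact:.3f}")
```

The estimate is unbiased because the outcome 0 occurs with probability 1 - p, which cancels the weight. A real implementation would need to propagate such weighted alternative outcomes through the rest of the program, which is the part the paper automates.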

Alternatives

One alternative is to use score function estimators such as REINFORCE, which are based on the log-derivative trick. However, these estimators tend to have high variance, and variance-reduced variants such as RELAX require extra machinery such as baselines or learned control variates. Another alternative is to replace the discrete variables with a continuous relaxation, such as the Gumbel-Softmax relaxation commonly used in discrete variational autoencoders (VAEs). However, relaxations introduce bias into the gradient estimate and may not preserve the semantics of the original program.
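For comparison, here is a minimal sketch of a score function (REINFORCE-style) estimator for the same kind of coin flip, written in plain Python. The test function `f`, the helper names, and the constant baseline value 6.5 are all assumptions picked for illustration. The sketch shows that the estimator is unbiased with or without a baseline, but that its variance depends heavily on that extra choice.

```python
import random
import statistics


def f(x: int) -> float:
    """Downstream function of the discrete outcome; f(0) = 4, f(1) = 9."""
    return float((x + 2) ** 2)


def score(x: int, p: float) -> float:
    """d/dp log P(X = x; p) for X ~ Bernoulli(p)."""
    return 1.0 / p if x == 1 else -1.0 / (1.0 - p)


def score_function_samples(p: float, n: int, baseline: float, seed: int) -> list:
    """Samples of (f(X) - baseline) * score(X, p); the mean estimates d/dp E[f(X)]."""
    rng = random.Random(seed)
    samples = []
    for _ in range(n):
        x = 1 if rng.random() < p else 0
        samples.append((f(x) - baseline) * score(x, p))
    return samples


p = 0.3
plain = score_function_samples(p, 20_000, baseline=0.0, seed=0)
with_baseline = score_function_samples(p, 20_000, baseline=6.5, seed=0)

mean_plain = statistics.fmean(plain)
mean_baseline = statistics.fmean(with_baseline)
var_plain = statistics.pvariance(plain)
var_baseline = statistics.pvariance(with_baseline)

# Both means estimate f(1) - f(0) = 5; the variances differ substantially.
print(f"no baseline:   mean {mean_plain:.2f}, variance {var_plain:.1f}")
print(f"with baseline: mean {mean_baseline:.2f}, variance {var_baseline:.1f}")
```

The baseline does not change the expected value because the score has mean zero, but a poorly chosen one leaves the variance large. That tuning burden is the drawback noted above.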

Additional context

Arya et al., 2022: https://arxiv.org/pdf/2210.08572.pdf