
Beyond Straight-Through

Primary language: Python. License: MIT.


ReinMax



ReinMax approximates the gradient with second-order accuracy and is as fast as the original Straight-Through, which has only first-order accuracy.

What is Straight-Through

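Straight-Through (shown below) bridges discrete variables (y_hard) and back-propagation: the forward pass uses the discrete sample, while the backward pass differentiates a continuous surrogate.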

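import torch
import torch.nn.functional as F

theta = torch.randn(4, 10, requires_grad=True)  # illustrative logits: 4 rows, 10 categories
y_soft = theta.softmax(dim=-1)

# one_hot_multinomial is a non-differentiable function (one possible definition):
# it samples one category per row and one-hot encodes the sample
def one_hot_multinomial(probs):
    index = torch.multinomial(probs.detach(), num_samples=1).squeeze(-1)
    return F.one_hot(index, num_classes=probs.size(-1)).to(probs.dtype)

y_hard = one_hot_multinomial(y_soft)
# with straight-through, the forward value is y_hard, while the derivative
# of y_hard acts as if you had `y_soft` in the forward
y_hard = y_soft - y_soft.detach() + y_hard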

How Straight-Through works has long been a mystery, which leaves open questions such as whether we should use:

  • y_soft - y_soft.detach(),
  • (theta/tau).softmax() - (theta/tau).softmax().detach(),
  • or something else?
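Both candidates leave the forward value (y_hard) unchanged; they differ only in the gradient that flows back to theta. The pure-Python sketch below (no PyTorch; the helper names `softmax` and `surrogate_grad` are our own) computes that gradient for an assumed upstream gradient dLoss/dy_hard. With y_soft - y_soft.detach() it is the softmax Jacobian-vector product at theta; with the temperature variant it is the same product evaluated at softmax(theta / tau) and scaled by 1 / tau, so the two choices generally disagree whenever tau != 1.

```python
import math


def softmax(logits):
    # numerically stable softmax over a plain Python list
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def surrogate_grad(theta, upstream, tau=1.0):
    """Gradient w.r.t. theta that straight-through passes back when the surrogate
    is (theta / tau).softmax(); tau = 1 recovers y_soft - y_soft.detach()."""
    p = softmax([t / tau for t in theta])
    dot = sum(pi * gi for pi, gi in zip(p, upstream))
    # softmax Jacobian-vector product p_k * (g_k - p . g), scaled by 1 / tau
    return [pk * (gk - dot) / tau for pk, gk in zip(p, upstream)]


theta = [1.0, 0.5, -0.5]
upstream = [0.2, -1.0, 0.7]  # illustrative dLoss/dy_hard
g1 = surrogate_grad(theta, upstream, tau=1.0)
g2 = surrogate_grad(theta, upstream, tau=2.0)
print([round(x, 4) for x in g1])
print([round(x, 4) for x in g2])
```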

Understand Straight-Through and Go Beyond

We reveal that Straight-Through works as a special case of the forward Euler method, a numerical method with first-order accuracy. Inspired by Heun's method, a numerical method that achieves second-order accuracy without requiring the Hessian or other second-order derivatives, we propose ReinMax, which approximates the gradient with second-order accuracy at negligible computational overhead.
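As a generic illustration of the two numerical methods named above (not of ReinMax itself), the sketch below integrates dy/dt = y from y(0) = 1 to t = 1, where the exact answer is e. The helper names are our own. Halving the step size should roughly halve the forward-Euler error (first order) and roughly quarter the Heun error (second order), while Heun still evaluates only the first derivative f, twice per step.

```python
import math


def euler_step(f, t, y, h):
    # forward Euler: one slope evaluation at the start of the step (first order)
    return y + h * f(t, y)


def heun_step(f, t, y, h):
    # Heun's method: average the slope at the start and at the Euler-predicted
    # end point (second order, still using only first derivatives)
    k1 = f(t, y)
    k2 = f(t + h, y + h * k1)
    return y + 0.5 * h * (k1 + k2)


def integrate(step, f, y0, t_end, n_steps):
    h = t_end / n_steps
    t, y = 0.0, y0
    for _ in range(n_steps):
        y = step(f, t, y, h)
        t += h
    return y


f = lambda t, y: y  # dy/dt = y, y(0) = 1, exact solution y(1) = e
for n in (10, 20, 40):
    err_euler = abs(integrate(euler_step, f, 1.0, 1.0, n) - math.e)
    err_heun = abs(integrate(heun_step, f, 1.0, 1.0, n) - math.e)
    print(f"n={n:3d}  Euler error={err_euler:.2e}  Heun error={err_heun:.2e}")
```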

How to use?

reinmax can be installed via pip:

pip install reinmax

To replace Straight-Through Gumbel-Softmax with ReinMax:

from reinmax import reinmax
...
- y_hard = torch.nn.functional.gumbel_softmax(logits, tau=tau, hard=True)
+ y_hard, _ = reinmax(logits, tau) # note: reinmax prefers tau >= 1, while gumbel-softmax prefers tau < 1
...

To replace Straight-Through with ReinMax:

from reinmax import reinmax
...
- y_hard = one_hot_multinomial(logits.softmax(dim=-1)) 
- y_soft_tau = (logits/tau).softmax(dim=-1)
- y_hard = y_soft_tau - y_soft_tau.detach() + y_hard 
+ y_hard, y_soft = reinmax(logits, tau) 
...
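To make the idea concrete, here is a hypothetical PyTorch sketch, `reinmax_sketch`, written from our reading of Algorithm 2 in the paper cited below (arXiv:2304.08612). It is not the code of the reinmax package, whose reinmax(logits, tau) should be preferred in practice; it only mirrors that call's (y_hard, y_soft) return values. The forward value is the one-hot sample, while the backward pass combines two softmax Jacobians in a Heun-like way: twice the Jacobian at the midpoint between the sample and softmax(logits / tau), minus half the Jacobian at softmax(logits).

```python
import torch


def reinmax_sketch(logits: torch.Tensor, tau: float):
    """Hypothetical ReinMax re-implementation (our reading of the paper's Algorithm 2).
    Returns (y_hard, y_soft), mirroring the README's reinmax(logits, tau)."""
    y_soft = logits.softmax(dim=-1)                               # pi_0
    index = torch.multinomial(y_soft.detach(), num_samples=1)     # non-differentiable sample
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)    # D, one-hot
    # pi_1: midpoint between the sample and the tempered softmax; re-parameterized so
    # its value is unchanged but its gradient is the softmax Jacobian at pi_1
    pi_1 = 0.5 * (y_hard + (logits / tau).softmax(dim=-1))
    pi_1 = ((pi_1.log() - logits).detach() + logits).softmax(dim=-1)
    # pi_2 combines the two Jacobians: 2 x (at pi_1) minus 1/2 x (at pi_0)
    pi_2 = 2.0 * pi_1 - 0.5 * y_soft
    # straight-through trick: forward value is D, gradient flows through pi_2
    y_hard = pi_2 - pi_2.detach() + y_hard
    return y_hard, y_soft


# usage on illustrative logits and an illustrative downstream loss
torch.manual_seed(0)
logits = torch.randn(4, 10, requires_grad=True)
y_hard, y_soft = reinmax_sketch(logits, tau=1.5)
loss = (y_hard * torch.arange(10.0)).sum()
loss.backward()
print(y_hard.argmax(dim=-1), logits.grad)
```

Because every gradient term is a softmax Jacobian-vector product, each row of logits.grad should sum to roughly zero, just as with plain Straight-Through.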

Examples

Citation

Please cite the following paper if you find our work useful. Thanks!

Liyuan Liu, Chengyu Dong, Xiaodong Liu, Bin Yu, and Jianfeng Gao (2023). Bridging Discrete and Backpropagation: Straight-Through and Beyond. ArXiv, abs/2304.08612.

@inproceedings{liu2023bridging,
  title={Bridging Discrete and Backpropagation: Straight-Through and Beyond},
  author = {Liu, Liyuan and Dong, Chengyu and Liu, Xiaodong and Yu, Bin and Gao, Jianfeng},
  booktitle = {arXiv:2304.08612 [cs]},
  year={2023}
}