graphcore-research/jax-scalify

MNIST training broken with `min/max` scale propagation

Fixing min/max ops scale propagation (PR #68) had the side effect of breaking MNIST training. Early investigation shows the scale factors diverging after a couple of iterations, similar to an unstable dynamical system.
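
For intuition, here is a minimal sketch (not the actual jax-scalify implementation; all names are illustrative) of how a `max` op can propagate scales in a `(data, scale)` representation. Because the output scale is derived from the input scales, each training iteration feeds scales back into the next one, which is where the unstable-dynamical-system behaviour can come from.

```python
from typing import NamedTuple

import jax.numpy as jnp


class ScaledArray(NamedTuple):
    data: jnp.ndarray   # normalized payload
    scale: jnp.ndarray  # scalar scale factor; represented value is data * scale


def scaled_max(a: ScaledArray, b: ScaledArray) -> ScaledArray:
    # Propagate the larger input scale to the output, and rescale both
    # operands into that common scale so the represented value is unchanged.
    out_scale = jnp.maximum(a.scale, b.scale)
    out_data = jnp.maximum(a.data * (a.scale / out_scale),
                           b.data * (b.scale / out_scale))
    return ScaledArray(out_data, out_scale)
```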

TODO: additional investigation is needed to understand the dynamics of the issue.

#74 implements dynamic rescaling methods. An investigation into how to use them properly is still necessary.
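
As a hedged sketch of what a dynamic rescaling method could look like (reusing the `ScaledArray` sketch above; the function name and the max-abs statistic are illustrative choices, not necessarily the ops added in #74): the scale is re-derived from the current data statistics and the data renormalized, leaving the represented value `data * scale` unchanged.

```python
import jax.numpy as jnp


def dynamic_rescale_max(a: ScaledArray) -> ScaledArray:
    # Re-derive the scale from the data statistics (here: max absolute
    # value, clamped to avoid division by zero), then renormalize the data
    # so that data * scale is numerically unchanged.
    new_scale = jnp.maximum(jnp.max(jnp.abs(a.data)), 1e-12) * a.scale
    return ScaledArray(a.data * (a.scale / new_scale), new_scale)
```

After such a call the data is O(1) again, so its scale cannot keep compounding across iterations.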

Bug fixed in #75 and #76, with dynamic rescaling of the logits gradient.
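
For context, a hedged sketch of where such a rescale could sit in an MNIST-style loss. `predict` is a hypothetical model forward pass, and `jsa.ops.dynamic_rescale_grad` is an assumed op name (identity on the forward value, dynamic rescaling of the cotangent in the backward pass); see #74/#75/#76 for the actual ops and placement.

```python
import jax
import jax.numpy as jnp
import jax_scalify as jsa


@jsa.scalify  # trace the loss with scale propagation
def loss_fn(params, images, labels):
    logits = predict(params, images)  # hypothetical model forward pass
    # Assumed op: leaves the logits value untouched in the forward pass and
    # dynamically rescales the logits gradient in the backward pass, so its
    # scale is re-derived every step instead of drifting across iterations.
    logits = jsa.ops.dynamic_rescale_grad(logits)
    logp = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.sum(logp * labels, axis=-1))
```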