/sam

Primary language: Python. License: MIT.

Implementation of Sharpness-Aware Minimization (SAM) as a JAX/Optax GradientTransformation, with additional adaptive and periodic extensions. The codebase does not currently implement the layer-wise extensions that the latter report specifies for large-batch training. See demo.py for a worked example of how to use the interface.
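For orientation, the core two-pass SAM gradient can be sketched in plain JAX as below. This is a minimal illustration of the general technique, not this repository's interface: the helper name `sam_gradient`, the `rho` default, the toy `quadratic_loss`, and the plain SGD step are all assumptions made for the example. It omits the adaptive and periodic extensions; demo.py remains the reference for actual usage.

```python
import jax
import jax.numpy as jnp


def sam_gradient(loss_fn, params, rho=0.05):
    """Gradient of loss_fn at the SAM-perturbed point (hypothetical helper)."""
    # First pass: ordinary gradient at the current parameters.
    grads = jax.grad(loss_fn)(params)
    # Global L2 norm across every leaf of the gradient pytree.
    leaves = jax.tree_util.tree_leaves(grads)
    grad_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))
    scale = rho / (grad_norm + 1e-12)
    # Ascent step: move distance rho along the normalized gradient direction.
    perturbed = jax.tree_util.tree_map(lambda p, g: p + scale * g, params, grads)
    # Second pass: the gradient at the perturbed point is the SAM update direction.
    return jax.grad(loss_fn)(perturbed)


def quadratic_loss(params):
    # Toy objective, used only to exercise the helper.
    return 0.5 * jnp.sum(params["w"] ** 2)


params = {"w": jnp.array([3.0, 4.0])}
sam_grads = sam_gradient(quadratic_loss, params, rho=0.5)

# Plain SGD step with the SAM gradient (learning rate 0.1, chosen arbitrarily).
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, sam_grads)
```

Each step costs two gradient evaluations, which is the overhead that periodic variants aim to reduce by not perturbing on every step.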