Issues
`sharding` parameter bug in JAX 0.4.31
#132 opened by balancap
Fix import bug regression with JAX `<0.4.28`
#121 opened by balancap
Add implementation of `jax_scalify.tree` sub-module.
#115 opened by balancap
Reference documentation website
#114 opened by balancap
Fix usability of `Scalify` in JAX eager mode
#89 opened by balancap
Training setup for 3 MLP layers
#37 opened by balancap
Gaussian ops analysis notebook
#17 opened by balancap
MNIST training example in FP8
#62 opened by balancap
Use FP32 scale in MNIST training example
#77 opened by balancap
Formalize minimal set of ops on a scale datatype
#64 opened by balancap
Implement power of 2 scaling in `AutoScale`
#61 opened by balancap
Support JAX (experimental) IPU version
#38 opened by balancap
Support mixed normal/scaled graph in `AutoScale`
#22 opened by balancap