upcoming feature tracker
vkuzo opened this issue · 3 comments
vkuzo commented
configurability
- [done] support delayed vs dynamic scaling type, configurable separately for activations/weights/gradients (see the sketch after this list)
- [planned] support rowwise/blockwise scaling granularity, configurable separately for each gemm
- [planned] configure settings for each of the three gemms in linear fwd/bwd separately
- [planned] support more fine grained configuration of how to apply `Float8Linear` to individual modules
- [planned] inference support (see #314)
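For context on the dynamic vs delayed distinction above, here is a minimal conceptual sketch in plain PyTorch (the names, the history length, and the eps are illustrative assumptions, not this repo's implementation):

```python
import torch

# max representable value of the float8 dtype used here for illustration
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

def dynamic_scale(x: torch.Tensor) -> torch.Tensor:
    # Dynamic scaling is stateless: the scale is recomputed from the
    # current tensor's amax on every cast.
    amax = x.abs().max().float()
    return E4M3_MAX / torch.clamp(amax, min=1e-12)

class DelayedScaler:
    # Delayed scaling is stateful: the scale comes from a rolling history of
    # amaxes observed in earlier iterations, so the cast does not need an
    # extra reduction over the current tensor. (history_len=16 is an
    # assumption for illustration.)
    def __init__(self, history_len: int = 16):
        self.amax_history = torch.zeros(history_len)

    def observe(self, x: torch.Tensor) -> None:
        self.amax_history = torch.roll(self.amax_history, 1)
        self.amax_history[0] = x.abs().max().item()

    def scale(self) -> torch.Tensor:
        return E4M3_MAX / torch.clamp(self.amax_history.max(), min=1e-12)
```

The statefulness of delayed scaling is also why an AMP-style API is a poor fit, as discussed in the comments below.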
performance
- [done] `torch._scaled_mm` support for per-tensor scaled float8 gemm (see the sketch after this list)
- [in progress] `torch._scaled_mm` support for rowwise scaled float8 gemm
  - [done] eager mode support
  - [planned] torch.compile support, backed by triton/cutlass
- [in progress] optimize torch.compile performance for float8 scaling/casting kernels
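As a reference for the per-tensor path, a hedged sketch of calling `torch._scaled_mm` directly. This is a private op whose signature has changed across PyTorch releases; the sketch assumes a recent build where the scales are required float32 tensors applied as dequantization scales and the output is returned directly (older builds returned an `(output, amax)` tuple), and it needs a GPU with float8 gemm support:

```python
import torch

device = "cuda"  # requires hardware with float8 gemm support, e.g. H100
a = torch.randn(128, 256, device=device, dtype=torch.bfloat16)
b = torch.randn(256, 512, device=device, dtype=torch.bfloat16)

# per-tensor quantization scales derived from each tensor's amax
f8_max = torch.finfo(torch.float8_e4m3fn).max
scale_a = f8_max / a.abs().max().float()
scale_b = f8_max / b.abs().max().float()

a_f8 = (a * scale_a).to(torch.float8_e4m3fn)
# the second operand is expected in column-major layout
b_f8 = (b * scale_b).to(torch.float8_e4m3fn).t().contiguous().t()

out = torch._scaled_mm(
    a_f8,
    b_f8,
    scale_a=1.0 / scale_a,   # dequantization scales applied to the output
    scale_b=1.0 / scale_b,
    out_dtype=torch.bfloat16,
    use_fast_accum=True,     # float8 accumulation, see the `other` section below
)
```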
distributed
- [done] integrate with TP/SP via DTensor APIs
- [done] integrate with FSDP1 with 16-bit all-gather
- [done] integrate with FSDP2 with 16-bit or 8-bit all-gather with dynamic scaling for weights
  - performance optimizations are ongoing
- [in progress] integrate with FSDP2 with 16-bit or 8-bit all-gather with delayed scaling for weights
  - POC is done, performance optimizations are ongoing
- [planned] verify integration with PP
other
- weight gradient accumulation in float32
- add `use_fast_accum` (float8 accumulation of gemm) option to UX - #144
- improve saturated casting performance (see the sketch after this list)
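On the saturated casting item: saturated casting clamps out-of-range values to the float8 representable maximum before the cast instead of letting them overflow to inf/nan. A minimal sketch of the general technique (illustrative only, not this repo's kernel):

```python
import torch

def saturated_cast(x: torch.Tensor, float8_dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    # clamp to the representable range of the target float8 dtype so that
    # out-of-range values saturate instead of overflowing after the cast
    max_val = torch.finfo(float8_dtype).max
    return x.clamp(min=-max_val, max=max_val).to(float8_dtype)
```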
bhack commented
Is there a plan to support AMP?
vkuzo commented
> Is there a plan to support AMP?
Sorry for the late reply! We don't have a plan to support AMP in the near future: delayed scaling is stateful, so the engineering cost of supporting it in an AMP-like API would be too high. For now we would like to keep a consistent API between dynamic and delayed scaling to make ablation studies easy. If the community converges on dynamic scaling (which is stateless) in the future, we could adjust.
vkuzo commented
moved to pytorch/ao#556