pytorch-labs/float8_experimental

upcoming feature tracker

vkuzo opened this issue

vkuzo commented

configurability

  • [done] support delayed vs dynamic scaling type, configurable separately for activations/weights/gradients
  • [planned] support rowwise/blockwise scaling granularity, configurable separately for each gemm
  • [planned] configure settings for each of the three gemms in linear fwd/bwd separately
  • [planned] support more fine-grained configuration of how to apply Float8Linear to individual modules
  • [planned] inference support (see #314)
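The three gemms mentioned above come from a linear layer `y = x @ w.T`. Below is a minimal pure-Python sketch of those gemms, with each one paired with its own settings object. `GemmConfig`, its fields, and the gemm names are hypothetical illustrations, not float8_experimental's API. The arithmetic runs in full precision, with no float8 casting.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GemmConfig:
    # Hypothetical per-gemm settings; field names are illustrative only.
    scaling_granularity: str = "tensorwise"  # e.g. "tensorwise" or "rowwise"
    use_fast_accum: bool = False


def transpose(m):
    return [list(row) for row in zip(*m)]


def matmul(a, b):
    # Plain (M, K) @ (K, N) matrix product on nested lists.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def linear_gemms(x, w, grad_out):
    """The three gemms of y = x @ w.T, with x: (M, K), w: (N, K), grad_out: (M, N)."""
    return {
        "output": matmul(x, transpose(w)),  # forward: (M, K) @ (K, N) -> (M, N)
        "grad_input": matmul(grad_out, w),  # backward: (M, N) @ (N, K) -> (M, K)
        "grad_weight": matmul(transpose(grad_out), x),  # backward: (N, M) @ (M, K) -> (N, K)
    }


# One config per gemm, so each can be tuned independently (values are arbitrary).
configs = {
    "output": GemmConfig(scaling_granularity="rowwise"),
    "grad_input": GemmConfig(),
    "grad_weight": GemmConfig(),
}
```

Keying the configs by gemm rather than by layer lets the forward gemm and the two backward gemms use different scaling granularities.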

performance

  • [done] torch._scaled_mm support for per-tensor scaled float8 gemm
  • [in progress] torch._scaled_mm support for rowwise scaled float8 gemm
    • [done] eager mode support
    • [planned] torch.compile support, backed by triton/cutlass
  • [in progress] optimize torch.compile performance for float8 scaling/casting kernels
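A pure-Python fake-quantization sketch shows how per-tensor and rowwise scaling granularity differ. `fake_cast_e4m3` is a hypothetical helper that emulates float8 e4m3fn rounding on Python floats (3 mantissa bits, largest finite value 448, saturating). It is not torch's cast. Only the rows of one operand are scaled, the gemm itself is omitted, and the sketch assumes no row is all zeros.

```python
import math

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def fake_cast_e4m3(x: float) -> float:
    """Round x to the nearest float8 e4m3fn value (saturating at +/-448), kept as a Python float."""
    x = max(-E4M3_MAX, min(E4M3_MAX, x))  # saturate instead of overflowing
    if x == 0.0:
        return 0.0
    _, e = math.frexp(abs(x))  # abs(x) = m * 2**e with 0.5 <= m < 1
    step = 2.0 ** (max(e, -5) - 4)  # 3 mantissa bits; below 2**-6 the spacing stays 2**-9
    return math.copysign(round(abs(x) / step) * step, x)


def per_tensor_scales(rows):
    amax = max(abs(v) for row in rows for v in row)
    return [E4M3_MAX / amax] * len(rows)  # one shared scale for every row


def rowwise_scales(rows):
    return [E4M3_MAX / max(abs(v) for v in row) for row in rows]  # one scale per row


def quantize_dequantize(rows, scales):
    # Scale into float8 range, round to float8, then undo the scale.
    return [[fake_cast_e4m3(v * s) / s for v in row] for row, s in zip(rows, scales)]


a = [[1000.0, 2000.0], [0.001, 0.002]]  # one large-magnitude row, one tiny row
per_tensor = quantize_dequantize(a, per_tensor_scales(a))
rowwise = quantize_dequantize(a, rowwise_scales(a))
```

With one shared scale, the 2000.0 outlier sets the scale to 0.224. That pushes the second row below the smallest e4m3fn subnormal (2**-9), so the row flushes to zero. Per-row scales keep both rows.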

distributed

  • [done] integrate with TP/SP via DTensor APIs
  • [done] integrate with FSDP1 with 16-bit all-gather
  • [done] integrate with FSDP2 with 16-bit or 8-bit all-gather with dynamic scaling for weights
    • performance optimizations are ongoing
  • [in progress] integrate with FSDP2 with 16-bit or 8-bit all-gather with delayed scaling for weights
    • POC is done, performance optimizations are ongoing
  • [planned] verify integration with PP
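The 8-bit all-gather with dynamic weight scaling in the FSDP2 items above can be illustrated with a pure-Python simulation. This sketch makes two assumptions: every rank casts its shard with one shared scale, and that scale comes from a MAX all-reduce of the local amaxes. Both assumptions and the helper names are the sketch's own. No real collectives or torch.distributed calls are used.

```python
E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def all_reduce_max(values):
    # Stand-in for a MAX all-reduce: every rank ends up with the global maximum.
    global_max = max(values)
    return [global_max] * len(values)


def plan_float8_all_gather(shards):
    """Simulate dynamic scaling for a weight sharded across ranks (one list per rank)."""
    local_amax = [max(abs(v) for v in shard) for shard in shards]
    # All ranks cast with the same scale so the gathered float8 weight is consistent.
    scales = [E4M3_MAX / amax for amax in all_reduce_max(local_amax)]
    numel = sum(len(shard) for shard in shards)
    # Bytes of weight data in the gathered (unsharded) tensor; the scale itself is ignored.
    return scales, {"bf16": 2 * numel, "float8": numel}


shards = [[0.5, -1.0], [4.0, 2.0], [0.25, 0.125]]  # one shard per simulated rank
scales, gathered_bytes = plan_float8_all_gather(shards)
```

Every simulated rank gets the same scale (448 / 4.0). The gathered payload is half the size of a 16-bit all-gather.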

other

  • weight gradient accumulation in float32
  • add a use_fast_accum option (float8 accumulation of gemm) to the UX (#144)
  • improve saturated casting performance
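The motivation for accumulating weight gradients in float32 can be seen in a pure-Python sketch that emulates a running sum stored in bfloat16 versus float32. Using bfloat16 as the low-precision comparison is an assumption of this sketch. The helper names are hypothetical, and the sketch does not model how float8_experimental would implement the feature.

```python
import struct


def round_to_float32(x: float) -> float:
    return struct.unpack("<f", struct.pack("<f", x))[0]


def round_to_bf16(x: float) -> float:
    # bfloat16 keeps the top 16 bits of a float32; round-to-nearest-even on the dropped bits.
    # NaN/inf handling is omitted.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits = (bits >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def accumulate(grads, round_fn):
    acc = 0.0
    for g in grads:
        acc = round_fn(acc + g)  # the running sum is stored at round_fn's precision
    return acc


grads = [1.0] + [1e-3] * 1000  # one large contribution, then many small ones
bf16_sum = accumulate(grads, round_to_bf16)
fp32_sum = accumulate(grads, round_to_float32)
```

Near 1.0, bfloat16 values are spaced 2**-7 apart, so each 1e-3 increment rounds away and the bfloat16 sum stays at 1.0. The float32 sum reaches roughly 2.0.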

Is there a plan to support AMP?

Sorry for the late reply! We don't plan to support AMP in the near future. Delayed scaling is stateful, so supporting it in an AMP-like API would carry too high an engineering cost. For now we want a consistent API for dynamic and delayed scaling, to make ablation studies easy. If the community converges on dynamic scaling in the future (which is stateless), we could revisit this.
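A minimal pure-Python sketch makes the stateful/stateless distinction in the reply concrete. `DelayedScaler`, `dynamic_scale`, the history length of 16, and taking the max over the history are illustrative assumptions, not float8_experimental's actual API or defaults.

```python
from collections import deque

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def amax(values):
    return max(abs(v) for v in values)


def dynamic_scale(values):
    # Stateless: a pure function of the tensor being cast right now.
    return E4M3_MAX / amax(values)


class DelayedScaler:
    """Stateful: the scale comes from amaxes seen in earlier iterations."""

    def __init__(self, history_len: int = 16):
        self.amax_history = deque(maxlen=history_len)

    def scale(self, values):
        # Use past amaxes; on the very first call fall back to the current amax
        # (an assumption of this sketch).
        ref = max(self.amax_history) if self.amax_history else amax(values)
        self.amax_history.append(amax(values))  # state carried into the next step
        return E4M3_MAX / ref


steps = [[1.0, -2.0], [4.0, 0.5], [0.25, 0.1]]  # one tensor per training step
delayed = DelayedScaler()
dynamic_scales = [dynamic_scale(t) for t in steps]
delayed_scales = [delayed.scale(t) for t in steps]
```

In the second step, the delayed scale (224) was computed before the amax jumped to 4.0. So 4.0 * 224 = 896 exceeds 448 and must rely on saturated casting. The history buffer is the per-tensor state that an AMP-style API would have to own somewhere.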