Lightning-AI/pytorch-lightning

FP8 mixed precision via NVIDIA's Transformer Engine

carmocca opened this issue · 6 comments

Description & Motivation

Support FP8 mixed-precision training with NVIDIA's Transformer Engine: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html

Pitch

Write a precision plugin using the library above that is enabled via:

  • precision="transformer-engine"

Alternatives

Don't implement this until it's vendored by PyTorch, if that ever happens.

Additional context

No response

cc @Borda @carmocca @justusschock @awaelchli

Enabling the library should only require wrapping the forward pass in its autocast context manager.
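A minimal sketch of such an autocast-only plugin, assuming the Transformer Engine PyTorch API at the time of writing (transformer_engine.pytorch.fp8_autocast and transformer_engine.common.recipe.DelayedScaling). The class name TransformerEnginePrecisionSketch and its enabled flag are hypothetical. The forward_context method only mirrors the name of the hook that Fabric's precision plugins expose; this is not Lightning's implementation.

```python
import contextlib


class TransformerEnginePrecisionSketch:
    """Hypothetical precision-plugin sketch (not Lightning's actual class).

    Only the forward pass is wrapped. Transformer Engine is imported lazily
    so this class can be defined on machines where it is not installed.
    """

    def __init__(self, enabled: bool = True) -> None:
        self.enabled = enabled

    def forward_context(self):
        if not self.enabled:
            # FP8 disabled: return a no-op context manager.
            return contextlib.nullcontext()
        import transformer_engine.pytorch as te
        from transformer_engine.common import recipe

        # HYBRID: E4M3 for forward tensors, E5M2 for gradients.
        fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID)
        return te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)
```

Usage would look like `with plugin.forward_context(): out = model(x)`. Note that fp8_autocast only affects Transformer Engine modules (for example te.Linear); plain torch.nn layers inside the context still run in their usual precision.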

There is one more thing: the user also needs to replace their layers (e.g. torch.nn.Linear) with the library's custom ones. What's the plan here? Will the plugin implement the module_init_context() manager? On the other hand, one might not want to replace all layers. If this is left to the user, then there is a lot less value in adding the plugin.

Yes, we'll need to implement a replacement mechanism. The plugin can have a flag to disable it if necessary.

This also means that we'll have it in Fabric first, as these APIs do not exist in the Trainer yet.

Actually, convert_module might be a better fit than module_init_context if we prefer replacing existing layers over patching the torch.nn classes.
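A rough sketch of the convert_module approach: walk an existing model and swap each torch.nn.Linear for a module built by a factory, with a flag to skip the conversion entirely. The names replace_linear_layers and to_te_linear are hypothetical. The factory assumes that transformer_engine.pytorch.Linear accepts (in_features, out_features, bias=...) and exposes weight and bias parameters; check this against the installed version.

```python
from typing import Callable

import torch
from torch import nn


def replace_linear_layers(
    module: nn.Module,
    make_replacement: Callable[[nn.Linear], nn.Module],
    enabled: bool = True,
) -> nn.Module:
    """Recursively swap every nn.Linear child of `module` (hypothetical helper).

    The root module itself is not replaced, only its descendants.
    """
    if not enabled:
        return module
    # Materialize the children first so replacing them cannot disturb iteration.
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            setattr(module, name, make_replacement(child))
        else:
            replace_linear_layers(child, make_replacement, enabled)
    return module


def to_te_linear(linear: nn.Linear) -> nn.Module:
    """Example factory; assumes te.Linear's constructor and .weight/.bias names."""
    import transformer_engine.pytorch as te

    new = te.Linear(linear.in_features, linear.out_features, bias=linear.bias is not None)
    with torch.no_grad():
        # Carry over the pretrained values (copy_ also handles CPU -> CUDA).
        new.weight.copy_(linear.weight)
        if linear.bias is not None:
            new.bias.copy_(linear.bias)
    return new
```

Passing a factory rather than hard-coding te.Linear keeps the traversal testable without an H100, and a user who only wants some layers converted could pass a factory that returns the original layer unchanged for the ones to skip.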

Any update on support for this?

@nanand2 Our access to H100s is very limited, so we haven't merged this yet. However, the branch https://github.com/Lightning-AI/lightning/tree/carmocca/transformer-engine should be usable if you want to try it right now.

Great, thanks!