FP8 mixed precision via NVIDIA's Transformer Engine
carmocca opened this issue · 6 comments
Description & Motivation
Support https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Pitch
Write a precision plugin using the library above that is enabled via:
precision="transformer-engine"
Alternatives
Don't implement this until it's vendored by PyTorch, if that ever happens.
Additional context
No response
The library only requires enabling an autocast context manager.
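For context, a minimal sketch of what that looks like with Transformer Engine used directly, following the user guide linked above (layer sizes and recipe settings are arbitrary):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# A Transformer Engine layer; a plain torch.nn.Linear would not run in FP8
model = te.Linear(768, 768, bias=True).cuda()
inp = torch.randn(32, 768, device="cuda")

# Delayed-scaling FP8 recipe: E4M3 in the forward pass, E5M2 in the backward
recipe = DelayedScaling(fp8_format=Format.HYBRID)

# Supported ops inside this context run in FP8
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = model(inp)
```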
There is one more thing: the user needs to replace their layers with the custom ones from the library. What's the plan here? Will the plugin implement the module_init_context() manager? On the other hand, one might not want to replace all layers. If this is left to the user, then there is a lot less value in adding the plugin.
Yes, we'll need to implement a replacement mechanism. The plugin can have a flag to disable it if necessary.
This also means that we'll have it in Fabric first, as these APIs do not exist in the Trainer yet.
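To make the idea concrete, here is a sketch of what such a replacement mechanism could look like. `replace_layers` is a hypothetical helper name, and a real implementation would also copy the existing weights over:

```python
import torch.nn as nn
import transformer_engine.pytorch as te

def replace_layers(module: nn.Module) -> nn.Module:
    """Hypothetical helper: recursively swap supported torch.nn layers
    for their Transformer Engine equivalents (weights not copied here)."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, te.Linear(
                child.in_features, child.out_features, bias=child.bias is not None
            ))
        elif isinstance(child, nn.LayerNorm):
            setattr(module, name, te.LayerNorm(child.normalized_shape[0], eps=child.eps))
        else:
            replace_layers(child)
    return module
```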
Actually, convert_module might be a better fit than init_context if we prefer replacing existing layers over patching the torch.nn classes.
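A rough sketch of how a Fabric precision plugin built around convert_module might look, reusing the hypothetical `replace_layers` helper from the previous comment. The class name and flag are illustrative, not the merged implementation:

```python
import transformer_engine.pytorch as te
from lightning.fabric.plugins.precision import Precision
from torch import nn

class TransformerEnginePrecision(Precision):
    """Illustrative plugin sketch: swap layers once, autocast every forward."""

    def __init__(self, replace: bool = True) -> None:
        self.replace = replace

    def convert_module(self, module: nn.Module) -> nn.Module:
        # Replace existing layers rather than patching the torch.nn classes
        return replace_layers(module) if self.replace else module

    def forward_context(self):
        # Run the wrapped forward pass under TE's FP8 autocast
        return te.fp8_autocast(enabled=True)
```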
Any update on support for this?
@nanand2 Our access to H100s is very limited so we haven't merged this yet. However, the branch https://github.com/Lightning-AI/lightning/tree/carmocca/transformer-engine should be usable if you want to play with it right now
Great, thanks!