[Tracking Issue] Relax Training M0
Ubospica opened this issue · 2 comments
There has been increased interest from the community in using TVM for training. Relax, the next-generation graph-level IR of TVM, also faces demand for model training.
We are building a training workflow on Relax, including:
- an automatic differentiation tool based on source code transformation
- an optimizer abstraction and common optimizers
- a loss function abstraction and common loss functions
- a Trainer API that integrates them and is easy to use
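To make the division of labor among these four pieces concrete, here is a minimal pure-Python sketch. It does not use TVM or the Relax training API: every name in it (`SGD`, `mse_loss_and_grad`, `Trainer`) is hypothetical, and the hand-written gradient only stands in for what the AD tool would generate.

```python
from typing import Callable, List, Sequence, Tuple

Params = List[float]


class SGD:
    """Hypothetical optimizer abstraction: maps (params, grads) to new params."""

    def __init__(self, lr: float) -> None:
        self.lr = lr

    def step(self, params: Params, grads: Sequence[float]) -> Params:
        return [p - self.lr * g for p, g in zip(params, grads)]


def mse_loss_and_grad(
    params: Params, xs: Sequence[float], ys: Sequence[float]
) -> Tuple[float, List[float]]:
    """Loss plus gradients for the model y = w * x + b.

    The gradient is written by hand here; in the real workflow this
    function is what the AD tool would derive from the forward code.
    """
    w, b = params
    n = len(xs)
    residuals = [w * x + b - y for x, y in zip(xs, ys)]
    loss = sum(r * r for r in residuals) / n
    dw = sum(2.0 * r * x for r, x in zip(residuals, xs)) / n
    db = sum(2.0 * r for r in residuals) / n
    return loss, [dw, db]


class Trainer:
    """Hypothetical trainer: ties the differentiated loss to an optimizer."""

    def __init__(self, loss_and_grad: Callable, optimizer: SGD, params: Params) -> None:
        self.loss_and_grad = loss_and_grad
        self.optimizer = optimizer
        self.params = params

    def update(self, xs: Sequence[float], ys: Sequence[float]) -> float:
        loss, grads = self.loss_and_grad(self.params, xs, ys)
        self.params = self.optimizer.step(self.params, grads)
        return loss


# Fit y = 2x + 1 from four samples, starting from w = b = 0.
xs = [0.0, 1.0, 2.0, 3.0]
ys = [2.0 * x + 1.0 for x in xs]
trainer = Trainer(mse_loss_and_grad, SGD(lr=0.1), [0.0, 0.0])
losses = [trainer.update(xs, ys) for _ in range(500)]
```

The point of the split is that the loss and optimizer can be swapped independently, while the trainer owns only the parameter state and the update step.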
The training APIs can serve many needs. You will be able to:
- train a model from scratch. You can use the compilation advantages of TVM to speed up the training process.
- fine-tune a model on-device with TVM.
- deploy model training to the various devices that TVM supports, such as FPGAs and Raspberry Pi.
This work is mainly done by @SiriusNEO and @Ubospica, with help from @tqchen @junrushao @MasterJH5574 @Hzfengsy @spectrometerHBH et al.
Further introductions to our work:
A Jupyter notebook tutorial of the training APIs can be found here.
A detailed explanation of the AD pass and its limitations can be found here.
A large part of our work has already been merged into the mlc repo, and ongoing work is tracked in this issue.
The APIs are still changing. We will update the tutorial after each API change.
Discussion from the Feb. 7, 2023 community meeting after this work was presented:
- This seems to be focused on the single-device case, so the natural comparison is against PyTorch 2.0 or JAX. Relax might have more opportunities for optimizations
- Response: We might be able to extend to the distributed case using the method from ZeRO
- Q: Is it necessary to switch control between Relax and Python in the training loop?
- A: No, it is not a necessary portion of the design and might be costing us some efficiency. We can work on getting the entire training loop into Relax (possibilities: do gradient updates in Relax via PackedFuncs, use recursion for the loop itself) and try further optimizations
- Q: What are the next steps for this work? What are the priorities?
- A: Supporting call_tir would be the next big objective (doing AD in TIR), since this would let us derive gradients directly from the TIR definitions of operators rather than requiring a separate gradient definition for each operator.
- Q: Do you think changing the Relax AST could help for training? We might consider making the notion of parameters first-class, simplifying some of the APIs.
- A: A first-class parameter node would be nice, but we will try to get farther in the prototype before deciding on it. We might be able to accomplish parameters through function attributes; it might not be desirable to complicate the AST. Of possible changes, some notion of parameter seems like it would be the most useful, but we would have to figure out the necessary properties.
- Q: What obstacles are there to supporting AD for TIR PrimFuncs?
- A: It would be useful for the legalizer to be able to provide the PrimFunc implementations of Relax operators. Additionally, dealing with dynamic shapes in TIR could be a challenge.
I would be curious to hear about the plans for further developing AD in Relax. We should be able to support all of the language's features by building up the "tape" within Relax (this could be accomplished through PackedFuncs).
We could consider using the approach in Relay's general-purpose AD pass, where the "tape" is built up using a closure: https://github.com/apache/tvm/blob/main/src/relay/transforms/higher_order_gradient.cc (I think we should aim to make AD as general as we can, as this would allow us to target all manner of diverse models for training).
Edit: One issue with using PackedFuncs to represent the tape would be that the data structure and the updates would be opaque to the language. This would be fine for a first pass on a prototype, but would really prevent us from doing optimizations on it. This could be a good impetus to develop proper data structures for Relax, so that might be something to consider.
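For readers unfamiliar with the closure-based tape mentioned above, here is a minimal pure-Python sketch of the idea. It is only an illustration, not the Relay pass or any Relax API: the names `Tape`, `Dual`, `add`, `mul`, and `grad` are hypothetical, and it handles scalars only. Each operation wraps the previous backward closure in a new one, so calling the final closure replays the gradient updates in reverse order.

```python
class Tape:
    """Holds a single closure that runs all recorded backward updates."""

    def __init__(self):
        self.backward = lambda: None  # empty tape: nothing to propagate

    def record(self, update):
        previous = self.backward

        def backward():
            update()    # this op's gradient update runs first
            previous()  # then every op recorded before it

        self.backward = backward


class Dual:
    """A value paired with a mutable gradient accumulator."""

    def __init__(self, value):
        self.value = value
        self.grad = 0.0


def add(tape, a, b):
    out = Dual(a.value + b.value)

    def update():
        a.grad += out.grad
        b.grad += out.grad

    tape.record(update)
    return out


def mul(tape, a, b):
    out = Dual(a.value * b.value)

    def update():
        a.grad += b.value * out.grad
        b.grad += a.value * out.grad

    tape.record(update)
    return out


def grad(f, *inputs):
    """Run f while recording the tape, then replay the tape backwards."""
    tape = Tape()
    args = [Dual(v) for v in inputs]
    out = f(tape, *args)
    out.grad = 1.0  # seed d(out)/d(out)
    tape.backward()
    return out.value, [a.grad for a in args]


# f(x, y) = x * y + x * x, so df/dx = y + 2x and df/dy = x.
value, grads = grad(lambda t, x, y: add(t, mul(t, x, y), mul(t, x, x)), 3.0, 4.0)
```

Because the tape is built while the program runs, ordinary control flow in `f` needs no special handling, which is the generality argument made above. The cost is also visible in the sketch: the tape is an opaque chain of closures, which is the same optimization concern raised in the edit about PackedFuncs.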