Meta-issue tracking PyTorch integration attempts
parasj opened this issue · 3 comments
Most viable approaches
Options are sorted by implementation difficulty.
Support only nn.Sequential models (no arbitrary graph support!)
At the moment, torch.utils.checkpoint only supports Sequential layers and not arbitrary graphs. An easy way to integrate with PyTorch would be to continue using this API, where we would simply call seq.children() recursively. However, this approach could be error-prone, as it is common to override the forward method of a Sequential block. Another example of nn.Sequential parsing is https://gitlab.inria.fr/hiepacs/rotor/blob/master/rotor/inspection.py#L24-33.
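For illustration, a minimal sketch (not Checkmate code) of that recursion on top of torch.utils.checkpoint.checkpoint, assuming no child overrides forward() in a way that bypasses children():

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def run_checkpointed(seq: nn.Sequential, x: torch.Tensor) -> torch.Tensor:
    for child in seq.children():
        if isinstance(child, nn.Sequential):
            x = run_checkpointed(child, x)   # recurse into nested Sequential blocks
        else:
            x = checkpoint(child, x)         # recompute this child's activations during backward
    return x

model = nn.Sequential(
    nn.Linear(16, 32), nn.ReLU(),
    nn.Sequential(nn.Linear(32, 32), nn.ReLU()),
)
out = run_checkpointed(model, torch.randn(4, 16, requires_grad=True))
out.sum().backward()
```

Any module that overrides forward() (residual connections, multi-input blocks) would silently be traversed incorrectly by this scheme, which is the error-proneness mentioned above.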
Integration via module hooks
Status: LazyTensors are still incomplete, but may allow an easy implementation of rematerialization. Care will need to be taken around backward hooks (see below).
Backward hooks are still quite buggy, so it seems prudent to avoid integrating via that method (see Maxim's comments on pytorch/pytorch#12331). It is reliable to attach a hook to a tensor that is called AFTER the gradient with respect to that tensor is computed (https://pytorch.org/docs/stable/_modules/torch/tensor.html#Tensor.register_hook). Moreover, module hooks (pre-forward and forward) appear to be reliable and stable in my tests.
Supporting memory paging to CPU DRAM via unified memory is straightforward with a forward hook that moves the correct tensors to host via tensor.cpu(). To support rematerialization using module + tensor hooks, we would wrap the result of a forward operation in a wrapper similar to the ongoing LazyTensor implementation and then detach the tensor in the backward pass. During backpropagation, the lazy tensor would be evaluated before the corresponding gradient node.
A good example of using hooks to get to Tensors is the Distiller codebase from Intel Nervana, https://discuss.pytorch.org/t/how-to-extract-features-of-an-image-from-a-trained-model/119/100.
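As a minimal sketch of these hook mechanics (illustrative only, not Checkmate code): a forward hook pages a detached copy of an activation to host memory, and a tensor hook on the output fires once the gradient with respect to that output has been computed, which is the point where a paged or lazy tensor would need to have been materialized. A real implementation would also restore the tensor to the GPU before the backward pass consumes it.

```python
import torch
import torch.nn as nn

paged_out = {}  # module -> activation copy stored in host DRAM

def forward_hook(module, inputs, output):
    # Page a detached copy of the activation to the CPU after the forward pass.
    paged_out[module] = output.detach().cpu()

    def on_grad(grad):
        # Fires AFTER the gradient w.r.t. `output` has been computed; a paged or
        # lazy tensor would be restored just before this point and can be freed here.
        paged_out.pop(module, None)
        return grad

    output.register_hook(on_grad)  # tensor hook, called during backward

layer = nn.Linear(8, 8)
layer.register_forward_hook(forward_hook)
y = layer(torch.randn(2, 8, requires_grad=True))
y.sum().backward()
```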
Adding a new JIT pass for rematerialization
Common Subexpression Elimination caused us some headaches when rematerializing graphs in TensorFlow, and it seems that PyTorch also applies a CSE pass on JIT graphs. Is there an easy way to convert the SSA form into a graph, rematerialize nodes, and apply the whole routine as a JIT pass (to be added in the GraphExecutor, https://github.com/pytorch/pytorch/blob/f326045b3757236aabe367dfca1894be14ce31ef/torch/csrc/jit/graph_executor.cpp#L550-L612)?
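For a rough sense of where such a pass would operate (the pass itself would be registered in the C++ GraphExecutor linked above), the SSA-form graph is visible from Python after tracing; this sketch only inspects it:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))
traced = torch.jit.trace(model, torch.randn(1, 16))

# Each node in the SSA-form graph is a candidate for recomputation. A real
# rematerialization pass would duplicate cheap nodes near their consumers in the
# backward graph instead of keeping their outputs live, and would have to run
# after (or be protected from) the CSE pass mentioned above.
for node in traced.graph.nodes():
    print(node.kind())  # e.g. aten::addmm, aten::relu
```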
Approaches not currently viable
Integration via ONNX training support
Status: ONNX training support has not been merged yet, so not viable at the moment
Originally, we attempted to support PyTorch training via ONNX integration, which was inference-only at the time. We were able to extract PyTorch graphs and infer the backwards pass (via edge reversal) in the first public release of the code.
However, there was no clear path to actually running the rescheduled graphs, and we have since abandoned this line of investigation. If ONNX training support materializes (ONNX integration is ongoing in PR onnx/onnx-tensorflow#508), then this may become viable. We tried to use torch.autograd.backward for this, but encountered some issues (@aninrusimha has more context on the limitations). We need to extract a gradient function which, given (a) the output of a layer and (b) the gradient with respect to that output, produces the gradient with respect to each input of the layer. I believe one issue we encountered was that the output of a layer may not necessarily be derived from the inputs to the layer, and therefore the implicit graph stored in the Tensors may not make logical sense.
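The per-layer gradient function we need could, in principle, look like the following sketch built on torch.autograd.grad (illustrative only; this is exactly what breaks down when a layer's output is not derived from its inputs through autograd):

```python
import torch
import torch.nn as nn

layer = nn.Linear(8, 4)
x = torch.randn(2, 8, requires_grad=True)
y = layer(x)

# Stand-in for the gradient flowing back into this layer from downstream.
grad_y = torch.randn_like(y)

# Gradient w.r.t. the layer input, given (a) the output and (b) the upstream gradient.
(grad_x,) = torch.autograd.grad(outputs=y, inputs=x, grad_outputs=grad_y, retain_graph=True)
```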
Integration via TorchScript
Status: The tracing API has changed significantly since PyTorch 0.3; training via TorchScript now appears feasible (see below), but tooling is still buggy.
TorchScript will trace an eager-mode PyTorch graph for the purpose of JIT optimizations. We attempted to leverage an approach similar to pytorchviz. However, it appears that the tracing API has changed significantly since PyTorch 0.3. We encountered issues very similar to those faced by the authors of Astra in this issue.
Recently, there has been some success in training TorchScript models in C++. Moreover, pytorchviz has been updated with support for inferring training graphs via TorchScript tracing, though its authors note this approach is buggy. Training models using TorchScript now appears to be feasible (pytorch/pytorch#17614).
Language introspection to infer training graph from grad_fn
Status: It is feasible to extract the training graph and modify the schedule, but this is difficult to accomplish and would make adoption by end users hard.
We attempted to extract the training graph from the composed partial function grad_fn. However, this proved to be difficult. Eunjie Jeong of the JANUS project had some success hacking the Python interpreter to get at the PyTorch training graph, but this is not a preferred approach, as the need to release and maintain a custom version of Python would preclude easy integration with established codebases for most users.
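To illustrate the kind of introspection this approach relies on, the sketch below walks grad_fn.next_functions to print the backward graph; the node types are internal autograd classes, not a stable API:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
loss = model(torch.randn(4, 16)).sum()

def walk(fn, depth=0, seen=None):
    # Recursively print the backward graph reachable from a grad_fn node.
    seen = set() if seen is None else seen
    if fn is None or fn in seen:
        return
    seen.add(fn)
    print("  " * depth + type(fn).__name__)  # e.g. AddmmBackward, ReluBackward0
    for next_fn, _ in getattr(fn, "next_functions", ()):
        walk(next_fn, depth + 1, seen)

walk(loss.grad_fn)
```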
User-level graph annotations
This is the approach that PyTorch checkpointing support currently leverages. However, the PyTorch checkpointing code is very buggy at the moment and barely functional (anecdotal evidence from users within BAIR; it seems to only moderately reduce training memory requirements). The original implementation was released in this PR, and an optimized version of the checkpointing functionality is implemented in a third-party repo. Users either decorate the functions they would like to checkpoint, or pass an iterable of layers to which a very simple policy from Chen et al. (2016) is applied. Current support only enables checkpointing a layer permanently, not the dynamic, time-varying support we describe in the paper (aka rematerialization).
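For reference, a minimal sketch of that user-level API as it exists today, using torch.utils.checkpoint.checkpoint_sequential with a sqrt(n) segment count in the spirit of Chen et al. (2016); note the checkpointing decision is fixed per layer and cannot vary over time:

```python
import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

layers = nn.Sequential(*[nn.Linear(64, 64) for _ in range(16)])
x = torch.randn(8, 64, requires_grad=True)

# sqrt(n) segments keeps roughly O(sqrt(n)) activations live, as in Chen et al. (2016).
segments = int(math.sqrt(len(layers)))
out = checkpoint_sequential(layers, segments, x)
out.sum().backward()
```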
@aninrusimha can you elaborate on what issues we encountered using ONNX and torch.autograd.grad?
Is this done?
Hi @eric-haibin-lin,
At the moment, we've decided not to pursue PyTorch support due to the complexity of integration. Happy to discuss ideas to extend Checkmate to new DL frameworks (PyTorch, MXNet, etc.) if interested.