Optimize PyTorch computation graph for training
Hi,
Thanks for the great work. Does TASO support optimizing PyTorch computation graphs for faster training? If so, are there any benchmarks?
@knsong Thanks for your interest in TASO. We support optimizing PyTorch graphs by transforming the graph to ONNX format using torch.onnx. For training, please set the batch_size, which is typically the first dimension of the input tensors in ONNX graphs, to the desired number.
Note that TASO currently only considers forward operators, so you will get a graph optimized for forward processing.
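For reference, a minimal sketch of that workflow, assuming TASO's Python bindings expose load_onnx/optimize/export_onnx as described in the project README; the model choice, batch size, and file names here are placeholders, not part of the original answer:

```python
import onnx
import torch
import torchvision
import taso

# Export a PyTorch model to ONNX. The dummy input's first dimension
# becomes the batch size recorded in the ONNX graph, so set it to the
# batch size you actually intend to use (64 here, as an example).
model = torchvision.models.resnet18()
dummy_input = torch.randn(64, 3, 224, 224)
torch.onnx.export(model, dummy_input, "resnet18.onnx")

# Load the ONNX graph into TASO, search for an optimized equivalent
# graph, and write the result back out as ONNX.
graph = taso.load_onnx("resnet18.onnx")
optimized = taso.optimize(graph)
onnx.save(taso.export_onnx(optimized), "resnet18_taso.onnx")
```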
Did you ever train the optimized graph using PyTorch or TensorFlow?
TASO optimizes for inference performance (i.e., minimizing forward processing time). The optimized graph is mathematically equivalent to the original graph, and therefore can be used for training as well, though the graphs are optimized for inference.
We are currently working on adding training cost into the cost model, and will update this thread once training support is ready.
I notice that TASO can merge a conv and a following BN into a single conv for better inference efficiency, but that fused form is not suitable for training.
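As a generic illustration of why that fusion is inference-only (this is a sketch of the standard conv+BN folding math, not TASO's code, and fuse_conv_bn is a hypothetical helper): the fused weights bake in the BN's frozen running statistics, whereas training-mode BN recomputes statistics from each batch.

```python
import torch
import torch.nn as nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Fold BN(Conv(x)) into a single Conv2d using the BN's *running*
    statistics. Only equivalent to the original pair in eval mode."""
    fused = nn.Conv2d(conv.in_channels, conv.out_channels,
                      conv.kernel_size, conv.stride, conv.padding,
                      bias=True)
    # BN(y) = gamma * (y - mean) / sqrt(var + eps) + beta
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused.weight.data = conv.weight.data * scale.reshape(-1, 1, 1, 1)
    conv_bias = (conv.bias.data if conv.bias is not None
                 else torch.zeros(conv.out_channels))
    fused.bias.data = (conv_bias - bn.running_mean) * scale + bn.bias.data
    return fused

conv, bn = nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16)
bn.running_mean, bn.running_var = torch.randn(16), torch.rand(16) + 0.5
conv.eval(); bn.eval()  # fused form matches only with frozen statistics
x = torch.randn(2, 3, 8, 8)
assert torch.allclose(bn(conv(x)), fuse_conv_bn(conv, bn)(x), atol=1e-5)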
Thanks. Looking forward to it.