Fusing?
srush opened this issue · 8 comments
This project is really neat.
Curious if you considered separating out the load / store logic from the math to make it easier to fuse operations? For instance, it would be nice to be able to use dropout (forward / backward) from within an NN+dropout fused function.
Hello,
Thank you for your interest in attorch. Yes, I have considered refactoring the math operations into a separate collection of functions that receive loaded arrays (in lieu of pointers) and return transformed arrays (in lieu of storing them). These could then be incorporated into higher-level kernels tasked with loading / storing the data and applying the appropriate mathematical transformations.

This is an appealing idea: it would enable aggressive fusion of kernels, as you say, and it would make the design of most layers more modular. For example, rather than having individual kernels for batch and layer normalization, both would be instances of a more general pipeline - calculate the mean and standard deviation -> subtract the mean and divide by the standard deviation -> apply the affine transformation - that can handle arbitrary axes. However, it would also sacrifice the atomic design of attorch, where each module has its own file that can easily be copy-pasted and modified to suit the user's needs. The raison d'être of this project is to facilitate exactly this type of forking or copy-pasting, and I fear that refactoring out the math would defeat that purpose.

Nevertheless, I agree that for commonly fused modules such as dropout, following the pattern you refer to would be particularly helpful. I will flesh something out in the coming weeks and add support for fused dropout to linear layers (a rough sketch of the pattern follows below).
I would be happy to hear your thoughts.
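To make the pattern concrete, here is a minimal, hypothetical sketch (not attorch's actual code) of a pointer-free dropout helper fused into another kernel. The names dropout_forward_math and add_bias_dropout_kernel, as well as the bias-add epilogue, are illustrative assumptions; the seeded-dropout trick follows the Triton tutorials.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def dropout_forward_math(x, p, seed, offsets):
    # Pure math: receives already-loaded values and returns transformed values.
    # Keeps each element with probability 1 - p and rescales by 1 / (1 - p).
    random = tl.rand(seed, offsets)
    keep = random > p
    return tl.where(keep, x / (1 - p), 0.0)


@triton.jit
def add_bias_dropout_kernel(x_ptr, bias_ptr, out_ptr, n_elements, p, seed,
                            BLOCK_SIZE: tl.constexpr):
    # Thin wrapper that owns the loads / stores and fuses a bias add with dropout.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    bias = tl.load(bias_ptr + offsets, mask=mask)
    out = dropout_forward_math(x + bias, p, seed, offsets)
    tl.store(out_ptr + offsets, out, mask=mask)


# Hypothetical usage on flattened tensors.
x = torch.randn(4096, device='cuda')
bias = torch.randn(4096, device='cuda')
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
add_bias_dropout_kernel[grid](x, bias, out, x.numel(), 0.1, 42, BLOCK_SIZE=1024)
```

Because dropout_forward_math never touches pointers, the same helper could equally be called in the epilogue of a matmul kernel, which is essentially the fused linear + dropout case mentioned above.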
I should also add that one interesting benefit of pure math functions would be the possibility of a Triton counterpart to jax.grad to perform autodiff, but that is a completely different topic.
I agree that having single-file implementations is really nice. I guess the design would move from 2 functions to 4 for each module, where one handles the loads / stores from pointers and the other the mathematical implementation. That possibly wouldn't be so bad for keeping the complexity low.
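As a hedged illustration of that four-function layout (the names are hypothetical, not attorch's), the backward pair for dropout might look like the following, with the keep mask recomputed from the same seed used in the forward pass:

```python
import triton
import triton.language as tl


@triton.jit
def dropout_backward_math(grad_out, p, seed, offsets):
    # Pure math: recompute the keep mask from the same seed and offsets used
    # in the forward pass and scale the incoming gradient accordingly.
    random = tl.rand(seed, offsets)
    keep = random > p
    return tl.where(keep, grad_out / (1 - p), 0.0)


@triton.jit
def dropout_backward_kernel(grad_out_ptr, grad_in_ptr, n_elements, p, seed,
                            BLOCK_SIZE: tl.constexpr):
    # Load / store wrapper around the math helper.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    grad_out = tl.load(grad_out_ptr + offsets, mask=mask)
    grad_in = dropout_backward_math(grad_out, p, seed, offsets)
    tl.store(grad_in_ptr + offsets, grad_in, mask=mask)
```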
Actually this idea does tie into backprop. I wrote a small library https://github.com/srush/triton-autodiff that allows you to derive backprop code in Triton. It might let you derive backprop of fused ops directly.
I like the idea of having a layer_forward_math, layer_forward_kernel, layer_backward_math, and layer_backward_kernel, but I am not sure it is applicable to all layers. In the case of batch normalization, for example, it would not be possible to cram the math into a single function, since at least two passes over the data are necessary - one to calculate the batch statistics, another to transform the input. I suppose we could have multiple math-only functions per kernel though (e.g., calc_mean_std and normalize). A compromise between the current design and a thoroughgoing refactoring of it would be to have separate calc_mean_std and normalize functions for each normalization layer. I will start experimenting with this.
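A minimal sketch of that compromise, assuming a layer-norm-style kernel for simplicity (one program per row); the names, signatures, and the wrapper kernel are illustrative, not attorch's actual implementation:

```python
import triton
import triton.language as tl


@triton.jit
def calc_mean_std(x, mask, size, eps):
    # First pass over the loaded block: mean and standard deviation of the
    # valid (masked) elements only.
    mean = tl.sum(tl.where(mask, x, 0.0), axis=0) / size
    diff = tl.where(mask, x - mean, 0.0)
    var = tl.sum(diff * diff, axis=0) / size
    return mean, tl.sqrt(var + eps)


@triton.jit
def normalize(x, mean, std, weight, bias):
    # Second pass: standardize, then apply the affine transformation.
    return (x - mean) / std * weight + bias


@triton.jit
def layer_norm_forward_kernel(x_ptr, weight_ptr, bias_ptr, out_ptr,
                              feat_dim, eps, BLOCK_SIZE: tl.constexpr):
    # One program normalizes one row; all loads / stores stay in the kernel.
    row = tl.program_id(axis=0)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < feat_dim
    x = tl.load(x_ptr + row * feat_dim + offsets, mask=mask, other=0.0)
    weight = tl.load(weight_ptr + offsets, mask=mask, other=1.0)
    bias = tl.load(bias_ptr + offsets, mask=mask, other=0.0)
    mean, std = calc_mean_std(x, mask, feat_dim, eps)
    out = normalize(x, mean, std, weight, bias)
    tl.store(out_ptr + row * feat_dim + offsets, out, mask=mask)
```

A batch normalization kernel could reuse the same two helpers while reducing over a different axis, which is the more general pipeline described earlier in the thread.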
triton-autodiff seems very cool. The difficulty of a full-fledged autodiff mechanism for Triton lies in impure functions, but the issue is side-stepped if only math functions are considered. Were you planning to add support for conditionals and loops too?
Yeah, I agree these are interesting complications.
Agree that you don't want full autodiff for Triton, but if you have a couple of specialized impure forward / backward functions, you can potentially connect them together with autodiff.
The backend for triton-autodiff actually does support if / for, which is neat (it's a general-purpose Python source-to-source autodiff). I haven't tested it thoroughly yet.
I have gone through triton-autodiff (super cool that the code is so concise) and Tangent, and I had an idea: I will keep the design of the existing layers as-is but will add a separate math module that implements the forward passes of basic mathematical functions whose gradients can be derived through triton-autodiff, thereby allowing developers to plug them into other kernels and do fusion should attorch not support their use case. There will inevitably be duplicate code involved, but I would rather have that than lose the simplicity of the neural net modules.
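As a rough sketch, one entry in such a math module might be a pointer-free activation forward like the one below. The sigmoid approximation of GELU is used purely as an example; whether a given function can actually be differentiated by triton-autodiff depends on the ops it supports.

```python
import triton
import triton.language as tl


@triton.jit
def gelu_forward(x):
    # Sigmoid approximation of GELU. Pure math with no loads or stores, so it
    # can be inlined into any kernel after the data has been loaded, and its
    # gradient can in principle be derived by a source-to-source tool.
    return x * tl.sigmoid(1.702 * x)
```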
Cool. Feel free to reach out if I can help in any way. Neat project.
I appreciate it, will do.