Profiling code for low-level CUDA analysis?
Coming here after reading this Twitter thread, and the results look super cool!
Is it possible to share the profiling scripts for both of the plots referenced there? I was interested in plugging it into NVIDIA's Nsight Systems and looking at the low-level CUDA kernels (similar to what was done here).
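For context, here is roughly the kind of annotation I had in mind. This is a minimal sketch with a stand-in model (the `nn.Linear` and its input are placeholders for the actual network and batch); running it under `nsys profile -t cuda,nvtx python script.py` groups the launched kernels under named NVTX ranges in the timeline:

```python
import torch

# Placeholder model/input; swap in the real network and a representative batch.
model = torch.nn.Linear(64, 64).cuda()
x = torch.randn(1024, 64, device="cuda")

for step in range(10):
    # NVTX ranges appear as named regions in the Nsight Systems timeline,
    # which makes the CUDA kernels launched inside them easy to attribute.
    torch.cuda.nvtx.range_push(f"forward_step_{step}")
    out = model(x)
    torch.cuda.nvtx.range_pop()
torch.cuda.synchronize()
```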
Thanks for sharing this work!
Edit: Found this helpful thread about benchmarking `torch.compile`, which might be relevant.
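The short version, as I understand it: run a few warmup iterations so compilation and autotuning happen outside the timed region, and synchronize the GPU before reading the clock. A minimal sketch (the linear layer is just a placeholder workload):

```python
import time
import torch

def bench(fn, x, warmup=3, iters=20):
    # Warmup triggers torch.compile's compilation/autotuning so it is
    # excluded from the measurement.
    for _ in range(warmup):
        fn(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    torch.cuda.synchronize()  # drain queued kernels before stopping the clock
    return (time.perf_counter() - start) / iters

model = torch.nn.Linear(64, 64).cuda()  # placeholder workload
x = torch.randn(1024, 64, device="cuda")
print(f"eager:    {bench(model, x):.6f}s/iter")
print(f"compiled: {bench(torch.compile(model), x):.6f}s/iter")
```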
Hi Mit,
Sure, I will push the code soon, likely on Friday/Saturday.
Thanks for sharing!
Apologies for reopening this, but I was wondering whether you have timing/benchmarking scripts that I can quickly throw onto my RTX A5500 to see what the numbers look like and hopefully do an apples-to-apples comparison with your plots. Thanks again for doing this!
Also, I was trying to get `torch.compile` to work with `fullgraph=True`, and as you said on Twitter there seem to be some errors. Is it OK if I file issues for them so that other folks can help out with debugging?
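For anyone who hits this later, here is a self-contained illustration of the failure mode (not the actual CEGNN errors, just a stand-in showing how `fullgraph=True` turns a graph break into a hard error):

```python
import torch

def breaks_the_graph(x):
    # .item() forces data-dependent Python control flow, a typical
    # source of graph breaks; fullgraph=True raises an error instead
    # of silently falling back to eager execution.
    if x.sum().item() > 0:
        return x * 2
    return x - 1

compiled = torch.compile(breaks_the_graph, fullgraph=True)
try:
    compiled(torch.randn(8))
except Exception as e:
    print(type(e).__name__, ":", e)
```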
Would be cool to compare `jax.jit` vs `torch.compile` for this type of implementation (don't worry, I am not trying to start another Twitter war :D).
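If it helps, this is the shape of the comparison I would run on both sides (placeholder workloads; assumes a GPU build of JAX). The two pitfalls are that JAX dispatch is asynchronous, so you need `block_until_ready()`, and that both frameworks compile lazily on the first call:

```python
import time
import jax
import jax.numpy as jnp
import torch

ITERS = 20

# --- JAX side (placeholder workload) ---
@jax.jit
def jax_fn(x):
    return jnp.tanh(x @ x.T).sum()

xj = jnp.ones((1024, 1024))
jax_fn(xj).block_until_ready()  # first call compiles
start = time.perf_counter()
for _ in range(ITERS):
    out = jax_fn(xj)
out.block_until_ready()  # dispatch is async; wait before reading the clock
print("jax.jit:      ", (time.perf_counter() - start) / ITERS)

# --- PyTorch side (placeholder workload) ---
def torch_fn(x):
    return torch.tanh(x @ x.T).sum()

compiled = torch.compile(torch_fn)
xt = torch.ones(1024, 1024, device="cuda")
compiled(xt)  # first call compiles
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(ITERS):
    compiled(xt)
torch.cuda.synchronize()
print("torch.compile:", (time.perf_counter() - start) / ITERS)
```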
And FYI, for SEGNN and EGNN there are also these implementations that would be interesting to compare against.
I just added the script for CEGNN in JAX.
Concerning opening issues, please go ahead, I would love people to help. I was recently playing with CEGNNs and managed to identify operations that make torch struggle with compilation (a sketch for locating the breaks follows the list):
- normalization (perhaps due to the for loop in `algebra.norms`)
- broadcasting along subspaces in `MVLinear`, i.e.
  `weight = self.weight.repeat_interleave(self.algebra.subspaces, dim=-1)`
- `get_weight` in the Cayley table
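In case it is useful for whoever picks this up, `torch._dynamo.explain` reports each graph break together with its reason. A minimal sketch with a stand-in function (not the actual CEGNN modules):

```python
import torch
import torch._dynamo

def stand_in(x):
    # Stand-in for the problematic patterns, not the real modules:
    # .item() introduces data-dependent control flow, a common break source.
    if x.norm().item() > 1.0:
        x = x / x.norm()
    return x.repeat_interleave(2, dim=-1)

# The ExplainOutput reports graph count, graph break count, and break reasons.
print(torch._dynamo.explain(stand_in)(torch.randn(4, 8)))
```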
I will update the codebase if I manage to resolve them :)