maxxxzdn/jax-geometric

Profiling code for low-level CUDA analysis?

Opened this issue · 5 comments

Coming here after reading this Twitter thread, and the results look super cool!

Is it possible to share the profiling scripts for both of the plots referenced there? I was interested in plugging them into NVIDIA's Nsight Systems and looking at the low-level CUDA kernels (similar to what was done here).
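For context, this is roughly the setup I had in mind: an NVTX-annotated toy step that I would capture with `nsys profile -t cuda,nvtx python bench.py`. The workload, function names, and the use of the `nvtx` pip package are my own placeholders, not anything from this repo.

    import jax
    import jax.numpy as jnp
    import nvtx  # pip install nvtx

    # Placeholder workload; the real script would call the repo's model instead.
    @jax.jit
    def step(x):
        return jnp.tanh(x @ x.T)

    x = jnp.ones((1024, 16))
    step(x).block_until_ready()  # warm-up, so compilation stays out of the trace

    # The NVTX range shows up in the Nsight Systems timeline next to the CUDA kernels.
    with nvtx.annotate("jitted_step"):
        for _ in range(10):
            step(x).block_until_ready()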

Thanks for sharing this work!

Edit: Found this helpful thread about benchmarking torch.compile, which might be relevant.

Hi Mit,

Sure, I will push the code soon, likely on Friday/Saturday.
Thanks for sharing!

Apologies for reopening this, but I was wondering whether you have timing/benchmarking scripts that I can quickly throw onto my RTX A5500 to see what the numbers look like, and hopefully do an apples-to-apples comparison with your plots. Thanks again for doing this!
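In case it helps, here is the kind of minimal timing loop I was planning to run on my end. The toy shapes and the `step` function are placeholders for the repo's actual CEGNN forward, not taken from your code:

    import time
    import jax
    import jax.numpy as jnp

    # Placeholder workload; swap in the repo's CEGNN forward here.
    @jax.jit
    def step(x):
        return jnp.tanh(x @ x.T)

    x = jnp.ones((1024, 16))
    step(x).block_until_ready()  # warm-up, so compile time is excluded

    n_iters = 100
    t0 = time.perf_counter()
    for _ in range(n_iters):
        step(x).block_until_ready()  # block so we time GPU work, not dispatch
    t1 = time.perf_counter()
    print(f"mean per call: {(t1 - t0) / n_iters * 1e3:.3f} ms")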

I was also trying to get torch.compile to work with fullgraph=True, and as you said on Twitter, there seem to be some errors. Is it OK if I file issues for them so that other folks can help out with debugging?

It would be cool to compare jax.jit vs torch.compile for these types of implementations (don't worry, I am not trying to start another Twitter war :D).
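For that comparison, here is a torch-side counterpart under the same protocol, using CUDA events for the timing. Again, the workload is just a stand-in, not your model:

    import torch

    # Same toy workload as the JAX snippet above, as a stand-in for the model.
    def forward(x):
        return torch.tanh(x @ x.T)

    compiled = torch.compile(forward)
    x = torch.randn(1024, 16, device="cuda")
    compiled(x)  # warm-up triggers compilation

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(100):
        compiled(x)
    end.record()
    torch.cuda.synchronize()  # wait for all queued kernels before reading the timer
    print(f"mean per call: {start.elapsed_time(end) / 100:.3f} ms")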

Also, FYI: for SEGNN and EGNN there are also these implementations that would be interesting to compare against.

I just added the script for CEGNN in JAX.

Regarding opening issues, please go ahead; I would love people to help. I was recently playing with CEGNNs and managed to pin down the operations that make torch struggle with compilation:

  • normalization (perhaps due to the for loop in algebra.norms)
  • broadcasting along subspaces in MVLinear (minimal repro after this list):
    • weight = self.weight.repeat_interleave(self.algebra.subspaces, dim=-1)
  • get_weight in the Cayley table
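Here is a minimal repro sketch for the MVLinear case. The subspace dimensions below are made up, and the explanation is my guess at the failure mode: repeat_interleave with tensor-valued repeats produces a data-dependent output shape, which fullgraph=True cannot trace into a single graph.

    import torch

    # Made-up subspace dimensions, just to illustrate the MVLinear pattern.
    subspaces = torch.tensor([1, 3, 3, 1])

    @torch.compile(fullgraph=True)
    def broadcast_weight(weight):
        # Output width depends on the values in `subspaces`,
        # i.e. a data-dependent shape that dynamo cannot specialize on.
        return weight.repeat_interleave(subspaces, dim=-1)

    weight = torch.randn(8, 4)
    try:
        broadcast_weight(weight)
    except Exception as e:  # expecting a torch._dynamo Unsupported error
        print(type(e).__name__, e)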

I will update the codebase if I manage to resolve them :)