dfm/extending-jax

How would I include cuDNN?

helange23 opened this issue · 2 comments

First of all, thank you very much for this. You have saved me tons of work and I am very grateful for the great documentation.

I would like to extend JAX with custom calls that internally make use of cuDNN. To do this, I added an include at the top of "kernels.cc.cu". I tried both of the following:

#include <cudnn.h>
#include "/usr/include/cudnn.h"

The compiler finds the header and does not complain when I add the following host code:

  cudnnHandle_t handle_;
  cudnnCreate(&handle_);

However, as soon as I try to run the code from JAX, I get an error saying that cudnnCreate is an undefined symbol. If I remove the includes, the compiler complains instead.

Do you have any idea how I could potentially fix this?

dfm commented

It sounds like you'll need to edit the CMakeLists.txt file to link cuDNN appropriately. I don't have any experience with this, so I'm not sure exactly what would be required, but that's the direction I would look!
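A minimal sketch of that change follows. It assumes cuDNN is installed as a shared library (libcudnn.so) in a standard location. Including cudnn.h is enough for compilation because the header only declares cudnnCreate. If the extension module is never linked against libcudnn, the dynamic loader cannot resolve the symbol when JAX imports the module, which matches the error above. The target name gpu_ops and the search hints are assumptions. Use the target name from your own pybind11_add_module call and place these lines after that call.

```cmake
# Locate the cuDNN header and shared library. The HINTS paths are assumptions;
# adjust them to wherever cuDNN is installed on your system.
find_path(CUDNN_INCLUDE_DIR cudnn.h
  HINTS /usr/include /usr/local/cuda/include)
find_library(CUDNN_LIBRARY
  NAMES cudnn
  HINTS /usr/lib/x86_64-linux-gnu /usr/local/cuda/lib64)

if(NOT CUDNN_INCLUDE_DIR OR NOT CUDNN_LIBRARY)
  message(FATAL_ERROR "cuDNN not found; set CUDNN_INCLUDE_DIR and CUDNN_LIBRARY")
endif()

# "gpu_ops" is a hypothetical target name: use the module built from
# kernels.cc.cu. The PRIVATE keyword matches how pybind11_add_module links.
target_include_directories(gpu_ops PRIVATE ${CUDNN_INCLUDE_DIR})
target_link_libraries(gpu_ops PRIVATE ${CUDNN_LIBRARY})
```

After rebuilding, running ldd on the built .so file should list libcudnn if the link succeeded.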

Thanks! I will have a look.