google-research/torchsde

GPU support for `sde_gan.py`

bras-p opened this issue · 4 comments

I am trying to run the example `sde_gan.py` with GPU support.
I have run the code exactly as it is, both in a Google Colab session (with GPU execution) and locally, and in both cases there is almost no improvement in running time compared with CPU execution.
I have looked through the code but did not find any problem in `sde_gan.py`; maybe the problem is in the source files for torchsde and torchcde.
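
To illustrate the kind of comparison I mean, here is a minimal sketch (a toy problem, not `sde_gan.py` itself; all names are made up) of timing a torchsde solve on CPU versus GPU:

```python
import time

import torch
import torchsde


class ToySDE(torch.nn.Module):
    noise_type = "diagonal"
    sde_type = "ito"

    def __init__(self, dim):
        super().__init__()
        self.lin = torch.nn.Linear(dim, dim)

    def f(self, t, y):  # drift
        return self.lin(y)

    def g(self, t, y):  # diagonal diffusion
        return 0.1 * torch.ones_like(y)


def time_solve(device, batch=512, dim=32):
    sde = ToySDE(dim).to(device)
    y0 = torch.randn(batch, dim, device=device)
    ts = torch.linspace(0.0, 1.0, 100, device=device)
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    torchsde.sdeint(sde, y0, ts, method="euler", dt=1e-2)
    if device == "cuda":
        torch.cuda.synchronize()
    return time.perf_counter() - start


print("cpu :", time_solve("cpu"))
if torch.cuda.is_available():
    print("cuda:", time_solve("cuda"))
```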

Have you taken a look at GPU utilization?

@patrick-kidger might have better ideas.
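
For example (a rough sketch, assuming a reasonably recent PyTorch; `torch.cuda.utilization` relies on the optional `pynvml` package), you can check from Python that tensors really live on the GPU and poll utilization while training, or simply watch `nvidia-smi` in a terminal:

```python
import torch

print(torch.cuda.is_available())   # should be True in a GPU Colab session

x = torch.randn(1024, 1024, device="cuda")
print(x.device)                    # expect: cuda:0

# Instantaneous GPU utilization in percent, sampled via NVML.
# Call this periodically during training; it reports the same number as nvidia-smi.
print(torch.cuda.utilization())
```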

Yes, I have: GPU utilization is quite low (but not zero). The GPU is far from fully used.

I don't think this reflects an issue in torchsde. Rather, it generally seems to be the case that PyTorch + diffeqs are overhead-bound, so GPU utilisation can be quite low.

If this is a concern for you then it's probably worth switching to JAX; for example, see the SDE-GAN example in JAX + Diffrax.
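
To give a rough flavour, here is a minimal toy solve adapted from the Diffrax documentation (not the SDE-GAN example itself; the API may differ slightly across versions):

```python
import diffrax
import jax.numpy as jnp
import jax.random as jr

# Scalar SDE dy = -y dt + 0.1 t dW, solved with Euler over a virtual Brownian tree.
t0, t1 = 0.0, 3.0
drift = lambda t, y, args: -y
diffusion = lambda t, y, args: 0.1 * t
bm = diffrax.VirtualBrownianTree(t0, t1, tol=1e-3, shape=(), key=jr.PRNGKey(0))
terms = diffrax.MultiTerm(diffrax.ODETerm(drift), diffrax.ControlTerm(diffusion, bm))
sol = diffrax.diffeqsolve(
    terms,
    diffrax.Euler(),
    t0,
    t1,
    dt0=0.05,
    y0=1.0,
    saveat=diffrax.SaveAt(ts=jnp.linspace(t0, t1, 11)),
)
print(sol.ts, sol.ys)
```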

I think this makes sense. The size of the model is also a factor: for small models, the Python overhead can be large compared to the actual computation on the GPU.

JAX's JIT would alleviate this.
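
As a toy illustration (not from the thread; all names made up): under `jax.jit` a loop of small operations is traced once and dispatched as a single compiled call, so the per-step Python overhead is paid only at compile time.

```python
import time

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (64, 64))
y0 = jnp.ones((16, 64))


def many_small_steps(y):
    # Stands in for a solver loop over a small model: many tiny ops, each of
    # which pays Python/dispatch overhead when run eagerly.
    for _ in range(100):
        y = jnp.tanh(y @ w)
    return y


fast = jax.jit(many_small_steps)
fast(y0).block_until_ready()  # first call traces and compiles

start = time.perf_counter()
fast(y0).block_until_ready()
print("jitted:", time.perf_counter() - start)

start = time.perf_counter()
many_small_steps(y0).block_until_ready()
print("eager :", time.perf_counter() - start)
```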