GPU support for `sde_gan.py`
bras-p opened this issue · 4 comments
I am trying to run the example `sde_gan.py` with GPU support.
I have run the code unmodified both in a Google Colab session (with a GPU runtime) and locally, and in both cases there is almost no improvement in running time compared with CPU execution.
I have looked through the code but did not find any problem in `sde_gan.py`; maybe the problem lies in the source files for torchsde and torchcde.
Have you taken a look at GPU utilization?
@patrick-kidger might have better ideas.
Yes, and GPU utilization is quite low (but not zero). The GPU is far from being fully used.
I don't think this reflects an issue in torchsde. Rather, it generally seems to be the case that PyTorch + diffeqs are overhead-bound, so GPU utilisation can be quite low.
If this is a concern for you, then it's probably worth switching to JAX. For example, there is an SDE-GAN example written in JAX + Diffrax.
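A back-of-the-envelope model makes "overhead-bound" concrete. This is a sketch under simplifying assumptions, not a measurement of `sde_gan.py`. Each solver step is assumed to cost a fixed host-side overhead (Python and dispatch) that the GPU cannot speed up, plus some arithmetic that the GPU runs `gpu_factor` times faster, with no overlap between the two. The names `StepCost`, `speedup` and `gpu_busy_fraction` and all of the numbers are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class StepCost:
    """Hypothetical per-solver-step costs, in seconds."""
    overhead: float    # host-side Python/dispatch time; not sped up by the GPU
    compute: float     # time the step's arithmetic takes on the CPU
    gpu_factor: float  # how many times faster the GPU runs that arithmetic


def cpu_step_time(c: StepCost) -> float:
    # On the CPU, overhead and arithmetic simply add up.
    return c.overhead + c.compute


def gpu_step_time(c: StepCost) -> float:
    # Simplifying assumption: overhead and GPU work do not overlap.
    return c.overhead + c.compute / c.gpu_factor


def speedup(c: StepCost) -> float:
    # Wall-clock speedup of a GPU step over a CPU step.
    return cpu_step_time(c) / gpu_step_time(c)


def gpu_busy_fraction(c: StepCost) -> float:
    # Fraction of each GPU step spent doing arithmetic rather than waiting.
    return (c.compute / c.gpu_factor) / gpu_step_time(c)


if __name__ == "__main__":
    small = StepCost(overhead=100e-6, compute=10e-6, gpu_factor=10.0)
    large = StepCost(overhead=100e-6, compute=10e-3, gpu_factor=10.0)
    for name, c in [("small model", small), ("large model", large)]:
        print(f"{name}: speedup {speedup(c):.2f}x, GPU busy {gpu_busy_fraction(c):.0%}")
```

With 100 µs of overhead and 10 µs of arithmetic per step, the GPU gives only about a 1.09x speedup and is busy roughly 1% of the time; making the arithmetic 1000 times heavier (10 ms) pushes the speedup above 9x. In this model, the only ways to benefit from the GPU are more arithmetic per step or less overhead per step.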
I think this makes sense. The size of the model is also a factor: for small models, the Python overhead can be large compared to the computation on the GPU.
JAX's JIT would alleviate this.
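To illustrate why JIT compilation helps, here is a minimal JAX sketch. It is not the Diffrax SDE-GAN example and not a neural SDE: it is a plain fixed-step Euler-Maruyama solver whose time loop is written with `jax.lax.scan` and wrapped in `jax.jit`, so the loop is traced once and compiled by XLA into a single program. Python then runs once per call rather than once per solver step. The drift `-y`, the constant diffusion `0.5`, and the function names are hypothetical placeholders standing in for the generator's networks.

```python
from functools import partial

import jax
import jax.numpy as jnp


def drift(y):
    # Hypothetical drift: mean reversion towards zero.
    return -y


def diffusion(y):
    # Hypothetical diagonal diffusion with a constant scale.
    return 0.5 * jnp.ones_like(y)


@partial(jax.jit, static_argnames=("num_steps",))
def euler_maruyama(y0, key, dt, num_steps):
    # One PRNG key per step; scan iterates over them.
    keys = jax.random.split(key, num_steps)

    def step(y, k):
        # Brownian increment dW ~ N(0, dt) for every state component.
        dw = jnp.sqrt(dt) * jax.random.normal(k, y.shape)
        y_next = y + drift(y) * dt + diffusion(y) * dw
        return y_next, y_next

    # scan lowers to a single compiled loop instead of num_steps Python calls.
    _, path = jax.lax.scan(step, y0, keys)
    return path  # shape (num_steps, *y0.shape)


if __name__ == "__main__":
    y0 = jnp.ones((64, 3))  # batch of 64 three-dimensional states
    path = euler_maruyama(y0, jax.random.PRNGKey(0), dt=0.01, num_steps=1000)
    print(path.shape)
```

Note that the first call includes tracing and compilation time; later calls with the same shapes and the same `num_steps` reuse the compiled program.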