peng-lab/BaSiCPy

Implement DCT from Jax

Closed this issue · 9 comments

Jax is a numerical acceleration library that should make it easy to switch computation between CPU and GPU. This is a step toward adding complete Jax support, as requested in #4.

This is blocked by #15.

https://github.com/google/jax

  • Add GPU settings to the DCT base class, including whether the class supports GPU as a class property and a device parameter to the dct and idct methods
  • Implement a Jax DCT class
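A minimal sketch of those two bullets, assuming a Python class hierarchy. `DCTBase`, `JaxDCT`, `supports_gpu`, the `device` keyword and `_ortho_dct_matrix` are hypothetical names, not BaSiCPy's actual API. To sidestep the missing idct, the sketch builds an orthonormal DCT-II matrix, whose transpose is its inverse.

```python
from abc import ABC, abstractmethod
from typing import ClassVar

import jax
import jax.numpy as jnp


def _ortho_dct_matrix(n):
    """Hypothetical helper: orthonormal DCT-II matrix C[k, m] = s_k * cos(pi * k * (2m + 1) / (2n))."""
    k = jnp.arange(n)[:, None]
    m = jnp.arange(n)[None, :]
    scale = jnp.where(k == 0, jnp.sqrt(1.0 / n), jnp.sqrt(2.0 / n))
    return scale * jnp.cos(jnp.pi * k * (2 * m + 1) / (2 * n))


class DCTBase(ABC):
    """Hypothetical base class for 2D DCT backends."""

    # Class-level flag: whether this backend can run on a GPU.
    supports_gpu: ClassVar[bool] = False

    @abstractmethod
    def dct(self, x, device=None):
        """Forward 2D DCT; `device` selects where to compute (None = default)."""

    @abstractmethod
    def idct(self, x, device=None):
        """Inverse 2D DCT; `device` selects where to compute (None = default)."""


class JaxDCT(DCTBase):
    """Hypothetical JAX backend using dense DCT matrices (O(N^2) work per axis)."""

    supports_gpu = True

    @staticmethod
    def _prepare(x, device):
        # Move the input to the requested JAX device; None keeps JAX's default device.
        x = jax.device_put(jnp.asarray(x, dtype=jnp.float32), device)
        return x, _ortho_dct_matrix(x.shape[0]), _ortho_dct_matrix(x.shape[1])

    def dct(self, x, device=None):
        x, rows, cols = self._prepare(x, device)
        return rows @ x @ cols.T

    def idct(self, x, device=None):
        x, rows, cols = self._prepare(x, device)
        # C is orthogonal, so C.T (the orthonormal DCT-III) undoes C.
        return rows.T @ x @ cols
```

The dense-matrix form keeps the sketch short; an FFT-based transform would scale better for large images.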

We may have to write our own idct. Jax currently only provides the type 2 DCT.

https://github.com/google/jax/blob/bebe9845a873b3203f8050395255f173ba3bbb71/jax/_src/scipy/fft.py#L41

Isn't type 2 what we want? Or are you saying they don't have type 3, so we can't do the inverse?

https://docs.scipy.org/doc/scipy/reference/generated/scipy.fftpack.dct.html#scipy-fftpack-dct

I mean they don't have an idct function, and my understanding is that the type 3 DCT is the inverse of the type 2 DCT (up to a scale factor).

So, right, we can't do the inverse dct.
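Written out with SciPy's orthonormal (`norm="ortho"`) scaling, both transforms use the same cosine coefficients, once as a matrix $C$ and once as its transpose:

$$
y_k = s_k \sum_{n=0}^{N-1} x_n \cos\!\left(\frac{\pi k (2n+1)}{2N}\right), \qquad
x_n = \sum_{k=0}^{N-1} s_k \, y_k \cos\!\left(\frac{\pi k (2n+1)}{2N}\right),
$$

$$
s_0 = \sqrt{1/N}, \qquad s_k = \sqrt{2/N} \ \ (k \ge 1).
$$

The first is the type 2 DCT ($y = Cx$) and the second the type 3 DCT ($x = C^{\top} y$). Because $C$ is orthogonal ($C^{\top} C = I$), the orthonormal type 3 DCT exactly inverts the orthonormal type 2 DCT; for the unnormalized transforms the round trip picks up a factor of $2N$.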

How hard would it be to write our own (and then maybe make a PR on the JAX repo)?

How hard is it? I don't know. I feel like I've done it in the past. I don't think it should be hard, but doing it the "correct" way for JAX might increase the level of difficulty. Let's put it on the back burner until we get a mostly functional version, then make this high priority.

Sounds good. Looking at the source code for dct, I think it will be very doable when we get to it.
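For reference, a sketch of such an idct using only `jnp.fft`, assuming the forward transform follows the usual Makhoul-style scheme (interleave even/odd samples, FFT, multiply by a twiddle factor, take twice the real part). It runs those steps in reverse. `idct_type2` is a hypothetical name; it inverts the unnormalized type 2 DCT (SciPy's default scaling), i.e. it equals the type 3 DCT divided by 2N.

```python
import jax
import jax.numpy as jnp


def idct_type2(y):
    """Hypothetical inverse of the unnormalized type-2 DCT along the last axis."""
    n = y.shape[-1]
    k = jnp.arange(n)
    # y_rev[k] = y[N - k] for k >= 1, and 0 for k == 0.
    y_rev = jnp.concatenate(
        [jnp.zeros_like(y[..., :1]), jnp.flip(y, axis=-1)[..., :-1]], axis=-1
    )
    # Forward: y_k = 2 * Re(exp(-i*pi*k/(2N)) * V_k). Using the conjugate symmetry
    # of V (v is real), V_k = exp(+i*pi*k/(2N)) * (y_k - i * y_{N-k}) / 2.
    twiddle = jnp.exp(1j * jnp.pi * k / (2 * n))
    v = jnp.fft.ifft(twiddle * (y - 1j * y_rev) / 2, axis=-1).real
    # Undo the interleave v = [x0, x2, x4, ..., x5, x3, x1].
    half = (n + 1) // 2
    x = jnp.zeros_like(v)
    x = x.at[..., ::2].set(v[..., :half])
    x = x.at[..., 1::2].set(jnp.flip(v[..., half:], axis=-1))
    return x
```

Applying it along each axis in turn would give a 2D inverse.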

Are we good to close this? @Nicholas-Schaub @yfukai

> Add GPU settings to the DCT base class, including whether the class supports GPU as a class property and a device parameter to the dct and idct methods

This is not done yet, but I think this can be set from the JAX side. Let's close this for now.
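As an illustration of setting the device from the JAX side, here is a sketch that drops the per-method `device` parameter and relies on JAX's `jax.default_device` context manager. `dct2_ortho` is a hypothetical helper; the `JAX_PLATFORMS` environment variable is another way to restrict JAX to a given backend.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dctn


def dct2_ortho(x):
    """Hypothetical helper: orthonormal 2D type-2 DCT; JAX decides the device."""
    return dctn(jnp.asarray(x), type=2, norm="ortho")


# Pin work inside the block to the first CPU device. Using jax.devices("gpu")[0]
# instead would target a GPU, if one is available.
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    y = dct2_ortho(jnp.ones((4, 4)))
```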