google-research/torchsde

Different `t` for data in a minibatch

ain-soph opened this issue · 5 comments

In the torchdiffeq ODE repo rtqichen/torchdiffeq#122 (comment), there is a dummy-variable trick that lets users use a different t for different inputs in the same minibatch. I wonder if we can do the same thing for an SDE. If so, could anyone kindly provide the formulas for the modified f and g?

dx = f(x,t)dt + g(x,t)dw

I'm currently solving the SDE with a ts tuple (t_start, t_end).


For the ODE dx = f(x,t)dt,
I can make the variable substitution t = (t_end - t_start)*s + t_start to map each sample's t range onto the common range s in (0, 1).
The new drift is new_f(x, s) = dx/ds = dx/dt * dt/ds = f(x, t(s)) * (t_end - t_start)
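
As a library-independent sanity check of this change of variables, here is a minimal fixed-step Euler sketch (plain Python; the ODE dx/dt = -x and the interval [0.3, 0.5] are arbitrary illustrative choices, not anything from torchdiffeq itself):

```python
# Sketch: verify the substitution t = (t_end - t_start) * s + t_start for an ODE.
# Toy ODE dx/dt = f(x, t) = -x, chosen only for illustration.

def f(x, t):
    return -x

def euler(func, x0, t0, t1, n_steps):
    """Fixed-step explicit Euler integration from t0 to t1."""
    x, t = x0, t0
    h = (t1 - t0) / n_steps
    for _ in range(n_steps):
        x = x + h * func(x, t)
        t = t + h
    return x

t_start, t_end = 0.3, 0.5

# Original parameterization: integrate over [t_start, t_end].
x_direct = euler(f, 1.0, t_start, t_end, 1000)

# Reparameterized drift: new_f(x, s) = f(x, t(s)) * (t_end - t_start), s in [0, 1].
def new_f(x, s):
    t = (t_end - t_start) * s + t_start
    return f(x, t) * (t_end - t_start)

x_reparam = euler(new_f, 1.0, 0.0, 1.0, 1000)

# With the same number of steps, the two integrations agree
# up to floating-point rounding.
assert abs(x_direct - x_reparam) < 1e-9
```

In a batch setting, t_start and t_end would become per-sample tensors and the multiplication by (t_end - t_start) broadcasts elementwise, which is the essence of the dummy-variable trick.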

I'm pinging everyone I've seen active in this repo; I hope that's not too intrusive.
@lxuechen @patrick-kidger @mtsokol

For batch-dependent times, probably the simplest thing to do is to use Diffrax instead, in conjunction with jax.vmap.

@patrick-kidger Thanks for your quick reply. That seems like a perfect solution in JAX. Sadly, all my code is written in PyTorch. Is there anything I could use in functorch? It has vmap as well.

You could try torch.vmap, but I believe it's untested with torchsde. (Or torchdiffeq.)

If you want to reparameterise, then you can use the scaling self-similarity property of Brownian motion. Note that rescaling time by c scales the noise by sqrt(c) (rather than by c as in the ODE case).
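
Concretely, the scaling property says W(ct) has the same distribution as sqrt(c) * W(t). A quick stdlib-only Monte Carlo sketch (illustrative only, not torchsde code) comparing the empirical standard deviations of the two:

```python
import random
import statistics

random.seed(0)

c, t = 4.0, 1.0
n = 200_000

# W(c*t) ~ Normal(0, c*t); sqrt(c) * W(t) is also Normal(0, c*t).
w_ct = [random.gauss(0.0, (c * t) ** 0.5) for _ in range(n)]
w_t_scaled = [c ** 0.5 * random.gauss(0.0, t ** 0.5) for _ in range(n)]

sd1 = statistics.pstdev(w_ct)
sd2 = statistics.pstdev(w_t_scaled)

# Both empirical stds should be close to sqrt(c*t) = 2.
assert abs(sd1 - 2.0) < 0.05
assert abs(sd2 - 2.0) < 0.05
```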

That said, I've done this before, and it can behave a bit weirdly. It generally means you need your output times to be a union over the whole batch, and then do a gather -- this is exceptionally slow. In addition, torchsde's adaptive solvers handle the whole batch in one go, so if two different batch elements have very different scales then you might find more steps are taken than are really necessary.

(This exact use case was one of the things I was really looking to fix when I wrote Diffrax.)

@patrick-kidger Thanks for your advice! You are absolutely right!
dx = f(x,t)dt + g(x,t)dw(t)
I use the relation dw(t) = sqrt(dt) * ϵ = sqrt(dt/ds) * dw(s), and the equation becomes
dx = f(x,t) * dt/ds * ds + g(x,t) * sqrt(dt/ds) * dw(s)
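
To make the reparameterized equation concrete, here is a library-independent Euler-Maruyama sketch (plain Python; the drift, diffusion, and interval are toy choices for illustration). Driving both integrations with the same Gaussian draws, the rescaled version over s in [0, 1] reproduces the direct integration over [t_start, t_end]:

```python
import random

def f(x, t):
    # Toy drift, for illustration only.
    return -x

def g(x, t):
    # Toy diffusion (constant noise scale), for illustration only.
    return 0.1

t_start, t_end = 0.3, 0.5
n_steps = 1000
x0 = 1.0

# Shared Gaussian draws so both integrations see the same noise path.
random.seed(0)
eps = [random.gauss(0.0, 1.0) for _ in range(n_steps)]

# Direct Euler-Maruyama over [t_start, t_end]: dx = f dt + g dw.
dt = (t_end - t_start) / n_steps
x, t = x0, t_start
for k in range(n_steps):
    x = x + f(x, t) * dt + g(x, t) * dt ** 0.5 * eps[k]
    t = t + dt
x_direct = x

# Reparameterized over s in [0, 1], with t = (t_end - t_start) * s + t_start:
#   dx = f * (dt/ds) ds + g * sqrt(dt/ds) dw(s)
scale = t_end - t_start  # dt/ds
ds = 1.0 / n_steps
x, s = x0, 0.0
for k in range(n_steps):
    t = scale * s + t_start
    dw_s = ds ** 0.5 * eps[k]  # Brownian increment in s-time
    x = x + f(x, t) * scale * ds + g(x, t) * scale ** 0.5 * dw_s
    s = s + ds
x_reparam = x

# With identical noise, the two trajectories agree to floating-point precision.
assert abs(x_direct - x_reparam) < 1e-9
```

In torchsde this would amount to wrapping the original drift and diffusion so that f is multiplied by (t_end - t_start) and g by sqrt(t_end - t_start), with per-sample broadcasting when those endpoints differ across the batch.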

I applied this in my code and it works fine. It's no longer necessary to use vmap to turn a per-sample function into a batch-wise one. (Personally, I suspect vmap-style dispatch into a Python function is not as efficient as native batched operations, though I've never benchmarked the execution speed.)
But just as you said, it becomes much slower. I suspect it's because the variable range grows.
(e.g., originally t ranged over [0.3, 0.5], but now s ranges over [0, 1]. Meanwhile the solver step size is unchanged, so it takes many more steps.)
I'll try rescaling s to the range [0, precision], where precision is a tunable argument, and see whether that speeds things up.

Again, thanks for your generous help! I'll close this issue once I've tested the precision idea and posted the results here.

Yes. After setting precision to match the original t scale, the execution speed using s is the same as using the raw t.