Use Jax functions in Pytorch via DLPack, as outlined in a gist by @mattjj. This repository was created to make this differentiable alignment work interoperable with Pytorch projects.
$ pip install jax2torch
By default, Jax preallocates a large share of GPU memory as soon as its first operation runs (90% in older releases, 75% in more recent ones), which leaves Pytorch with very little. To prevent this, set the XLA_PYTHON_CLIENT_PREALLOCATE environment variable to false before running any Jax code:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
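If turning preallocation off entirely is not what you want, Jax also reads XLA_PYTHON_CLIENT_MEM_FRACTION, which caps the preallocated share instead. The sketch below wraps both settings in a small helper. Note that `configure_jax_gpu_memory` is a hypothetical name used for illustration, not part of jax2torch or Jax, and the 0.4 fraction is an arbitrary example value.

```python
import os
import sys
import warnings


def configure_jax_gpu_memory(preallocate=False, mem_fraction=None):
    """Hypothetical helper: set Jax GPU memory environment variables.

    Call this before any Jax computation runs, ideally before importing jax.
    """
    if "jax" in sys.modules:
        # The variables are read when the Jax backend starts up, so changing
        # them after Jax has already run something may have no effect.
        warnings.warn("jax is already imported; settings may not take effect")

    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"

    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in the interval (0, 1]")
        # Only relevant while preallocation is enabled.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"


# Keep preallocation, but limit Jax to 40% of GPU memory so Pytorch has room.
configure_jax_gpu_memory(preallocate=True, mem_fraction=0.4)
```

Capping the fraction keeps Jax's faster pooled allocator while still leaving headroom for Pytorch. Disabling preallocation is simpler, but can lead to more memory fragmentation.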
import jax
import torch
from jax2torch import jax2torch
# Jax function
@jax.jit
def jax_pow(x, y = 2):
    return x ** y
# convert to Torch function
torch_pow = jax2torch(jax_pow)
# run it on Torch data!
x = torch.tensor([1., 2., 3.])
y = torch_pow(x, y = 3)
print(y) # tensor([1., 8., 27.])
# And differentiate!
x = torch.tensor([2., 3.], requires_grad = True)
y = torch.sum(torch_pow(x, y = 3))
y.backward()
print(x.grad) # tensor([12., 27.])
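For readers curious how a bridge like this can work, here is a minimal sketch of the general idea: tensors cross between the two frameworks through DLPack, and `jax.vjp` supplies the backward pass for a `torch.autograd.Function`. This is not the jax2torch implementation. The names `torch_to_jax`, `jax_to_torch` and `wrap_unary` are hypothetical, only single-input functions are handled, and it assumes recent Jax and Pytorch releases in which `jax.dlpack.from_dlpack` and `torch.from_dlpack` accept each other's arrays directly.

```python
import jax
import torch
from jax import dlpack as jax_dlpack


def torch_to_jax(t):
    # DLPack export refuses tensors that require grad, hence detach();
    # contiguous() guards against stride layouts that Jax cannot import.
    return jax_dlpack.from_dlpack(t.detach().contiguous())


def jax_to_torch(a):
    # Jax arrays implement the DLPack protocol, so Pytorch can wrap them.
    return torch.from_dlpack(a)


def wrap_unary(jax_fn):
    """Hypothetical wrapper for a Jax function of a single array."""

    class _Wrapped(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # jax.vjp returns the output and a function mapping output
            # cotangents to input cotangents.
            y, vjp_fn = jax.vjp(jax_fn, torch_to_jax(x))
            ctx.vjp_fn = vjp_fn
            return jax_to_torch(y)

        @staticmethod
        def backward(ctx, grad_y):
            (grad_x,) = ctx.vjp_fn(torch_to_jax(grad_y))
            return jax_to_torch(grad_x)

    return _Wrapped.apply


cube = wrap_unary(lambda v: v ** 3)

x = torch.tensor([2., 3.], requires_grad = True)
cube(x).sum().backward()
print(x.grad)  # gradient of x**3 is 3 * x**2
```

The real library additionally handles multiple arguments, keyword arguments and nested outputs, which this sketch leaves out for brevity.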