Support AD of PyTorch Calls within Julia functions
azane opened this issue · 5 comments
In order to, e.g., use neural ODEs, we need neural networks to operate within the dynamics. These are typically PyTorch modules involving impure forward passes (#13) and torch functions (like an activation function) in those forward passes, which is a challenge because torch functions can only operate on torch tensors.
For example, the backward pass in the code below will fail for a variety of reasons. First, the `relu` will fail because the `ArrayValue` is not a `Tensor`; second, the gradients won't show up in `m` and `b` (#13); and even when calling `.item()` on `m` and `b`, the issues discussed in #14 would need addressing.
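```python
from juliatorch import JuliaFunction  # Load this first to avoid https://github.com/pytorch/pytorch/issues/78829
import numpy as np
import torch

class NeuralNetwork(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.m = torch.nn.Parameter(torch.tensor(1.))
        self.b = torch.nn.Parameter(torch.tensor(3.))

    def forward(self, x):
        # In practice this would be more complex.
        return self.m * torch.relu(x) + self.b

nn = NeuralNetwork()

def f(x):
    # Convert numpy inputs to tensors; Julia-side values (e.g. ArrayValue) are not converted.
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
    return nn(x)

y = JuliaFunction.apply(f, torch.tensor(1., requires_grad=True))
y.backward()  # Fails, for the reasons listed above
```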
In practice, however, one wants to use julia from within python for its speed, which is undercut when the julia function (e.g. an ODE solve) has to surface back to python to use python features (e.g. pytorch). I.e. pursuing this as a feature really doesn't play to the strengths of using julia as a backend in general or as a backend ODE solver.
To expand on the final point in the context of using diffeqpy as a backend for chirho — if we need to surface to python within the solve, then we might as well use the torchdiffeq backend because we won't be getting the performance benefits (afaik). The advantage of a diffeqpy backend is going to be speed, which we can only get if there is no python involved in the dynamics (or it can be compiled away).
Hey, sorry for the slow response.
Up until now, juliatorch has been focused on allowing the use of Julia functions within a PyTorch stack. This would support a pytorch optimization loop around a model written in python where some of the layers are calls to Julia (e.g. diffeq solvers).
I've been operating under the assumption that the function `f` passed as the first argument to `JuliaFunction.apply` is a thin wrapper around a function implemented in Julia. This explains why I didn't think of supporting #13 (Julia functions are unlikely to depend on Python global state) and why I didn't find #14 concerning until you pointed it out. In the Julia world, most functions can be expected to handle all sorts of unexpected types, so passing in a few different types in different phases of evaluation didn't strike me as concerning. However, as soon as you use a Python function for `f`, that becomes an issue.
If `f` is, itself, a PyTorch model or some other non-trivial Python code, then there is no need to wrap it in `JuliaFunction.apply`; PyTorch should handle AD.
As for surfacing Python within an ODE solver, that is in the realm of diffeqpy and is actually supported reasonably well. It can get good performance by converting Python code into a symbolic expression, performing symbolic simplification, and then evaluating the result in Julia.
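To make the "convert Python code into a symbolic expression" idea concrete, here is a toy tracer. It is only an illustration under loose assumptions, not how diffeqpy or ModelingToolkit actually implement it: `Sym` is a hypothetical class, the symbolic simplification step is skipped entirely, and the `eval`-generated lambda merely stands in for the function that would be compiled and run in Julia.

```python
# Toy symbolic tracer (hypothetical; NOT diffeqpy's implementation).
class Sym:
    """Records arithmetic on symbolic values as expression strings."""

    def __init__(self, expr):
        self.expr = expr

    @staticmethod
    def _lift(other):
        # Constants become literal sub-expressions.
        return other if isinstance(other, Sym) else Sym(repr(other))

    def __add__(self, other):
        return Sym(f"({self.expr} + {Sym._lift(other).expr})")

    def __radd__(self, other):
        return Sym._lift(other) + self

    def __mul__(self, other):
        return Sym(f"({self.expr} * {Sym._lift(other).expr})")

    def __rmul__(self, other):
        return Sym._lift(other) * self

    def __neg__(self):
        return Sym(f"(-{self.expr})")

    def __repr__(self):
        return self.expr


# The same plain-Python dynamics as the readme example below.
def ode_f(du, u, p, t):
    x = u[0]
    v = u[1]
    dx = v
    dv = -p * x
    du[0] = dx
    du[1] = dv


# 1. Trace: run the Python dynamics once on symbolic inputs.
du = [None, None]
ode_f(du, [Sym("x"), Sym("v")], Sym("p"), Sym("t"))
print(du)  # [v, ((-p) * x)]

# 2. "Compile": build a function from the recorded expressions that no longer
#    calls back into ode_f (in diffeqpy this step would happen in Julia).
fast_f = eval(f"lambda x, v, p: ({du[0]!r}, {du[1]!r})")
print(fast_f(1.0, 2.0, 3.0))  # (2.0, -3.0)
```

The point of the design is that Python runs only once, during tracing; afterwards the solver can call the generated function without surfacing to Python. That only works when the dynamics are simple enough to trace.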
To see how this would work, here is the pipeline:
A PyTorch optimization loop calls a Python function that includes neural net layers or other PyTorch constructs that support AD. Among these "PyTorch constructs that support AD" are `JuliaFunction`s. For example, one of them may be a `JuliaFunction` wrapping a differential equation solver. The dynamics passed to that solver may in turn be implemented in Python, but diffeqpy knows how to handle Python code in its dynamics, so that should be okay.
Here's an example from the readme to which I've added extra annotations to try to clarify how it fits into this conversation:
```python
from juliatorch import JuliaFunction
from diffeqpy import de
import juliacall, torch

jl = juliacall.Main.seval

# Define the ODE kernel
# Note the dynamics are written in python
def ode_f(du, u, p, t):
    x = u[0]
    v = u[1]
    dx = v
    dv = -p * x
    du[0] = dx
    du[1] = dv

# Use diffeqpy to solve the differential equation for given parameters
def solve(parameters):
    x0, v0, p = parameters
    tspan = (0.0, 1.0)
    # Why not just use `de.ODEProblem`? That would pass gradcheck but fail in the
    # optimization loop. See https://github.com/SciML/juliatorch/issues/10
    prob = de.seval("ODEProblem{true, SciMLBase.FullSpecialize}")(ode_f, [x0, v0], tspan, p)
    return de.solve(prob)

# Extract the desired results
def solve_and_query(parameters):
    sol = solve(parameters)
    return de.hcat(sol(.5), sol(1.0))

print(solve_and_query([1, 2, 3]))
# [1.5274653930969104 0.9791625277649281; -0.023690980408490492 -2.0306945154435274]

# At this point `solve_and_query` is a thin wrapper around the Julia function `de.solve`.
# It's admittedly kind of annoying and finicky to write this wrapper (e.g. the use of Julia functions
# instead of python functions for solution handling in `de.hcat(sol(.5), sol(1.0))`)
# but there isn't any substantive computation in the wrapper so it's not a huge deal.

# JuliaFunction.apply(solve_and_query, x) now behaves as a PyTorch compatible autograd function
# and it can be used just like other pytorch functions that support AD
x = torch.randn(3, dtype=torch.double, requires_grad=True)
print(JuliaFunction.apply(solve_and_query, x))
# tensor([[-0.4471, -0.3979],
#         [ 0.3155, -0.1103]], dtype=torch.float64,
#        grad_fn=<JuliaFunctionBackward>)

# Verify that autograd through solve_and_query is correct
from torch.autograd import gradcheck
print(gradcheck(JuliaFunction.apply, (solve_and_query, x), eps=1e-6, atol=1e-4))
# True

# We could go on to write an optimization loop on top of this (see the readme for an example of that)
```
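For context on what that `gradcheck` call verifies: it compares the Jacobian that autograd reports (here, the one `JuliaFunction` computes on the Julia side) against a central finite-difference estimate. Below is a minimal pure-Python sketch of that comparison. The helper names `numerical_jacobian` and `check_jacobian` are hypothetical, and the sketch uses only an absolute tolerance, whereas PyTorch's `gradcheck` also applies a relative tolerance (`rtol`).

```python
import math

def numerical_jacobian(f, x, eps=1e-6):
    """Central-difference Jacobian of f: R^n -> R^m at x (lists of floats)."""
    fx = f(x)
    jac = [[0.0] * len(x) for _ in range(len(fx))]
    for j in range(len(x)):
        xp = list(x)
        xp[j] += eps
        xm = list(x)
        xm[j] -= eps
        fp, fm = f(xp), f(xm)
        for i in range(len(fx)):
            jac[i][j] = (fp[i] - fm[i]) / (2 * eps)
    return jac

def check_jacobian(f, jac_f, x, eps=1e-6, atol=1e-4):
    """True if the analytic Jacobian jac_f(x) matches finite differences within atol."""
    num = numerical_jacobian(f, x, eps)
    ana = jac_f(x)
    return all(abs(a - n) <= atol
               for row_a, row_n in zip(ana, num)
               for a, n in zip(row_a, row_n))

# Example function f(x) = [x0 * x1, sin(x0)] with a correct and a wrong Jacobian.
def f(x):
    return [x[0] * x[1], math.sin(x[0])]

def good_jac(x):
    return [[x[1], x[0]], [math.cos(x[0]), 0.0]]

def bad_jac(x):
    return [[x[1], x[0]], [-math.cos(x[0]), 0.0]]  # sign error in d sin / dx0

x0 = [0.3, -1.2]
print(check_jacobian(f, good_jac, x0))  # True
print(check_jacobian(f, bad_jac, x0))   # False
```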
> we need neural networks to operate within the dynamics.
As long as we can get `diffeqpy` to solve differential equations with neural networks within the dynamics, we should be able to wrap those calls to `solve` in `JuliaFunction.apply`, even if we are unable to wrap the neural networks directly.
The fundamental issue here is that we don't have robust AD over Python code, so we can't compute gradients at the differential equation solver / PyTorch boundary when PyTorch is embedded in the dynamics.
This is okay (and an example in the readme demonstrates it):

- Top-level Python code
  - Python optimization loop
    - Diffeq solver
      - Dynamics written in plain Python

But it only works for simple dynamics that can accept duals.
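Why plain-Python dynamics can "accept duals": the readme's `ode_f` only uses `+`, `*`, and unary `-`, so any number-like object that overloads those operators flows through it unchanged. The sketch below uses a toy `Dual` class as a hypothetical stand-in for Julia's `ForwardDiff.Dual` (it is not the real type juliacall hands to Python) and seeds `p` to get the derivative of the dynamics with respect to `p` in one forward pass.

```python
# Toy forward-mode dual number: a hypothetical stand-in for ForwardDiff.Dual,
# NOT the actual object that diffeqpy/juliatorch pass to Python.
class Dual:
    def __init__(self, value, deriv=0.0):
        self.value = value  # primal value
        self.deriv = deriv  # derivative w.r.t. the seeded input

    @staticmethod
    def _lift(other):
        # Treat plain numbers as constants (zero derivative).
        return other if isinstance(other, Dual) else Dual(other, 0.0)

    def __add__(self, other):
        o = Dual._lift(other)
        return Dual(self.value + o.value, self.deriv + o.deriv)

    __radd__ = __add__  # addition is commutative

    def __mul__(self, other):
        o = Dual._lift(other)
        # Product rule: (a + a'e)(b + b'e) = ab + (ab' + a'b)e
        return Dual(self.value * o.value, self.value * o.deriv + self.deriv * o.value)

    __rmul__ = __mul__  # multiplication is commutative

    def __neg__(self):
        return Dual(-self.value, -self.deriv)


# The plain-Python dynamics from the readme example, unchanged.
def ode_f(du, u, p, t):
    x = u[0]
    v = u[1]
    dx = v
    dv = -p * x
    du[0] = dx
    du[1] = dv


# Seed p with derivative 1 to get d(du)/dp alongside du.
du = [0.0, 0.0]
ode_f(du, [1.0, 2.0], Dual(3.0, 1.0), 0.0)
print(du[0], du[1].value, du[1].deriv)  # 2.0 -3.0 -1.0
```

Here `d(dv)/dp = -x = -1`, as expected. Dynamics that instead hand their inputs to code expecting concrete floats or tensors (e.g. a PyTorch model plus `.item()`) have no such operator overloads to ride on, which is what breaks the second case.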
But this is, IIUC, what you want, and what is not currently supported [edit: it is supported, but not if you also want to compute gradients via AD, which presumably you do want]:

- Top-level Python code
  - Python optimization loop
    - Diffeq solver
      - Dynamics written in plain Python
        - Dynamics call into PyTorch

This would require AD integration with PyTorch not only in the PyTorch-calls-Julia direction but also in the Julia-calls-PyTorch direction. That's totally possible, but it will be a new feature and will take some time to implement.
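One way to picture the missing Julia-calls-PyTorch piece: the Julia-side AD pushes dual numbers into the dynamics, but an opaque PyTorch call cannot accept them, so the boundary needs an attached derivative rule. The sketch below is a toy under stated assumptions: `Dual`, `opaque_layer`, and `with_frule` are all hypothetical names, `math.tanh` stands in for the network, and the derivative is written by hand. In a real integration it would presumably come from PyTorch's own autograd (for instance `torch.autograd.functional.jvp`).

```python
import math

class Dual:
    """Toy dual number (value + derivative); hypothetical stand-in for ForwardDiff.Dual."""
    def __init__(self, value, deriv):
        self.value = value
        self.deriv = deriv

def opaque_layer(x):
    """Stand-in for a PyTorch call that only understands plain floats."""
    if not isinstance(x, float):
        raise TypeError(f"opaque_layer expected float, got {type(x).__name__}")
    return math.tanh(x)

def with_frule(f, df):
    """Hypothetical helper: let duals cross an opaque boundary by evaluating f
    on the primal value and applying the chain rule with the derivative df."""
    def wrapped(x):
        if isinstance(x, Dual):
            return Dual(f(x.value), df(x.value) * x.deriv)
        return f(x)
    return wrapped

# Without a rule, the dual cannot enter the opaque function.
try:
    opaque_layer(Dual(0.5, 1.0))
except TypeError as err:
    print("fails:", err)

# With a rule supplying d/dx tanh(x) = 1 - tanh(x)**2, the derivative survives.
layer = with_frule(opaque_layer, lambda x: 1.0 - math.tanh(x) ** 2)
out = layer(Dual(0.5, 1.0))
print(out.value, out.deriv)
```

For reverse mode the same idea applies with a pullback (vector-Jacobian product) instead of a forward derivative.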
Example of how the second case fails:
```python
from juliatorch import JuliaFunction # Load this first to avoid https://github.com/pytorch/pytorch/issues/78829
import numpy as np
from juliacall import Main as jl
from diffeqpy import de
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Sequential(
    nn.Linear(2, 6),
    nn.ReLU(),
    nn.Linear(6, 2),
    nn.Sigmoid())

def ode_f(du, u, p, t):
    p2 = np.array(p)
    model[0].weight = nn.Parameter(torch.from_numpy(p2[0:12]).reshape(6, 2))
    model[2].weight = nn.Parameter(torch.from_numpy(p2[12:24]).reshape(2, 6))
    model[0].bias = nn.Parameter(torch.from_numpy(p2[24:30]))
    model[2].bias = nn.Parameter(torch.from_numpy(p2[30:32]))
    res = model(torch.from_numpy(np.array(u)))
    du[0] = res[0].item()
    du[1] = res[1].item()

ode_f([0,0], np.array([1.0,1.0]), np.random.randn(32), 0)

def solve(parameters):
    x0, v0, p = 1.0, 1.0, parameters
    tspan = (0.0, 1.0)
    # Why not just use `de.ODEProblem`? That would pass gradcheck but fail in the
    # optimization loop. See https://github.com/SciML/juliatorch/issues/10
    prob = de.seval("ODEProblem{true, SciMLBase.FullSpecialize}")(ode_f, [x0, v0], tspan, p)
    return de.solve(prob)

solve(np.random.randn(32))

def solve_and_query(parameters):
    sol = solve(parameters)
    return de.hcat(sol(.5), sol(1.0))

print(solve_and_query(np.random.randn(32)))

parameters = torch.randn(32, dtype=torch.double, requires_grad=True)
print(JuliaFunction.apply(solve_and_query, parameters))
# Okay up to here

from torch.autograd import gradcheck
print(gradcheck(JuliaFunction.apply, (solve_and_query, parameters), eps=1e-6, atol=1e-4))
# Fails with
# TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.
```
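A plausible reading of that error (an assumption, not something traced through juliatorch's internals): when gradients are requested, `ode_f` receives Julia dual numbers instead of plain floats, so `np.array(p)` and `np.array(u)` can only build arrays of generic Python objects, i.e. `dtype=object`, which is not in the list of dtypes the message says `torch.from_numpy` accepts. The snippet below reproduces just the numpy half of that, with a hypothetical `OpaqueDual` class standing in for the Julia values.

```python
import numpy as np

# Hypothetical opaque object standing in for a Julia Dual as seen from Python.
class OpaqueDual:
    def __init__(self, value, deriv):
        self.value = value
        self.deriv = deriv

plain = np.array([1.0, 1.0])                                   # what ode_f sees in a plain solve
traced = np.array([OpaqueDual(1.0, 1.0), OpaqueDual(1.0, 0.0)])  # what it sees under AD
print(plain.dtype)   # float64
print(traced.dtype)  # object
```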
Hey @LilithHafner, thanks!
Yes — you are totally right that the real ask here is AD not only through julia in a pytorch graph, but through pytorch inside of julia in a pytorch graph 😵💫.
For now, I've essentially pivoted to just assuming that we're going to always jit-compile the dynamics so that we can get the speed benefits (which is the primary advantage for chirho wrt diffeqpy, at least currently — given that we don't support e.g. SDEs). This assumption rules out e.g. a complex pytorch function inside the dynamics. I.e. no rush on this from us (chirho) right now. We have plenty of non-neural-ode applications we can use it for. :)
Also, I did some simple profiling: if we run the solver with uncompiled dynamics, it goes about 10x slower than torchdiffeq (and this is after moving the ODEProblem definition outside the solve loop and using remake instead), while it runs about 10x faster than torchdiffeq when the dynamics are compiled. I.e., speed could continue to be an issue, especially when PyTorch is called from inside Julia.
Since you mentioned it here, I wanted to say that I'll also post more info tomorrow re: #14. At this point I have a solution on our end: wrapping "julia things" (i.e. symbolics, duals, etc.) in a way that keeps numpy from trying to unpack them as >32 ndim arrays 😠. It's a little clunky but could be helpful for the more general case as well. Anyway, stay tuned.
TLDR though is that I am unblocked at the moment on the chirho side.
@LilithHafner just added the promised post to #14.