Is the MixedFunctionSpace supported?
Closed this issue · 2 comments
Naruki-Ichihara commented
Hello, @IvanYashchuk.
I have been trying out this repository.
Is MixedFunctionSpace supported? If it is, please tell me how to define the returned 'u' in the forward function decorated by 'build_jax_solve_eval'.
Thanks.
IvanYashchuk commented
Hi, @Naruki-Ichihara!
Yes, that should work, although it may not seem straightforward.
You can pass an instance of a `Function` defined on a `MixedFunctionSpace`, and inside the function decorated by `build_jax_solve_eval` or `build_jax_assemble_eval` you can use `fenics.split` to get the individual functions of the mixed space and define the variational form with them.
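To make the data flow concrete: the whole mixed `Function` crosses the JAX boundary as one flat array of length `W.dim()`, and each component occupies its own set of entries, which need not be contiguous. The sketch below is a hypothetical NumPy illustration, not part of jaxfenics: `split_flat` and the index sets are made up for the example, whereas with legacy FEniCS running in serial you would typically obtain the index sets from `W.sub(0).dofmap().dofs()` and `W.sub(1).dofmap().dofs()`.

```python
import numpy as np


def split_flat(w_flat, index_sets):
    """Gather one array per sub-space from the flat mixed-space vector (hypothetical helper)."""
    covered = np.sort(np.concatenate(index_sets))
    if not np.array_equal(covered, np.arange(w_flat.size)):
        raise ValueError("index sets must partition the flat vector")
    return tuple(w_flat[idx] for idx in index_sets)


# Hypothetical DOF indices of the two sub-spaces (interleaved on purpose).
# With legacy FEniCS in serial these would typically come from
# W.sub(0).dofmap().dofs() and W.sub(1).dofmap().dofs().
dofs_sigma = np.array([0, 1, 3, 4, 6, 7])
dofs_u = np.array([2, 5, 8])

w_flat = np.arange(9, dtype=float)  # stand-in for the flat mixed solution
sigma_vals, u_vals = split_flat(w_flat, (dofs_sigma, dofs_u))
print(sigma_vals)  # [0. 1. 3. 4. 6. 7.]
print(u_vals)      # [2. 5. 8.]
```

Because the extraction is plain integer indexing, the same pattern should also work on a JAX array returned by a function built with `build_jax_solve_eval`, so downstream computations that only use `u` can stay differentiable.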
Here is an example based on the mixed Poisson demo:
```python
import jax
jax.config.update("jax_enable_x64", True)  # FEniCS works in double precision
import jax.numpy as np
import fenics
import fdm  # pip install fdm
from jaxfenics import solve_eval, vjp_solve_eval_impl
from jaxfenics import jvp_solve_eval
from jaxfenics import fenics_to_numpy, numpy_to_fenics

mesh = fenics.UnitSquareMesh(6, 5)

# Define finite element spaces and build mixed space
BDM = fenics.FiniteElement("BDM", mesh.ufl_cell(), 1)
DG = fenics.FiniteElement("DG", mesh.ufl_cell(), 0)
W = fenics.FunctionSpace(mesh, BDM * DG)

f = fenics.Expression(
    "10*exp(-(pow(x[0] - 0.5, 2) + pow(x[1] - 0.5, 2)) / 0.02)", degree=2
)

# Define function G such that G \cdot n = g
class BoundarySource(fenics.UserExpression):
    def __init__(self, mesh, **kwargs):
        self.mesh = mesh
        super().__init__(**kwargs)

    def eval_cell(self, values, x, ufc_cell):
        cell = fenics.Cell(self.mesh, ufc_cell.index)
        n = cell.normal(ufc_cell.local_facet)
        g = fenics.sin(5*x[0])
        values[0] = g*n[0]
        values[1] = g*n[1]

    def value_shape(self):
        return (2,)

G = BoundarySource(mesh, degree=2)

# Define essential boundary
def boundary(x):
    return x[1] < fenics.DOLFIN_EPS or x[1] > 1.0 - fenics.DOLFIN_EPS

def solve_fenics(w):
    # w is a Function on the mixed space W; split it into its components
    sigma, u = fenics.split(w)
    (tau, v) = fenics.TestFunctions(W)

    # Define variational form
    dot, div, dx = fenics.dot, fenics.div, fenics.dx
    a = (dot(sigma, tau) + div(tau)*u + div(sigma)*v)*dx
    L = - f*v*dx
    F = a - L  # create suitable form for F == 0 solving

    bcs = [fenics.DirichletBC(W.sub(0), G, boundary)]
    fenics.solve(F == 0, w, bcs=bcs)
    # Return the whole mixed Function w, not the split component u
    return w, F, bcs

templates = (fenics.Function(W), )
w_numpy = np.ones(W.dim())
numpy_output, _, _, _, _ = solve_eval(solve_fenics, templates, w_numpy)

# Now coupling with JAX, before that everything was only related to FEniCS
from jaxfenics import build_jax_solve_eval
jax_solve_eval = build_jax_solve_eval(templates)(solve_fenics)

jax_jac = jax.jacrev(jax_solve_eval)(w_numpy)  # adjoint-based Jacobian (identity matrix for this example)
fdm_jac = fdm.jacobian(jax_solve_eval)(w_numpy)  # finite-difference Jacobian (takes a LOT of time)
assert np.allclose(fdm_jac, jax_jac)
```
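For a sense of scale: assuming `UnitSquareMesh(6, 5)` uses its default diagonal, so that each of the 6 × 5 rectangles is cut into two triangles, the mixed space has one degree-0 `DG` DOF per cell and two degree-1 `BDM` DOFs per edge. The count below is plain Python, independent of FEniCS, and the helper names are hypothetical. It should give `W.dim() == 262`, so both Jacobians above are 262 × 262, and the finite-difference version needs at least one extra nonlinear solve per column, which is why it is slow.

```python
def unit_square_counts(nx, ny):
    """Vertices, edges and cells of an nx-by-ny unit-square mesh with one diagonal per rectangle."""
    vertices = (nx + 1) * (ny + 1)
    edges = nx * (ny + 1) + ny * (nx + 1) + nx * ny  # horizontal + vertical + diagonal
    cells = 2 * nx * ny
    return vertices, edges, cells


def bdm1_dg0_dims(nx, ny):
    """Dimensions of BDM (degree 1), DG (degree 0) and their mixed product on that mesh."""
    _, edges, cells = unit_square_counts(nx, ny)
    dim_bdm = 2 * edges  # degree-1 BDM: two DOFs per edge
    dim_dg = cells       # degree-0 DG: one DOF per cell
    return dim_bdm, dim_dg, dim_bdm + dim_dg


v, e, c = unit_square_counts(6, 5)
assert v - e + c == 1  # Euler characteristic of a triangulated square
print(bdm1_dg0_dims(6, 5))  # (202, 60, 262)
```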
Naruki-Ichihara commented
Thanks, @IvanYashchuk.
This example looks good. I will try it.
I will share my code if it works well.