IvanYashchuk/jax-fenics

Is the MixedFunctionSpace supported?


Hello, @IvanYashchuk.
I have been trying out this repository.
Is MixedFunctionSpace supported? If it is, please tell me how to define the returned 'u' in the forward function decorated with 'build_jax_solve_eval'.
Thanks.

Hi, @Naruki-Ichihara!
Yes, that should work, although it may not seem straightforward.
You can pass an instance of a Function defined on a MixedFunctionSpace. Inside the function decorated with build_jax_solve_eval or build_jax_assemble_eval, use fenics.split to get the individual sub-functions of the mixed space and define the variational form with them. The decorated function then returns the whole mixed Function, so the components do not need to be returned separately.
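On the JAX side, a mixed Function is exchanged as one flat vector (the mixed Poisson example passes `np.ones(W.dim())` as its input), so if you need 'u' by itself you can slice it out of that vector afterwards. The sketch below assumes you already have the dof indices of each subspace; in legacy FEniCS these would come from `W.sub(0).dofmap().dofs()` and `W.sub(1).dofmap().dofs()`. The index arrays here are hypothetical stand-ins for a tiny six-dof mixed space, and `split_mixed` and `loss` are illustrative names, not part of jax-fenics.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical dof layout of a tiny mixed space. With FEniCS these index
# arrays would come from W.sub(0).dofmap().dofs() and W.sub(1).dofmap().dofs().
sigma_dofs = jnp.array([0, 1, 3, 4])
u_dofs = jnp.array([2, 5])


def split_mixed(w_flat):
    """Return the (sigma, u) coefficient blocks of a flat mixed-space vector."""
    return w_flat[sigma_dofs], w_flat[u_dofs]


def loss(w_flat):
    # Example objective that depends only on the scalar component u.
    _, u = split_mixed(w_flat)
    return jnp.sum(u ** 2)


w = jnp.arange(6.0)  # stand-in for the flat output of jax_solve_eval
sigma, u = split_mixed(w)
grad = jax.grad(loss)(w)  # nonzero only at the u dofs
```

Because the slicing is ordinary JAX indexing, gradients flow back through it: `grad` equals `2 * w` at the u dofs and zero at the sigma dofs.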
Here is an example based on the mixed Poisson demo:

```python
import math

import jax
jax.config.update("jax_enable_x64", True)  # double precision for accurate Jacobian checks
import jax.numpy as jnp

import fenics
import fdm  # pip install fdm

from jaxfenics import solve_eval

mesh = fenics.UnitSquareMesh(6, 5)
# Define finite element spaces and build the mixed space
BDM = fenics.FiniteElement("BDM", mesh.ufl_cell(), 1)
DG  = fenics.FiniteElement("DG", mesh.ufl_cell(), 0)
W = fenics.FunctionSpace(mesh, BDM * DG)

# Source term
f = fenics.Expression(
    "10*exp(-(pow(x[0] - 0.5, 2) + pow(x[1] - 0.5, 2)) / 0.02)", degree=2
)

# Define function G such that G \cdot n = g
class BoundarySource(fenics.UserExpression):
    def __init__(self, mesh, **kwargs):
        self.mesh = mesh
        super().__init__(**kwargs)
    def eval_cell(self, values, x, ufc_cell):
        cell = fenics.Cell(self.mesh, ufc_cell.index)
        n = cell.normal(ufc_cell.local_facet)  # outward normal of the facet
        g = math.sin(5*x[0])  # x is a plain coordinate array, so math.sin suffices
        values[0] = g*n[0]
        values[1] = g*n[1]
    def value_shape(self):
        return (2,)

G = BoundarySource(mesh, degree=2)

# Define essential boundary (bottom and top edges)
def boundary(x):
    return x[1] < fenics.DOLFIN_EPS or x[1] > 1.0 - fenics.DOLFIN_EPS

def solve_fenics(w):
    # w is a Function on the mixed space W; split it into its components
    sigma, u = fenics.split(w)
    (tau, v) = fenics.TestFunctions(W)

    # Define variational form
    dot, div, dx = fenics.dot, fenics.div, fenics.dx
    a = (dot(sigma, tau) + div(tau)*u + div(sigma)*v)*dx
    L = - f*v*dx
    F = a - L  # residual form suitable for solving F == 0

    bcs = [fenics.DirichletBC(W.sub(0), G, boundary)]
    fenics.solve(F == 0, w, bcs=bcs)
    # Return the whole mixed Function together with the residual and the BCs
    return w, F, bcs

templates = (fenics.Function(W), )

w_numpy = jnp.ones(W.dim())
numpy_output, _, _, _, _ = solve_eval(solve_fenics, templates, w_numpy)

# Now coupling with JAX, before that everything was only related to FEniCS
from jaxfenics import build_jax_solve_eval
jax_solve_eval = build_jax_solve_eval(templates)(solve_fenics)

jax_jac = jax.jacrev(jax_solve_eval)(w_numpy)  # adjoint based jacobian (identity matrix for this example)
fdm_jac = fdm.jacobian(jax_solve_eval)(w_numpy)  # finite-differencing based jacobian (takes a LOT of time)
assert jnp.allclose(fdm_jac, jax_jac)
```
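The last lines compare an adjoint-based Jacobian with a finite-difference one. To see the idea without FEniCS, here is a minimal pure-JAX sketch on a hypothetical linear residual F(w, m) = A w − m = 0 with a made-up 3×3 matrix `A`. The backward pass solves the adjoint system (∂F/∂w)ᵀ λ = g and returns −(∂F/∂m)ᵀ λ, which is just λ here because ∂F/∂m = −I. This illustrates the technique only; it is not the jax-fenics implementation, and `A`, `solve` and `fd_jacobian` are names invented for the sketch.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical toy "PDE": residual F(w, m) = A @ w - m, solved for w given m.
A = jnp.array([[4.0, 1.0, 0.0],
               [2.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])


@jax.custom_vjp
def solve(m):
    return jnp.linalg.solve(A, m)


def solve_fwd(m):
    # dF/dw = A is constant, so nothing needs to be saved for the backward pass
    return solve(m), None


def solve_bwd(_, g):
    # Adjoint solve: (dF/dw)^T lam = g
    lam = jnp.linalg.solve(A.T, g)
    # VJP = -(dF/dm)^T lam, and dF/dm = -I for this residual
    return (lam,)


solve.defvjp(solve_fwd, solve_bwd)


def fd_jacobian(fun, x, eps=1e-6):
    """Central-difference Jacobian, built one column per input entry."""
    cols = []
    for i in range(x.shape[0]):
        e = jnp.zeros_like(x).at[i].set(eps)
        cols.append((fun(x + e) - fun(x - e)) / (2 * eps))
    return jnp.stack(cols, axis=1)


m0 = jnp.array([1.0, 2.0, 3.0])
jac_adjoint = jax.jacrev(solve)(m0)  # built from adjoint solves
jac_fd = fd_jacobian(solve, m0)      # built from forward solves only
assert jnp.allclose(jac_adjoint, jac_fd, atol=1e-6)
```

For this linear toy problem the exact Jacobian is A⁻¹, and the central difference reproduces it up to round-off. Each row of the adjoint Jacobian costs one adjoint solve, while each column of the finite-difference one costs two forward solves, which is why the `fdm` check above is so slow for a real mesh.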

Thanks, @IvanYashchuk.
This example looks good. I will try it.
I will share my code if it works well.