FEniCS/basix

Provide interface to run Basix functions inside Numba-compiled kernels


Currently, a function is provided to allow dof transformations to be applied, but it would be good to allow more functions to be called without having to reimplement everything in Python.

@mscroggs could you sketch out a use case?

Hello! Any update on this? By the way, is there any reason that prevents numba from JITting pybind11 generated functions?
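On the pybind11 question: as far as I understand, Numba's nopython mode can only lower calls whose raw C address and signature it can see (e.g. ctypes/cffi function pointers), while a pybind11-generated function is an opaque CPython callable, so there is nothing for Numba to compile against. A minimal sketch of the mechanism that does work, using ctypes and the system libm (assuming a Unix-like platform; the library name is an assumption):

```python
import ctypes
import ctypes.util

# Locate the C math library (assumption: Unix-like system; on some
# platforms libm is folded into libc, hence the fallback name).
name = ctypes.util.find_library("m") or "libm.so.6"
libm = ctypes.CDLL(name)

# Declare the raw C signature: double sqrt(double).
libm.sqrt.restype = ctypes.c_double
libm.sqrt.argtypes = [ctypes.c_double]

# A ctypes-declared function like this can be called from inside a
# numba.njit-compiled function, because Numba sees the address and the
# C signature. A pybind11 wrapper exposes neither, so Numba cannot
# JIT through it.
print(libm.sqrt(9.0))  # → 3.0
```

This is why exposing selected basix entry points as raw C function pointers (or cffi/ctypes bindings) would be enough for Numba interop, independently of the pybind11 layer.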

@garth-wells Here is a real world use case for you:
Let's say I need to interpolate the results of a FEM simulation on a very large set of points. This is required, for example, when implementing digital volume correlation, where we seek to optimize a displacement field from two volumetric images. As that interpolation is part of the process of minimizing an objective function, we also need the Jacobian of this interpolation operator, which is essentially the associated sparse interpolation matrix. The entries of that matrix are computed by evaluating each FEM basis function at a given coordinate.

A snippet of code to construct such a matrix would be the following (this code may contain some errors, I have not checked it; it is just to illustrate the idea):

import numpy as np
import dolfinx.mesh, dolfinx.fem as fem
import basix
from basix.ufl_wrapper import BasixElement
from petsc4py import PETSc


def interpolation_matrix(
    x: np.ndarray,
    x_to_cell: np.ndarray,
    mesh: dolfinx.mesh.Mesh,
    element: basix.finite_element.FiniteElement,
):
    """Create a sparse interpolation matrix from a function space to a set of fixed points"""
    nx = np.shape(x)[0]

    # create dofmap
    ufl_element = BasixElement(element)
    V = fem.FunctionSpace(mesh, ufl_element)
    dofmap = V.dofmap
    ndofs = element.dim  # number of DOFs per cell
    bs = element.value_size

    rows = np.zeros(nx * bs + 1, dtype=np.int32)
    cols = np.zeros(nx * ndofs * bs, dtype=np.int32)
    vals = np.zeros(nx * ndofs * bs)

    num_cells = mesh.topology.index_map(mesh.topology.dim).size_global
    num_dofs_x = mesh.geometry.dofmap.links(0).size
    coords = mesh.geometry.x
    x_dofs = mesh.geometry.dofmap.array.reshape(num_cells, num_dofs_x)

    # loop needs to be JIT'd
    for k in range(nx):
        if x_to_cell[k] == -1:  # point is outside of the mesh
            for b in range(bs):
                rows[k * bs + b + 1] = rows[k * bs + b]
            continue

        cell = x_to_cell[k]

        vertices = np.array([coords[x_dofs[cell, i]] for i in range(num_dofs_x)])
        x_ref = mesh.geometry.cmap.pull_back(
            x[k].reshape(1, -1), vertices
        )  # need to use JIT'd basix.pull_back
        tab = element.tabulate(0, x_ref)  # need to use JIT'd basix element tabulation
        num_entries = np.shape(tab)[2]

        columns = dofmap.cell_dofs(cell)

        for b in range(bs):
            rows[k * bs + b + 1] = rows[k * bs + b] + num_entries
            cols[rows[k * bs + b] : rows[k * bs + b] + num_entries] = columns
            vals[rows[k * bs + b] : rows[k * bs + b] + num_entries] = tab[
                ..., b
            ].flatten()

    matrix = PETSc.Mat().createAIJWithArrays(
        size=(nx * bs, dofmap.index_map.size_global),
        csr=(rows, cols, vals),
        comm=PETSc.COMM_SELF,
    )
    matrix.assemble()

    return matrix

where x_to_cell is a mapping from the interpolation points to the mesh cells constructed beforehand.
As nx is typically very large, this Python loop has miserable performance, and in my experience constructing this matrix can be slower than solving several FEM problems.
That code could benefit from JIT'd versions of core basix functionality, such as tabulating a FEM basis and performing push-forward and pull-back transformations.
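To make the missing pieces concrete, here is a self-contained, plain-NumPy sketch of the same inner loop, with the pull-back and tabulation hand-coded for a linear (P1) triangle. All names here are illustrative, not basix API. The point is that the loop body only needs `pull_back` and `tabulate` as JIT-compatible functions; if basix exposed them that way, this whole loop could be decorated with `numba.njit` instead of being hand-specialized per element:

```python
import numpy as np


def p1_pull_back(x, verts):
    # Affine pull-back on a triangle: solve verts[0] + J @ x_ref = x,
    # where J has the two edge vectors as columns.
    J = np.column_stack((verts[1] - verts[0], verts[2] - verts[0]))
    return np.linalg.solve(J, x - verts[0])


def p1_tabulate(x_ref):
    # Linear Lagrange basis values on the reference triangle.
    return np.array([1.0 - x_ref[0] - x_ref[1], x_ref[0], x_ref[1]])


def interpolation_csr(points, cells, verts, dofs):
    # Build CSR data (rows, cols, vals) for the interpolation matrix:
    # row k holds the basis values of cell cells[k] evaluated at
    # points[k]. With JIT'd pull_back/tabulate from basix, this loop
    # could be compiled as-is for any element, not just P1.
    npts = points.shape[0]
    rows = np.zeros(npts + 1, dtype=np.int32)
    cols = np.zeros(npts * 3, dtype=np.int32)
    vals = np.zeros(npts * 3)
    for k in range(npts):
        c = cells[k]
        x_ref = p1_pull_back(points[k], verts[dofs[c]])
        phi = p1_tabulate(x_ref)
        rows[k + 1] = rows[k] + 3
        cols[rows[k]:rows[k + 1]] = dofs[c]
        vals[rows[k]:rows[k + 1]] = phi
    return rows, cols, vals


# Illustrative usage: unit square split into two triangles.
verts = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
dofs = np.array([[0, 1, 2], [1, 3, 2]])
points = np.array([[0.25, 0.25], [0.5, 0.5]])
cells = np.array([0, 1])
rows, cols, vals = interpolation_csr(points, cells, verts, dofs)
```

For each row the values sum to one (partition of unity), which is a cheap sanity check on the tabulation. The hand-coded `p1_pull_back`/`p1_tabulate` are exactly what one has to reimplement today for every element type, which is the duplication this issue asks to avoid.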