dfm/extending-jax

How to reinterpret_cast a matrix?


Dear @dfm, your tutorial is excellent. But I am not familiar with C++, so I have a naive question.

You showed how to receive input values like this:

#include <cstdint> // int64_t

template <typename T>
void cpu_kepler(void *out, const void **in) {
  const std::int64_t size = *reinterpret_cast<const std::int64_t *>(in[0]);
  const T *mean_anom = reinterpret_cast<const T *>(in[1]);
  const T *ecc = reinterpret_cast<const T *>(in[2]);
}

However, if one of my input values is a matrix, how can I reinterpret_cast it?

Thanks.

dfm commented

The parameters will always be raw pointers, even if they are defined as matrices or tensors in Python. So if you need more structure, you'll have to pass all the necessary metadata (e.g. the shape or strides) to your function as extra input parameters, and then use that metadata to index into the pointer appropriately. The example here works with inputs of any shape, but it just treats them as flat and aligned.
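For instance, here is a minimal sketch of that pattern (the kernel name and single-output signature are hypothetical, not from the tutorial): pass the row and column counts as two extra int64 inputs, then do the 2D indexing arithmetic yourself on the flat, row-major data pointer:

#include <cstdint>  // std::int64_t

// Hypothetical kernel: in[0] and in[1] are the matrix dimensions, passed as
// extra int64 inputs, and in[2] is the matrix itself, which XLA hands over
// as a flat buffer in row-major order
template <typename T>
void cpu_matrix_op(void *out, const void **in) {
  const std::int64_t rows = *reinterpret_cast<const std::int64_t *>(in[0]);
  const std::int64_t cols = *reinterpret_cast<const std::int64_t *>(in[1]);
  const T *mat = reinterpret_cast<const T *>(in[2]);
  T *result = reinterpret_cast<T *>(out);
  for (std::int64_t i = 0; i < rows; ++i) {
    for (std::int64_t j = 0; j < cols; ++j) {
      // Element (i, j) lives at flat offset i * cols + j
      result[i * cols + j] = T(2) * mat[i * cols + j];
    }
  }
}

Note that the matrix is never reinterpret_cast "as a matrix": you cast the flat pointer exactly as before and compute the offsets yourself.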

The logic for how this is handled is all in the XLA translation rule:

import numpy as np
from jax.lib import xla_client

xops = xla_client.ops

# The compiled CUDA extension; this is None if the module was built
# without GPU support
try:
    from . import gpu_ops
except ImportError:
    gpu_ops = None


# We also need a translation rule to convert the function into an XLA op. In
# our case this is the custom XLA op that we've written. We're wrapping two
# translation rules into one here: one for the CPU and one for the GPU
def _kepler_translation(c, mean_anom, ecc, *, platform="cpu"):
    # The inputs have "shapes" that provide both the shape and the dtype
    mean_anom_shape = c.get_shape(mean_anom)
    ecc_shape = c.get_shape(ecc)

    # Extract the dtype and shape
    dtype = mean_anom_shape.element_type()
    dims = mean_anom_shape.dimensions()
    assert ecc_shape.element_type() == dtype
    assert ecc_shape.dimensions() == dims

    # The total size of the input is the product across dimensions
    size = np.prod(dims).astype(np.int64)

    # The inputs and outputs all have the same shape so let's predefine this
    # specification
    shape = xla_client.Shape.array_shape(
        np.dtype(dtype), dims, tuple(range(len(dims) - 1, -1, -1))
    )

    # We dispatch a different call depending on the dtype
    if dtype == np.float32:
        op_name = platform.encode() + b"_kepler_f32"
    elif dtype == np.float64:
        op_name = platform.encode() + b"_kepler_f64"
    else:
        raise NotImplementedError(f"Unsupported dtype {dtype}")

    # And then the following is what changes between the GPU and CPU
    if platform == "cpu":
        # On the CPU, we pass the size of the data as the first input
        # argument
        return xops.CustomCallWithLayout(
            c,
            op_name,
            # The inputs:
            operands=(xops.ConstantLiteral(c, size), mean_anom, ecc),
            # The input shapes:
            operand_shapes_with_layout=(
                xla_client.Shape.array_shape(np.dtype(np.int64), (), ()),
                shape,
                shape,
            ),
            # The output shapes:
            shape_with_layout=xla_client.Shape.tuple_shape((shape, shape)),
        )

    elif platform == "gpu":
        if gpu_ops is None:
            raise ValueError(
                "The 'kepler_jax' module was not compiled with CUDA support"
            )

        # On the GPU, we do things a little differently and encapsulate the
        # dimension using the 'opaque' parameter
        opaque = gpu_ops.build_kepler_descriptor(size)
        return xops.CustomCallWithLayout(
            c,
            op_name,
            operands=(mean_anom, ecc),
            operand_shapes_with_layout=(shape, shape),
            shape_with_layout=xla_client.Shape.tuple_shape((shape, shape)),
            opaque=opaque,
        )

    raise ValueError("Unsupported platform; this must be either 'cpu' or 'gpu'")
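On the Python side, you'd forward extra shape information to a kernel the same way the size argument is passed above. A hypothetical variant of the CPU branch (not in the tutorial) might look like:

# Hypothetical CPU branch that forwards the matrix dimensions to the
# kernel as two extra int64 operands (assuming a 2D input)
rows, cols = dims
scalar_i64 = xla_client.Shape.array_shape(np.dtype(np.int64), (), ())
return xops.CustomCallWithLayout(
    c,
    op_name,
    operands=(
        xops.ConstantLiteral(c, np.int64(rows)),
        xops.ConstantLiteral(c, np.int64(cols)),
        mean_anom,
        ecc,
    ),
    operand_shapes_with_layout=(scalar_i64, scalar_i64, shape, shape),
    shape_with_layout=xla_client.Shape.tuple_shape((shape, shape)),
)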

Hope this helps!

Dear @dfm, thank you very much.