mfschubert/sparsejac

Error with "x64" precision

agudym opened this issue · 3 comments

Hi there!

If x64 precision is enabled for JAX:

jax.config.update("jax_enable_x64", True)

The following error appears in jacfwd:
TypeError: primal and tangent arguments to jax.jvp do not match; dtypes must be equal, or in case of int/bool primal dtype the tangent dtype must be float0.Got primal dtype float64
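A minimal reproduction of the mismatch, independent of sparsejac: with x64 enabled, a float64 primal paired with a float32 tangent (standing in for the hardcoded basis) makes `jax.jvp` raise this `TypeError`. The arrays `x` and `t` are illustrative only.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit mode, as in the report above.
jax.config.update("jax_enable_x64", True)

x = jnp.linspace(0.0, 1.0, 3)       # primal: float64 under x64
t = jnp.ones(3, dtype=jnp.float32)  # tangent: float32, like the hardcoded basis

try:
    jax.jvp(jnp.sin, (x,), (t,))
    error_message = None
except TypeError as err:
    error_message = str(err)

print(x.dtype, t.dtype)
print(error_message)
```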

The cause is the hardcoded float32 cast in jacfwd:

def jacfwd(
    fn: Callable[[Any], ArrayWithOptionalAux],
    sparsity: jsparse.BCOO,
    argnums: int = 0,
    has_aux: bool = False,
    coloring_strategy: str = _DEFAULT_COLORING_STRATEGY,
) -> Callable[[Any], ArrayWithOptionalAux]:
    # ...
    basis = basis.astype(jnp.float32)
    # ...

I assume an environment variable should be used here instead, shouldn't it?
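One possible approach, sketched here as an assumption rather than the actual patch, is to cast the basis to the dtype of the primal being differentiated instead of to `jnp.float32`. The helper `dense_jacfwd` below is hypothetical and deliberately simplified: it uses a dense identity basis with no graph coloring, unlike sparsejac, and only illustrates the dtype handling.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def dense_jacfwd(fn, x):
    """Hypothetical dense forward-mode Jacobian, for illustration only.

    Unlike sparsejac, this uses one tangent per input coordinate (no
    coloring). The point is the dtype: the basis is cast to the primal's
    dtype instead of a hardcoded jnp.float32.
    """
    basis = jnp.eye(x.size, dtype=bool)  # 0/1 pattern, dtype-agnostic
    basis = basis.astype(x.dtype)        # was: basis.astype(jnp.float32)
    # Each row t of `basis` yields J @ t; for t = e_i that is column i of J.
    columns = jax.vmap(lambda t: jax.jvp(fn, (x,), (t,))[1])(basis)
    return columns.T


x = jnp.array([0.0, 0.5, 1.0])  # float64 because x64 is enabled
jac = dense_jacfwd(jnp.sin, x)
print(jac.dtype)  # float64
```

With x64 disabled, `x` is float32 and the basis follows it, so the same code works in both modes without a separate setting. Where no primal is at hand, `jax.dtypes.canonicalize_dtype(jnp.float64)` returns float64 or float32 depending on the `jax_enable_x64` flag, which JAX itself can already read from the `JAX_ENABLE_X64` environment variable.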

Cheers!

Hi @agudym, thanks for reporting this. Please feel free to suggest a fix.

Actually, there's a simple fix. I'll make a patch shortly.

Awesome, thanks!!!