Error with "x64" precision
agudym opened this issue · 3 comments
agudym commented
Hi there!
If x64 precision is enabled for JAX:
jax.config.update("jax_enable_x64", True)
The following error appears in jacfwd:
TypeError: primal and tangent arguments to jax.jvp do not match; dtypes must be equal, or in case of int/bool primal dtype the tangent dtype must be float0.Got primal dtype float64
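For reference, the same kind of mismatch can be reproduced with plain jax.jvp, independent of this library. The snippet below is a minimal sketch under assumptions: the function f and the array shapes are made up for illustration. It only shows that a float32 tangent is rejected when the primal is float64, and that casting the tangent to the primal dtype avoids the error.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit mode before creating any arrays.
jax.config.update("jax_enable_x64", True)


def f(x):
    # Hypothetical test function, not taken from the issue.
    return x ** 2


x = jnp.ones(3)  # float64, because x64 is enabled
tangent32 = jnp.ones(3, dtype=jnp.float32)  # mimics the hard-coded float32 basis

try:
    jax.jvp(f, (x,), (tangent32,))
    mismatch_raised = False
except TypeError as err:
    # jax.jvp requires the tangent dtype to match the primal dtype.
    mismatch_raised = True
    message = str(err)

# Casting the tangent to the primal dtype makes the call valid.
_, tangent_out = jax.jvp(f, (x,), (tangent32.astype(x.dtype),))
```

With the cast in place, tangent_out has the same dtype as x.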
This is due to the hard-coded float32 cast of the basis in jacfwd:
def jacfwd(
    fn: Callable[[Any], ArrayWithOptionalAux],
    sparsity: jsparse.BCOO,
    argnums: int = 0,
    has_aux: bool = False,
    coloring_strategy: str = _DEFAULT_COLORING_STRATEGY,
) -> Callable[[Any], ArrayWithOptionalAux]:
    # ...
    basis = basis.astype(jnp.float32)
    # ...
I assume a configuration setting (such as the jax_enable_x64 flag above) should be used here instead of the hard-coded dtype, shouldn't it?
Cheers!
mfschubert commented
Hi @agudym, thanks for reporting this. Please feel free to suggest a fix.
mfschubert commented
Actually, there's a simple fix. I'll make a patch shortly.
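The patch itself is not shown in this thread. As an illustration only, one plausible approach is to build the basis in the dtype of the primal input instead of casting it to float32. The sketch below assumes a 1-D input and uses a hypothetical dense helper, dense_jacfwd, as a stand-in; it is not the sparse jacfwd from the issue and does not use the sparsity or coloring arguments.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def dense_jacfwd(fn, x):
    """Hypothetical helper: forward-mode Jacobian with a dtype-matched basis."""
    # Build the basis in the primal dtype rather than hard-coding float32.
    basis = jnp.eye(x.size, dtype=x.dtype)

    def pushforward(tangent):
        # Jacobian-vector product for a single basis vector.
        return jax.jvp(fn, (x,), (tangent,))[1]

    # Each basis row yields one Jacobian column; stack columns on the last axis.
    return jax.vmap(pushforward, out_axes=-1)(basis)


fn = lambda x: x ** 2  # made-up example function
x = jnp.array([1.0, 2.0, 3.0])  # float64 under x64
jac = dense_jacfwd(fn, x)
reference = jax.jacfwd(fn)(x)
```

Because the basis follows x.dtype, the same helper works with and without jax_enable_x64, and the result can be compared against jax.jacfwd as a sanity check.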