Calling sparsejac within a jitted linen module
jamesheald opened this issue · 2 comments
jamesheald commented
Great work on this function. I noticed that sparsejac raises a ConcretizationTypeError when I call it within a jitted linen module (if I don't jit the module call, it works fine, but that's not practical). Here is an MWE:
import flax.linen as nn
import jax
from jax import numpy as jnp
import sparsejac
from jax.experimental.sparse import BCOO
class Network(nn.Module):
dim: int
def setup(self):
self.dense = nn.Dense(features = self.dim)
self.sparsity = BCOO.fromdense(jnp.eye(self.dim))
fn = lambda x: self.dense(x)
self.sparse_fn = sparsejac.jacrev(fn, self.sparsity)
def __call__(self, x):
return self.sparse_fn(x).todense()
dim = 2
model = Network(dim)
params = model.init(x = jnp.ones((dim,)), rngs = {'params': jax.random.PRNGKey(0)})
jit_model = jax.jit(model.apply)
jit_model(params, x = jnp.ones((dim,)))
The traceback is
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<stdin>", line 5, in setup
File "/nfs/nhome/live/jheald/.conda/envs/hDiffLocal/lib/python3.9/site-packages/jax/experimental/sparse/bcoo.py", line 2510, in fromdense
return bcoo_fromdense(
File "/nfs/nhome/live/jheald/.conda/envs/hDiffLocal/lib/python3.9/site-packages/jax/experimental/sparse/bcoo.py", line 266, in bcoo_fromdense
nse = _count_stored_elements(mat, n_batch, n_dense)
File "/nfs/nhome/live/jheald/.conda/envs/hDiffLocal/lib/python3.9/site-packages/jax/experimental/sparse/util.py", line 106, in _count_stored_elements
return int(_count_stored_elements_per_batch(mat, n_batch, n_dense).max(initial=0))
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The problem arose with the `int` function. If trying to convert the data type of a value, try using `x.astype(int)` or `jnp.array(x, int)` instead.
The error occurred while tracing the function apply at /nfs/nhome/live/jheald/.conda/envs/hDiffLocal/lib/python3.9/site-packages/flax/linen/module.py:1831 for jit. This value became a tracer due to JAX operations on these lines:
operation a:i32[2,2] = add b c
from line <stdin>:5 (setup)
I want to call the sparse jacobian function with a linen module
operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
from line <stdin>:5 (setup)
operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
from line <stdin>:5 (setup)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
I have tried to call sparsejac outside of the linen module and pass the sparsified function in as an argument but that doesn't work either:
import flax.linen as nn
import jax
from jax import numpy as jnp
import sparsejac
from jax.experimental.sparse import BCOO
from functools import partial
class Network(nn.Module):
dim: int
def setup(self):
self.dense = nn.Dense(features = self.dim)
def __call__(self, sparse_fn, x):
return sparse_fn(x, self.dense).todense()
dim = 2
sparsity = BCOO.fromdense(jnp.eye(dim))
fn = lambda x, mod: mod(x)
sparse_fn = sparsejac.jacrev(fn, sparsity)
model = Network(dim)
params = model.init(sparse_fn = sparse_fn, x = jnp.ones((dim,)), rngs = {'params': jax.random.PRNGKey(0)})
jit_model = jax.jit(model.apply)
jit_model(params, x = jnp.ones((dim,)), sparse_fn = sparse_fn)
Any suggestions on what I should do here?
Thanks in advance!
mfschubert commented
@jamesheald, thanks! I believe the solution here would be to put the jacrev call inside the jax.ensure_compile_time_eval context manager, so that the sparsity pattern is built from concrete values rather than tracers, i.e.
import flax.linen as nn
import jax
from jax import numpy as jnp
import sparsejac
from jax.experimental.sparse import BCOO
class Network(nn.Module):
dim: int
def setup(self):
self.dense = nn.Dense(features = self.dim)
with jax.ensure_compile_time_eval():
self.sparsity = BCOO.fromdense(jnp.eye(self.dim))
fn = lambda x: self.dense(x)
self.sparse_fn = sparsejac.jacrev(fn, self.sparsity)
def __call__(self, x):
return self.sparse_fn(x).todense()
dim = 2
model = Network(dim)
params = model.init(x = jnp.ones((dim,)), rngs = {'params': jax.random.PRNGKey(0)})
jit_model = jax.jit(model.apply)
jit_model(params, x = jnp.ones((dim,)))
jamesheald commented
Ah, great. I wasn't aware of jax.ensure_compile_time_eval; that seems like a useful function. Thanks!