mfschubert/sparsejac

Calling sparsejac within a jitted linen module

jamesheald opened this issue · 2 comments

Great work on this function. I noticed that sparsejac raises a ConcretizationTypeError when I call it within a jitted linen module. If I don't jit the module call it works fine, but that's not practical. Here is an MWE:

import flax.linen as nn
import jax
from jax import numpy as jnp
import sparsejac
from jax.experimental.sparse import BCOO

class Network(nn.Module):
    dim: int
    def setup(self):
        self.dense = nn.Dense(features = self.dim)
        self.sparsity = BCOO.fromdense(jnp.eye(self.dim))
        fn = lambda x: self.dense(x)
        self.sparse_fn = sparsejac.jacrev(fn, self.sparsity)
    def __call__(self, x):
        return self.sparse_fn(x).todense()

dim = 2
model = Network(dim)
params = model.init(x = jnp.ones((dim,)), rngs = {'params': jax.random.PRNGKey(0)})
jit_model = jax.jit(model.apply)
jit_model(params, x = jnp.ones((dim,)))

The traceback is

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 5, in setup
  File "/nfs/nhome/live/jheald/.conda/envs/hDiffLocal/lib/python3.9/site-packages/jax/experimental/sparse/bcoo.py", line 2510, in fromdense
    return bcoo_fromdense(
  File "/nfs/nhome/live/jheald/.conda/envs/hDiffLocal/lib/python3.9/site-packages/jax/experimental/sparse/bcoo.py", line 266, in bcoo_fromdense
    nse = _count_stored_elements(mat, n_batch, n_dense)
  File "/nfs/nhome/live/jheald/.conda/envs/hDiffLocal/lib/python3.9/site-packages/jax/experimental/sparse/util.py", line 106, in _count_stored_elements
    return int(_count_stored_elements_per_batch(mat, n_batch, n_dense).max(initial=0))
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The problem arose with the `int` function. If trying to convert the data type of a value, try using `x.astype(int)` or `jnp.array(x, int)` instead.
The error occurred while tracing the function apply at /nfs/nhome/live/jheald/.conda/envs/hDiffLocal/lib/python3.9/site-packages/flax/linen/module.py:1831 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[2,2] = add b c
    from line <stdin>:5 (setup)

  operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    from line <stdin>:5 (setup)

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line <stdin>:5 (setup)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
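The failure can likely be reproduced without Flax or sparsejac. The following is a minimal sketch, assuming the only problem is that BCOO.fromdense runs while jax.jit is tracing; the name build_and_apply is made up for illustration.

```python
import jax
from jax import numpy as jnp
from jax.experimental.sparse import BCOO

def build_and_apply(x):
    # Under jit, jnp.eye(...) is traced, so the number of stored elements
    # that BCOO.fromdense needs as a Python int is not a concrete value.
    sparsity = BCOO.fromdense(jnp.eye(x.shape[0]))
    return sparsity @ x

x = jnp.ones((2,))

# Eager call: every array is concrete, so this works.
eager_out = build_and_apply(x)

# Jitted call: expected to fail in the same way as the traceback above.
try:
    jax.jit(build_and_apply)(x)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True
```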

I have also tried calling sparsejac outside of the linen module and passing the sparsified function in as an argument, but that doesn't work either:

import flax.linen as nn
import jax
from jax import numpy as jnp
import sparsejac
from jax.experimental.sparse import BCOO
from functools import partial

class Network(nn.Module):
    dim: int
    def setup(self):
        self.dense = nn.Dense(features = self.dim)
    def __call__(self, sparse_fn, x):
        return sparse_fn(x, self.dense).todense()

dim = 2

sparsity = BCOO.fromdense(jnp.eye(dim))
fn = lambda x, mod: mod(x)
sparse_fn = sparsejac.jacrev(fn, sparsity)

model = Network(dim)
params = model.init(sparse_fn = sparse_fn, x = jnp.ones((dim,)), rngs = {'params': jax.random.PRNGKey(0)})
jit_model = jax.jit(model.apply)
jit_model(params, x = jnp.ones((dim,)), sparse_fn = sparse_fn)

Any suggestions on what I should do here?

Thanks in advance.

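@jamesheald, thanks! BCOO.fromdense needs concrete values to count the stored elements, but under jit the arrays created in setup are tracers. I believe the solution here would be to put the jacrev call inside the jax.ensure_compile_time_eval context manager, so that the sparsity pattern is computed eagerly, i.e.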

import flax.linen as nn
import jax
from jax import numpy as jnp
import sparsejac
from jax.experimental.sparse import BCOO

class Network(nn.Module):
    dim: int
    def setup(self):
        self.dense = nn.Dense(features = self.dim)
        with jax.ensure_compile_time_eval():
            self.sparsity = BCOO.fromdense(jnp.eye(self.dim))
            fn = lambda x: self.dense(x)
            self.sparse_fn = sparsejac.jacrev(fn, self.sparsity)
    def __call__(self, x):
        return self.sparse_fn(x).todense()

dim = 2
model = Network(dim)
params = model.init(x = jnp.ones((dim,)), rngs = {'params': jax.random.PRNGKey(0)})
jit_model = jax.jit(model.apply)
jit_model(params, x = jnp.ones((dim,)))
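
To see the effect of jax.ensure_compile_time_eval in isolation, here is a minimal sketch without Flax or sparsejac. It assumes, as the traceback suggests, that the error comes from BCOO.fromdense needing concrete values; apply_identity is a hypothetical name used only for this example.

```python
import jax
from jax import numpy as jnp
from jax.experimental.sparse import BCOO

@jax.jit
def apply_identity(x):
    # Operations inside this context are evaluated eagerly even though the
    # enclosing function is being traced, so sparsity is a concrete BCOO.
    with jax.ensure_compile_time_eval():
        sparsity = BCOO.fromdense(jnp.eye(x.shape[0]))
    # Outside the context, x is still a tracer and is handled by jit as usual.
    return sparsity @ x

out = apply_identity(jnp.ones((2,)))
```

Only values that do not depend on traced inputs (here, the static shape of x) can be computed inside the context.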

Ah great. I wasn't aware of jax.ensure_compile_time_eval, that seems like a useful function. Thanks!