google-deepmind/chex

without_jit=True for already jitted functions

fabiannagel opened this issue · 0 comments

Most JAX-based implementations use jit. Unless there is a reason not to, people take advantage of its speedup.

I noticed that @chex.variants(with_jit=True, without_jit=True) is a great way to assert the same behavior for both execution paths, as long as the variant is derived from a non-jitted function.

In the following example, I would expect to see "Tracing fn" four times in total: three times for the non-jitted variant and once for the initial jit compilation. In reality, test_variant_pre_jitted() is executed twice with the jitted fn, resulting in only two tracer outputs.

@chex.variants(with_jit=True, without_jit=True)
def test_variant_pre_jitted(self):
  @jit
  def fn(x, y):
    print("Tracing fn")
    return x + y

  var_fn = self.variant(fn)
  self.assertEqual(var_fn(1, 2), 3)
  self.assertEqual(var_fn(3, 4), 7)
  self.assertEqual(var_fn(5, 6), 11)

Of course, omitting @jit leads to the expected behavior. However, when a more complex implementation already uses jit internally, variants sadly no longer serve their purpose.
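The observation can be reproduced in plain JAX, without chex. This is a minimal sketch under two assumptions that are not confirmed here: that the without_jit variant calls the function unchanged, and that the with_jit variant wraps it in jax.jit. The names `trace_log` and `make_fn` are hypothetical helpers; `trace_log` stands in for the print statement.

```python
import jax

trace_log = []  # hypothetical helper: records each time the Python body runs


def make_fn():
    # Build a fresh pre-jitted function, as the test body does on each run.
    @jax.jit
    def fn(x, y):
        trace_log.append("Tracing fn")  # side effect runs only while tracing
        return x + y

    return fn


inputs = [(1, 2), (3, 4), (5, 6)]

# Stand-in for the without_jit variant: call the pre-jitted function directly.
plain = make_fn()
results_plain = [int(plain(a, b)) for a, b in inputs]
traces_plain = len(trace_log)

# Stand-in for the with_jit variant: wrap the pre-jitted function in jit again.
trace_log.clear()
wrapped = jax.jit(make_fn())
results_wrapped = [int(wrapped(a, b)) for a, b in inputs]
traces_wrapped = len(trace_log)
```

In both cases the body is traced once and later calls hit the compilation cache, so neither stand-in exercises a truly non-jitted path.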

My case is the latter, and the only option I see is a model-wide use_jit flag so that I can derive variants from non-jitted code. However, this largely defeats the purpose of variants.
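One possible alternative to a model-wide flag is to run the already-jitted code under `jax.disable_jit()`. This is a plain JAX context manager, not a chex feature, and whether it fits together with `self.variant` in a given test setup is an assumption I have not verified. The sketch below only shows that the Python body then runs on every call; `call_log` is a hypothetical stand-in for the print statement.

```python
import jax

call_log = []  # hypothetical helper: records each time the Python body runs


@jax.jit
def fn(x, y):
    call_log.append("Tracing fn")
    return x + y


# Inside this context jit is disabled, so fn runs eagerly on every call.
with jax.disable_jit():
    eager_results = [int(fn(a, b)) for a, b in [(1, 2), (3, 4), (5, 6)]]

eager_calls = len(call_log)
```

Here the body runs three times, which matches what I would expect from a non-jitted variant.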

I'm aware this could well be a limitation of JAX and jit itself rather than chex. In that case, I think an error when jitted code is passed to variant() would make this more transparent.