without_jit=True for already jitted functions
fabiannagel opened this issue · 0 comments
In JAX-based implementations, jit is almost always used: unless there is a reason to avoid it, people take advantage of its speedup.
I noticed that @chex.variants(with_jit=True, without_jit=True) is a great way to assert the same behavior for both execution paths, as long as the variant is derived from a non-jitted function.
In the following example, I would expect to see "Tracing fn" four times in total: three times for the non-jitted variant (once per call) and once for the initial jit compilation. In reality, test_variant_pre_jitted() is executed twice with the jitted fn, resulting in only two tracer outputs.
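import chex
from jax import jit


class VariantTest(chex.TestCase):  # hypothetical enclosing test class
    @chex.variants(with_jit=True, without_jit=True)
    def test_variant_pre_jitted(self):
        @jit
        def fn(x, y):
            print("Tracing fn")  # printed only while JAX traces fn
            return x + y

        var_fn = self.variant(fn)
        self.assertEqual(var_fn(1, 2), 3)
        self.assertEqual(var_fn(3, 4), 7)
        self.assertEqual(var_fn(5, 6), 11)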
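To make the counting argument concrete outside of chex, here is a minimal sketch that compares a plain function with a jitted one over the same three calls. The names plain_fn, jitted_fn and calls are made up for illustration. The assumption is that a Python side effect inside a jitted function runs only during tracing, so the jitted counter should stay at one while the plain counter reaches three.

```python
import jax

# Hypothetical counters standing in for the "Tracing fn" print statements.
calls = {"plain": 0, "jitted": 0}


def plain_fn(x, y):
    calls["plain"] += 1  # runs on every call
    return x + y


@jax.jit
def jitted_fn(x, y):
    calls["jitted"] += 1  # Python side effect: runs only while JAX traces
    return x + y


for a, b in [(1, 2), (3, 4), (5, 6)]:
    assert plain_fn(a, b) == a + b
    assert jitted_fn(a, b) == a + b

print(calls)
```

This matches the observation above: once fn is jitted, both variants end up calling compiled code, so each test run shows a single trace.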
Of course, omitting @jit leads to the expected behavior. However, when a more complex implementation already uses jit internally, variants sadly no longer make sense.
My case is the latter, and the only option I see is implementing a model-wide use_jit flag so that I can derive variants from non-jitted code. However, this largely defeats the purpose of variants.
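One possible alternative to a model-wide flag, sketched here under the assumption that running the whole test body inside JAX's jax.disable_jit() context manager is acceptable: inside that context, already jitted functions execute their Python body eagerly on every call. This is not something chex variants do on their own, and run_three_calls and body_runs are hypothetical names.

```python
import jax

body_runs = 0  # hypothetical counter replacing the print statement


@jax.jit
def fn(x, y):
    global body_runs
    body_runs += 1  # counts executions of the Python body
    return x + y


def run_three_calls():
    """Reset the counter, call fn three times, return the body-run count."""
    global body_runs
    body_runs = 0
    for a, b in [(1, 2), (3, 4), (5, 6)]:
        assert fn(a, b) == a + b
    return body_runs


compiled_runs = run_three_calls()  # jit active: the body runs once, at trace time

with jax.disable_jit():
    eager_runs = run_three_calls()  # jit disabled: the body runs on every call

print(compiled_runs, eager_runs)
```

A test could wrap its non-jitted path in this context manager instead of threading a use_jit flag through the model, at the cost of handling the two paths manually rather than through self.variant.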
I'm aware this could well be a limitation of JAX and jit itself rather than of chex. In that case, I think raising an error when jitted code is passed to variant() would make this more transparent.
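A rough sketch of what such a check might look like. It relies on a heuristic, not on an official "is this jitted?" API: functions returned by jax.jit expose the ahead-of-time method lower, which plain Python functions lack. Both looks_jitted and checked_variant are hypothetical helpers and not part of chex.

```python
import jax


def looks_jitted(fn):
    """Heuristic: objects returned by jax.jit expose a callable `lower`."""
    return callable(getattr(fn, "lower", None))


def checked_variant(variant, fn):
    """Hypothetical guard: refuse to derive a variant from jitted code."""
    if looks_jitted(fn):
        raise ValueError(
            "fn appears to be jitted already; the without_jit variant "
            "would still run compiled code."
        )
    return variant(fn)


def add(x, y):
    return x + y


# A plain function passes the guard and can be wrapped as usual.
wrapped = checked_variant(jax.jit, add)
assert wrapped(1, 2) == 3

# A pre-jitted function is rejected.
try:
    checked_variant(jax.jit, jax.jit(add))
    rejected = False
except ValueError:
    rejected = True
```

Because the check is attribute-based, it could misfire on unrelated callables that happen to define lower, so a real implementation would likely need something more robust.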