replacing jax.vmap with objax.Vectorize
mathDR opened this issue · 3 comments
I am trying to replace jax.vmap with objax.Vectorize (since mixing these types of operations can cause problems) and am running into issues. When the function I call Vectorize on is an objax module, everything works fine (the examples in the test code are great).
My issue is in trying to replace something like the following:
X0 = objax.random.uniform((5, 3))
T = jax.vmap(jnp.diag)(X0)
where this does something similar to tf.linalg.diag and returns a jnp.array of shape (5, 3, 3), i.e. five 3x3 matrices whose diagonals are the corresponding rows of X0.
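For reference, here is a minimal pure-JAX sketch of the working vmap version (no objax involved). It uses jax.random.uniform with a fixed PRNGKey instead of objax.random.uniform, an assumption made only so the snippet is self-contained. The broadcasting comparison is an illustrative cross-check of the (5, 3, 3) result, not part of the original code.

```python
import jax
import jax.numpy as jnp

# Random (5, 3) input; each row becomes the diagonal of one 3x3 matrix.
x0 = jax.random.uniform(jax.random.PRNGKey(0), (5, 3))

# jnp.diag turns a 1-D vector into a 2-D diagonal matrix, so mapping it
# over axis 0 yields a stack of five 3x3 matrices.
t = jax.vmap(jnp.diag)(x0)
print(t.shape)  # (5, 3, 3)

# Cross-check: the same stack built by broadcasting each row against the
# identity matrix (x0[i, j] on the diagonal, zeros elsewhere).
t_bcast = x0[:, :, None] * jnp.eye(3)
print(jnp.allclose(t, t_bcast))  # True
```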
Running objax.Vectorize(jnp.diag)(X0) fails with:
ValueError: You must supply the VarCollection used by the function f
So my question: how can I pass in the VarCollection for this example? Is that possible?
objax.Vectorize requires a variable collection when the function you vectorize is a plain callable rather than an objax module. In your case jnp.diag does not use any variables, so you can simply pass an empty variable collection, created with objax.VarCollection().
On top of that, by default objax.Vectorize will try to vectorize the function over all of its arguments. jnp.diag has two arguments: the first is the input, and the second, k, is an integer indicating which diagonal to use. So you can only vectorize jnp.diag over the first argument. This can be achieved either by using lambda x: jnp.diag(x, k=0) instead of jnp.diag, or by passing an extra batch_axis argument to objax.Vectorize.
Below are two versions of the code that demonstrate this.
Here is the first version, which uses a lambda:
import jax.numpy as jnp
import objax

x0 = objax.random.uniform((5, 3))
vec_diag = objax.Vectorize(lambda x: jnp.diag(x, k=0), objax.VarCollection())
t = vec_diag(x0)
Here is the second version, which uses the batch_axis argument and passes k to the vectorized function:
x0 = objax.random.uniform((5, 3))
vec_diag2 = objax.Vectorize(jnp.diag, objax.VarCollection(), batch_axis=(0, None))
t = vec_diag2(x0, 0)
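For comparison, the same two options can be sketched with plain jax.vmap (no objax). Here in_axes=(0, None) is used as the jax.vmap counterpart of batch_axis=(0, None); treating the two as analogous is an assumption for illustration. The fixed PRNGKey input and the k=1 case are likewise illustrative additions.

```python
import jax
import jax.numpy as jnp

x0 = jax.random.uniform(jax.random.PRNGKey(0), (5, 3))

# Option 1: close over k so the mapped function has a single argument.
diag_lambda = jax.vmap(lambda x: jnp.diag(x, k=0))
t1 = diag_lambda(x0)

# Option 2: keep both arguments but map only the first one; None means
# "do not batch this argument", so the same k is used for every row.
diag_axes = jax.vmap(jnp.diag, in_axes=(0, None))
t2 = diag_axes(x0, 0)

print(t1.shape, t2.shape)  # (5, 3, 3) (5, 3, 3)
print(jnp.array_equal(t1, t2))  # True

# A nonzero k moves the vector onto an off-diagonal, so each matrix grows to 4x4.
t3 = diag_axes(x0, 1)
print(t3.shape)  # (5, 4, 4)
```

Because k is not batched, it stays a plain Python int inside the mapped function, which is what jnp.diag expects for its diagonal offset.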
Let me know if you have any other questions.
I'm closing this for now. Feel free to re-open if there are any follow up questions.
Thanks! This solved it. I guess I was surprised that I had to explicitly add default parameters (k=0). Appreciate the help!
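The need for k=0 comes from the signature of jnp.diag: k has a default value but is still a positional parameter, so anything that inspects positional parameters sees two of them. The sketch below checks this with inspect.signature via a hypothetical helper, positional_params. That objax.Vectorize counts positional parameters in a similar way is an assumption inferred from the behaviour in this thread. functools.partial is shown as one more way to pin k, since binding it by keyword makes it keyword-only in the resulting signature.

```python
import functools
import inspect

import jax
import jax.numpy as jnp


def positional_params(fn):
    """Hypothetical helper: names of parameters that can be passed positionally."""
    kinds = (inspect.Parameter.POSITIONAL_ONLY,
             inspect.Parameter.POSITIONAL_OR_KEYWORD)
    return [name for name, p in inspect.signature(fn).parameters.items()
            if p.kind in kinds]


# k has a default value, but it still counts as a positional parameter.
# (Assumption: objax.Vectorize checks batch_axis against parameters like these.)
print(len(positional_params(jnp.diag)))  # 2

# Binding k with functools.partial makes it keyword-only, leaving a single
# positional parameter, much like the lambda in the answer above.
diag_k0 = functools.partial(jnp.diag, k=0)
print(len(positional_params(diag_k0)))  # 1
print(positional_params(lambda x: jnp.diag(x, k=0)))  # ['x']

# The partial still works as a mapped function.
x0 = jax.random.uniform(jax.random.PRNGKey(0), (5, 3))
print(jax.vmap(diag_k0)(x0).shape)  # (5, 3, 3)
```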