pad_with_graphs written with numpy
GrantMcConachie opened this issue · 1 comment
Hello! I was wondering whether there is a particular reason that the pad_with_graphs function uses the numpy library rather than jax.numpy. It looks like every numpy function in there could simply be replaced with its jax.numpy counterpart without any issues, but I could be missing something.
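Below is a minimal sketch of the kind of bookkeeping a padding function does. It uses a hypothetical helper `padding_counts`, which is not jraph's actual code or API. The totals it computes become Python ints that fix the shapes of the padded arrays. That may be one reason to keep this step in plain numpy on the host, where nothing is dispatched to a device or compiled.

```python
import numpy as np


def padding_counts(n_node, n_edge, pad_n_node, pad_n_edge, pad_n_graph=2):
    """Hypothetical helper: how many padding nodes, edges and graphs to add.

    n_node, n_edge: per-graph node/edge counts of a batched graph.
    pad_n_node, pad_n_edge, pad_n_graph: target totals after padding.
    """
    n_node = np.asarray(n_node)
    n_edge = np.asarray(n_edge)
    # Plain numpy runs eagerly on the host; int() turns the results into
    # Python ints that can be used directly as (static) array shapes.
    sum_n_node = int(np.sum(n_node))
    sum_n_edge = int(np.sum(n_edge))
    n_graph = int(n_node.shape[0])
    if pad_n_node < sum_n_node:
        raise ValueError("pad_n_node is smaller than the total number of nodes")
    if pad_n_edge < sum_n_edge:
        raise ValueError("pad_n_edge is smaller than the total number of edges")
    if pad_n_graph <= n_graph:
        raise ValueError("pad_n_graph must leave room for at least one padding graph")
    return pad_n_node - sum_n_node, pad_n_edge - sum_n_edge, pad_n_graph - n_graph


# Two graphs with 5 nodes and 5 edges in total, padded to 8 nodes, 8 edges, 3 graphs.
print(padding_counts([3, 2], [4, 1], pad_n_node=8, pad_n_edge=8, pad_n_graph=3))
# (3, 3, 1)
```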
I have wondered about this in the past too. Another place where numpy is used (or could be used) is batch/batch_np. IIRC the numpy version was much faster in some situations, which hints that some unwanted jit compilation happens when using jnp functions. The same might happen if the numpy functions in pad_with_graphs were replaced with jnp functions. It seems to me that jax.numpy.sum always triggers some jit compilation, which is not what you want when array sizes keep changing. It would be nice to have some clarification on this, though.
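The following small sketch makes that compilation behaviour visible with an explicit `jax.jit`. The function `total` and the counter `trace_count` are hypothetical and exist only for illustration. The Python side effect runs only when JAX traces the function, so the counter shows that each new input shape triggers a fresh trace and compile, while a repeated shape hits the cache. As far as I know, `jnp.sum` is itself wrapped in `jit` inside JAX, so eager calls on arrays of changing size should behave similarly. `np.sum` never compiles anything.

```python
import numpy as np
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when JAX (re)traces `total`


@jax.jit
def total(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on every call
    return jnp.sum(x)


total(jnp.ones(3))  # first shape: traced and compiled
total(jnp.ones(3))  # same shape and dtype: cached, no retrace
total(jnp.ones(5))  # new shape: traced and compiled again
print(trace_count)  # 2

# Plain numpy has no compilation step, so changing sizes adds no overhead.
print(np.sum(np.ones(3)), np.sum(np.ones(5)))
```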