google-deepmind/jraph

pad_with_graphs written with numpy

GrantMcConachie opened this issue · 1 comment

Hello! I was wondering whether there is any particular reason that the pad_with_graphs function uses numpy rather than jax.numpy. It looks like every numpy call in it could be replaced with its jax.numpy equivalent without any issues, but I could be missing something.
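One possible reason, offered only as an assumption and not as a description of jraph's actual code: padding amounts depend on the *values* in `n_node`, not just on array shapes. With numpy on the host those values are concrete, so they can be turned into Python ints and used to size new arrays. Under `jax.jit`, `jnp.sum(n_node)` would be an abstract tracer and the `int(...)` conversion would fail. The helper `pad_node_count` below is hypothetical and only illustrates that point:

```python
import numpy as np


def pad_node_count(n_node: np.ndarray, n_node_target: int) -> np.ndarray:
    """Hypothetical helper: append one padding graph that absorbs the leftover nodes.

    Assumption for illustration only; this is not jraph's pad_with_graphs.
    """
    # np.sum on a host array returns a concrete number right away, so int() works.
    # Inside jax.jit, jnp.sum(n_node) would be a tracer and int() would raise.
    pad = int(n_node_target - np.sum(n_node))
    if pad < 0:
        raise ValueError("n_node_target is smaller than the number of real nodes")
    # The padding graph's node count is data-dependent, hence a host-side value.
    return np.concatenate([n_node, np.array([pad], dtype=n_node.dtype)])


print(pad_node_count(np.array([3, 4]), 10))  # [3 4 3]
```

This kind of function is meant to run once per batch on the host, before the padded arrays are handed to a jitted model.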

I have wondered about this in the past too. Another place where numpy is used (or can be used) is batch versus batch_np. IIRC the numpy version was much faster in some situations, which hints that some unwanted JIT compilation happens when jnp functions are used. The same might happen if the numpy calls in pad_with_graphs were replaced with jnp calls. It seems to me that jax.numpy.sum always triggers some JIT compilation, which is not what you want when array sizes keep changing, because every new input shape needs its own compiled version. It would be nice to get some clarification on this, though.
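A minimal sketch of the per-shape compilation behaviour described above, using an explicitly jitted function (`total_nodes` is a hypothetical name, not a jraph API). The Python side effect inside the function only runs while JAX traces it, so the `traces` list records every (re)compilation. Calls with a previously seen shape and dtype reuse the cached version, while a new shape forces a new trace:

```python
import jax
import jax.numpy as jnp
import numpy as np

traces = []  # filled only while JAX traces the function, i.e. when it compiles


@jax.jit
def total_nodes(n_node):
    # This append is a Python side effect: it runs at trace time, not on every call.
    traces.append(n_node.shape)
    return jnp.sum(n_node)


print(int(total_nodes(np.array([3, 4]))))     # 7, first call: trace + compile
print(int(total_nodes(np.array([5, 6]))))     # 11, same shape: cached, no retrace
print(int(total_nodes(np.array([1, 2, 3]))))  # 6, new shape: retrace + recompile
print(traces)                                 # [(2,), (3,)]
```

If the arrays passed through jnp code change size on every batch, as unpadded graph batches do, each new size pays this compilation cost, whereas the equivalent numpy call simply runs.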