google-deepmind/jraph

How to create custom batches?

MichaelMMeskhi opened this issue

Suppose I have multiple graphs, each with its own nodes, senders, receivers, and a graph-level label. How can I combine them into batches that work with `jax.vmap`?
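One possible approach is sketched below; it is not an official jraph recipe. `jax.vmap` maps over a leading axis, so every graph in a batch needs identical array shapes. The sketch pads each graph to a shared node/edge budget, routes padding edges to a spare padding node, keeps boolean masks so padding does not leak into results, and stacks everything into one dict of arrays. The helpers `pad_graph`, `make_batch`, and the toy model `graph_readout` are hypothetical names for illustration. Without `vmap`, jraph's own `jraph.batch` (which merges graphs into a single disjoint-union `GraphsTuple`) and `jraph.pad_with_graphs` are the usual route.

```python
import jax
import jax.numpy as jnp
import numpy as np


def pad_graph(nodes, senders, receivers, max_nodes, max_edges):
    """Pads one graph to fixed sizes so all graphs share identical array shapes.

    Padding edges connect the last node slot (index max_nodes - 1), so max_nodes
    must be larger than the real node count of every graph in the batch.
    """
    n, e = nodes.shape[0], senders.shape[0]
    assert n < max_nodes and e <= max_edges, "graph too large for padding budget"
    padded_nodes = np.zeros((max_nodes,) + nodes.shape[1:], dtype=nodes.dtype)
    padded_nodes[:n] = nodes
    padded_senders = np.full((max_edges,), max_nodes - 1, dtype=np.int32)
    padded_receivers = np.full((max_edges,), max_nodes - 1, dtype=np.int32)
    padded_senders[:e] = senders
    padded_receivers[:e] = receivers
    return {
        "nodes": padded_nodes,
        "senders": padded_senders,
        "receivers": padded_receivers,
        "node_mask": np.arange(max_nodes) < n,  # True for real nodes
        "edge_mask": np.arange(max_edges) < e,  # True for real edges
    }


def make_batch(graphs, labels, max_nodes, max_edges):
    """Stacks padded graphs along a new leading axis, ready for jax.vmap."""
    padded = [pad_graph(*g, max_nodes, max_edges) for g in graphs]
    batch = {k: jnp.stack([p[k] for p in padded]) for k in padded[0]}
    batch["labels"] = jnp.asarray(labels)  # one graph-level label per graph
    return batch


def graph_readout(graph):
    """Hypothetical per-graph model: one sum-aggregation step, then a masked sum."""
    nodes = graph["nodes"]
    edge_w = graph["edge_mask"].astype(nodes.dtype)[:, None]
    messages = nodes[graph["senders"]] * edge_w  # zero out padding edges
    aggregated = jax.ops.segment_sum(
        messages, graph["receivers"], num_segments=nodes.shape[0])
    node_w = graph["node_mask"].astype(nodes.dtype)[:, None]
    return jnp.sum((nodes + aggregated) * node_w, axis=0)  # ignore padding nodes


# Usage: two graphs of different sizes, each as (nodes, senders, receivers).
graphs = [
    (np.array([[1.0], [2.0], [3.0]], np.float32), np.array([0, 1]), np.array([1, 2])),
    (np.array([[4.0], [5.0]], np.float32), np.array([0]), np.array([1])),
]
batch = make_batch(graphs, labels=[0, 1], max_nodes=4, max_edges=2)
out = jax.vmap(graph_readout)(batch)  # shape (2, 1): one readout per graph
print(out)
```

The padding budget (`max_nodes`, `max_edges`) is fixed here so every batch has the same shapes, which also avoids recompilation under `jax.jit`; picking it from the largest graph in the dataset plus one padding node is an assumption, not a requirement.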