`GraphsTuple` vs. batched `GraphsTuple` in `get_graph_padding_mask`
thorben-frank commented
Hey,
thanks for this great package. I noticed that `jraph.get_graph_padding_mask` returns `jnp.array([False])` when applied to a non-batched `GraphsTuple`.

I am wondering why this is. Would it be possible to check the length of `jraph.GraphsTuple.n_node` and return `jnp.array([True])` when it has length 1? Or would this break some assumption elsewhere in jraph? Below is a minimal example.
Thanks and best,
Thorben
```python
import jax.numpy as jnp
import jraph


def get_number_of_graphs(graph):
    """Return the number of graphs in `graph`.

    This function works for GraphsTuple and batched GraphsTuple.
    For the latter the padding graph(s) are also counted.
    """
    # n_node has one entry per graph, so its length is the graph count.
    return len(graph.n_node)


def is_batched_bool(graph):
    # Treat anything with more than one graph entry as batched.
    return get_number_of_graphs(graph) > 1


def modified_get_graph_padding_mask(graph):
    if is_batched_bool(graph):
        return jraph.get_graph_padding_mask(graph)
    # A single, unpadded graph: mark it as a real (non-padding) graph.
    return jnp.array([True])


# A single graph with 10 nodes and 10 self-loop edges.
graph = jraph.GraphsTuple(
    nodes=dict(
        atomic_numbers=jnp.ones((10,)),
        positions=jnp.ones((10, 3)),
        z=jnp.ones((10, 3)),
    ),
    edges=None,
    receivers=jnp.arange(10),
    senders=jnp.arange(10),
    globals=dict(),
    n_node=jnp.array([10]),
    n_edge=jnp.array([10]),
)

print('On unbatched graph')
print('Original version: graph_mask =', jraph.get_graph_padding_mask(graph))
print('Modified version: graph_mask =', modified_get_graph_padding_mask(graph))

# Batch with padding up to 11 nodes, 11 edges and 3 graphs per batch.
batched_graph_iterator = jraph.dynamically_batch(
    [graph, graph], n_node=11, n_edge=11, n_graph=3)
batched_graph = next(batched_graph_iterator)

print('\nOn batched graph')
print('Original version: graph_mask =', jraph.get_graph_padding_mask(batched_graph))
print('Modified version: graph_mask =', modified_get_graph_padding_mask(batched_graph))
```
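Another option is to avoid inferring padding from `n_node` altogether and pass the number of padding graphs explicitly. `mask_from_padding_count` below is a hypothetical helper, not part of jraph; it assumes padding graphs, if any, are the trailing entries of `n_node`, which is where `jraph.pad_with_graphs` appends them. Because `n_node.shape[0]` is a static Python int under `jax.jit`, the mask shape is known at trace time, just like the length check in `is_batched_bool` above.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def mask_from_padding_count(n_node: jax.Array, num_padding_graphs: int) -> jax.Array:
    """Hypothetical helper (not jraph API): True for real graphs, False for padding.

    Assumes the last `num_padding_graphs` entries of `n_node` are padding graphs.
    """
    num_graphs = n_node.shape[0]  # static at trace time
    return jnp.arange(num_graphs) < num_graphs - num_padding_graphs


# Unbatched, unpadded graph: no padding graphs, so the only entry is real.
single = mask_from_padding_count(jnp.array([10]), 0)

# Example layout: one real graph followed by two padding graphs.
batched = mask_from_padding_count(jnp.array([10, 1, 0]), 2)
print(single, batched)
```

The trade-off is that the caller has to track how many padding graphs were added, but an unpadded single graph then needs no special case.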