google-deepmind/jraph

`GraphsTuple` vs. batched `GraphsTuple` in `get_graph_padding_mask`

Hey,

thanks for this great package. I realized that jraph.get_graph_padding_mask returns jnp.array([False]) when applied to a non-batched GraphsTuple, i.e. the single graph is masked out as if it were padding.

I am wondering why that is. Would it be possible to check the length of jraph.GraphsTuple.n_node and return jnp.array([True]) when it has length 1, or would that break assumptions elsewhere in jraph? Below is a minimal example.

Thanks and best,
Thorben

import jraph
import jax.numpy as jnp


def get_number_of_graphs(graph):
    """Returns the number of graphs in a (possibly batched) GraphsTuple.

    For a batched GraphsTuple the padding graph(s) are counted as well.
    """
    return len(graph.n_node)


def is_batched_bool(graph):
    """Returns True if the GraphsTuple contains more than one graph."""
    return get_number_of_graphs(graph) > 1


def modified_get_graph_padding_mask(graph):
    if is_batched_bool(graph):
        return jraph.get_graph_padding_mask(graph)
    else:
        return jnp.array([True])

    
graph = jraph.GraphsTuple(
    nodes=dict(
        atomic_numbers=jnp.ones((10, )),
        positions=jnp.ones((10, 3)),
        z=jnp.ones((10, 3))
    ),
    edges=None,
    receivers=jnp.arange(10),
    senders=jnp.arange(10),
    globals=dict(),
    n_node=jnp.array([10]),
    n_edge=jnp.array([10])
)


print('On unbatched graph')
print('Original version: graph_mask =', jraph.get_graph_padding_mask(graph))
print('Modified version: graph_mask =', modified_get_graph_padding_mask(graph))

batched_graph_iterator = jraph.dynamically_batch([graph, graph], n_node=11, n_edge=11, n_graph=3)
batched_graph = next(batched_graph_iterator)
print('\nOn batched graph')
print('Original version: graph_mask =', jraph.get_graph_padding_mask(batched_graph))
print('Modified version: graph_mask =', modified_get_graph_padding_mask(batched_graph))
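
For completeness, a possible workaround with the current API is to pad the single graph explicitly before asking for the mask. The sketch below assumes the budgets n_node=11, n_edge=11, n_graph=2, chosen only to leave room for one padding node and one padding graph; it should print True for the real graph and False for the padding graph.

# Sketch of a workaround (not a change to jraph itself): pad the single graph
# explicitly so that get_graph_padding_mask sees at least one padding graph.
# The budgets n_node=11, n_edge=11, n_graph=2 are assumptions that leave room
# for one padding node and one padding graph.
padded_single_graph = jraph.pad_with_graphs(graph, n_node=11, n_edge=11, n_graph=2)
print('\nOn explicitly padded single graph')
print('Workaround: graph_mask =', jraph.get_graph_padding_mask(padded_single_graph))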