google-deepmind/jraph

Problem with ValueError

Closed this issue · 2 comments

Hi everyone 👋,

First of all, thank you for such great work on this library!
I'm having some trouble understanding how to create my own GNN. I'm trying to do a form of graph classification. I followed this tutorial, and now I'm trying to apply its example network to my own data.

This is an example of a graph:

GraphsTuple(nodes=DeviceArray([[0.0000000e+00, 1.5747571e+12, 6.0000000e+00],
             [1.0000000e+00, 1.5701138e+12, 2.0000000e+00],
             [2.0000000e+00, 1.5747571e+12, 2.0000000e+00],
             [3.0000000e+00, 1.5747555e+12, 3.0000000e+00],
             [4.0000000e+00, 1.5701127e+12, 7.0000000e+00],
             [5.0000000e+00, 0.0000000e+00, 1.0000000e+00],
             [6.0000000e+00, 0.0000000e+00, 1.0000000e+00]],            dtype=float32), edges=DeviceArray([1, 1, 1, 1, 2, 2], dtype=int32), receivers=DeviceArray([3, 1, 2, 0, 4, 3], dtype=int32), senders=DeviceArray([5, 6, 6, 6, 3, 0], dtype=int32), globals=None, n_node=DeviceArray([7], dtype=int32), n_edge=DeviceArray([6], dtype=int32))

When I try to initialize the network, it raises a ValueError: data type <class 'numpy.int32'> not inexact.
This error comes from the last line of this code block (net.init(jax.random.PRNGKey(42), graph)):

def train(dataset: List[Dict[str, Any]], num_train_steps: int) -> hk.Params:
  """Training loop."""

  # Transform impure `net_fn` to pure functions with hk.transform.
  net = hk.without_apply_rng(hk.transform(net_fn))
  # Get a candidate graph and label to initialize the network.
  graph = dataset[0]['input_graph']
  
  # Initialize the network.
  params = net.init(jax.random.PRNGKey(42), graph)

The graph dataset[0]['input_graph'] is the one shown above.

After reading the docs and some Stack Overflow threads, and searching on Google, I haven't found anything that helps me understand or resolve this error.

I'm unsure about the data types in the GraphsTuple. I tried to change the int32 type to Python's native int type (as Nate says in this Stack Overflow thread), but I couldn't change the types. Could the float32 type of the nodes field also be the problem?
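
For context, the wording "not inexact" matches the message NumPy raises when finfo (floating-point type information) is asked about an integer dtype. The sketch below reproduces that situation with plain NumPy. Whether this is the exact call path inside net.init is an assumption, since the full traceback is not shown in this thread.

```python
import numpy as np

# float32 is an "inexact" (floating-point) dtype, so finfo accepts it.
float_info = np.finfo(np.float32)
float_is_inexact = np.issubdtype(np.float32, np.inexact)

# int32 is an exact (integer) dtype, so finfo rejects it with a ValueError.
int_is_inexact = np.issubdtype(np.int32, np.inexact)
try:
    np.finfo(np.int32)
    message = None
except ValueError as err:
    message = str(err)

print("float32 bits:", float_info.bits)
print("float32 inexact:", float_is_inexact, "| int32 inexact:", int_is_inexact)
print("finfo(int32) error:", message)
```

If this is what happens here, the error would point at an int32 array being used where floating-point values are expected (for example as the input to a learned layer), rather than at the float32 nodes field.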

I'm submitting this issue because I haven't found any useful resource to help me debug this error. I hope that's not an inconvenience, and that it helps others resolve this error faster.

Thank you!

Hi @ademait,

I can't seem to reproduce your bug here. Would it be possible for you to send the full error message? It would also be helpful if you shared how you defined dataset and net.


Here's my attempt at reproducing your bug.

  • Defining the graph
# Imports used by the snippets in this comment.
from typing import Any, Dict, List

import haiku as hk
import jax
import jax.numpy as jnp
import jraph

graph = jraph.GraphsTuple(
    nodes=jnp.asarray([
        [0.0000000e+00, 1.5747571e+12, 6.0000000e+00],
        [1.0000000e+00, 1.5701138e+12, 2.0000000e+00],
        [2.0000000e+00, 1.5747571e+12, 2.0000000e+00],
        [3.0000000e+00, 1.5747555e+12, 3.0000000e+00],
        [4.0000000e+00, 1.5701127e+12, 7.0000000e+00],
        [5.0000000e+00, 0.0000000e+00, 1.0000000e+00],
        [6.0000000e+00, 0.0000000e+00, 1.0000000e+00]]),
    edges=jnp.asarray([1, 1, 1, 1, 2, 2]),
    receivers=jnp.asarray([3, 1, 2, 0, 4, 3]),
    senders=jnp.asarray([5, 6, 6, 6, 3, 0]),
    globals=None,
    n_node=jnp.asarray([7]),
    n_edge=jnp.asarray([6]))
dataset = [{'input_graph': graph}]  # Dummy dataset
  • Defining arbitrary net
def net_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Network function."""

  net = jraph.GraphConvolution(
    update_node_fn = lambda x: jax.nn.relu(hk.Linear(100)(x)),
    add_self_edges=True,
  )
  return net(graph)
  • Running train
def train(dataset: List[Dict[str, Any]], num_train_steps: int) -> hk.Params:
  """Training loop."""

  # Transform impure `net_fn` to pure functions with hk.transform.
  net = hk.without_apply_rng(hk.transform(net_fn))
  # Get a candidate graph and label to initialize the network.
  graph = dataset[0]['input_graph']
  
  # Initialize the network.
  params = net.init(jax.random.PRNGKey(42), graph)

Running train works without any errors for me.

Hi @salfaris, thanks for the answer.

It seems the problem was how I declared the variables with int32 types. Now the error is gone. 👍
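
One way to guard against this class of error is to cast feature arrays to a floating-point dtype before building the graph, while leaving index arrays such as senders and receivers as integers. Below is a minimal sketch with plain NumPy. The helper name as_float_features is hypothetical and not part of jraph, and the thread does not confirm which field was declared as int32 in the original code.

```python
import numpy as np


def as_float_features(array, dtype=np.float32):
    """Return `array` as a floating-point array.

    Arrays that are already floating point are returned unchanged;
    anything else (e.g. int32) is cast to `dtype`.
    """
    array = np.asarray(array)
    if np.issubdtype(array.dtype, np.floating):
        return array
    return array.astype(dtype)


# Edge features declared as int32, as in the graph from this issue.
edges = as_float_features(np.array([1, 1, 1, 1, 2, 2], dtype=np.int32))

# Index arrays are not features, so they stay integer.
senders = np.array([5, 6, 6, 6, 3, 0], dtype=np.int32)

print(edges.dtype, senders.dtype)
```

The cast is applied only to arrays that a network would treat as features; senders, receivers, n_node and n_edge are used for indexing and counting, so converting them to floats would not be appropriate.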