Problem with ValueError
Hi everyone,
First of all, thank you for such great work on this library!
I'm having some trouble understanding and creating my own GNN. I'm trying to do a form of graph classification. I followed this tutorial, and now I'm trying to apply its network example to my own data.
This is an example of a graph:
```
GraphsTuple(nodes=DeviceArray([[0.0000000e+00, 1.5747571e+12, 6.0000000e+00],
                               [1.0000000e+00, 1.5701138e+12, 2.0000000e+00],
                               [2.0000000e+00, 1.5747571e+12, 2.0000000e+00],
                               [3.0000000e+00, 1.5747555e+12, 3.0000000e+00],
                               [4.0000000e+00, 1.5701127e+12, 7.0000000e+00],
                               [5.0000000e+00, 0.0000000e+00, 1.0000000e+00],
                               [6.0000000e+00, 0.0000000e+00, 1.0000000e+00]], dtype=float32),
            edges=DeviceArray([1, 1, 1, 1, 2, 2], dtype=int32),
            receivers=DeviceArray([3, 1, 2, 0, 4, 3], dtype=int32),
            senders=DeviceArray([5, 6, 6, 6, 3, 0], dtype=int32),
            globals=None,
            n_node=DeviceArray([7], dtype=int32),
            n_edge=DeviceArray([6], dtype=int32))
```
When I try to initialize the network, it raises `ValueError: data type <class 'numpy.int32'> not inexact`.
The error comes from the last line of this code block (`net.init(jax.random.PRNGKey(42), graph)`):
```python
from typing import Any, Dict, List

import haiku as hk
import jax


def train(dataset: List[Dict[str, Any]], num_train_steps: int) -> hk.Params:
  """Training loop."""
  # Transform impure `net_fn` to pure functions with hk.transform.
  net = hk.without_apply_rng(hk.transform(net_fn))
  # Get a candidate graph and label to initialize the network.
  graph = dataset[0]['input_graph']
  # Initialize the network.
  params = net.init(jax.random.PRNGKey(42), graph)
  # ... (rest of the training loop not shown)
```
The graph `dataset[0]['input_graph']` is the one shown above.
After reading the docs, some stackoverflow threads, and searching in Google, I haven't found anything to either understand or resolve this error.
I'm unsure about the data types in the `GraphsTuple`. I tried to change the `int32` type to Python's native `int` type (as Nate suggests in this Stack Overflow thread), but I couldn't change the types. Could the `float32` type of the `nodes` field also be the problem?
I'm submitting this issue because I haven't found any useful resource to help me debug this error. I hope that's not an inconvenience, and that it helps others resolve this error faster.
Thank you!
Hi @ademait,
I can't seem to reproduce your bug here. Could you send the full error message? It would also help if you shared how you defined `dataset` and `net`.
Here's my attempt at reproducing your bug.
- Defining the graph
```python
import jax.numpy as jnp
import jraph

graph = jraph.GraphsTuple(
    nodes=jnp.asarray([
        [0.0000000e+00, 1.5747571e+12, 6.0000000e+00],
        [1.0000000e+00, 1.5701138e+12, 2.0000000e+00],
        [2.0000000e+00, 1.5747571e+12, 2.0000000e+00],
        [3.0000000e+00, 1.5747555e+12, 3.0000000e+00],
        [4.0000000e+00, 1.5701127e+12, 7.0000000e+00],
        [5.0000000e+00, 0.0000000e+00, 1.0000000e+00],
        [6.0000000e+00, 0.0000000e+00, 1.0000000e+00]]),
    edges=jnp.asarray([1, 1, 1, 1, 2, 2]),
    receivers=jnp.asarray([3, 1, 2, 0, 4, 3]),
    senders=jnp.asarray([5, 6, 6, 6, 3, 0]),
    globals=None,
    n_node=jnp.asarray([7]),
    n_edge=jnp.asarray([6]))
dataset = [{'input_graph': graph}]  # Dummy dataset
```
- Defining an arbitrary net
```python
import haiku as hk
import jax
import jraph


def net_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Network function."""
  net = jraph.GraphConvolution(
      update_node_fn=lambda x: jax.nn.relu(hk.Linear(100)(x)),
      add_self_edges=True,
  )
  return net(graph)
```
- Running `train`
```python
from typing import Any, Dict, List

import haiku as hk
import jax


def train(dataset: List[Dict[str, Any]], num_train_steps: int) -> hk.Params:
  """Training loop."""
  # Transform impure `net_fn` to pure functions with hk.transform.
  net = hk.without_apply_rng(hk.transform(net_fn))
  # Get a candidate graph and label to initialize the network.
  graph = dataset[0]['input_graph']
  # Initialize the network.
  params = net.init(jax.random.PRNGKey(42), graph)
  return params
```
Running `train` works without any errors for me.