Replicating GAT with CORA dataset
Steboss89 opened this issue · 0 comments
Steboss89 commented
Hello,
Thanks very much for such a wonderful product! I am trying to replicate GAT's paper with the CORA dataset, but I am finding some issues in using jraph
. I started from your example notebook, implementing GAT, along with add_self_edges_fn
:
def add_self_edges_fn(receivers: jnp.ndarray,
senders: jnp.ndarray,
total_num_nodes: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
r"""Adds self edges. Assumes self edges are not in the graph yet."""
receivers = jnp.concatenate((receivers, jnp.arange(total_num_nodes)), axis=0)
senders = jnp.concatenate((senders, jnp.arange(total_num_nodes)), axis=0)
return receivers, senders
def GAT(attention_query_fn: Callable,
attention_logit_fn: Callable,
node_update_fn: Optional[Callable] = None,
add_self_edges: bool = True) -> Callable:
r""" Main GAT function"""
# pylint: disable=g-long-lambda
if node_update_fn is None:
# By default, apply the leaky relu and then concatenate the heads on the
# feature axis.
node_update_fn = lambda x: jnp.reshape(jax.nn.leaky_relu(x), (x.shape[0], -1))
def _ApplyGAT(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
"""Applies a Graph Attention layer."""
nodes, edges, receivers, senders, _, _, _ = graph
try:
sum_n_node = nodes.shape[0]
except IndexError:
raise IndexError('GAT requires node features')
nodes = attention_query_fn(nodes)
total_num_nodes = tree.tree_leaves(nodes)[0].shape[0]
if add_self_edges:
receivers, senders = add_self_edges_fn(receivers, senders,
total_num_nodes)
sent_attributes = nodes[senders]
received_attributes = nodes[receivers]
att_softmax_logits = attention_logit_fn(sent_attributes,
received_attributes, edges)
att_weights = jraph.segment_softmax(
att_softmax_logits, segment_ids=receivers, num_segments=sum_n_node)
messages = sent_attributes * att_weights
nodes = jax.ops.segment_sum(messages, receivers, num_segments=sum_n_node)
nodes = node_update_fn(nodes)
return graph._replace(nodes=nodes)
return _ApplyGAT
def gat_definition(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
""" Define GAT algorithm to run
Parameters
----------
graph: jraph.GraphsTupe, input network to be processed
Return
-------
jraph.GraphsTuple updated node graph
"""
def _attention_logit_fn(sender_attr: jnp.ndarray, receiver_attr: jnp.ndarray,
edges: jnp.ndarray) -> jnp.ndarray:
del edges
x = jnp.concatenate((sender_attr, receiver_attr), axis=-1)
return jax.nn.leaky_relu(hk.Linear(1)(x))
gn = GAT(
attention_query_fn=lambda n: hk.Linear(8)(n),
attention_logit_fn=_attention_logit_fn,
node_update_fn=None,
add_self_edges=True)
graph = gn(graph)
gn = GAT(
attention_query_fn=lambda n: hk.Linear(8)(n),
attention_logit_fn=_attention_logit_fn,
node_update_fn=hk.Linear(2),
add_self_edges=True)
graph = gn(graph)
return graph
Then, after defining the main GAT, I run the training as:
def run_cora(network: hk.Transformed, num_steps: int) -> jnp.ndarray:
r""" Run training on CORA dataset """
cora_graph = cora_ds[0]['input_graph']
labels = cora_ds[0]['target']
params = network.init(jax.random.PRNGKey(42), cora_graph)
@jax.jit
def predict(params: hk.Params) -> jnp.ndarray:
decoded_graph = network.apply(params, cora_graph)
return jnp.argmax(decoded_graph.nodes, axis=1)
@jax.jit
def prediction_loss(params: hk.Params) -> jnp.ndarray:
decoded_graph = network.apply(params, cora_graph)
preds = jnp.argmax(decoded_graph.nodes, axis=1)
# We interpret the decoded nodes as a pair of logits for each node.
loss = compute_bce_with_logits_loss(preds, labels)
return loss#, preds
opt_init, opt_update = optax.adam(5e-4)
opt_state = opt_init(params)
@jax.jit
def update(params: hk.Params, opt_state) -> Tuple[hk.Params, Any]:
"""Returns updated params and state."""
g = jax.grad(prediction_loss)(params)
updates, opt_state = opt_update(g, opt_state)
return optax.apply_updates(params, updates), opt_state
@jax.jit
def accuracy(params: hk.Params) -> jnp.ndarray:
decoded_graph = network.apply(params, cora_graph)
return jnp.mean(jnp.argmax(decoded_graph.nodes, axis=1) == labels)
for step in range(num_steps):
if step%100==0:
print(f"step {step} accuracy {accuracy(params).item():.2f}")
params, opt_state = update(params, opt_state)
return predict(params)
The problem is that accuracy stick to the same values throughout all the steps I am running (e.g. 1000 steps, accuracy = 0.13).
Could I ask you some indications to understand where I am wrong?
Thank you