google-deepmind/jraph

Replicating GAT with CORA dataset

Steboss89 opened this issue · 0 comments

Hello,

Thanks very much for such a wonderful product! I am trying to replicate GAT's paper with the CORA dataset, but I am finding some issues in using jraph . I started from your example notebook, implementing GAT, along with add_self_edges_fn:

def add_self_edges_fn(receivers: jnp.ndarray,
                      senders: jnp.ndarray,
                      total_num_nodes: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    r"""Adds self edges. Assumes self edges are not in the graph yet."""
    receivers = jnp.concatenate((receivers, jnp.arange(total_num_nodes)), axis=0)
    senders = jnp.concatenate((senders, jnp.arange(total_num_nodes)), axis=0)
    return receivers, senders
  
def GAT(attention_query_fn: Callable,
        attention_logit_fn: Callable,
        node_update_fn: Optional[Callable] = None,
        add_self_edges: bool = True) -> Callable:
    r""" Main GAT function"""
    # pylint: disable=g-long-lambda
    if node_update_fn is None:
        # By default, apply the leaky relu and then concatenate the heads on the
        # feature axis.
        node_update_fn = lambda x: jnp.reshape(jax.nn.leaky_relu(x), (x.shape[0], -1))

    def _ApplyGAT(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Applies a Graph Attention layer."""
        nodes, edges, receivers, senders, _, _, _ = graph
        
        try:
            sum_n_node = nodes.shape[0]
        except IndexError:
            raise IndexError('GAT requires node features')

        nodes = attention_query_fn(nodes)
        total_num_nodes = tree.tree_leaves(nodes)[0].shape[0]

        if add_self_edges:
            receivers, senders = add_self_edges_fn(receivers, senders,
                                                    total_num_nodes)
        sent_attributes = nodes[senders]
        received_attributes = nodes[receivers]
        att_softmax_logits = attention_logit_fn(sent_attributes,
                                                received_attributes, edges)

        att_weights = jraph.segment_softmax(
            att_softmax_logits, segment_ids=receivers, num_segments=sum_n_node)

        messages = sent_attributes * att_weights

        nodes = jax.ops.segment_sum(messages, receivers, num_segments=sum_n_node)

        nodes = node_update_fn(nodes)

        return graph._replace(nodes=nodes)

    return _ApplyGAT


def gat_definition(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """ Define GAT algorithm to run 
    Parameters
    ----------
    graph: jraph.GraphsTupe, input network to be processed 
    
    Return 
    -------
    jraph.GraphsTuple updated node graph
    """

    def _attention_logit_fn(sender_attr: jnp.ndarray, receiver_attr: jnp.ndarray,
                            edges: jnp.ndarray) -> jnp.ndarray:
        del edges
        x = jnp.concatenate((sender_attr, receiver_attr), axis=-1)
        return jax.nn.leaky_relu(hk.Linear(1)(x))

    gn = GAT(
        attention_query_fn=lambda n: hk.Linear(8)(n),
        attention_logit_fn=_attention_logit_fn,
        node_update_fn=None,
        add_self_edges=True)
    graph = gn(graph)

    gn = GAT(
        attention_query_fn=lambda n: hk.Linear(8)(n),
        attention_logit_fn=_attention_logit_fn,
        node_update_fn=hk.Linear(2),
        add_self_edges=True)
    graph = gn(graph)
    return graph

Then, after defining the main GAT, I run the training as:


def run_cora(network: hk.Transformed, num_steps: int) -> jnp.ndarray:
  r""" Run training on CORA dataset """
  cora_graph = cora_ds[0]['input_graph']
  labels = cora_ds[0]['target']
  params = network.init(jax.random.PRNGKey(42), cora_graph)

  @jax.jit
  def predict(params: hk.Params) -> jnp.ndarray:
    decoded_graph = network.apply(params, cora_graph)
    return jnp.argmax(decoded_graph.nodes, axis=1)

  @jax.jit
  def prediction_loss(params: hk.Params) -> jnp.ndarray:
    decoded_graph = network.apply(params, cora_graph)
    preds = jnp.argmax(decoded_graph.nodes, axis=1)
    # We interpret the decoded nodes as a pair of logits for each node.
    loss = compute_bce_with_logits_loss(preds, labels)
    return loss#, preds

  opt_init, opt_update = optax.adam(5e-4)
  opt_state = opt_init(params)

  @jax.jit
  def update(params: hk.Params, opt_state) -> Tuple[hk.Params, Any]:
    """Returns updated params and state."""
    g = jax.grad(prediction_loss)(params)
    updates, opt_state = opt_update(g, opt_state)
    return optax.apply_updates(params, updates), opt_state

  @jax.jit
  def accuracy(params: hk.Params) -> jnp.ndarray:
    decoded_graph = network.apply(params, cora_graph)
    return jnp.mean(jnp.argmax(decoded_graph.nodes, axis=1) == labels)

  for step in range(num_steps):
    if step%100==0:
        print(f"step {step} accuracy {accuracy(params).item():.2f}")
    params, opt_state = update(params, opt_state)

  return predict(params)

The problem is that accuracy stick to the same values throughout all the steps I am running (e.g. 1000 steps, accuracy = 0.13).
Could I ask you some indications to understand where I am wrong?
Thank you