This is the official codebase of the paper
A*Net: A Scalable Path-based Reasoning Approach for Knowledge Graphs
Zhaocheng Zhu*, Xinyu Yuan*, Mikhail Galkin, Sophie Xhonneux, Ming Zhang, Maxime Gazeau, Jian Tang
A*Net is a scalable path-based method for knowledge graph reasoning. Inspired by the classical A* algorithm, A*Net learns a neural priority function to select important nodes and edges at each iteration, which significantly reduces time and memory footprint for both training and inference.
A*Net is the first path-based method that scales to ogbl-wikikg2 (2.5M entities, 16M triplets). It also enjoys the advantages of path-based methods such as inductive capacity and interpretability.
This codebase contains the implementations of A*Net and its predecessor NBFNet.
The dependencies can be installed via either conda or pip. A*Net is compatible with 3.7 <= Python <= 3.10 and PyTorch >= 1.13.0.
From conda:

```bash
conda install pytorch cudatoolkit torchdrug pytorch-sparse -c pytorch -c pyg -c milagraph
conda install ogb easydict pyyaml -c conda-forge
```

From pip:

```bash
pip install torch torchdrug torch-sparse
pip install ogb easydict pyyaml
```
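As an optional sanity check (not required by the codebase), you can verify that the core dependencies import without errors:

```bash
python -c "import torch, torchdrug, torch_sparse, ogb; print(torch.__version__)"
```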
To run A*Net, use the following command. The argument `-c` specifies the experiment configuration file, which includes the dataset, model architecture, and hyperparameters. You can find all configuration files in `config/.../*.yaml`. All datasets will be automatically downloaded by the code.

```bash
python script/run.py -c config/transductive/fb15k237_astarnet.yaml --gpus [0]
```
For each experiment, you can specify the number of GPUs via the argument `--gpus`. You may use `--gpus null` to run A*Net on a CPU, though it would be very slow.
To run A*Net with multiple GPUs, launch the experiment with `torchrun`:

```bash
torchrun --nproc_per_node=4 script/run.py -c config/transductive/fb15k237_astarnet.yaml --gpus [0,1,2,3]
```
For the inductive setting, there are 4 different splits for each dataset. You need to additionally specify the split version with `--version v1`, as in the example below.
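For instance, a run on the first inductive split might look like the following. The exact configuration file path here is an assumption for illustration; use the inductive configuration file in `config/` that matches your dataset.

```bash
python script/run.py -c config/inductive/fb15k237_astarnet.yaml --gpus [0] --version v1
```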
A*Net supports visualization of important paths for its predictions. With a trained model, you can visualize the important paths with the following command. Please replace the checkpoint with the path to your own model.

```bash
python script/visualize.py -c config/knowledge_graph/fb15k237_astarnet_visualize.yaml --checkpoint /path/to/astarnet/experiment/model_epoch_20.pth
```
A*Net is designed to be a general framework for knowledge graph reasoning. This means you can parameterize it with a broad range of message-passing GNNs. To do so, just implement a convolution layer in `reasoning/layer.py` and register it with `@R.register`. The GNN layer is expected to have the following member functions
```python
def message(self, graph, input):
    ...
    return message

def aggregate(self, graph, message):
    ...
    return update

def combine(self, input, update):
    ...
    return output
```
where the arguments and the return values are

- `graph` (data.PackedGraph): a batch of subgraphs selected by A*Net, with `graph.query` being the query embeddings of shape `(batch_size, input_dim)`.
- `input` (Tensor): node representations of shape `(graph.num_node, input_dim)`.
- `message` (Tensor): messages of shape `(graph.num_edge, input_dim)`.
- `update` (Tensor): aggregated messages of shape `(graph.num_node, *)`.
- `output` (Tensor): output representations of shape `(graph.num_node, output_dim)`.
To support the neural priority function in A*Net, we additionally need to provide an interface for computing messages:

```python
def compute_message(self, node_input, edge_input):
    ...
    return msg_output
```
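As a concrete illustration, here is a minimal sketch of a custom convolution layer that follows this interface. The class name `TutorialConv`, its constructor arguments, the relation embedding table, and the use of `torch_scatter` for aggregation are assumptions made for this example; they are not part of the official implementation.

```python
import torch
from torch import nn
from torch_scatter import scatter_add
from torchdrug import layers
from torchdrug.core import Registry as R


@R.register("layer.TutorialConv")
class TutorialConv(layers.MessagePassingBase):
    """A minimal example layer following the A*Net interface (illustrative only)."""

    def __init__(self, input_dim, output_dim, num_relation):
        super(TutorialConv, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # hypothetical relation embedding table used to build edge features
        self.relation = nn.Embedding(num_relation, input_dim)
        self.msg_linear = nn.Linear(input_dim, input_dim)
        self.combine_linear = nn.Linear(2 * input_dim, output_dim)

    def compute_message(self, node_input, edge_input):
        # message computed from the source node and the edge (relation) features
        return self.msg_linear(node_input + edge_input)

    def message(self, graph, input):
        # one message per edge in the subgraph selected by A*Net
        node_in = graph.edge_list[:, 0]
        relation = graph.edge_list[:, 2]
        edge_input = self.relation(relation)
        message = self.compute_message(input[node_in], edge_input)
        return message  # (graph.num_edge, input_dim)

    def aggregate(self, graph, message):
        # sum messages over the incoming edges of each node
        node_out = graph.edge_list[:, 1]
        update = scatter_add(message, node_out, dim=0, dim_size=graph.num_node)
        return update  # (graph.num_node, input_dim)

    def combine(self, input, update):
        # fuse the previous node states with the aggregated messages
        output = self.combine_linear(torch.cat([input, update], dim=-1))
        return output  # (graph.num_node, output_dim)
```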
You may also refer to the tutorials of TorchDrug for more details.
The code is stuck at the beginning of epoch 0.

This is probably because the JIT cache is broken. Try removing the cache and running the code again:

```bash
rm -r ~/.cache/torch_extensions/*
```
If you find this project useful, please consider citing the following paper:

```bibtex
@article{zhu2022scalable,
  title={A*Net: A Scalable Path-based Reasoning Approach for Knowledge Graphs},
  author={Zhu, Zhaocheng and Yuan, Xinyu and Galkin, Mikhail and Xhonneux, Sophie and Zhang, Ming and Gazeau, Maxime and Tang, Jian},
  journal={arXiv preprint arXiv:2206.04798},
  year={2022}
}
```