
[AAAI 2022] Early-Bird GCNs: Graph-Network Co-Optimization Towards More Efficient GCN Training and Inference via Drawing Early-Bird Lottery Tickets



License: Apache 2.0

Haoran You, Zhihan Lu, Zijian Zhou, Yonggan Fu, Yingyan Lin

Accepted by AAAI 2022. More Info: [ Paper | Appendix | Slide | Poster | Video | Github ]

Install the conda environment

conda env create -f env.yaml
pip install torch_geometric
pip uninstall torch-scatter
pip install torch-scatter==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.4.0.html
pip uninstall torch-sparse
pip install torch-sparse==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.4.0.html
pip uninstall torch-cluster
pip install torch-cluster==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.4.0.html
pip uninstall torch-spline-conv
pip install torch-spline-conv==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.4.0.html
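After the steps above, you can quickly confirm that PyTorch Geometric and its compiled extensions are at least installed by looking each module up without importing it. The snippet below is a minimal sketch that uses only the standard library. It checks presence, not CUDA or version compatibility, so a module reported as ok can still fail at import time if its wheel does not match the installed torch build.

```python
import importlib.util

# Modules the install steps above are expected to provide.
# Import names use underscores; the pip package names use hyphens.
REQUIRED_MODULES = [
    "torch",
    "torch_geometric",
    "torch_scatter",
    "torch_sparse",
    "torch_cluster",
    "torch_spline_conv",
]


def check_install(modules=REQUIRED_MODULES):
    """Map each top-level module name to True if Python can locate it."""
    # find_spec returns None for a missing top-level module without importing it.
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, found in check_install().items():
        print(f"{name:20s} {'ok' if found else 'MISSING'}")
```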

Run the code

  • To pretrain, prune, and retrain separately:

    • Pretrain the GCN:

    • python3 pytorch_train.py --epochs 10 --dataset Cora
    • Prune the pretrained GCN using one of the following pruning methods:

    • python3 pytorch_prune_weight_iterate.py --ratio_graph 60 --ratio_weight 60
      # or
      python3 pytorch_prune_weight_cotrain.py --ratio_graph 60 --ratio_weight 60
      # or 
      python3 pytorch_prune_weight_first.py --ratio_graph 60 --ratio_weight 60
    • Retrain the pruned GCN to recover the accuracy:

    • python3 pytorch_retrain_with_graph.py --load_path prune_weight_iterate/model.pth.tar
  • By calling the scripts above from Python with functions such as os.system("python3 "+"pytorch_train.py"+" --epochs "+str(1)+" --dataset "+str(args.dataset)), the whole process can be run from a single file that stops automatically once a jointEB ticket is found:

    • python run_threshold_jointEB.py --times 100 --epochs 1 --dataset Cora --ratio_graph 20 --ratio_weight 50
  • Furthermore, a script runs all experiment settings (e.g., different graph and weight pruning ratios) automatically:

    • python test_jointEB_dist_traj.py
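The --ratio_graph and --ratio_weight flags set how much of the graph and of the GCN weights are pruned. As a rough illustration, the sketch below assumes each ratio is the percentage of entries removed and uses plain magnitude pruning, which drops the entries with the smallest absolute score. The helper prune_mask is hypothetical, and the repository's pruning scripts may score graph edges and weights differently.

```python
def prune_mask(scores, ratio):
    """Hypothetical helper: return a 0/1 mask that removes `ratio` percent of
    entries with the smallest absolute score (ties broken by position)."""
    if not 0 <= ratio <= 100:
        raise ValueError("ratio must be a percentage in [0, 100]")
    n_prune = int(len(scores) * ratio / 100)
    # Indices ordered from smallest to largest magnitude (sorted() is stable).
    order = sorted(range(len(scores)), key=lambda i: abs(scores[i]))
    mask = [1] * len(scores)
    for i in order[:n_prune]:
        mask[i] = 0  # pruned entry
    return mask


if __name__ == "__main__":
    weights = [0.5, -0.1, 0.05, -0.9, 0.3]
    print(prune_mask(weights, 60))  # [1, 0, 0, 1, 0]
```

You can apply the same helper to a list of per-edge scores to get a graph mask. Only the scoring differs between the weight and graph cases.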
      
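To stop automatically once a jointEB ticket is found, the code needs a test for when the pruning masks have settled. The sketch below assumes the criterion from the original Early-Bird ticket work carries over. It computes the normalized Hamming distance between the masks of consecutive epochs, keeps the most recent distances in a fixed-length window, and declares a ticket once every distance in the window is below a threshold. Three details are assumptions for illustration and not necessarily what run_threshold_jointEB.py does: the graph mask and weight mask are checked together by taking the larger of their two distances, and the defaults are threshold=0.1 and window=5. JointEBDetector is a hypothetical name.

```python
from collections import deque


def mask_distance(mask_a, mask_b):
    """Normalized Hamming distance between two equal-length 0/1 masks."""
    if len(mask_a) != len(mask_b):
        raise ValueError("masks must have the same length")
    return sum(a != b for a, b in zip(mask_a, mask_b)) / len(mask_a)


class JointEBDetector:
    """Hypothetical helper: signals a joint EB ticket once both the graph mask
    and the weight mask have stopped changing between epochs."""

    def __init__(self, threshold=0.1, window=5):
        self.threshold = threshold
        self.prev = None  # (graph_mask, weight_mask) from the previous epoch
        # Most recent max(graph distance, weight distance) values.
        self.history = deque(maxlen=window)

    def update(self, graph_mask, weight_mask):
        """Record this epoch's masks; return True once a ticket is found."""
        if self.prev is not None:
            g = mask_distance(self.prev[0], graph_mask)
            w = mask_distance(self.prev[1], weight_mask)
            self.history.append(max(g, w))
        self.prev = (list(graph_mask), list(weight_mask))
        return (len(self.history) == self.history.maxlen
                and max(self.history) < self.threshold)
```

In a training loop, you would call update once per epoch with the masks produced by pruning the current model at the target ratios, and stop when it returns True.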

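You can also write the orchestration described above with subprocess.run and an argument list instead of os.system and string concatenation. This runs the scripts from Python and sweeps over pruning ratios. The argument list avoids shell-quoting issues, and check=True stops at the first failing run. This is a sketch: the ratio grid is hypothetical and is not the set of settings test_jointEB_dist_traj.py actually runs. Only the flags shown for run_threshold_jointEB.py are used.

```python
import itertools
import subprocess
import sys

# Hypothetical grid for illustration only.
GRAPH_RATIOS = [20, 40, 60]
WEIGHT_RATIOS = [50, 70]


def build_commands(dataset="Cora", times=100, epochs=1):
    """Build one run_threshold_jointEB.py command per (graph, weight) ratio pair."""
    commands = []
    for rg, rw in itertools.product(GRAPH_RATIOS, WEIGHT_RATIOS):
        commands.append([
            sys.executable, "run_threshold_jointEB.py",
            "--times", str(times),
            "--epochs", str(epochs),
            "--dataset", dataset,
            "--ratio_graph", str(rg),
            "--ratio_weight", str(rw),
        ])
    return commands


def run_all(commands):
    """Run each command in turn; check=True raises on a non-zero exit code."""
    for cmd in commands:
        print("Running:", " ".join(cmd))
        subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(build_commands())
```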
Run the code with Sparse Graph and SGCN-deep settings

More details are coming soon.