DistTGL: Distributed Memory-based Temporal Graph Neural Network Training

Overview

This repository contains the open-source code for our work DistTGL: Distributed Memory-based Temporal Graph Neural Network Training.

Requirements

  • python >= 3.8.13
  • pytorch >= 1.11.0
  • pandas >= 1.1.5
  • numpy >= 1.19.5
  • dgl >= 0.8.2
  • pyyaml >= 5.4.1
  • tqdm >= 4.61.0
  • pybind11 >= 2.6.2
  • g++ >= 7.5.0
  • openmp >= 201511
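You can check the Python-side requirements above with a short script before you build anything. This is a minimal sketch and is not part of the repository. It assumes the PyPI distribution names `torch` (for pytorch) and `PyYAML` (for pyyaml), compares only the numeric release part of each version, and does not check g++ or OpenMP.

```python
import re
import sys
from importlib import metadata
from typing import Dict, Optional, Tuple

# Minimum versions copied from the Requirements list. The keys are ASSUMED
# PyPI distribution names ("torch" for pytorch, "PyYAML" for pyyaml).
MINIMUMS: Dict[str, str] = {
    "torch": "1.11.0",
    "pandas": "1.1.5",
    "numpy": "1.19.5",
    "dgl": "0.8.2",
    "PyYAML": "5.4.1",
    "tqdm": "4.61.0",
    "pybind11": "2.6.2",
}


def parse_version(text: str) -> Tuple[int, ...]:
    """Keep the leading numeric release part, e.g. '1.13.1+cu117' -> (1, 13, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets(installed: str, minimum: str) -> bool:
    """True if installed >= minimum, padding with zeros so '1.11' equals '1.11.0'."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_environment() -> Dict[str, Optional[bool]]:
    """Map each package to True (new enough), False (too old) or None (not found)."""
    report: Dict[str, Optional[bool]] = {}
    for name, minimum in MINIMUMS.items():
        try:
            report[name] = meets(metadata.version(name), minimum)
        except metadata.PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    print("python:", "ok" if sys.version_info >= (3, 8, 13) else "too old")
    for name, ok in check_environment().items():
        print(f"{name}: {'missing' if ok is None else ('ok' if ok else 'too old')}")
```

A package reported as missing may still be installed under another distribution name. For example, older CUDA builds of DGL were published under names such as `dgl-cu113`.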

Dataset

Download the datasets using the down.sh script. Note that we do not release the Flights dataset due to license restrictions. You can download the Flights dataset directly from this link.

Download notes: GDELT/edges.csv downloads at about 612 KB/s (estimated time 3h 41m). Requesting GDELT/ec_edge_class.pt.pt returns 403 Forbidden, and many of the other files also return 403 Forbidden.

As a workaround, try downloading individual files with wget, for example:

wget -P ./DATA/WIKI https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/edge_features.pt
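Many files return 403 Forbidden, so it can help to fetch them one at a time and record which ones fail. The sketch below uses only the Python standard library. It assumes, without confirmation, that the other dataset files follow the same S3 layout as the WIKI URL above (`<base>/<dataset>/<filename>`). The helper names are hypothetical.

```python
import os
import urllib.error
import urllib.request
from typing import Optional

# ASSUMPTION: every dataset file lives at <BASE_URL>/<dataset>/<filename>,
# generalized from the single WIKI/edge_features.pt URL in this README.
BASE_URL = "https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl"


def build_url(dataset: str, filename: str) -> str:
    return f"{BASE_URL}/{dataset}/{filename}"


def local_path(dataset: str, filename: str, root: str = "./DATA") -> str:
    # Same destination as `wget -P ./DATA/<dataset> <url>`.
    return os.path.join(root, dataset, filename)


def fetch(dataset: str, filename: str, root: str = "./DATA") -> Optional[int]:
    """Download one file. Return None on success, or the HTTP status code
    (e.g. 403 Forbidden) on failure so the caller can keep going."""
    dest = local_path(dataset, filename, root)
    os.makedirs(os.path.dirname(dest), exist_ok=True)
    try:
        urllib.request.urlretrieve(build_url(dataset, filename), dest)
    except urllib.error.HTTPError as err:
        return err.code
    return None


if __name__ == "__main__":
    failures = {}
    for name in ["edge_features.pt"]:  # extend with the other files you need
        status = fetch("WIKI", name)
        if status is not None:
            failures[name] = status
    print("failed:", failures or "none")
```

Because `fetch` returns the status code instead of raising, a loop over several file names can report every failed file at the end.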

Mini-batch Preparation

DistTGL pre-computes mini-batches before training. To ensure fast mini-batch loading from disk, store the mini-batches on a fast SSD. In the paper, we use a RAID 0 array of two NVMe SSDs.
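A toy example can illustrate the idea of pre-computing mini-batches. It sorts the temporal edges by timestamp, cuts them into fixed-size batches in chronological order, and writes each batch to its own file, so training only has to read the batches back. This is a simplified sketch. The `(src, dst, timestamp)` edge tuples, the `batch_XXXXXX.pkl` file names, and the use of `pickle` are all hypothetical and do not reflect DistTGL's on-disk format. The real pipeline also runs the TGL sampler.

```python
import os
import pickle
import tempfile
from typing import List, Sequence, Tuple

# Hypothetical edge record: (source node, destination node, timestamp).
Edge = Tuple[int, int, float]


def precompute_minibatches(edges: Sequence[Edge], batch_size: int, out_dir: str) -> List[str]:
    """Write chronologically ordered, fixed-size batches to out_dir, one file each."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    os.makedirs(out_dir, exist_ok=True)
    ordered = sorted(edges, key=lambda edge: edge[2])  # oldest interactions first
    paths: List[str] = []
    for index, start in enumerate(range(0, len(ordered), batch_size)):
        path = os.path.join(out_dir, f"batch_{index:06d}.pkl")
        with open(path, "wb") as f:
            pickle.dump(ordered[start:start + batch_size], f)
        paths.append(path)
    return paths


def load_minibatch(path: str) -> List[Edge]:
    """Read one pre-computed batch back; this is the only work left at training time."""
    with open(path, "rb") as f:
        return pickle.load(f)


if __name__ == "__main__":
    toy_edges = [(0, 1, 3.0), (1, 2, 1.0), (2, 0, 2.0)]
    with tempfile.TemporaryDirectory() as tmp:
        for p in precompute_minibatches(toy_edges, batch_size=2, out_dir=tmp):
            print(os.path.basename(p), load_minibatch(p))
```

Once batches are pre-computed, load speed depends largely on disk read throughput. This is consistent with the recommendation to use a fast SSD.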

We first compile the sampler from TGL with

python setup.py build_ext --inplace
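`build_ext --inplace` typically places the compiled extension module in the source directory. A quick way to confirm that the build succeeded is to look for files carrying this interpreter's extension suffix. The sketch below uses `importlib.machinery.EXTENSION_SUFFIXES`. It does not assume the sampler module's actual name, so it lists every compiled extension it finds. The `built_extensions` helper is hypothetical.

```python
import glob
import importlib.machinery
import os
from typing import List


def built_extensions(directory: str = ".") -> List[str]:
    """Return compiled extension modules in `directory`, matched by this
    interpreter's suffixes (e.g. '.cpython-310-x86_64-linux-gnu.so' or '.so')."""
    found = set()
    pattern_root = glob.escape(directory)
    for suffix in importlib.machinery.EXTENSION_SUFFIXES:
        found.update(glob.glob(os.path.join(pattern_root, "*" + suffix)))
    return sorted(found)


if __name__ == "__main__":
    modules = built_extensions(".")
    if modules:
        for path in modules:
            print("built:", path)
    else:
        print("no compiled extension found; re-run `python setup.py build_ext --inplace`")
```

The set removes duplicates, because a file such as `x.cpython-310-x86_64-linux-gnu.so` also matches the plain `.so` suffix.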

Then generate the mini-batches using

python gen_minibatch.py --data <DatasetName> --gen_eval --minibatch_parallelism <NumberofMinibatchParallelism>

where <NumberofMinibatchParallelism> is the i in (i x j x k) in the paper. For example, to generate the mini-batches for WIKI with a mini-batch parallelism of 2:

python gen_minibatch.py --data WIKI --gen_eval --minibatch_parallelism 2

Run

On each machine, execute

torchrun --nnodes=<NumberofMachines> --nproc_per_node=<NumberofGPUPerMachine> --rdzv_id=<JobID> --rdzv_backend=c10d --rdzv_endpoint=<HostNodeIPAddress>:<HostNodePort> train.py --data <DatasetName> --group <NumberofGroupParallelism>

where <NumberofGroupParallelism> is the k in (i x j x k) in the paper. For example, to train on WIKI on a single machine with two GPUs:

torchrun --nnodes=1 --nproc_per_node=2 --rdzv_id=0 --rdzv_backend=c10d train.py --data WIKI --group 0
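The same `torchrun` command must run on every machine, so it can help to generate it from one set of parameters. The sketch below also checks the GPU count against the (i x j x k) notation. It assumes, though this README does not state it, that i x j x k must equal the total number of GPUs (`--nnodes` x `--nproc_per_node`). `solve_j` and `torchrun_command` are hypothetical helpers, and the endpoint in the example is made up.

```python
import shlex
from typing import List


def solve_j(nnodes: int, nproc_per_node: int, i: int, k: int) -> int:
    """Return j such that i * j * k == nnodes * nproc_per_node.

    ASSUMPTION: each GPU hosts exactly one trainer, so the product of the
    three parallelism factors equals the total GPU count.
    """
    if min(nnodes, nproc_per_node, i, k) < 1:
        raise ValueError("all arguments must be positive integers")
    world_size = nnodes * nproc_per_node
    if world_size % (i * k) != 0:
        raise ValueError(f"{world_size} GPUs cannot be split as {i} x j x {k}")
    return world_size // (i * k)


def torchrun_command(nnodes: int, nproc_per_node: int, job_id: int,
                     endpoint: str, data: str, group: int,
                     script: str = "train.py") -> List[str]:
    """Build the argument list for the multi-machine launch line in this README."""
    return [
        "torchrun",
        f"--nnodes={nnodes}",
        f"--nproc_per_node={nproc_per_node}",
        f"--rdzv_id={job_id}",
        "--rdzv_backend=c10d",
        f"--rdzv_endpoint={endpoint}",
        script,
        "--data", data,
        "--group", str(group),
    ]


if __name__ == "__main__":
    # Hypothetical cluster: 2 machines x 4 GPUs, mini-batch parallelism i=2, k=2.
    print("j =", solve_j(2, 4, i=2, k=2))
    print(shlex.join(torchrun_command(2, 4, job_id=0, endpoint="10.0.0.1:29400",
                                      data="WIKI", group=2)))
```

Every machine must use the same `--rdzv_id` and endpoint. You can pass the list directly to `subprocess.run` instead of a shell string, which avoids quoting issues.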

Error

Running the single-machine example above fails with the following error:

File "/root/share/disttgl/train.py", line 137, in <module>
    with open('minibatches/{}_stats.pkl'.format(args.data), 'rb') as f:
FileNotFoundError: [Errno 2] No such file or directory: 'minibatches/WIKI_stats.pkl'
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 169111) of binary: /opt/conda/bin/python
Traceback (most recent call last):
  File "/opt/conda/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==1.13.1', 'console_scripts', 'torchrun')())
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/run.py", line 762, in main
    run(args)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/run.py", line 753, in run
    elastic_launch(
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 246, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
train.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2023-12-20_13:01:35
  host      : mcnode31.maas
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 169112)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-12-20_13:01:35
  host      : mcnode31.maas
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 169111)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
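The root cause appears in the first frame of the log. `train.py` opens `minibatches/{}_stats.pkl`, a path relative to the current working directory, and `minibatches/WIKI_stats.pkl` does not exist there. The log does not confirm why. One possibility is that `gen_minibatch.py --data WIKI` did not finish or failed. Another is that `torchrun` was started from a different directory than the one where the mini-batches were generated. A small pre-flight check can catch the problem before `torchrun` launches any workers. It is written here as a hypothetical helper, not part of the repository.

```python
import os
import sys


def stats_path(data: str, root: str = ".") -> str:
    # Same relative path train.py opens in the traceback: minibatches/<data>_stats.pkl
    return os.path.join(root, "minibatches", "{}_stats.pkl".format(data))


def preflight(data: str, root: str = ".") -> bool:
    """Return True if the stats file train.py needs exists under `root`."""
    path = stats_path(data, root)
    if os.path.isfile(path):
        return True
    print(f"missing {os.path.abspath(path)}: run gen_minibatch.py --data {data} "
          "first, or start torchrun from the directory that contains minibatches/",
          file=sys.stderr)
    return False


if __name__ == "__main__":
    dataset = sys.argv[1] if len(sys.argv) > 1 else "WIKI"
    sys.exit(0 if preflight(dataset) else 1)
```

For example, `python preflight.py WIKI && torchrun ...` (the file name is hypothetical) launches training only when the stats file is present.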

Security

See CONTRIBUTING for more information.

License

This project is licensed under the Apache-2.0 License.