[RFC] `cugraph`-based neighborhood sampling
rusty1s opened this issue · 0 comments
The feature, motivation and pitch
GPU-based neighborhood sampling can accelerate mini-batch creation for graphs that fit into GPU memory.
Currently, the (solely) CPU-based sampling interface inside PyG looks as follows:
template <bool replace, bool directed>
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
sample(const torch::Tensor &colptr, const torch::Tensor &row,
       const torch::Tensor &input_node, const vector<int64_t> num_neighbors);
The PyG routine expects:
- `(colptr, row)`: CSC/CSR representation of the graph
- `input_node`: The seed nodes for which to sample neighbors
- `num_neighbors`: A list with the number of neighbors to sample in each layer
- `replace`: Whether to sample with or without replacement
- `directed`: Whether sampled edges are directed or not. If not, we extract the full subgraph induced by the sampled nodes.
It returns (re-labeled) `row` and `col` vectors of the sampled subgraph/adjacency matrix, as well as `output_node_id` and `output_edge_id` of the sampled nodes/edges to perform feature fetching in a later stage.
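To make the expected inputs and outputs concrete, here is a minimal CPU sketch of this kind of routine. It is not the PyG implementation: it uses `std::vector` instead of `torch::Tensor`, covers only the `replace = false`, `directed = true` case, and the names `sample_csc` and `SampledSubgraph` are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <random>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for the four returned tensors.
struct SampledSubgraph {
  std::vector<int64_t> row, col;  // re-labeled (local) edge endpoints
  std::vector<int64_t> node_id;   // local index -> global node id
  std::vector<int64_t> edge_id;   // sampled edge -> global edge id
};

// Sketch only: sampling without replacement, directed edges.
SampledSubgraph sample_csc(const std::vector<int64_t>& colptr,
                           const std::vector<int64_t>& row,
                           const std::vector<int64_t>& input_node,
                           const std::vector<int64_t>& num_neighbors,
                           uint64_t seed = 0) {
  assert(!colptr.empty() &&
         colptr.back() == static_cast<int64_t>(row.size()));
  std::mt19937_64 rng(seed);
  SampledSubgraph out;
  std::unordered_map<int64_t, int64_t> to_local;

  // Local ids are assigned in order of first appearance; seeds come first.
  auto local_id = [&](int64_t v) {
    auto res = to_local.emplace(v, static_cast<int64_t>(out.node_id.size()));
    if (res.second) out.node_id.push_back(v);
    return res.first->second;
  };
  for (int64_t v : input_node) local_id(v);

  std::size_t begin = 0, end = out.node_id.size();
  for (int64_t k : num_neighbors) {
    // Expand only the nodes discovered in the previous layer.
    for (std::size_t i = begin; i < end; ++i) {
      const int64_t v = out.node_id[i];
      std::vector<int64_t> edges;  // global ids of the edges stored for v
      for (int64_t e = colptr[v]; e < colptr[v + 1]; ++e) edges.push_back(e);
      // Keep everything if v has at most k neighbors, else pick k at random.
      if (static_cast<int64_t>(edges.size()) > k) {
        std::shuffle(edges.begin(), edges.end(), rng);
        edges.resize(static_cast<std::size_t>(k));
      }
      for (int64_t e : edges) {
        out.row.push_back(local_id(row[e]));
        out.col.push_back(static_cast<int64_t>(i));
        out.edge_id.push_back(e);
      }
    }
    begin = end;
    end = out.node_id.size();
  }
  return out;
}
```

In this sketch, `node_id` and `edge_id` play the role of `output_node_id` and `output_edge_id`: indexing a feature matrix with them yields the features of the sampled nodes and edges in their re-labeled order.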
On the other hand, the sampling interface inside `cugraph` looks as follows:
template <typename graph_t>
std::tuple<rmm::device_uvector<typename graph_t::edge_type>,
           rmm::device_uvector<typename graph_t::vertex_type>>
sample_neighbors_adjacency_list(raft::handle_t const& handle,
                                raft::random::RngState& rng_state,
                                graph_t const& graph,
                                typename graph_t::vertex_type const* ptr_d_start,
                                size_t num_start_vertices,
                                size_t sampling_size,
                                ops::gnn::graph::SamplingAlgoT sampling_algo);
The major difference seems to be that `cugraph` performs sampling for 1-hop, while PyG supports multi-hop sampling (which can be fixed easily by just calling the `cugraph` routine multiple times) [to be confirmed by @pyg-team/nvidia-team].
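The idea of building multi-hop sampling from repeated 1-hop calls can be sketched on the CPU as follows. This is an illustration under assumptions, not the `cugraph` API: the 1-hop result is assumed to be an adjacency list of `(offsets, neighbors)`, and `sample_multi_hop`, `take_first_k`, `OneHopResult` and `Edge` are hypothetical names. `take_first_k` is a deterministic toy stand-in for the GPU routine.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <unordered_set>
#include <utility>
#include <vector>

// Assumed 1-hop result: the sampled neighbors of start vertex i are
// neighbors[offsets[i]] .. neighbors[offsets[i + 1] - 1].
using OneHopResult = std::pair<std::vector<int64_t>, std::vector<int64_t>>;

struct Edge { int64_t src, dst; };

// Toy stand-in for a 1-hop sampler: takes the first k stored neighbors.
OneHopResult take_first_k(const std::vector<int64_t>& colptr,
                          const std::vector<int64_t>& row,
                          const std::vector<int64_t>& start, std::size_t k) {
  OneHopResult out;
  out.first.push_back(0);
  for (int64_t v : start) {
    const int64_t deg = colptr[v + 1] - colptr[v];
    const int64_t take = deg < static_cast<int64_t>(k)
                             ? deg : static_cast<int64_t>(k);
    for (int64_t e = colptr[v]; e < colptr[v] + take; ++e)
      out.second.push_back(row[e]);
    out.first.push_back(static_cast<int64_t>(out.second.size()));
  }
  return out;
}

// Calls the 1-hop routine once per layer; the newly discovered vertices
// of one layer become the start vertices of the next.
template <typename OneHopFn>
std::vector<Edge> sample_multi_hop(OneHopFn&& one_hop,
                                   std::vector<int64_t> frontier,
                                   const std::vector<std::size_t>& num_neighbors) {
  std::unordered_set<int64_t> seen(frontier.begin(), frontier.end());
  std::vector<Edge> edges;
  for (std::size_t k : num_neighbors) {
    if (frontier.empty()) break;
    const OneHopResult hop = one_hop(frontier, k);
    const std::vector<int64_t>& offsets = hop.first;
    const std::vector<int64_t>& nbrs = hop.second;
    assert(offsets.size() == frontier.size() + 1);
    std::vector<int64_t> next;
    for (std::size_t i = 0; i < frontier.size(); ++i) {
      for (int64_t j = offsets[i]; j < offsets[i + 1]; ++j) {
        edges.push_back({nbrs[j], frontier[i]});
        // Vertices seen in an earlier layer are not expanded again.
        if (seen.insert(nbrs[j]).second) next.push_back(nbrs[j]);
      }
    }
    frontier = std::move(next);
  }
  return edges;
}
```

The driver loop only needs the 1-hop call and the bookkeeping for the next set of start vertices. Re-labeling and edge-id tracking, which the PyG interface also returns, would have to be layered on top.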
Roadmap
For integrating GPU-based sampling inside PyG, we thus need to:
- Integrate `torch-sparse` neighborhood sampling interface in `pyg-lib`
- Add `cugraph` as a dependency inside `pyg-lib`
- Call the `cugraph`-based sampling routine inside the GPU-based dispatcher
- Integrate changes on PyG side