Batch linear assignment for PyTorch

Batch computation of the linear assignment problem on GPU.

Install

Build and install via PyPI (source distribution):

pip install torch-linear-assignment

Build and install from Git repository:

pip install .

When building with CUDA, make sure NVCC uses the same CUDA version as PyTorch. To select a CUDA version, put its bin directory first on PATH:

export PATH=/usr/local/cuda-<version>/bin:"$PATH"
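
To check whether the two versions agree before building, a small script along these lines may help. It is only a sketch: the helper names parse_nvcc_release and nvcc_release are hypothetical and not part of this package, and the parsing assumes the usual "release X.Y" line printed by nvcc --version.

```python
import re
import shutil
import subprocess

import torch


def parse_nvcc_release(version_output):
    """Extract the 'X.Y' release number from `nvcc --version` output, or None."""
    match = re.search(r"release (\d+\.\d+)", version_output)
    return match.group(1) if match else None


def nvcc_release():
    """Return the release of the first nvcc found on PATH, or None if absent."""
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return None
    result = subprocess.run([nvcc, "--version"], capture_output=True, text=True)
    return parse_nvcc_release(result.stdout)


torch_cuda = torch.version.cuda  # None for CPU-only PyTorch builds
nvcc_cuda = nvcc_release()
print(f"PyTorch CUDA: {torch_cuda}, NVCC: {nvcc_cuda}")
if torch_cuda is not None and nvcc_cuda is not None and torch_cuda != nvcc_cuda:
    print("Version mismatch: adjust PATH before building.")
```

If the script reports a mismatch, change PATH as shown above and run it again before invoking pip.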

If you need a custom C++ compiler, set the CXX and CC environment variables:

CXX=<c++-compiler> CC=<gcc-compiler> pip install .

If you get a torch-not-found error, try the following command:

pip install --upgrade pip wheel setuptools
python -m pip install .

Example

import torch
from torch_linear_assignment import batch_linear_assignment

cost = torch.tensor([
    8, 4, 7,
    5, 2, 3,
    9, 6, 7,
    9, 4, 8
]).reshape(1, 4, 3).cuda()

assignment = batch_linear_assignment(cost)
print(assignment)

The output gives, for each row, the index of its assigned column; -1 marks a row left unassigned:

tensor([[ 0,  2, -1,  1]], device='cuda:0')
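
As an illustration of how such an output can be consumed, the following sketch computes the total cost of the matching with plain PyTorch operations. It assumes that each entry is a column index and that -1 means "no column", as the example output suggests. The assignment tensor is copied from the output above instead of being recomputed, and everything stays on the CPU so that no GPU is needed.

```python
import torch

# Same cost matrix as in the example, kept on the CPU.
cost = torch.tensor([
    [8, 4, 7],
    [5, 2, 3],
    [9, 6, 7],
    [9, 4, 8],
]).reshape(1, 4, 3)

# Assignment copied from the example output (not recomputed here).
assignment = torch.tensor([[0, 2, -1, 1]])

matched = assignment >= 0            # (batch, rows) mask of assigned rows
safe_cols = assignment.clamp(min=0)  # replace -1 so gather stays in bounds
# picked[b, r] = cost[b, r, safe_cols[b, r]]
picked = cost.gather(2, safe_cols.unsqueeze(-1)).squeeze(-1)
total_cost = (picked * matched).sum(dim=1)  # unassigned rows contribute 0

print(total_cost)  # tensor([15])
```

The clamp is needed only to keep gather from indexing with -1; the mask then discards the placeholder value picked for the unassigned row.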

To get indices in SciPy's format:

from torch_linear_assignment import assignment_to_indices

row_ind, col_ind = assignment_to_indices(assignment)
print(row_ind)
print(col_ind)

The output is:

tensor([[0, 1, 3]], device='cuda:0')
tensor([[0, 2, 1]], device='cuda:0')
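
For comparison, the same single cost matrix can be solved on the CPU with SciPy's linear_sum_assignment. This sketch is only a cross-check of the example above and says nothing about how this package works internally. SciPy handles one matrix at a time, so there is no batch dimension here.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# The 4x3 cost matrix from the example, without the batch dimension.
cost = np.array([
    [8, 4, 7],
    [5, 2, 3],
    [9, 6, 7],
    [9, 4, 8],
])

row_ind, col_ind = linear_sum_assignment(cost)
print(row_ind)                       # [0 1 3]
print(col_ind)                       # [0 2 1]
print(cost[row_ind, col_ind].sum())  # 15
```

The indices agree with the batched output above, and row 2 is absent because the matrix has more rows than columns.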

Citation

The code was originally developed for the HoTPP Benchmark. If you use this code in your research project, please cite one of the following papers:

@article{karpukhin2024hotppbenchmark,
  title={HoTPP Benchmark: Are We Good at the Long Horizon Events Forecasting?},
  author={Karpukhin, Ivan and Shipilov, Foma and Savchenko, Andrey},
  journal={arXiv preprint arXiv:2406.14341},
  year={2024},
  url={https://arxiv.org/abs/2406.14341}
}

@article{karpukhin2024detpp,
  title={DeTPP: Leveraging Object Detection for Robust Long-Horizon Event Prediction},
  author={Karpukhin, Ivan and Savchenko, Andrey},
  journal={arXiv preprint arXiv:2408.13131},
  year={2024},
  url={https://arxiv.org/abs/2408.13131}
}