
`torch.compile`-compatible image binarization algorithm implemented in PyTorch.


torch-image-binarization


Local image binarization algorithms implemented in PyTorch. Includes the Otsu thresholding algorithm and the algorithm by Su et al. (which is an extension of the Otsu algorithm). The implementations are optimized for performance: with torch.compile, the Su implementation runs approximately 4125x faster than the CPU-based NumPy version.
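For readers unfamiliar with Otsu thresholding, the following is a minimal pure-Python sketch of the textbook method on a flat list of 8-bit grayscale values. It is not the package's implementation; the function names `otsu_threshold` and `binarize` and the 0/255 output convention are assumptions made for illustration.

```python
def otsu_threshold(pixels):
    """Return the 8-bit threshold t that maximizes the between-class variance.

    Values <= t form one class, values > t the other.
    """
    # Histogram of the 256 possible intensity values.
    hist = [0] * 256
    for p in pixels:
        hist[p] += 1

    total = len(pixels)
    total_sum = sum(i * h for i, h in enumerate(hist))

    best_t, best_var = 0, -1.0
    w0 = 0    # number of pixels with value <= t
    sum0 = 0  # intensity sum of those pixels
    for t in range(256):
        w0 += hist[t]
        sum0 += t * hist[t]
        w1 = total - w0
        if w0 == 0 or w1 == 0:
            continue  # one class is empty, variance is undefined
        mean0 = sum0 / w0
        mean1 = (total_sum - sum0) / w1
        # Between-class variance, up to a constant factor of 1 / total**2.
        var_between = w0 * w1 * (mean0 - mean1) ** 2
        if var_between > best_var:
            best_t, best_var = t, var_between
    return best_t


def binarize(pixels, threshold):
    """Map values above the threshold to 255 and all others to 0 (assumed convention)."""
    return [255 if p > threshold else 0 for p in pixels]


# Toy input: a dark cluster and a bright cluster.
pixels = [10, 12, 11, 200, 205, 198, 13, 202]
t = otsu_threshold(pixels)
binary = binarize(pixels, t)
```

The sketch scans all 256 candidate thresholds once while maintaining running class sums, which keeps the search linear in the number of intensity levels.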

This was written to test the benefits of torch.compile. For a log of performance improvements, refer to optimizations.ipynb. The notebook expects triton==3.0.0 and torch==2.3.0.

Install

Install using pip:

pip install git+https://github.com/nopperl/torch-image-binarization

The package requires torch>=2.2.0 and optionally triton>=2.21, which need to be installed separately, e.g. using pip:

pip install torch torchvision triton

Usage

Read an image:

from torchvision.io import ImageReadMode, read_image
img = read_image("test_image.png", mode=ImageReadMode.GRAY)

Binarize the image:

from torch_image_binarization.thresholding import su
su(img)

For more information, refer to torch_image_binarization/main.py.

Benchmark

To show the performance gains of the CUDA-based PyTorch implementation over the CPU-based NumPy implementation and the benefits of torch.compile, the runtime is measured across different input image sizes. For more information, refer to optimizations.ipynb.

Results:

[------------------------------------ su -------------------------------------]
                                                                    |   runtime
1 threads: --------------------------------------------------------------------
      numpy                                                         | 3548992.0
      su(img)                                                       |   10426.5
      torch.compile(su)(img)                                        |    1333.6
      torch.compile(su, mode='reduce-overhead')(img)                |     858.8
      torch.compile(su, mode='max-autotune')(img)                   |     859.6
      torch.compile(su, dynamic=True)(img)                          |     859.7
      torch.compile(su, dynamic=True, mode='reduce-overhead')(img)  |     860.0
      torch.compile(su, dynamic=True, mode='max-autotune')(img)     |     860.0

Times are in microseconds (us).
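To read the table in relative terms, the speedup of each variant over the NumPy baseline can be computed from the listed runtimes. The short sketch below does only that arithmetic on a subset of the rows; the dictionary and variable names are illustrative and not part of the package.

```python
# Runtimes in microseconds, copied from the benchmark table above.
runtimes_us = {
    "numpy": 3548992.0,
    "su(img)": 10426.5,
    "torch.compile(su)(img)": 1333.6,
    "torch.compile(su, mode='reduce-overhead')(img)": 858.8,
}

baseline = runtimes_us["numpy"]

# Speedup relative to the NumPy baseline: baseline runtime / variant runtime.
speedups = {name: baseline / t for name, t in runtimes_us.items()}

# Speedup of the fastest compiled variant over the uncompiled PyTorch version.
compile_gain = runtimes_us["su(img)"] / min(runtimes_us.values())

for name, s in speedups.items():
    print(f"{name:50s} {s:10.1f}x")
print(f"fastest compiled vs. su(img): {compile_gain:.1f}x")
```

On these numbers, the uncompiled PyTorch version is already a few hundred times faster than NumPy, and compilation adds roughly another order of magnitude on top.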