Global_contrast_CUDA

Global contrast operator written in CUDA for PyTorch.


segment_mm

This is an operator written in CUDA for PyTorch.

It computes the global contrast of each pixel, i.e. how strongly that pixel differs from all other pixels in the image. The work is inspired by Xinyu Zhang and builds on the paper by Mingming Cheng et al., "Global Contrast Based Salient Region Detection".
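The README does not spell out the exact distance used. Reading it together with the histogram-contrast idea in Cheng et al., one plausible definition (an assumption, not confirmed by this repo) is that each output pixel holds the sum of Euclidean distances, taken over the C channels, between that pixel's feature vector and every pixel of the same image:

```latex
S_b(p) \;=\; \sum_{q \in \Omega} \bigl\lVert x_b(p) - x_b(q) \bigr\rVert_2,
\qquad x_b(p) \in \mathbb{R}^{C},\quad \Omega = \{1,\dots,W\} \times \{1,\dots,H\}
```

Under this reading every pixel is compared with every other one, so the work is quadratic in W·H: a 336 × 336 image has 112,896 pixels and therefore roughly 1.27 × 10^10 pixel pairs.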

def forward(
    input,
):
    """
    Params:
    ------
        input: float tensor, shape (B, C, W, H)
    
    Returns:
    ------
        output: float tensor, shape (B, 1, W, H)
    """
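A naive CPU reference is useful for checking the kernel's output on tiny inputs. The sketch below is NumPy, is not part of this repo, and uses a hypothetical name `global_contrast_forward_ref`. It assumes each output pixel is the sum of Euclidean channel-space distances from that pixel to every pixel in the same image, and it keeps the (B, C, W, H) axis order documented above.

```python
import numpy as np


def global_contrast_forward_ref(x: np.ndarray) -> np.ndarray:
    """Hypothetical naive reference for forward().

    Assumes output[b, 0, w, h] = sum over all pixels (w', h') of the
    Euclidean distance between x[b, :, w, h] and x[b, :, w', h'].
    x: float array of shape (B, C, W, H); returns shape (B, 1, W, H).
    """
    B, C, W, H = x.shape
    # Flatten the spatial grid: one C-dimensional feature vector per pixel.
    feats = x.reshape(B, C, W * H).transpose(0, 2, 1)   # (B, N, C), N = W * H
    # All pairwise differences; memory is O(N^2 * C), so keep inputs tiny.
    diff = feats[:, :, None, :] - feats[:, None, :, :]  # (B, N, N, C)
    dist = np.sqrt((diff ** 2).sum(axis=-1))            # (B, N, N)
    # Each pixel's contrast is its summed distance to every pixel (self term is 0).
    return dist.sum(axis=2).reshape(B, 1, W, H)


# Three single-channel pixels with values 0, 1, 3.
x = np.array([0.0, 1.0, 3.0]).reshape(1, 1, 1, 3)
print(global_contrast_forward_ref(x).ravel())  # [4. 3. 5.]
```

Pixel 0 is at distance 1 and 3 from the others (sum 4), pixel 1 at 1 and 2 (sum 3), pixel 2 at 3 and 2 (sum 5). Because the pairwise tensor is materialised in full, this sketch is only practical for very small W·H.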

def backward(
    grad,
    input
):
    """
    Params:
    ------
        grad: float tensor, shape (B, 1, W, H)
        input: float tensor, shape (B, C, W, H)
    
    Returns:
    ------
        d_input: float tensor, shape (B, C, W, H)
    """
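Assuming the same sum-of-Euclidean-distances definition (not confirmed by the README), pixel k appears both in its own sum and in every other pixel's sum. With upstream gradient g, the chain rule then gives dL/dx(k) = Σ_q (g(k) + g(q)) · (x(k) − x(q)) / ‖x(k) − x(q)‖, taking a term as 0 when the two pixels coincide. That zero is a subgradient choice; how the CUDA kernel treats identical pixels is not documented. The hypothetical NumPy sketch below implements this formula and checks it against central finite differences.

```python
import numpy as np


def _pairwise(x: np.ndarray):
    """Return pairwise differences x_k - x_q (B, N, N, C) and distances (B, N, N)."""
    B, C, W, H = x.shape
    feats = x.reshape(B, C, W * H).transpose(0, 2, 1)   # (B, N, C)
    diff = feats[:, :, None, :] - feats[:, None, :, :]
    return diff, np.sqrt((diff ** 2).sum(axis=-1))


def global_contrast_forward_ref(x: np.ndarray) -> np.ndarray:
    """Assumed forward: per-pixel sum of distances to all pixels, shape (B, 1, W, H)."""
    B, _, W, H = x.shape
    _, dist = _pairwise(x)
    return dist.sum(axis=2).reshape(B, 1, W, H)


def global_contrast_backward_ref(grad: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Hypothetical reference for backward(); grad (B, 1, W, H), x (B, C, W, H)."""
    B, C, W, H = x.shape
    diff, dist = _pairwise(x)
    g = grad.reshape(B, W * H)
    nonzero = dist > 0
    # Unit vectors (x_k - x_q) / d(k, q); coincident pixels (incl. k == q) contribute 0.
    unit = np.where(nonzero[..., None], diff / np.where(nonzero, dist, 1.0)[..., None], 0.0)
    # Pixel k appears in its own sum (weight g_k) and in every pixel q's sum (weight g_q).
    weight = g[:, :, None] + g[:, None, :]               # (B, N, N)
    d_feats = (weight[..., None] * unit).sum(axis=2)     # (B, N, C)
    return d_feats.transpose(0, 2, 1).reshape(B, C, W, H)


def max_grad_error(seed: int = 0, h: float = 1e-6) -> float:
    """Largest gap between the analytic gradient and central finite differences."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((2, 3, 2, 2))
    g = rng.standard_normal((2, 1, 2, 2))
    analytic = global_contrast_backward_ref(g, x)
    numeric = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += h
        xm[idx] -= h
        # Derivative of sum(g * forward(x)) with respect to one input element.
        delta = global_contrast_forward_ref(xp) - global_contrast_forward_ref(xm)
        numeric[idx] = (delta * g).sum() / (2 * h)
    return float(np.abs(analytic - numeric).max())


print(f"max |analytic - numeric| = {max_grad_error():.2e}")
```

The same finite-difference comparison can be pointed at the CUDA backward on a tiny input, provided its true forward definition matches the one assumed here.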

Test code

Installation

./install.sh

Script

./test.sh <loop_time>

IPython

>>> import torch
>>> from global_contrast import GlobalContrast
>>> x = torch.rand((20, 16, 336, 336)).cuda()  # (B, C, W, H) on the GPU
>>> model = GlobalContrast()
>>> y = model(x)
>>> y.size()
torch.Size([20, 1, 336, 336])

Benchmark

Implementation   forward (ms)   backward (ms)
naive            -              -
cuda             53.693         179.787

Reference