pytorch/vision

OOM Error with `roi_align` in PyTorch 2.1.1 but fine in PyTorch 2.0.1

davidjaw opened this issue · 17 comments

🐛 Describe the bug

Description

I am encountering an Out of Memory (OOM) error when calling torchvision's roi_align with PyTorch 2.1.1 and torchvision 0.16.1. The issue does not occur with PyTorch 2.0.1 and torchvision 0.15.2. The error happens regardless of the GPU used (tested on an NVIDIA A2000 and an RTX 4090).
When I downgrade PyTorch and torchvision back to 2.0.1 and 0.15.2, the function works properly.
I would like to understand why this OOM error occurs in the newer versions of PyTorch and torchvision, and whether it is a bug or a change in how roi_align manages memory.

Background

  • I hook feature maps from a YOLO model and use RoI align to crop out object features.

Function

  • The object_roi_align function crops feature maps based on YOLO's object detection labels and uses RoI align to extract features of the object. The function accepts feature maps, YOLO detection labels, and several optional parameters for noise and class constraints.
import torch
from torchvision.ops import box_convert, roi_align


def object_roi_align(feature, targets, nc=None, target_size=7, return_cls_feature=False,
                     noise_ratio=None, constrained_cls=None, constrained_cls_ratio=None):
    """
    This function will crop feature map based on the label, and use RoI align to get the feature of the object
    Args:
        feature: feature maps in shape [batch_size, channels, height, width]
        targets: [batch_idx, cls_index, cx, cy, w, h] (coords are normalized)
        target_size: size of the roi-align output (will be one of [7, 14, 28])
        noise_ratio: if not None, add noise to the target box (0 to 2)
        constrained_cls: if not None, index of the class whose box noise ratio is limited
        constrained_cls_ratio: maximum noise ratio applied to boxes of constrained_cls
        return_cls_feature: if True, return a list of feature maps for each class
    """
    scale = feature.shape[-1] - 1
    bidx, cidx, *box = [targets[:, i] for i in range(6)]
    box = torch.stack(box, dim=1)
    if noise_ratio is not None:
        # randomly add noise to the box coordinates to provide context
        box_wh = box[:, 2:]
        box_wh_ratio = torch.rand_like(box_wh, device=box_wh.device) * noise_ratio
        if constrained_cls is not None:
            # if targets[:, 1] == 7 (board), ratio is limited to max of 0.3
            box_wh_ratio[targets[:, 1] == constrained_cls] = torch.clamp(
                box_wh_ratio[targets[:, 1] == constrained_cls], max=constrained_cls_ratio)
        box_wh_noise = box_wh_ratio * box_wh
        box = torch.cat([box[:, :2], box_wh + box_wh_noise], dim=1)
    # RoI align function takes real coordinates instead of normalized coordinates
    box = box_convert(box, 'cxcywh', 'xyxy')
    box = torch.clamp(box, min=0, max=1) * scale
    box = torch.cat([bidx.unsqueeze(-1), box], dim=1).to(feature.device)
    roi_features = roi_align(feature, box, output_size=target_size, aligned=True)
    if return_cls_feature and nc:
        return split_feature_by_cls(roi_features, targets, nc)

    return roi_features
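
To make the coordinate handling in the function above easier to follow, here is a minimal pure-Python sketch of the box conversion for a single label. The helper name `normalized_cxcywh_to_roi` and the example numbers are hypothetical and not part of my code. It only mirrors the steps shown above (cxcywh to xyxy, clamp to [0, 1], scale by width - 1, prepend the batch index) to produce one row in the [batch_idx, x1, y1, x2, y2] layout that roi_align accepts for a Tensor[K, 5] of boxes.

```python
def normalized_cxcywh_to_roi(batch_idx, cx, cy, w, h, feat_width):
    """Convert one normalized (cx, cy, w, h) box into a (batch_idx, x1, y1, x2, y2) row."""
    # Same scale as the function above: feature-map width - 1, used for both axes.
    scale = feat_width - 1
    # cxcywh -> xyxy, still in normalized coordinates.
    corners = (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)
    # Clamp to [0, 1] so the box stays inside the feature map.
    clamped = [min(max(v, 0.0), 1.0) for v in corners]
    # Scale to feature-map coordinates and prepend the batch index.
    return [float(batch_idx)] + [v * scale for v in clamped]


# Hypothetical example: a centered box on an 81-pixel-wide feature map.
roi = normalized_cxcywh_to_roi(batch_idx=0, cx=0.5, cy=0.5, w=0.5, h=0.25, feat_width=81)
print(roi)  # [0.0, 20.0, 30.0, 60.0, 50.0]
```

Note that the sketch, like the original function, scales the y coordinates by the feature-map width as well, so it assumes square feature maps.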

Error message (A2000)

  File "/home/davidjaw/Desktop/yolov5/custom_func.py", line 322, in <listcomp>
    x = [object_roi_align(x[i], targets, self.nc, target_size=self.base_resolution[i],
  File "/home/davidjaw/Desktop/yolov5/custom_func.py", line 486, in object_roi_align
    roi_features = roi_align(feature, box, output_size=target_size, aligned=True)
  File "/home/davidjaw/miniconda3/envs/torch210/lib/python3.10/site-packages/torchvision/ops/roi_align.py", line 236, in roi_align
    return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
  File "/home/davidjaw/miniconda3/envs/torch210/lib/python3.10/site-packages/torchvision/ops/roi_align.py", line 168, in _roi_align
    val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)  # [K, C, PH, PW, IY, IX]
  File "/home/davidjaw/miniconda3/envs/torch210/lib/python3.10/site-packages/torchvision/ops/roi_align.py", line 62, in _bilinear_interpolate
    v1 = masked_index(y_low, x_low)
  File "/home/davidjaw/miniconda3/envs/torch210/lib/python3.10/site-packages/torchvision/ops/roi_align.py", line 55, in masked_index
    return input[
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 14.96 GiB. GPU 0 has a total capacty of 11.75 GiB of which 3.89 GiB is free. Including non-PyTorch memory, this process has 7.38 GiB memory in use. Of the allocated memory 7.01 GiB is allocated by PyTorch, and 219.61 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Error message (RTX 4090)

  File "/home/jdway/Desktop/yolov5/custom_func.py", line 322, in <listcomp>
    x = [object_roi_align(x[i], targets, self.nc, target_size=self.base_resolution[i],
  File "/home/jdway/Desktop/yolov5/custom_func.py", line 486, in object_roi_align
    roi_features = roi_align(feature, box, output_size=target_size, aligned=True)
  File "/home/jdway/miniconda3/envs/torch210/lib/python3.10/site-packages/torchvision/ops/roi_align.py", line 236, in roi_align
    return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
  File "/home/jdway/miniconda3/envs/torch210/lib/python3.10/site-packages/torchvision/ops/roi_align.py", line 168, in _roi_align
    val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)  # [K, C, PH, PW, IY, IX]
  File "/home/jdway/miniconda3/envs/torch210/lib/python3.10/site-packages/torchvision/ops/roi_align.py", line 62, in _bilinear_interpolate
    v1 = masked_index(y_low, x_low)
  File "/home/jdway/miniconda3/envs/torch210/lib/python3.10/site-packages/torchvision/ops/roi_align.py", line 55, in masked_index
    return input[
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 14.96 GiB. GPU 0 has a total capacty of 23.64 GiB of which 14.67 GiB is free. Including non-PyTorch memory, this process has 8.72 GiB memory in use. Of the allocated memory 6.98 GiB is allocated by PyTorch, and 224.38 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Versions

Versions (RTX 4090)

PyTorch version: 2.1.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.10.13 | packaged by conda-forge | (main, Oct 26 2023, 18:07:37) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4090
Nvidia driver version: 530.30.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      46 bits physical, 48 bits virtual
CPU(s):                             24
On-line CPU(s) list:                0-23
Thread(s) per core:                 1
Core(s) per socket:                 16
Socket(s):                          1
NUMA node(s):                       1
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              151
Model name:                         12th Gen Intel(R) Core(TM) i9-12900
Stepping:                           2
CPU MHz:                            2400.000
CPU max MHz:                        5100.0000
CPU min MHz:                        800.0000
BogoMIPS:                           4838.40
Virtualization:                     VT-x
L1d cache:                          384 KiB
L1i cache:                          256 KiB
L2 cache:                           10 MiB
NUMA node0 CPU(s):                  0-23
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l2 cdp_l2 ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdt_a rdseed adx smap clflushopt clwb intel_pt sha_ni xsaveopt xsavec xgetbv1 xsaves split_lock_detect avx_vnni dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req umip pku ospke waitpkg gfni vaes vpclmulqdq tme rdpid movdiri movdir64b fsrm md_clear serialize pconfig arch_lbr flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.26.2
[pip3] torch==2.1.1
[pip3] torchaudio==2.1.1
[pip3] torchdata==0.7.1
[pip3] torchtext==0.16.1
[pip3] torchvision==0.16.1
[pip3] triton==2.1.0
[conda] libopenvino-pytorch-frontend 2023.1.0             h59595ed_2    conda-forge
[conda] numpy                     1.26.2          py310hb13e2d6_0    conda-forge
[conda] torch                     2.1.1                    pypi_0    pypi
[conda] torchaudio                2.1.1                    pypi_0    pypi
[conda] torchdata                 0.7.1                    pypi_0    pypi
[conda] torchtext                 0.16.1                   pypi_0    pypi
[conda] torchvision               0.16.1                   pypi_0    pypi
[conda] triton                    2.1.0                    pypi_0    pypi

Versions (A2000)

Collecting environment information...
PyTorch version: 2.1.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.2.0-37-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX A2000 12GB
Nvidia driver version: 535.129.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             20
On-line CPU(s) list:                0-19
Vendor ID:                          GenuineIntel
Model name:                         12th Gen Intel(R) Core(TM) i7-12700
CPU family:                         6
Model:                              151
Thread(s) per core:                 2
Core(s) per socket:                 12
Socket(s):                          1
Stepping:                           2
CPU max MHz:                        4900.0000
CPU min MHz:                        800.0000
BogoMIPS:                           4224.00
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb intel_pt sha_ni xsaveopt xsavec xgetbv1 xsaves split_lock_detect avx_vnni dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req hfi umip pku ospke waitpkg gfni vaes vpclmulqdq tme rdpid movdiri movdir64b fsrm md_clear serialize pconfig arch_lbr ibt flush_l1d arch_capabilities
Virtualization:                     VT-x
L1d cache:                          512 KiB (12 instances)
L1i cache:                          512 KiB (12 instances)
L2 cache:                           12 MiB (9 instances)
L3 cache:                           25 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-19
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.2
[pip3] torch==2.1.1
[pip3] torchaudio==2.1.1
[pip3] torchdata==0.7.1
[pip3] torchtext==0.16.1
[pip3] torchvision==0.16.1
[pip3] triton==2.1.0
[conda] numpy                     1.26.2                   pypi_0    pypi
[conda] torch                     2.1.1                    pypi_0    pypi
[conda] torchaudio                2.1.1                    pypi_0    pypi
[conda] torchdata                 0.7.1                    pypi_0    pypi
[conda] torchtext                 0.16.1                   pypi_0    pypi
[conda] torchvision               0.16.1                   pypi_0    pypi
[conda] triton                    2.1.0                    pypi_0    pypi

cc @ezyang @gchanan @zou3519 @kadeng @ptrblck

We need to understand where the regression is coming from, but it sounds a bit like a torchvision problem, doesn't it?

Also, I wonder if this is a CUDA 11.8 vs CUDA 12.1 regression (2.0.1 was shipped with 11.8 by default, but 2.1 with 12.1).

Hello,

I've created a minimal toy example to demonstrate the issue in detail.
You can find it here: https://gist.github.com/davidjaw/40bcbcf44cb3db01fd9178e193edb0de

This example relies on the ultralytics library. For context, the code runs as expected with PyTorch 2.0.1 and torchvision 0.15.2+cu118, but hits the OOM error with PyTorch 2.1.1.
I believe this setup aligns with the requirements mentioned in the original issue.

Please take a look at the gist, and let me know if you need any more information or if there's anything else I can do to assist in resolving this issue.

Thank you!

I just want to chime in and mention that I have the same problem. A very large memory allocation is attempted on both the GPU and the CPU. I observe the problem in the following environment:

Collecting environment information...
PyTorch version: 2.1.0+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

But everything works with:

Collecting environment information...
PyTorch version: 2.1.0.dev20230714+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

Below is a little snippet which leads to the OOM error with 2.1.0+cu118:

import torchvision
import torch
inp = torch.rand((1, 256, 48, 64))
bbox = torch.tensor([[0, 0, 0, 128, 96]]).float()
output_size = (48, 64)
scale = 12 / 384
aligned = True
torch.use_deterministic_algorithms(True)
out = torchvision.ops.roi_align(inp.cuda(), bbox.cuda(), output_size, scale, aligned=aligned)

Can someone else replicate this?

Oh, you know what, it's probably because of torch.use_deterministic_algorithms(True). We added a deterministic implementation of roi_align, but it is very memory hungry.
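
As a rough illustration of how memory hungry that path can get, here is a back-of-the-envelope sketch for the snippet above. The traceback annotates the interpolated values as `[K, C, PH, PW, IY, IX]`. The helper `deterministic_roi_align_bytes` is hypothetical, and it assumes that with the default adaptive `sampling_ratio` the IY and IX dimensions span the full feature-map height and width (with unused samples masked out), while a positive `sampling_ratio` makes them equal to that ratio. Treat the result as an estimate for one such intermediate tensor, not as an exact accounting of what the implementation allocates.

```python
GIB = 1024 ** 3
MIB = 1024 ** 2


def deterministic_roi_align_bytes(num_rois, channels, pooled_h, pooled_w,
                                  feat_h, feat_w, sampling_ratio=-1,
                                  bytes_per_element=4):
    """Estimate the size in bytes of one [K, C, PH, PW, IY, IX] intermediate tensor."""
    if sampling_ratio > 0:
        # Fixed sampling: IY = IX = sampling_ratio.
        iy = ix = sampling_ratio
    else:
        # ASSUMPTION: adaptive sampling enumerates every row and column of the
        # feature map and masks out the samples that are not needed.
        iy, ix = feat_h, feat_w
    elements = num_rois * channels * pooled_h * pooled_w * iy * ix
    return elements * bytes_per_element


# Shapes from the snippet above: one RoI, input [1, 256, 48, 64] in float32,
# output_size (48, 64).
adaptive = deterministic_roi_align_bytes(1, 256, 48, 64, feat_h=48, feat_w=64)
fixed = deterministic_roi_align_bytes(1, 256, 48, 64, feat_h=48, feat_w=64,
                                      sampling_ratio=2)
print(f"adaptive sampling: {adaptive / GIB:.2f} GiB")  # 9.00 GiB
print(f"sampling_ratio=2:  {fixed / MIB:.2f} MiB")     # 12.00 MiB
```

Under that assumption, a single intermediate for the snippet is about 9 GiB, which would be consistent with the very large allocation reported above. If determinism is not required for this op, leaving deterministic algorithms off around the call, or passing an explicit positive `sampling_ratio`, may avoid the blow-up, though I have not verified either on the setups reported in this thread.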