mikgroup/sigpy

Segfault in 1x1 conv2d backward pass when both PyTorch and sigpy are installed (cuDNN version mismatch)

nishi951 opened this issue

Describe the bug
A segfault occurs in the backward pass of a 1x1 conv2d when running on the GPU with PyTorch, with torch.backends.cudnn.deterministic = True and sigpy imported before torch.

To Reproduce
Steps to reproduce the behavior:

  1. Install the environment below with conda env create -f environment.yaml
  2. Activate it with conda activate sigseg
  3. Run the script below with python segfault.py

environment.yaml:

name: sigseg
channels:
  - frankong
  - nvidia
  - pytorch
  - conda-forge
dependencies:
  - cupy=12.1
  - cudnn
  - cutensor
  - nccl
  - numpy=1.22
  - python=3.10
  - pytorch=2.1
  - pytorch-cuda=12.1
  - PyWavelets
  - scipy
  - torchvision
  - torchaudio
  - tqdm
  - pyyaml
  - pip
  - pip:
     - sigpy==0.1.25

segfault.py:

import sigpy as sp  # sigpy must be imported BEFORE torch to reproduce the crash
#from cupy import cudnn
import torch
torch.backends.cudnn.deterministic = True
import torch.nn.functional as F

device = 'cuda:0'
net_input = torch.randn(1, 10, 220, 220, dtype=torch.float32, device=device)
# Create the weight directly on the GPU so it stays a leaf tensor that receives .grad
weight = torch.randn(10, 10, 1, 1, device=device, requires_grad=True)
net_output = F.conv2d(net_input, weight, padding='same')  # 1x1 kernel; 3x3 does not crash
z = net_output
loss = torch.sum(torch.abs(z))
loss.backward()  # Segfault occurs here

Expected behavior
I expect the code to finish without segfaulting.

Desktop:

  • Ubuntu 22.04
  • NVIDIA RTX 3090, CUDA 12.1

Additional context

  1. The problem only occurs in the backward pass of a 1x1 conv2d (e.g., a 3x3 conv2d works fine).
  2. The problem is GPU-only.
  3. The problem disappears when torch is imported BEFORE sigpy (probably related to sigpy/config.py).
  • I suspect the cause is a cuDNN version mismatch between torch and sigpy/CuPy: in the frozen environment below, conda installed cudnn 8.8.0.121, while the pytorch package was built against cuDNN 8.9.2 (build string py3.10_cuda12.1_cudnn8.9.2_0).
  • I was able to resolve the problem by installing a PyTorch-compatible cuDNN directly from apt and NOT installing cudnn via conda, but this took some digging.
  • I can try to write a pull request that modifies config.py to warn more specifically about cuDNN version mismatches, e.g. by comparing cudnn.getVersion() with torch.backends.cudnn.version().
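Here is a minimal sketch of the kind of check proposed in the last bullet. It is not sigpy's actual config.py, and the function names are hypothetical. It assumes that cuDNN 8.x reports its version as major*1000 + minor*100 + patch (8902 for 8.9.2) and cuDNN 9.x as major*10000 + minor*100 + patch, and that CuPy exposes cudnn.getVersion() through cupy.cuda when it finds cuDNN. Only major.minor is compared, which is a judgment call.

```python
import warnings


def decode_cudnn_version(v: int) -> tuple[int, int, int]:
    """Split a cuDNN version integer into (major, minor, patch).

    Assumed encoding: cuDNN 8.x uses major*1000 + minor*100 + patch
    (8902 -> 8.9.2); cuDNN 9.x uses major*10000 + minor*100 + patch
    (90100 -> 9.1.0).
    """
    if v >= 10000:
        return v // 10000, (v % 10000) // 100, v % 100
    return v // 1000, (v % 1000) // 100, v % 100


def check_cudnn_match(torch_v: int, cupy_v: int) -> bool:
    """Warn (and return False) if major.minor differ between the two libraries."""
    t = decode_cudnn_version(torch_v)
    c = decode_cudnn_version(cupy_v)
    if t[:2] != c[:2]:
        warnings.warn(
            f"cuDNN mismatch: PyTorch reports {'.'.join(map(str, t))}, "
            f"CuPy reports {'.'.join(map(str, c))}. Using both in one process "
            "may crash (e.g. a segfault in conv2d backward).",
            RuntimeWarning,
        )
        return False
    return True


def query_versions():
    """Best-effort lookup; returns None for a library that cannot report a version."""
    try:
        import torch
        torch_v = torch.backends.cudnn.version()  # None when cuDNN is unavailable
    except Exception:
        torch_v = None
    try:
        from cupy.cuda import cudnn  # assumption: available when CuPy finds cuDNN
        cupy_v = cudnn.getVersion()
    except Exception:
        cupy_v = None
    return torch_v, cupy_v


if __name__ == "__main__":
    torch_v, cupy_v = query_versions()
    if torch_v is not None and cupy_v is not None:
        print("match" if check_cudnn_match(torch_v, cupy_v) else "mismatch")
```

With the versions from this environment, check_cudnn_match(8902, 8800) would warn, because 8.9 != 8.8.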

Here's the full frozen environment:

name: sigseg
channels:
  - pytorch
  - nvidia
  - conda-forge
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_kmp_llvm
  - blas=1.0=mkl
  - brotli-python=1.1.0=py310hc6cd4ac_1
  - bzip2=1.0.8=hd590300_5
  - ca-certificates=2023.11.17=hbcca054_0
  - certifi=2023.11.17=pyhd8ed1ab_0
  - charset-normalizer=3.3.2=pyhd8ed1ab_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - cuda-cudart=12.1.105=0
  - cuda-cupti=12.1.105=0
  - cuda-libraries=12.1.0=0
  - cuda-nvrtc=12.1.105=0
  - cuda-nvtx=12.1.105=0
  - cuda-opencl=12.3.101=0
  - cuda-runtime=12.1.0=0
  - cuda-version=12.2=he2b69de_2
  - cudnn=8.8.0.121=h264754d_4
  - cupy=12.1.0=py310hfc31588_1
  - cutensor=1.7.0.1=0
  - cutensor-cuda-12=2.0.0=0
  - fastrlock=0.8.2=py310hc6cd4ac_1
  - ffmpeg=4.3=hf484d3e_0
  - filelock=3.13.1=pyhd8ed1ab_0
  - freetype=2.12.1=h267a509_2
  - gmp=6.3.0=h59595ed_0
  - gmpy2=2.1.2=py310h3ec546c_1
  - gnutls=3.6.13=h85f3911_1
  - icu=73.2=h59595ed_0
  - idna=3.6=pyhd8ed1ab_0
  - jinja2=3.1.2=pyhd8ed1ab_1
  - jpeg=9e=h166bdaf_2
  - lame=3.100=h166bdaf_1003
  - lcms2=2.15=hfd0df8a_0
  - ld_impl_linux-64=2.40=h41732ed_0
  - lerc=4.0.0=h27087fc_0
  - libblas=3.9.0=16_linux64_mkl
  - libcblas=3.9.0=16_linux64_mkl
  - libcublas=12.1.0.26=0
  - libcufft=11.0.2.4=0
  - libcufile=1.8.1.2=0
  - libcurand=10.3.4.101=0
  - libcusolver=11.4.4.55=0
  - libcusparse=12.0.2.55=0
  - libcutensor-cuda-12=2.0.0.7=0
  - libcutensor-dev-cuda-12=2.0.0.7=0
  - libdeflate=1.17=h0b41bf4_0
  - libffi=3.4.2=h7f98852_5
  - libgcc-ng=13.2.0=h807b86a_3
  - libgfortran-ng=13.2.0=h69a702a_3
  - libgfortran5=13.2.0=ha4646dd_3
  - libhwloc=2.9.3=default_h554bfaf_1009
  - libiconv=1.17=hd590300_1
  - libjpeg-turbo=2.0.0=h9bf148f_0
  - liblapack=3.9.0=16_linux64_mkl
  - libnpp=12.0.2.50=0
  - libnsl=2.0.1=hd590300_0
  - libnvjitlink=12.1.105=0
  - libnvjpeg=12.1.1.14=0
  - libpng=1.6.39=h753d276_0
  - libsqlite=3.44.2=h2797004_0
  - libstdcxx-ng=13.2.0=h7e041cc_3
  - libtiff=4.5.0=h6adf6a1_2
  - libuuid=2.38.1=h0b41bf4_0
  - libwebp-base=1.3.2=hd590300_0
  - libxcb=1.13=h7f98852_1004
  - libxml2=2.11.6=h232c23b_0
  - libzlib=1.2.13=hd590300_5
  - llvm-openmp=15.0.7=h0cdce71_0
  - markupsafe=2.1.3=py310h2372a71_1
  - mkl=2022.2.1=h84fe81f_16997
  - mpc=1.3.1=hfe3b2da_0
  - mpfr=4.2.1=h9458935_0
  - mpmath=1.3.0=pyhd8ed1ab_0
  - nccl=2.19.4.1=h3a97aeb_0
  - ncurses=6.4=h59595ed_2
  - nettle=3.6=he412f7d_0
  - networkx=3.2.1=pyhd8ed1ab_0
  - numpy=1.22.4=py310h4ef5377_0
  - openh264=2.1.1=h780b84a_0
  - openjpeg=2.5.0=hfec8fc6_2
  - openssl=3.2.0=hd590300_1
  - pillow=9.4.0=py310h023d228_1
  - pip=23.3.1=pyhd8ed1ab_0
  - pthread-stubs=0.4=h36c2ea0_1001
  - pysocks=1.7.1=pyha2e5f31_6
  - python=3.10.13=hd12c33a_0_cpython
  - python_abi=3.10=4_cp310
  - pytorch=2.1.1=py3.10_cuda12.1_cudnn8.9.2_0
  - pytorch-cuda=12.1=ha16c6d3_5
  - pytorch-mutex=1.0=cuda
  - pywavelets=1.4.1=py310h1f7b6fc_1
  - pyyaml=6.0.1=py310h2372a71_1
  - readline=8.2=h8228510_1
  - requests=2.31.0=pyhd8ed1ab_0
  - scipy=1.11.4=py310hb13e2d6_0
  - setuptools=68.2.2=pyhd8ed1ab_0
  - sympy=1.12=pypyh9d50eac_103
  - tbb=2021.11.0=h00ab1b0_0
  - tk=8.6.13=noxft_h4845f30_101
  - torchaudio=2.1.1=py310_cu121
  - torchtriton=2.1.0=py310
  - torchvision=0.16.1=py310_cu121
  - tqdm=4.66.1=pyhd8ed1ab_0
  - typing_extensions=4.9.0=pyha770c72_0
  - tzdata=2023c=h71feb2d_0
  - urllib3=2.1.0=pyhd8ed1ab_0
  - wheel=0.42.0=pyhd8ed1ab_0
  - xorg-libxau=1.0.11=hd590300_0
  - xorg-libxdmcp=1.1.3=h7f98852_0
  - xz=5.2.6=h166bdaf_0
  - yaml=0.2.5=h7f98852_2
  - zlib=1.2.13=hd590300_5
  - zstd=1.5.5=hfc55251_0
  - pip:
      - llvmlite==0.41.1
      - numba==0.58.1
      - sigpy==0.1.25
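These pins show the suspected mismatch directly: conda installed cudnn=8.8.0.121, while the pytorch build string (py3.10_cuda12.1_cudnn8.9.2_0) indicates a build against cuDNN 8.9.2. Here is a minimal sketch that pulls both versions out of an exported environment like this one (presumably the output of conda env export). The function names are hypothetical, and the regexes assume the name=version=build layout shown above.

```python
import re


def cudnn_versions_from_env(spec: str) -> dict[str, str]:
    """Collect the conda cudnn pin and the cuDNN tag in the pytorch build string."""
    found = {}
    for line in spec.splitlines():
        entry = line.strip().lstrip("- ")  # "  - cudnn=..." -> "cudnn=..."
        m = re.match(r"cudnn=([\d.]+)=", entry)
        if m:
            found["conda cudnn"] = m.group(1)
        # pytorch=<version>=<build>, where <build> looks like py3.10_cuda12.1_cudnn8.9.2_0
        m = re.match(r"pytorch=[^=]+=\S*_cudnn([\d.]+)_", entry)
        if m:
            found["pytorch build"] = m.group(1)
    return found


def major_minor(version: str) -> tuple[int, ...]:
    """'8.8.0.121' -> (8, 8)"""
    return tuple(int(part) for part in version.split(".")[:2])


if __name__ == "__main__":
    # Two lines copied from the frozen environment above.
    frozen = """
      - cudnn=8.8.0.121=h264754d_4
      - pytorch=2.1.1=py3.10_cuda12.1_cudnn8.9.2_0
    """
    versions = cudnn_versions_from_env(frozen)
    print(versions)  # {'conda cudnn': '8.8.0.121', 'pytorch build': '8.9.2'}
    if len(versions) == 2:
        conda_v, torch_v = (major_minor(v) for v in versions.values())
        print("mismatch" if conda_v != torch_v else "match")  # mismatch
```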