ml-struct-bio/cryodrgn

pytorch 1.9 support

jordibc opened this issue · 3 comments

Describe the bug
The current code does not seem to be compatible with the newly released pytorch 1.9.

To Reproduce
With a set of preprocessed particles:

CUDA_VISIBLE_DEVICES=0 cryodrgn train_vae particles.64.mrcs --poses poses.pkl --ctf ctfs.pkl --zdim 1 -o output -n 3 --relion31 --enc-layers 3 --enc-dim 256 --dec-layers 3 --dec-dim 256

It eventually fails with:

  File "/home/jordi/miniconda3/envs/cryodrgn-0.3.2/lib/python3.7/site-packages/torch/utils/data/sampler.py", line 124, in __iter__
    yield from torch.randperm(n, generator=generator).tolist()
RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'

(When I try to "fix" it by changing torch.cuda.FloatTensor to torch.FloatTensor in backproject_voxel.py, train_nn.py, eval_vol.py, train_vae.py, and eval_images.py, I get the same kind of error from other parts of the code.)

Expected behavior
Finish the training without producing an exception.

Additional context
This is the conda environment that I am using:

> conda list
# packages in environment at /home/jordi/miniconda3/envs/cryodrgn-0.3.2:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                      1_llvm    conda-forge
anyio                     3.1.0            py37h89c1867_0    conda-forge
argon2-cffi               20.1.0           py37h5e8e339_2    conda-forge
async_generator           1.10                       py_0    conda-forge
attrs                     21.2.0             pyhd8ed1ab_0    conda-forge
babel                     2.9.1              pyh44b312d_0    conda-forge
backcall                  0.2.0              pyh9f0ad1d_0    conda-forge
backports                 1.0                        py_2    conda-forge
backports.functools_lru_cache 1.6.4              pyhd8ed1ab_0    conda-forge
blas                      1.0                         mkl  
bleach                    3.3.0              pyh44b312d_0    conda-forge
brotlipy                  0.7.0           py37h5e8e339_1001    conda-forge
bzip2                     1.0.8                h7f98852_4    conda-forge
ca-certificates           2021.5.30            ha878542_0    conda-forge
certifi                   2021.5.30        py37h89c1867_0    conda-forge
cffi                      1.14.5           py37hc58025e_0    conda-forge
chardet                   4.0.0            py37h89c1867_1    conda-forge
colorlover                0.3.0                    pypi_0    pypi
cryodrgn                  0.3.2                     dev_0    <develop>
cryptography              3.4.7            py37h5d9358c_0    conda-forge
cudatoolkit               11.1.74              h6bb024c_0    nvidia
cufflinks                 0.17.3                   pypi_0    pypi
cycler                    0.10.0                     py_2    conda-forge
decorator                 5.0.9              pyhd8ed1ab_0    conda-forge
defusedxml                0.7.1              pyhd8ed1ab_0    conda-forge
entrypoints               0.3             pyhd8ed1ab_1003    conda-forge
ffmpeg                    4.3.1                hca11adc_2    conda-forge
freetype                  2.10.4               h0708190_1    conda-forge
future                    0.18.2           py37h89c1867_3    conda-forge
gettext                   0.21.0               hf68c758_0  
gmp                       6.2.1                h58526e2_0    conda-forge
gnutls                    3.6.15               he1e5248_0  
icu                       68.1                 h58526e2_0    conda-forge
idna                      2.10               pyh9f0ad1d_0    conda-forge
importlib-metadata        4.5.0            py37h89c1867_0    conda-forge
intel-openmp              2021.2.0           h06a4308_610  
ipykernel                 5.5.5            py37h085eea5_0    conda-forge
ipython                   7.24.1           py37h085eea5_0    conda-forge
ipython_genutils          0.2.0                      py_1    conda-forge
ipywidgets                7.6.3                    pypi_0    pypi
jbig                      2.1               h7f98852_2003    conda-forge
jedi                      0.18.0           py37h89c1867_2    conda-forge
jinja2                    3.0.1              pyhd8ed1ab_0    conda-forge
joblib                    1.0.1              pyhd8ed1ab_0    conda-forge
jpeg                      9b                   h024ee3a_2  
json5                     0.9.5              pyh9f0ad1d_0    conda-forge
jsonschema                3.2.0              pyhd8ed1ab_3    conda-forge
jupyter_client            6.1.12             pyhd8ed1ab_0    conda-forge
jupyter_core              4.7.1            py37h89c1867_0    conda-forge
jupyter_server            1.8.0              pyhd8ed1ab_0    conda-forge
jupyterlab                3.0.16             pyhd8ed1ab_0    conda-forge
jupyterlab-widgets        1.0.0                    pypi_0    pypi
jupyterlab_pygments       0.1.2              pyh9f0ad1d_0    conda-forge
jupyterlab_server         2.6.0              pyhd8ed1ab_0    conda-forge
kiwisolver                1.3.1            py37h2527ec5_1    conda-forge
lame                      3.100             h14c3975_1001    conda-forge
lcms2                     2.12                 h3be6417_0  
ld_impl_linux-64          2.35.1               hea4e1c9_2    conda-forge
lerc                      2.2.1                h9c3ff4c_0    conda-forge
libblas                   3.9.0                     8_mkl    conda-forge
libcblas                  3.9.0                     8_mkl    conda-forge
libdeflate                1.7                  h7f98852_5    conda-forge
libffi                    3.3                  h58526e2_2    conda-forge
libgcc-ng                 9.3.0               h2828fa1_19    conda-forge
libgfortran-ng            9.3.0               hff62375_19    conda-forge
libgfortran5              9.3.0               hff62375_19    conda-forge
libgomp                   9.3.0               h2828fa1_19    conda-forge
libiconv                  1.16                 h516909a_0    conda-forge
libidn2                   2.3.1                h7f98852_0    conda-forge
liblapack                 3.9.0                     8_mkl    conda-forge
libllvm10                 10.0.1               he513fc3_3    conda-forge
libpng                    1.6.37               h21135ba_2    conda-forge
libprotobuf               3.15.8               h780b84a_0    conda-forge
libsodium                 1.0.18               h36c2ea0_1    conda-forge
libstdcxx-ng              9.3.0               h6de172a_19    conda-forge
libtasn1                  4.16.0               h27cfd23_0  
libtiff                   4.2.0                h3942068_0  
libunistring              0.9.10               h14c3975_0    conda-forge
libuv                     1.41.0               h7f98852_0    conda-forge
libwebp-base              1.2.0                h7f98852_2    conda-forge
libxml2                   2.9.12               h72842e0_0    conda-forge
llvm-openmp               11.1.0               h4bd325d_1    conda-forge
llvmlite                  0.36.0           py37h9d7f4d0_0    conda-forge
lz4-c                     1.9.3                h9c3ff4c_0    conda-forge
markupsafe                2.0.1            py37h5e8e339_0    conda-forge
matplotlib-base           3.4.2            py37hdd32ed1_0    conda-forge
matplotlib-inline         0.1.2              pyhd8ed1ab_2    conda-forge
mistune                   0.8.4           py37h5e8e339_1003    conda-forge
mkl                       2020.4             h726a3e6_304    conda-forge
mkl-service               2.3.0            py37h8f50634_2    conda-forge
mkl_fft                   1.3.0            py37h902c9e0_1    conda-forge
mkl_random                1.2.0            py37h9fdb41a_1    conda-forge
nbclassic                 0.3.1              pyhd8ed1ab_1    conda-forge
nbclient                  0.5.3              pyhd8ed1ab_0    conda-forge
nbconvert                 6.0.7            py37h89c1867_3    conda-forge
nbformat                  5.1.3              pyhd8ed1ab_0    conda-forge
ncurses                   6.2                  h58526e2_4    conda-forge
nest-asyncio              1.5.1              pyhd8ed1ab_0    conda-forge
nettle                    3.7.3                hbbd107a_1  
ninja                     1.10.2               h4bd325d_0    conda-forge
notebook                  6.4.0              pyha770c72_0    conda-forge
numba                     0.53.1           py37hb11d6e1_1    conda-forge
numpy                     1.19.2           py37h54aff64_0  
numpy-base                1.19.2           py37hfa32c7d_0  
olefile                   0.46               pyh9f0ad1d_1    conda-forge
openh264                  2.1.1                h780b84a_0    conda-forge
openjpeg                  2.4.0                hb52868f_1    conda-forge
openssl                   1.1.1k               h7f98852_0    conda-forge
packaging                 20.9               pyh44b312d_0    conda-forge
pandas                    1.2.4            py37h219a48f_0    conda-forge
pandoc                    2.14.0.2             h7f98852_0    conda-forge
pandocfilters             1.4.2                      py_1    conda-forge
parso                     0.8.2              pyhd8ed1ab_0    conda-forge
patsy                     0.5.1                      py_0    conda-forge
pexpect                   4.8.0              pyh9f0ad1d_2    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    8.2.0            py37he98fc37_0  
pip                       21.1.2             pyhd8ed1ab_0    conda-forge
plotly                    4.14.3                   pypi_0    pypi
prometheus_client         0.11.0             pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.18             pyha770c72_0    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pycparser                 2.20               pyh9f0ad1d_2    conda-forge
pygments                  2.9.0              pyhd8ed1ab_0    conda-forge
pynndescent               0.5.2              pyh44b312d_0    conda-forge
pyopenssl                 20.0.1             pyhd8ed1ab_0    conda-forge
pyparsing                 2.4.7              pyh9f0ad1d_0    conda-forge
pyrsistent                0.17.3           py37h5e8e339_2    conda-forge
pysocks                   1.7.1            py37h89c1867_3    conda-forge
python                    3.7.10          hffdb5ce_100_cpython    conda-forge
python-dateutil           2.8.1                      py_0    conda-forge
python_abi                3.7                     1_cp37m    conda-forge
pytorch                   1.9.0           py3.7_cuda11.1_cudnn8.0.5_0    pytorch
pytz                      2021.1             pyhd8ed1ab_0    conda-forge
pyzmq                     22.1.0           py37h336d617_0    conda-forge
readline                  8.1                  h46c0cb4_0    conda-forge
requests                  2.25.1             pyhd3deb0d_0    conda-forge
retrying                  1.3.3                    pypi_0    pypi
scikit-learn              0.24.2           py37h18a542f_0    conda-forge
scipy                     1.6.3            py37h29e03ee_0    conda-forge
seaborn                   0.11.1               hd8ed1ab_1    conda-forge
seaborn-base              0.11.1             pyhd8ed1ab_1    conda-forge
send2trash                1.5.0                      py_0    conda-forge
setuptools                49.6.0           py37h89c1867_3    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
sleef                     3.5.1                h7f98852_1    conda-forge
sniffio                   1.2.0            py37h89c1867_1    conda-forge
sqlite                    3.35.5               h74cdb3f_0    conda-forge
statsmodels               0.12.2           py37h902c9e0_0    conda-forge
tbb                       2020.2               h4bd325d_4    conda-forge
terminado                 0.10.1           py37h89c1867_0    conda-forge
testpath                  0.5.0              pyhd8ed1ab_0    conda-forge
threadpoolctl             2.1.0              pyh5ca1d4c_0    conda-forge
tk                        8.6.10               h21135ba_1    conda-forge
torchaudio                0.9.0                      py37    pytorch
torchvision               0.10.0               py37_cu111    pytorch
tornado                   6.1              py37h5e8e339_1    conda-forge
traitlets                 5.0.5                      py_0    conda-forge
typing_extensions         3.10.0.0           pyha770c72_0    conda-forge
umap-learn                0.5.1            py37h89c1867_1    conda-forge
urllib3                   1.26.5             pyhd8ed1ab_0    conda-forge
wcwidth                   0.2.5              pyh9f0ad1d_2    conda-forge
webencodings              0.5.1                      py_1    conda-forge
websocket-client          0.57.0           py37h89c1867_4    conda-forge
wheel                     0.36.2             pyhd3deb0d_0    conda-forge
widgetsnbextension        3.5.1                    pypi_0    pypi
x264                      1!161.3030           h7f98852_1    conda-forge
xz                        5.2.5                h516909a_1    conda-forge
zeromq                    4.3.4                h9c3ff4c_0    conda-forge
zipp                      3.4.1              pyhd8ed1ab_0    conda-forge
zlib                      1.2.11            h516909a_1010    conda-forge
zstd                      1.4.9                ha95c52a_0    conda-forge

I am having the same issue! I was able to circumvent this for now by downgrading to pytorch v1.8.0.

This required editing the .condarc file in my home directory to change pip_interop_enabled from true to false.
Then I could follow the instructions here: https://pytorch.org/get-started/previous-versions/

Thanks for reporting this. I think the newest version of pytorch uses a cpu random generator in the data loader's sampler, which conflicts with cuda being set as the default tensor type. A quick search of reported issues suggests that the problem is caused by torch.set_default_tensor_type(torch.cuda.FloatTensor), but if we remove that call, other places in the code will break... I'm not sure if there is a quick fix, but I'll look into it.
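For anyone writing their own training scripts on pytorch 1.9, one way to avoid this conflict is to leave the default tensor type alone and move each batch to the GPU explicitly. The sketch below is a minimal illustration under stated assumptions: a toy torch.nn.Linear model and random data stand in for cryoDRGN's real models and particle stacks, make_loader and run_epoch are hypothetical helpers, and this is not necessarily how the cryoDRGN fix was implemented.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Choose the compute device explicitly instead of calling
# torch.set_default_tensor_type(torch.cuda.FloatTensor).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def make_loader(particles: torch.Tensor, batch_size: int) -> DataLoader:
    # Hypothetical helper: the dataset stays on the CPU, so the sampler's
    # default CPU generator and torch.randperm agree on the device.
    return DataLoader(TensorDataset(particles), batch_size=batch_size, shuffle=True)


def run_epoch(model: torch.nn.Module, loader: DataLoader) -> int:
    # Hypothetical helper: forward pass only; loss and optimizer step omitted.
    seen = 0
    for (batch,) in loader:
        batch = batch.to(DEVICE)  # move each minibatch to the model's device
        model(batch)
        seen += batch.shape[0]
    return seen


if __name__ == "__main__":
    model = torch.nn.Linear(4, 1).to(DEVICE)
    loader = make_loader(torch.randn(10, 4), batch_size=3)
    print(run_epoch(model, loader))  # number of samples processed
```

The design choice here is that only the model and each minibatch live on the GPU; everything the DataLoader touches (dataset, sampler, generator) stays on the CPU, so the device mismatch in the traceback cannot arise.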

In the meantime, as Ruby suggested, you can install cryoDRGN with pytorch 1.8:

conda install pytorch==1.8.0 -c pytorch
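To check whether an environment still has an affected pytorch before running cryoDRGN 0.3.2, something like the following could be used. It is a hypothetical helper, not part of cryoDRGN; it reads the pytorch version from `conda list` output (the format shown above) and assumes that every release from 1.9 onward behaves like 1.9 for the purposes of this issue.

```python
def pytorch_version_from_conda_list(text: str):
    # Hypothetical helper: return the version column of the "pytorch" row
    # in `conda list` output, or None if pytorch is not listed.
    for line in text.splitlines():
        fields = line.split()
        if len(fields) >= 2 and fields[0] == "pytorch":
            return fields[1]
    return None


def parse_version(version: str) -> tuple:
    # "1.9.0" -> (1, 9, 0); local build tags such as "+cu111" are dropped,
    # and each component keeps only its leading digits.
    core = version.split("+", 1)[0]
    parts = []
    for piece in core.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def needs_downgrade(version: str) -> bool:
    # Assumption: 1.9 and anything newer hits the generator/device error.
    return parse_version(version)[:2] >= (1, 9)
```

For example, the pytorch row from the environment above yields "1.9.0", for which needs_downgrade returns True, while "1.8.0" returns False.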

This incompatibility with the latest version of pytorch is fixed in the top-of-tree and will be available in the upcoming version 1.0.0 release.