pytorch 1.9 support
jordibc opened this issue · 3 comments
Describe the bug
The current code doesn't seem to be compatible with the newly released pytorch 1.9.
To Reproduce
With a set of preprocessed particles:
CUDA_VISIBLE_DEVICES=0 cryodrgn train_vae particles.64.mrcs --poses poses.pkl --ctf ctfs.pkl --zdim 1 -o output -n 3 --relion31 --enc-layers 3 --enc-dim 256 --dec-layers 3 --dec-dim 256
It eventually fails with:
File "/home/jordi/miniconda3/envs/cryodrgn-0.3.2/lib/python3.7/site-packages/torch/utils/data/sampler.py", line 124, in __iter__
yield from torch.randperm(n, generator=generator).tolist()
RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
(When I try to "fix" it by changing torch.cuda.FloatTensor to torch.FloatTensor in the files backproject_voxel.py, train_nn.py, eval_vol.py, train_vae.py, and eval_images.py, I get the same kind of error from different parts of the code.)
Expected behavior
Training finishes without raising an exception.
Additional context
This is the conda environment that I am using:
> conda list (cryodrgn-0.3.2)
# packages in environment at /home/jordi/miniconda3/envs/cryodrgn-0.3.2:
#
# Name Version Build Channel
_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 1_llvm conda-forge
anyio 3.1.0 py37h89c1867_0 conda-forge
argon2-cffi 20.1.0 py37h5e8e339_2 conda-forge
async_generator 1.10 py_0 conda-forge
attrs 21.2.0 pyhd8ed1ab_0 conda-forge
babel 2.9.1 pyh44b312d_0 conda-forge
backcall 0.2.0 pyh9f0ad1d_0 conda-forge
backports 1.0 py_2 conda-forge
backports.functools_lru_cache 1.6.4 pyhd8ed1ab_0 conda-forge
blas 1.0 mkl
bleach 3.3.0 pyh44b312d_0 conda-forge
brotlipy 0.7.0 py37h5e8e339_1001 conda-forge
bzip2 1.0.8 h7f98852_4 conda-forge
ca-certificates 2021.5.30 ha878542_0 conda-forge
certifi 2021.5.30 py37h89c1867_0 conda-forge
cffi 1.14.5 py37hc58025e_0 conda-forge
chardet 4.0.0 py37h89c1867_1 conda-forge
colorlover 0.3.0 pypi_0 pypi
cryodrgn 0.3.2 dev_0 <develop>
cryptography 3.4.7 py37h5d9358c_0 conda-forge
cudatoolkit 11.1.74 h6bb024c_0 nvidia
cufflinks 0.17.3 pypi_0 pypi
cycler 0.10.0 py_2 conda-forge
decorator 5.0.9 pyhd8ed1ab_0 conda-forge
defusedxml 0.7.1 pyhd8ed1ab_0 conda-forge
entrypoints 0.3 pyhd8ed1ab_1003 conda-forge
ffmpeg 4.3.1 hca11adc_2 conda-forge
freetype 2.10.4 h0708190_1 conda-forge
future 0.18.2 py37h89c1867_3 conda-forge
gettext 0.21.0 hf68c758_0
gmp 6.2.1 h58526e2_0 conda-forge
gnutls 3.6.15 he1e5248_0
icu 68.1 h58526e2_0 conda-forge
idna 2.10 pyh9f0ad1d_0 conda-forge
importlib-metadata 4.5.0 py37h89c1867_0 conda-forge
intel-openmp 2021.2.0 h06a4308_610
ipykernel 5.5.5 py37h085eea5_0 conda-forge
ipython 7.24.1 py37h085eea5_0 conda-forge
ipython_genutils 0.2.0 py_1 conda-forge
ipywidgets 7.6.3 pypi_0 pypi
jbig 2.1 h7f98852_2003 conda-forge
jedi 0.18.0 py37h89c1867_2 conda-forge
jinja2 3.0.1 pyhd8ed1ab_0 conda-forge
joblib 1.0.1 pyhd8ed1ab_0 conda-forge
jpeg 9b h024ee3a_2
json5 0.9.5 pyh9f0ad1d_0 conda-forge
jsonschema 3.2.0 pyhd8ed1ab_3 conda-forge
jupyter_client 6.1.12 pyhd8ed1ab_0 conda-forge
jupyter_core 4.7.1 py37h89c1867_0 conda-forge
jupyter_server 1.8.0 pyhd8ed1ab_0 conda-forge
jupyterlab 3.0.16 pyhd8ed1ab_0 conda-forge
jupyterlab-widgets 1.0.0 pypi_0 pypi
jupyterlab_pygments 0.1.2 pyh9f0ad1d_0 conda-forge
jupyterlab_server 2.6.0 pyhd8ed1ab_0 conda-forge
kiwisolver 1.3.1 py37h2527ec5_1 conda-forge
lame 3.100 h14c3975_1001 conda-forge
lcms2 2.12 h3be6417_0
ld_impl_linux-64 2.35.1 hea4e1c9_2 conda-forge
lerc 2.2.1 h9c3ff4c_0 conda-forge
libblas 3.9.0 8_mkl conda-forge
libcblas 3.9.0 8_mkl conda-forge
libdeflate 1.7 h7f98852_5 conda-forge
libffi 3.3 h58526e2_2 conda-forge
libgcc-ng 9.3.0 h2828fa1_19 conda-forge
libgfortran-ng 9.3.0 hff62375_19 conda-forge
libgfortran5 9.3.0 hff62375_19 conda-forge
libgomp 9.3.0 h2828fa1_19 conda-forge
libiconv 1.16 h516909a_0 conda-forge
libidn2 2.3.1 h7f98852_0 conda-forge
liblapack 3.9.0 8_mkl conda-forge
libllvm10 10.0.1 he513fc3_3 conda-forge
libpng 1.6.37 h21135ba_2 conda-forge
libprotobuf 3.15.8 h780b84a_0 conda-forge
libsodium 1.0.18 h36c2ea0_1 conda-forge
libstdcxx-ng 9.3.0 h6de172a_19 conda-forge
libtasn1 4.16.0 h27cfd23_0
libtiff 4.2.0 h3942068_0
libunistring 0.9.10 h14c3975_0 conda-forge
libuv 1.41.0 h7f98852_0 conda-forge
libwebp-base 1.2.0 h7f98852_2 conda-forge
libxml2 2.9.12 h72842e0_0 conda-forge
llvm-openmp 11.1.0 h4bd325d_1 conda-forge
llvmlite 0.36.0 py37h9d7f4d0_0 conda-forge
lz4-c 1.9.3 h9c3ff4c_0 conda-forge
markupsafe 2.0.1 py37h5e8e339_0 conda-forge
matplotlib-base 3.4.2 py37hdd32ed1_0 conda-forge
matplotlib-inline 0.1.2 pyhd8ed1ab_2 conda-forge
mistune 0.8.4 py37h5e8e339_1003 conda-forge
mkl 2020.4 h726a3e6_304 conda-forge
mkl-service 2.3.0 py37h8f50634_2 conda-forge
mkl_fft 1.3.0 py37h902c9e0_1 conda-forge
mkl_random 1.2.0 py37h9fdb41a_1 conda-forge
nbclassic 0.3.1 pyhd8ed1ab_1 conda-forge
nbclient 0.5.3 pyhd8ed1ab_0 conda-forge
nbconvert 6.0.7 py37h89c1867_3 conda-forge
nbformat 5.1.3 pyhd8ed1ab_0 conda-forge
ncurses 6.2 h58526e2_4 conda-forge
nest-asyncio 1.5.1 pyhd8ed1ab_0 conda-forge
nettle 3.7.3 hbbd107a_1
ninja 1.10.2 h4bd325d_0 conda-forge
notebook 6.4.0 pyha770c72_0 conda-forge
numba 0.53.1 py37hb11d6e1_1 conda-forge
numpy 1.19.2 py37h54aff64_0
numpy-base 1.19.2 py37hfa32c7d_0
olefile 0.46 pyh9f0ad1d_1 conda-forge
openh264 2.1.1 h780b84a_0 conda-forge
openjpeg 2.4.0 hb52868f_1 conda-forge
openssl 1.1.1k h7f98852_0 conda-forge
packaging 20.9 pyh44b312d_0 conda-forge
pandas 1.2.4 py37h219a48f_0 conda-forge
pandoc 2.14.0.2 h7f98852_0 conda-forge
pandocfilters 1.4.2 py_1 conda-forge
parso 0.8.2 pyhd8ed1ab_0 conda-forge
patsy 0.5.1 py_0 conda-forge
pexpect 4.8.0 pyh9f0ad1d_2 conda-forge
pickleshare 0.7.5 py_1003 conda-forge
pillow 8.2.0 py37he98fc37_0
pip 21.1.2 pyhd8ed1ab_0 conda-forge
plotly 4.14.3 pypi_0 pypi
prometheus_client 0.11.0 pyhd8ed1ab_0 conda-forge
prompt-toolkit 3.0.18 pyha770c72_0 conda-forge
ptyprocess 0.7.0 pyhd3deb0d_0 conda-forge
pycparser 2.20 pyh9f0ad1d_2 conda-forge
pygments 2.9.0 pyhd8ed1ab_0 conda-forge
pynndescent 0.5.2 pyh44b312d_0 conda-forge
pyopenssl 20.0.1 pyhd8ed1ab_0 conda-forge
pyparsing 2.4.7 pyh9f0ad1d_0 conda-forge
pyrsistent 0.17.3 py37h5e8e339_2 conda-forge
pysocks 1.7.1 py37h89c1867_3 conda-forge
python 3.7.10 hffdb5ce_100_cpython conda-forge
python-dateutil 2.8.1 py_0 conda-forge
python_abi 3.7 1_cp37m conda-forge
pytorch 1.9.0 py3.7_cuda11.1_cudnn8.0.5_0 pytorch
pytz 2021.1 pyhd8ed1ab_0 conda-forge
pyzmq 22.1.0 py37h336d617_0 conda-forge
readline 8.1 h46c0cb4_0 conda-forge
requests 2.25.1 pyhd3deb0d_0 conda-forge
retrying 1.3.3 pypi_0 pypi
scikit-learn 0.24.2 py37h18a542f_0 conda-forge
scipy 1.6.3 py37h29e03ee_0 conda-forge
seaborn 0.11.1 hd8ed1ab_1 conda-forge
seaborn-base 0.11.1 pyhd8ed1ab_1 conda-forge
send2trash 1.5.0 py_0 conda-forge
setuptools 49.6.0 py37h89c1867_3 conda-forge
six 1.16.0 pyh6c4a22f_0 conda-forge
sleef 3.5.1 h7f98852_1 conda-forge
sniffio 1.2.0 py37h89c1867_1 conda-forge
sqlite 3.35.5 h74cdb3f_0 conda-forge
statsmodels 0.12.2 py37h902c9e0_0 conda-forge
tbb 2020.2 h4bd325d_4 conda-forge
terminado 0.10.1 py37h89c1867_0 conda-forge
testpath 0.5.0 pyhd8ed1ab_0 conda-forge
threadpoolctl 2.1.0 pyh5ca1d4c_0 conda-forge
tk 8.6.10 h21135ba_1 conda-forge
torchaudio 0.9.0 py37 pytorch
torchvision 0.10.0 py37_cu111 pytorch
tornado 6.1 py37h5e8e339_1 conda-forge
traitlets 5.0.5 py_0 conda-forge
typing_extensions 3.10.0.0 pyha770c72_0 conda-forge
umap-learn 0.5.1 py37h89c1867_1 conda-forge
urllib3 1.26.5 pyhd8ed1ab_0 conda-forge
wcwidth 0.2.5 pyh9f0ad1d_2 conda-forge
webencodings 0.5.1 py_1 conda-forge
websocket-client 0.57.0 py37h89c1867_4 conda-forge
wheel 0.36.2 pyhd3deb0d_0 conda-forge
widgetsnbextension 3.5.1 pypi_0 pypi
x264 1!161.3030 h7f98852_1 conda-forge
xz 5.2.5 h516909a_1 conda-forge
zeromq 4.3.4 h9c3ff4c_0 conda-forge
zipp 3.4.1 pyhd8ed1ab_0 conda-forge
zlib 1.2.11 h516909a_1010 conda-forge
zstd 1.4.9 ha95c52a_0 conda-forge
I am having the same issue! I was able to circumvent this for now by downgrading to pytorch v1.8.0.
This required modifying the .condarc file in my home directory by switching pip_interop_enabled: true to false.
Then I could follow the instructions here: https://pytorch.org/get-started/previous-versions/
Thanks for reporting this. I think the newest version of pytorch sets the default device to cpu instead of cuda in the data loader. A quick search of reported issues suggests that it might be caused by torch.set_default_tensor_type(torch.cuda.FloatTensor), but if we remove that, other places in the code will break... I'm not sure if there is a quick fix, but I'll look into it.
In the meantime, as Ruby suggested, you can install cryoDRGN with pytorch 1.8:
conda install pytorch==1.8.0 -c pytorch
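For anyone who cannot downgrade, a commonly used workaround for this kind of error is to give the DataLoader an explicit generator that lives on the same device as the default tensor type. The sampler then no longer falls back to a CPU generator of its own, which is the 'cpu' generator the traceback complains about. This is only a sketch and not necessarily the change that went into cryoDRGN 1.0.0. The dataset is a hypothetical stand-in for the particle stack, and the CPU branch exists only so the sketch also runs on machines without a GPU.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Use the GPU when there is one; the CPU branch lets the sketch run anywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # The global default suspected above: it makes factory functions such as
    # torch.randperm() (used by the shuffling sampler) allocate on the GPU.
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

# Hypothetical stand-in for the particle dataset: ten scalar "images".
dataset = TensorDataset(torch.arange(10, dtype=torch.float32))

# Generator on the same device as the default tensors, so the sampler uses it
# instead of creating its own CPU generator.
generator = torch.Generator(device=device)
generator.manual_seed(0)

loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=generator)

for (batch,) in loader:
    print(batch)
```

Because the generator is seeded explicitly, the shuffling order is also reproducible from run to run.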
This incompatibility with the latest version of pytorch is fixed in the top-of-tree and will be available in the upcoming version 1.0.0 release.
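The version rule from this thread can be turned into a small check to run before launching a long training job. The reported combination (cryoDRGN older than 1.0.0 with pytorch 1.9) fails, pytorch 1.8.0 works, and cryoDRGN 1.0.0 carries the fix. This is a sketch: the helper names parse_version and is_affected are hypothetical. It also assumes that pytorch releases after 1.9 behave like 1.9 with older cryoDRGN versions.

```python
import re
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn '1.9.0+cu111' into (1, 9, 0) and '0.3.2' into (0, 3, 2).

    Only the leading numeric release segment is kept; local tags such as
    '+cu111' and pre-release suffixes are ignored (a simplifying assumption).
    """
    match = re.match(r"\d+(?:\.\d+)*", version.strip())
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def is_affected(torch_version: str, cryodrgn_version: str) -> bool:
    """Return True for the failing combination reported in this issue.

    Assumption: every pytorch release from 1.9 on triggers the error with
    cryoDRGN releases older than 1.0.0, where the fix is expected to land.
    """
    return (parse_version(torch_version) >= (1, 9)
            and parse_version(cryodrgn_version) < (1, 0, 0))
```

For the environment listed above, is_affected("1.9.0", "0.3.2") returns True. After the downgrade, is_affected("1.8.0", "0.3.2") returns False.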