theislab/scarches

Issue with expiMap model: `NameError: name 'Normal' is not defined` during `train()`

Closed this issue · 1 comment

```python
early_stopping_kwargs = {
    "early_stopping_metric": "val_unweighted_loss",  # val_unweighted_loss
    "threshold": 0,
    "patience": 50,
    "reduce_lr": True,
    "lr_patience": 13,
    "lr_factor": 0.1,
}

intr_cvae.train(
    n_epochs=400,
    alpha_epoch_anneal=100,
    alpha=ALPHA,
    alpha_kl=0.5,
    weight_decay=0.,
    early_stopping_kwargs=early_stopping_kwargs,
    use_early_stopping=True,
    monitor_only_val=False,
    seed=2024,
)
```

`Preparing (32484, 1967)
Instantiating dataset
Init the group lasso proximal operator for the main terms.

NameError Traceback (most recent call last)
Cell In[16], line 9
1 early_stopping_kwargs = {
2 "early_stopping_metric": "val_unweighted_loss", # val_unweighted_loss
3 "threshold": 0,
(...)
7 "lr_factor": 0.1,
8 }
----> 9 intr_cvae.train(
10 n_epochs=400,
11 alpha_epoch_anneal=100,
12 alpha=ALPHA,
13 alpha_kl=0.5,
14 weight_decay=0.,
15 early_stopping_kwargs=early_stopping_kwargs,
16 use_early_stopping=True,
17 monitor_only_val=False,
18 seed=2024,
19 )

File ~/miniconda3/envs/scarches/lib/python3.9/site-packages/scarches/models/expimap/expimap_model.py:242, in EXPIMAP.train(self, n_epochs, lr, eps, alpha, omega, **kwargs)
232 kwargs["alpha_epoch_anneal"] = epochs_anneal
234 self.trainer = expiMapTrainer(
235 self.model,
236 self.adata,
(...)
240 **kwargs
241 )
--> 242 self.trainer.train(n_epochs, lr, eps)
243 self.is_trained_ = True

File ~/miniconda3/envs/scarches/lib/python3.9/site-packages/scarches/trainers/trvae/trainer.py:235, in Trainer.train(self, n_epochs, lr, eps)
232 batch_data[key] = batch.to(self.device)
234 # Loss Calculation
--> 235 self.on_iteration(batch_data)
237 # Validation of Model, Monitoring, Early Stopping
238 self.on_epoch_end()

File ~/miniconda3/envs/scarches/lib/python3.9/site-packages/scarches/trainers/expimap/regularized.py:296, in expiMapTrainer.on_iteration(self, batch_data)
293 def on_iteration(self, batch_data):
294 self.init_prox_ops()
--> 296 super().on_iteration(batch_data)
298 self.apply_prox_ops()

File ~/miniconda3/envs/scarches/lib/python3.9/site-packages/scarches/trainers/trvae/trainer.py:272, in Trainer.on_iteration(self, batch_data)
269 module.track_running_stats = False
271 # Calculate Loss depending on Trainer/Model
--> 272 self.current_loss = loss = self.loss(batch_data)
273 self.optimizer.zero_grad()
274 loss.backward()

File ~/miniconda3/envs/scarches/lib/python3.9/site-packages/scarches/trainers/expimap/regularized.py:346, in expiMapTrainer.loss(self, total_batch)
345 def loss(self, total_batch=None):
--> 346 recon_loss, kl_loss, hsic_loss = self.model(**total_batch)
348 if self.beta is not None and self.model.use_hsic:
349 weighted_hsic = self.beta * hsic_loss

File ~/miniconda3/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File ~/miniconda3/envs/scarches/lib/python3.9/site-packages/scarches/models/expimap/expimap.py:179, in expiMap.forward(self, x, batch, sizefactor, labeled)
176 x_log = x
178 z1_mean, z1_log_var = self.encoder(x_log, batch)
--> 179 z1 = self.sampling(z1_mean, z1_log_var)
180 outputs = self.decoder(z1, batch)
182 if self.recon_loss == "mse":

File ~/miniconda3/envs/scarches/lib/python3.9/site-packages/scarches/models/base/_base.py:406, in CVAELatentsModelMixin.sampling(self, mu, log_var)
391 """Samples from standard Normal distribution and applies re-parametrization trick.
392 It is actually sampling from latent space distributions with N(mu, var), computed by encoder.
393
(...)
403 Torch Tensor of sampled data.
404 """
405 var = torch.exp(log_var) + 1e-4
--> 406 return Normal(mu, var.sqrt()).rsample()

NameError: name 'Normal' is not defined
`

Installed package versions:

Name Version Build Channel

_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
absl-py 2.1.0 pypi_0 pypi
aiohttp 3.9.3 pypi_0 pypi
aiosignal 1.3.1 pypi_0 pypi
anndata 0.10.5.post1 pypi_0 pypi
array-api-compat 1.4.1 pypi_0 pypi
asttokens 2.4.1 pypi_0 pypi
async-timeout 4.0.3 pypi_0 pypi
attrs 23.2.0 pypi_0 pypi
beautifulsoup4 4.12.3 pypi_0 pypi
blas 1.0 mkl
brewer2mpl 1.4.1 pypi_0 pypi
brotli-python 1.0.9 py39h6a678d5_7
bzip2 1.0.8 h7b6447c_0
ca-certificates 2023.12.12 h06a4308_0
certifi 2024.2.2 py39h06a4308_0
charset-normalizer 2.0.4 pyhd3eb1b0_0
chex 0.1.85 pypi_0 pypi
comm 0.2.1 pypi_0 pypi
contextlib2 21.6.0 pypi_0 pypi
contourpy 1.2.0 pypi_0 pypi
cuda-cudart 11.8.89 0 nvidia
cuda-cupti 11.8.87 0 nvidia
cuda-libraries 11.8.0 0 nvidia
cuda-nvrtc 11.8.89 0 nvidia
cuda-nvtx 11.8.86 0 nvidia
cuda-runtime 11.8.0 0 nvidia
cycler 0.12.1 pypi_0 pypi
debugpy 1.8.1 pypi_0 pypi
decorator 5.1.1 pypi_0 pypi
docrep 0.3.2 pypi_0 pypi
etils 1.5.2 pypi_0 pypi
exceptiongroup 1.2.0 pypi_0 pypi
executing 2.0.1 pypi_0 pypi
ffmpeg 4.3 hf484d3e_0 pytorch
filelock 3.13.1 py39h06a4308_0
flax 0.8.1 pypi_0 pypi
fonttools 4.49.0 pypi_0 pypi
freetype 2.12.1 h4a9f257_0
frozenlist 1.4.1 pypi_0 pypi
fsspec 2024.2.0 pypi_0 pypi
gdown 5.1.0 pypi_0 pypi
get-annotations 0.1.2 pypi_0 pypi
gmp 6.2.1 h295c915_3
gmpy2 2.1.2 py39heeb90bb_0
gnutls 3.6.15 he1e5248_0
h5py 3.10.0 pypi_0 pypi
idna 3.4 py39h06a4308_0
igraph 0.11.4 pypi_0 pypi
importlib-metadata 7.0.1 pypi_0 pypi
importlib-resources 6.1.2 pypi_0 pypi
intel-openmp 2023.1.0 hdb19cb5_46306
ipykernel 6.29.3 pypi_0 pypi
ipython 8.18.1 pypi_0 pypi
jax 0.4.25 pypi_0 pypi
jaxlib 0.4.25 pypi_0 pypi
jedi 0.19.1 pypi_0 pypi
jinja2 3.1.3 py39h06a4308_0
joblib 1.3.2 pypi_0 pypi
jpeg 9e h5eee18b_1
jupyter-client 8.6.0 pypi_0 pypi
jupyter-core 5.7.1 pypi_0 pypi
kiwisolver 1.4.5 pypi_0 pypi
lame 3.100 h7b6447c_0
lcms2 2.12 h3be6417_0
ld_impl_linux-64 2.38 h1181459_1
leidenalg 0.10.2 pypi_0 pypi
lerc 3.0 h295c915_0
libcublas 11.11.3.6 0 nvidia
libcufft 10.9.0.58 0 nvidia
libcufile 1.8.1.2 0 nvidia
libcurand 10.3.4.107 0 nvidia
libcusolver 11.4.1.48 0 nvidia
libcusparse 11.7.5.86 0 nvidia
libdeflate 1.17 h5eee18b_1
libffi 3.4.4 h6a678d5_0
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libiconv 1.16 h7f8727e_2
libidn2 2.3.4 h5eee18b_0
libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
libnpp 11.8.0.86 0 nvidia
libnvjpeg 11.9.0.86 0 nvidia
libpng 1.6.39 h5eee18b_0
libstdcxx-ng 11.2.0 h1234567_1
libtasn1 4.19.0 h5eee18b_0
libtiff 4.5.1 h6a678d5_0
libunistring 0.9.10 h27cfd23_0
libwebp-base 1.3.2 h5eee18b_0
lightning 2.1.4 pypi_0 pypi
lightning-utilities 0.10.1 pypi_0 pypi
llvm-openmp 14.0.6 h9e868ea_0
llvmlite 0.42.0 pypi_0 pypi
lz4-c 1.9.4 h6a678d5_0
markdown-it-py 3.0.0 pypi_0 pypi
markupsafe 2.1.3 py39h5eee18b_0
matplotlib 3.8.3 pypi_0 pypi
matplotlib-inline 0.1.6 pypi_0 pypi
mdurl 0.1.2 pypi_0 pypi
mkl 2023.1.0 h213fc3f_46344
mkl-service 2.4.0 py39h5eee18b_1
mkl_fft 1.3.8 py39h5eee18b_0
mkl_random 1.2.4 py39hdb19cb5_0
ml-collections 0.1.1 pypi_0 pypi
ml-dtypes 0.3.2 pypi_0 pypi
mpc 1.1.0 h10f8cd9_1
mpfr 4.0.2 hb69a4c5_1
mpmath 1.3.0 py39h06a4308_0
msgpack 1.0.7 pypi_0 pypi
mudata 0.2.3 pypi_0 pypi
multidict 6.0.5 pypi_0 pypi
multipledispatch 1.0.0 pypi_0 pypi
muon 0.1.5 pypi_0 pypi
natsort 8.4.0 pypi_0 pypi
ncurses 6.4 h6a678d5_0
nest-asyncio 1.6.0 pypi_0 pypi
nettle 3.7.3 hbbd107a_1
networkx 3.1 py39h06a4308_0
newick 1.0.0 pypi_0 pypi
numba 0.59.0 pypi_0 pypi
numpy 1.26.4 py39h5f9d8c6_0
numpy-base 1.26.4 py39hb5e798b_0
numpyro 0.13.2 pypi_0 pypi
openh264 2.1.1 h4ff587b_0
openjpeg 2.4.0 h3ad879b_0
openssl 3.0.13 h7f8727e_0
opt-einsum 3.3.0 pypi_0 pypi
optax 0.1.9 pypi_0 pypi
orbax-checkpoint 0.5.3 pypi_0 pypi
packaging 23.2 pypi_0 pypi
pandas 1.5.3 pypi_0 pypi
parso 0.8.3 pypi_0 pypi
patsy 0.5.6 pypi_0 pypi
pexpect 4.9.0 pypi_0 pypi
pillow 10.2.0 py39h5eee18b_0
pip 23.3.1 py39h06a4308_0
platformdirs 4.2.0 pypi_0 pypi
prompt-toolkit 3.0.43 pypi_0 pypi
protobuf 4.25.3 pypi_0 pypi
psutil 5.9.8 pypi_0 pypi
ptyprocess 0.7.0 pypi_0 pypi
pure-eval 0.2.2 pypi_0 pypi
pygments 2.17.2 pypi_0 pypi
pynndescent 0.5.11 pypi_0 pypi
pyparsing 3.1.1 pypi_0 pypi
pyro-api 0.1.2 pypi_0 pypi
pyro-ppl 1.9.0 pypi_0 pypi
pysocks 1.7.1 py39h06a4308_0
python 3.9.18 h955ad1f_0
python-dateutil 2.8.2 pypi_0 pypi
pytorch 2.2.1 py3.9_cuda11.8_cudnn8.7.0_0 pytorch
pytorch-cuda 11.8 h7e8668a_5 pytorch
pytorch-lightning 2.2.0.post0 pypi_0 pypi
pytorch-mutex 1.0 cuda pytorch
pytz 2024.1 pypi_0 pypi
pyyaml 6.0.1 py39h5eee18b_0
pyzmq 25.1.2 pypi_0 pypi
readline 8.2 h5eee18b_0
requests 2.31.0 py39h06a4308_1
rich 13.7.0 pypi_0 pypi
scanpy 1.9.8 pypi_0 pypi
scarches 0.6.0 pypi_0 pypi
schpl 1.0.5 pypi_0 pypi
scikit-learn 1.4.1.post1 pypi_0 pypi
scipy 1.12.0 pypi_0 pypi
scvi-tools 1.1.1 pypi_0 pypi
seaborn 0.13.2 pypi_0 pypi
session-info 1.0.0 pypi_0 pypi
setuptools 68.2.2 py39h06a4308_0
six 1.16.0 pypi_0 pypi
slalom 1.0.0.dev11 pypi_0 pypi
soupsieve 2.5 pypi_0 pypi
sqlite 3.41.2 h5eee18b_0
stack-data 0.6.3 pypi_0 pypi
statsmodels 0.14.1 pypi_0 pypi
stdlib-list 0.10.0 pypi_0 pypi
sympy 1.12 py39h06a4308_0
tbb 2021.8.0 hdb19cb5_0
tensorstore 0.1.54 pypi_0 pypi
texttable 1.7.0 pypi_0 pypi
threadpoolctl 3.3.0 pypi_0 pypi
tk 8.6.12 h1ccaba5_0
toolz 0.12.1 pypi_0 pypi
torchaudio 2.2.1 py39_cu118 pytorch
torchmetrics 1.3.1 pypi_0 pypi
torchtriton 2.2.0 py39 pytorch
torchvision 0.17.1 py39_cu118 pytorch
tornado 6.4 pypi_0 pypi
tqdm 4.66.2 pypi_0 pypi
traitlets 5.14.1 pypi_0 pypi
typing_extensions 4.9.0 py39h06a4308_1
tzdata 2024a h04d1e81_0
umap-learn 0.5.5 pypi_0 pypi
urllib3 2.1.0 py39h06a4308_1
wcwidth 0.2.13 pypi_0 pypi
wheel 0.41.2 py39h06a4308_0
xz 5.4.6 h5eee18b_0
yaml 0.2.5 h7b6447c_0
yarl 1.9.4 pypi_0 pypi
zipp 3.17.0 pypi_0 pypi
zlib 1.2.13 h5eee18b_0
zstd 1.5.5 hc292b87_0

Yes, thanks for reporting. I have pushed the bugfix; please install the new version of scArches (e.g. `pip install --upgrade scarches`).