gao-lab/GLUE

Error in scglue.models.fit_SCGLUE() function

LalicJ opened this issue · 0 comments

LalicJ commented

Hi, sorry to bother you. When I run scglue.models.fit_SCGLUE(), I get the following error:

glue = scglue.models.fit_SCGLUE(
    {"rna": rna, "atac": atac}, guidance,
    fit_kws={"directory": "glue"}
)

[INFO] fit_SCGLUE: Pretraining SCGLUE model...
[INFO] autodevice: Using CPU as computation device.
[INFO] check_graph: Checking variable coverage...
[INFO] check_graph: Checking edge attributes...
[INFO] check_graph: Checking self-loops...
[INFO] check_graph: Checking graph symmetry...
[INFO] check_graph: All checks passed!
[INFO] SCGLUEModel: Setting `graph_batch_size` = 222199
[INFO] SCGLUEModel: Setting `max_epochs` = 98
[INFO] SCGLUEModel: Setting `patience` = 9
[INFO] SCGLUEModel: Setting `reduce_lr_patience` = 5
[INFO] SCGLUETrainer: Using training directory: "glue/pretrain"
Current run is terminating due to exception: Expected value argument (Tensor of shape (128, 3428)) to be within the support (IntegerGreaterThan(lower_bound=0)) of the distribution NegativeBinomial(total_count: torch.Size([128, 3428]), logits: torch.Size([128, 3428])), but found invalid values:
tensor([[0.0000, 0.0000, 0.0000, ..., 1.4051, 1.4051, 1.4051],
[0.0000, 0.0000, 1.9990, ..., 1.4329, 1.4329, 0.0000],
[0.0000, 0.0000, 1.5017, ..., 1.5017, 0.0000, 0.0000],
...,
[0.0000, 0.0000, 0.0000, ..., 0.0000, 1.3823, 1.3823],
[0.0000, 0.0000, 0.0000, ..., 2.2712, 1.3602, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]])
Engine run is terminating due to exception: Expected value argument (Tensor of shape (128, 3428)) to be within the support (IntegerGreaterThan(lower_bound=0)) of the distribution NegativeBinomial(total_count: torch.Size([128, 3428]), logits: torch.Size([128, 3428])), but found invalid values:
tensor([[0.0000, 0.0000, 0.0000, ..., 1.4051, 1.4051, 1.4051],
[0.0000, 0.0000, 1.9990, ..., 1.4329, 1.4329, 0.0000],
[0.0000, 0.0000, 1.5017, ..., 1.5017, 0.0000, 0.0000],
...,
[0.0000, 0.0000, 0.0000, ..., 0.0000, 1.3823, 1.3823],
[0.0000, 0.0000, 0.0000, ..., 2.2712, 1.3602, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]])

ValueError Traceback (most recent call last)
Cell In[26], line 1
----> 1 glue = scglue.models.fit_SCGLUE(
2 {"rna": rna, "atac": atac}, guidance,
3 fit_kws={"directory": "glue"}
4 )

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/scglue/models/__init__.py:206, in fit_SCGLUE(adatas, graph, model, init_kws, compile_kws, fit_kws, balance_kws)
204 pretrain = model(adatas, sorted(graph.nodes), **pretrain_init_kws)
205 pretrain.compile(**compile_kws)
--> 206 pretrain.fit(adatas, graph, **pretrain_fit_kws)
207 if "directory" in pretrain_fit_kws:
208 pretrain.save(os.path.join(pretrain_fit_kws["directory"], "pretrain.dill"))

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/scglue/models/scglue.py:946, in SCGLUEModel.fit(self, adatas, graph, neg_samples, val_split, data_batch_size, graph_batch_size, align_burnin, safe_burnin, max_epochs, patience, reduce_lr_patience, wait_n_lrs, directory)
943 if self.trainer.freeze_u:
944 self.logger.info("Cell embeddings are frozen")
--> 946 super().fit(
947 data, graph, val_split=val_split,
948 data_batch_size=data_batch_size, graph_batch_size=graph_batch_size,
949 align_burnin=align_burnin, safe_burnin=safe_burnin,
950 max_epochs=max_epochs, patience=patience,
951 reduce_lr_patience=reduce_lr_patience, wait_n_lrs=wait_n_lrs,
952 random_seed=self.random_seed,
953 directory=directory
954 )

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/scglue/models/base.py:333, in Model.fit(self, *args, **kwargs)
318 def fit(self, *args, **kwargs) -> None:
319 r"""
320 Alias of .trainer.fit.
321
(...)
331 Subclasses may override arguments for API definition.
332 """
--> 333 self.trainer.fit(*args, **kwargs)

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/scglue/models/glue.py:608, in GLUETrainer.fit(self, data, graph, val_split, data_batch_size, graph_batch_size, align_burnin, safe_burnin, max_epochs, patience, reduce_lr_patience, wait_n_lrs, random_seed, directory, plugins)
606 plugins = default_plugins + (plugins or [])
607 try:
--> 608 super().fit(
609 train_loader, val_loader=val_loader,
610 max_epochs=max_epochs, random_seed=random_seed,
611 directory=directory, plugins=plugins
612 )
613 finally:
614 data.clean()

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/scglue/models/base.py:199, in Trainer.fit(self, train_loader, val_loader, max_epochs, random_seed, directory, plugins)
197 # Start engines
198 torch.manual_seed(random_seed)
--> 199 train_engine.run(train_loader, max_epochs=max_epochs)
201 torch.cuda.empty_cache()

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/ignite/engine/engine.py:704, in Engine.run(self, data, max_epochs, epoch_length, seed)
701 raise ValueError("epoch_length should be provided if data is None")
703 self.state.dataloader = data
--> 704 return self._internal_run()

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/ignite/engine/engine.py:783, in Engine._internal_run(self)
781 self._dataloader_iter = None
782 self.logger.error(f"Engine run is terminating due to exception: {e}")
--> 783 self._handle_exception(e)
785 self._dataloader_iter = None
786 return self.state

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/ignite/engine/engine.py:464, in Engine._handle_exception(self, e)
462 def _handle_exception(self, e: BaseException) -> None:
463 if Events.EXCEPTION_RAISED in self._event_handlers:
--> 464 self._fire_event(Events.EXCEPTION_RAISED, e)
465 else:
466 raise e

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/ignite/engine/engine.py:421, in Engine._fire_event(self, event_name, *event_args, **event_kwargs)
419 kwargs.update(event_kwargs)
420 first, others = ((args[0],), args[1:]) if (args and args[0] == self) else ((), args)
--> 421 func(*first, *(event_args + others), **kwargs)

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/scglue/models/base.py:162, in Trainer.fit.<locals>._handle_exception(engine, e)
160 engine.terminate()
161 else:
--> 162 raise e

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/ignite/engine/engine.py:753, in Engine._internal_run(self)
750 if self._dataloader_iter is None:
751 self._setup_engine()
--> 753 time_taken = self._run_once_on_dataset()
754 # time is available for handlers but must be update after fire
755 self.state.times[Events.EPOCH_COMPLETED.name] = time_taken

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/ignite/engine/engine.py:854, in Engine._run_once_on_dataset(self)
852 except Exception as e:
853 self.logger.error(f"Current run is terminating due to exception: {e}")
--> 854 self._handle_exception(e)
856 return time.time() - start_time

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/ignite/engine/engine.py:464, in Engine._handle_exception(self, e)
462 def _handle_exception(self, e: BaseException) -> None:
463 if Events.EXCEPTION_RAISED in self._event_handlers:
--> 464 self._fire_event(Events.EXCEPTION_RAISED, e)
465 else:
466 raise e

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/ignite/engine/engine.py:421, in Engine._fire_event(self, event_name, *event_args, **event_kwargs)
419 kwargs.update(event_kwargs)
420 first, others = ((args[0],), args[1:]) if (args and args[0] == self) else ((), args)
--> 421 func(*first, *(event_args + others), **kwargs)

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/scglue/models/base.py:162, in Trainer.fit.<locals>._handle_exception(engine, e)
160 engine.terminate()
161 else:
--> 162 raise e

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/ignite/engine/engine.py:840, in Engine._run_once_on_dataset(self)
838 self.state.iteration += 1
839 self._fire_event(Events.ITERATION_STARTED)
--> 840 self.state.output = self._process_function(self, self.state.batch)
841 self._fire_event(Events.ITERATION_COMPLETED)
843 if self.should_terminate or self.should_terminate_single_epoch:

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/scglue/models/scglue.py:372, in SCGLUETrainer.train_step(self, engine, data)
369 self.dsc_optim.step()
371 # Generator step
--> 372 losses = self.compute_losses(data, epoch)
373 self.net.zero_grad(set_to_none=True)
374 losses["gen_loss"].backward()

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/scglue/models/scglue.py:318, in SCGLUETrainer.compute_losses(self, data, epoch, dsc_only)
315 g_kl = D.kl_divergence(v, prior).sum(dim=1).mean() / vsamp.shape[0]
316 g_elbo = g_nll + self.lam_kl * g_kl
--> 318 x_nll = {
319 k: -net.u2x[k](
320 usamp[k], vsamp[getattr(net, f"{k}_idx")], xbch[k], l[k]
321 ).log_prob(x[k]).mean()
322 for k in net.keys
323 }
324 x_kl = {
325 k: D.kl_divergence(
326 u[k], prior
327 ).sum(dim=1).mean() / x[k].shape[1]
328 for k in net.keys
329 }
330 x_elbo = {
331 k: x_nll[k] + self.lam_kl * x_kl[k]
332 for k in net.keys
333 }

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/scglue/models/scglue.py:319, in <dictcomp>(.0)
315 g_kl = D.kl_divergence(v, prior).sum(dim=1).mean() / vsamp.shape[0]
316 g_elbo = g_nll + self.lam_kl * g_kl
318 x_nll = {
--> 319 k: -net.u2x[k](
320 usamp[k], vsamp[getattr(net, f"{k}_idx")], xbch[k], l[k]
321 ).log_prob(x[k]).mean()
322 for k in net.keys
323 }
324 x_kl = {
325 k: D.kl_divergence(
326 u[k], prior
327 ).sum(dim=1).mean() / x[k].shape[1]
328 for k in net.keys
329 }
330 x_elbo = {
331 k: x_nll[k] + self.lam_kl * x_kl[k]
332 for k in net.keys
333 }

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/torch/distributions/negative_binomial.py:97, in NegativeBinomial.log_prob(self, value)
95 def log_prob(self, value):
96 if self._validate_args:
---> 97 self._validate_sample(value)
99 log_unnormalized_prob = (self.total_count * F.logsigmoid(-self.logits) +
100 value * F.logsigmoid(self.logits))
102 log_normalization = (-torch.lgamma(self.total_count + value) + torch.lgamma(1. + value) +
103 torch.lgamma(self.total_count))

File /data5/yali/softs/conda/Conda_data/envs/scglue/lib/python3.9/site-packages/torch/distributions/distribution.py:294, in Distribution._validate_sample(self, value)
292 valid = support.check(value)
293 if not valid.all():
--> 294 raise ValueError(
295 "Expected value argument "
296 f"({type(value).__name__} of shape {tuple(value.shape)}) "
297 f"to be within the support ({repr(support)}) "
298 f"of the distribution {repr(self)}, "
299 f"but found invalid values:\n{value}"
300 )

ValueError: Expected value argument (Tensor of shape (128, 3428)) to be within the support (IntegerGreaterThan(lower_bound=0)) of the distribution NegativeBinomial(total_count: torch.Size([128, 3428]), logits: torch.Size([128, 3428])), but found invalid values:
tensor([[0.0000, 0.0000, 0.0000, ..., 1.4051, 1.4051, 1.4051],
[0.0000, 0.0000, 1.9990, ..., 1.4329, 1.4329, 0.0000],
[0.0000, 0.0000, 1.5017, ..., 1.5017, 0.0000, 0.0000],
...,
[0.0000, 0.0000, 0.0000, ..., 0.0000, 1.3823, 1.3823],
[0.0000, 0.0000, 0.0000, ..., 2.2712, 1.3602, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]])
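If I am reading the error correctly, PyTorch's sample validation is rejecting the values passed to NegativeBinomial.log_prob because they are not non-negative integers. Just to confirm my reading of the message, the same kind of ValueError can be reproduced outside of scglue with made-up values (a minimal sketch, nothing from my actual data):

```python
import torch
from torch.distributions import NegativeBinomial

# Made-up parameters with matching shapes.
nb = NegativeBinomial(
    total_count=torch.ones(2, 3),
    logits=torch.zeros(2, 3),
)

# Integer-valued observations are accepted.
print(nb.log_prob(torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])))

# Non-integer observations (like the 1.4051 / 1.9990 values in my batch)
# raise the same "within the support (IntegerGreaterThan(lower_bound=0))" ValueError.
nb.log_prob(torch.tensor([[0.0, 1.5, 2.0], [0.0, 0.0, 1.4]]))
```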


I saw in a previous issue that you said this step needs to use the raw count matrix, but I did use the raw count matrices for both RNA and ATAC peaks. I don't know exactly what went wrong. Do you have any suggestions? @Jeff1995 Any help would be greatly appreciated.
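For reference, this is the kind of sanity check I would run to confirm that the matrices feeding the NB decoder really are raw integer counts (a minimal sketch; the matrix that matters is, as far as I understand, whatever `.X` or `use_layer` was pointed to in scglue.models.configure_dataset):

```python
import numpy as np
import scipy.sparse as sp

def is_raw_counts(x) -> bool:
    """Return True if the matrix contains only non-negative integer values."""
    data = x.data if sp.issparse(x) else np.asarray(x)
    return bool(np.all(data >= 0)) and bool(np.all(data == np.round(data)))

# Check whichever matrix/layer the model actually decodes;
# rna and atac are the AnnData objects passed to fit_SCGLUE above.
print("RNA:", is_raw_counts(rna.X))
print("ATAC:", is_raw_counts(atac.X))
```

If either check returns False, I suppose the decoder is seeing normalized or log-transformed values rather than counts, which would explain the non-integer entries (1.4051, 1.9990, ...) in the error above.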