theislab/scarches

Runtime error using scPoli - mat1 and mat2 must have the same dtype, but got Double and Float

Closed this issue · 7 comments

Hi,
I am trying to follow this tutorial: https://github.com/theislab/scarches/blob/master/notebooks/scpoli_surgery_pipeline.ipynb

- matplotlib==3.8.0
- numpy==1.26.0
- pandas==2.1.1
- scArches==0.5.9
- scanpy==1.9.5
- scikit-learn==1.3.1
- seaborn==0.13.0
- torch==2.1.0

When I try to train the model, I run into this error:

```
RuntimeError Traceback (most recent call last)
Cell In[17], line 8
1 scpoli_model = scPoli(
2 adata=ref,
3 condition_keys=condition_key,
(...)
6 recon_loss='zinb',
7 )
----> 8 scpoli_model.train(
9 n_epochs=125,
10 pretraining_epochs=100,
11 early_stopping_kwargs=early_stopping_kwargs,
12 eta=5,
13 )

File /opt/conda/envs/scarches/lib/python3.9/site-packages/scarches/models/scpoli/scpoli_model.py:304, in scPoli.train(self, n_epochs, pretraining_epochs, eta, lr, eps, alpha_epoch_anneal, reload_best, prototype_training, unlabeled_prototype_training, **kwargs)
287 pretraining_epochs = int(np.floor(n_epochs * 0.9))
290 self.trainer = scPoliTrainer(
291 self.model,
292 self.adata,
(...)
302 **kwargs,
303 )
--> 304 self.trainer.train(n_epochs, lr, eps)
305 self.is_trained_ = True
306 self.prototypes_labeled_ = self.model.prototypes_labeled

File /opt/conda/envs/scarches/lib/python3.9/site-packages/scarches/trainers/scpoli/trainer.py:305, in scPoliTrainer.train(self, n_epochs, lr, eps)
302 batch_data[key] = batch.to(self.device)
304 #loss calculation
--> 305 self.on_iteration(batch_data)
307 #validation of model, monitoring, early stopping
308 self.on_epoch_end()

File /opt/conda/envs/scarches/lib/python3.9/site-packages/scarches/trainers/scpoli/trainer.py:333, in scPoliTrainer.on_iteration(self, batch_data)
330 module.track_running_stats = False
332 #calculate loss depending on trainer/model
--> 333 self.current_loss = loss = self.loss(batch_data)
334 self.optimizer.zero_grad()
336 loss.backward()

File /opt/conda/envs/scarches/lib/python3.9/site-packages/scarches/trainers/scpoli/trainer.py:533, in scPoliTrainer.loss(self, total_batch)
532 def loss(self, total_batch=None):
--> 533 latent, recon_loss, kl_loss, mmd_loss = self.model(**total_batch)
535 #calculate classifier loss for labeled/unlabeled data
536 label_categories = total_batch["labeled"].unique().tolist()

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /opt/conda/envs/scarches/lib/python3.9/site-packages/scarches/models/scpoli/scpoli.py:142, in scpoli.forward(self, x, batch, combined_batch, sizefactor, celltypes, labeled)
140 x_log = x
141 if "encoder" in self.inject_condition:
--> 142 z1_mean, z1_log_var = self.encoder(x_log, batch_embeddings)
143 else:
144 z1_mean, z1_log_var = self.encoder(x_log, batch=None)

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /opt/conda/envs/scarches/lib/python3.9/site-packages/scarches/models/scpoli/scpoli.py:461, in Encoder.forward(self, x, batch)
459 x = torch.cat((x, batch), dim=-1)
460 if self.FC is not None:
--> 461 x = self.FC(x)
462 means = self.mean_encoder(x)
463 log_vars = self.log_var_encoder(x)

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/container.py:215, in Sequential.forward(self, input)
213 def forward(self, input):
214 for module in self:
--> 215 input = module(input)
216 return input

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /opt/conda/envs/scarches/lib/python3.9/site-packages/scarches/models/scpoli/scpoli.py:665, in CondLayers.forward(self, x)
663 else:
664 expr, cond = torch.split(x, [x.shape[1] - self.n_cond, self.n_cond], dim=1)
--> 665 out = self.expr_L(expr) + self.cond_L(cond)
666 return out

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /opt/conda/envs/scarches/lib/python3.9/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
113 def forward(self, input: Tensor) -> Tensor:
--> 114 return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 must have the same dtype, but got Double and Float
```

Has anyone run into this error and been able to fix it?

Hi @parkjosh-broadinstitute, could you check whether casting the input to `float32` fixes the issue?

Run `adata.X = adata.X.astype('float32')` before passing the data to the model.
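For reference, a minimal sketch of that cast as a reusable step. It assumes `adata.X` is either a dense NumPy array or a SciPy sparse matrix (both expose `.dtype` and `.astype`, so sparse input stays sparse). The helper name `ensure_float32` is hypothetical and not part of scArches.

```python
import numpy as np


def ensure_float32(X):
    """Return X with dtype float32 (hypothetical helper, not a scArches API).

    Works for NumPy arrays and SciPy sparse matrices, since both provide
    .dtype and .astype(). Integer and float64 inputs are converted;
    float32 input is returned unchanged without copying.
    """
    if X.dtype == np.float32:
        return X  # already matches PyTorch's default float32 weights
    return X.astype(np.float32)


# Usage with an AnnData object (assumed to be named `adata`):
# adata.X = ensure_float32(adata.X)
```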

Hi,
I get exactly the same RuntimeError when running scPoli.
I tried your suggestion:
`source_adata = source_adata.astype('float32')`
but then I get:
`AttributeError: 'AnnData' object has no attribute 'astype'`
My h5ad object was built from a Seurat object with the R function SeuratDisk::Convert.
Thanks for your help.

You need to cast the `.X` matrix inside your AnnData object, not the AnnData object itself, which has no `astype` method.

The correct command is `source_adata.X = source_adata.X.astype('float32')`.

Oops, my mistake. This works for me, and I managed to complete the rest of the analysis without errors.
Thanks!

Hi @cdedonno, yes, that worked. Thank you for your help!

Even though the `"nb"` and `"zinb"` losses expect raw counts, the data should still be cast to `float32` using @cdedonno's suggestion:

`adata.X = adata.X.astype('float32')`
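Casting raw counts to `float32` leaves them unchanged as long as every count is a non-negative whole number no larger than 2**24 (16,777,216), the range in which float32 represents every integer exactly. A hedged sketch of a sanity check, assuming dense NumPy input (for a SciPy sparse matrix, one could pass its `.data` array instead); `counts_survive_float32` is a hypothetical name, not a scArches function.

```python
import numpy as np

# float32 has a 24-bit significand, so every integer in [0, 2**24]
# is exactly representable; 2**24 + 1 is the first one that is not.
FLOAT32_EXACT_INT_LIMIT = 2 ** 24


def counts_survive_float32(X):
    """Return True if X holds non-negative whole-number counts that a
    float32 cast keeps exactly (hypothetical helper, dense arrays only)."""
    X = np.asarray(X)
    if X.size == 0:
        return True
    # NaN fails both checks because comparisons with NaN are False.
    is_whole = np.all(np.mod(X, 1) == 0)
    in_range = X.min() >= 0 and X.max() <= FLOAT32_EXACT_INT_LIMIT
    return bool(is_whole and in_range)
```

If this returns True, `adata.X.astype('float32')` changes only the storage type, not the count values the NB/ZINB likelihood sees.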

Hi @parkjosh-broadinstitute, the conversion should happen automatically, but there may be issues when integer data is passed; I will check. Thanks for reporting this.