theislab/scarches

expected scalar type Double but found Float in the forward pass of scGen

Hrovatin opened this issue · 1 comment

Any ideas on how to fix this? The data should be accessible on the server.

# Reproduce scGen training on the mock scIB data, avoiding the
# "expected scalar type Double but found Float" error raised in Linear.forward.
import numpy as np
import scanpy as sc  # `sc` was used below but never imported in this snippet; assumes the usual scanpy alias
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
from scipy import sparse

# Load the preprocessed AnnData object (unscaled, HVG subset, per the path).
adata = sc.read('/lustre/groups/ml01/workspace/lisa.sikkema/czi_tissue_references/data/scib/mock_scib/prepare/unscaled/hvg/adata_pre.h5ad')

# Densify X, then store it as float32 (PyTorch's default floating dtype).
# Casting to float64/double, as originally attempted, does not resolve the
# mismatch: torch Linear layers require input and weights to share a dtype.
adata = remove_sparsity(adata)
adata.X = adata.X.astype(np.float32)

network = sca.models.scgen(adata=adata, hidden_layer_sizes=[256, 128])

# Workaround confirmed in this issue: force every floating-point parameter and
# buffer of the underlying torch model to float32, so the model's dtype matches
# the float32 inputs. The traceback shows a Double/Float mismatch inside
# Linear.forward. Why the model ends up non-float32 is still unconfirmed.
network.model.float()
network.train()

Input In [21], in <cell line: 1>()
     53 adata.X=adata.X.astype(np.double)
     54 network = sca.models.scgen(adata = adata,
     55                        hidden_layer_sizes=[256,128])
---> 56 network.train()
     57 adata=network.batch_removal(
     58     adata, 
     59     batch_key=batch, 
     60     cell_label_key=cell_type,
     61     return_latent=True)
     62 # Edit object
     63 # REmove unnecesary added slots

File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/scarches/models/scgen/vaearith_model.py:44, in scgen.train(self, n_epochs, lr, eps, batch_size, **kwargs)
     42 def train(self, n_epochs: int = 100, lr: float = 0.001, eps: float = 1e-8, batch_size = 32, **kwargs):
     43     self.trainer = vaeArithTrainer(self.model, self.adata, batch_size, **kwargs)
---> 44     self.trainer.train(n_epochs, lr, eps)
     45     self.is_trained_ = True

File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/scarches/trainers/scgen/trainer.py:119, in vaeArithTrainer.train(self, n_epochs, lr, eps, **extras_kwargs)
    117 if upper - lower > 1:
    118     x_mb = x_mb.to(self.device) #to cuda or cpu
--> 119     reconstructions, mu, logvar = self.model(x_mb)
    121     loss = self.model._loss_function(x_mb, reconstructions, mu, logvar)
    123     self.optim.zero_grad()

File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/scarches/models/scgen/vaearith.py:163, in vaeArith.forward(self, x)
    162 def forward(self, x: torch.Tensor):
--> 163     mu, logvar = self.encoder(x)
    164     z = self._sample_z(mu, logvar)
    165     x_hat = self.decoder(z)

File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/scarches/models/scgen/modules.py:51, in Encoder.forward(self, x)
     49 def forward(self, x: torch.Tensor):
     50     if self.FC is not None:
---> 51         x = self.FC(x)
     53     mean = self.mean_encoder(x)
     54     log_var = self.log_var_encoder(x)

File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/torch/nn/modules/container.py:141, in Sequential.forward(self, input)
    139 def forward(self, input):
    140     for module in self:
--> 141         input = module(input)
    142     return input

File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/torch/nn/modules/linear.py:103, in Linear.forward(self, input)
    102 def forward(self, input: Tensor) -> Tensor:
--> 103     return F.linear(input, self.weight, self.bias)

RuntimeError: expected scalar type Double but found Float

Quick fix: cast the model's floating-point parameters to float32 before calling `network.train()`:
network.model.float()
But I will investigate why this happens in the first place.