expected scalar type Double but found Float in forward of scGEN
Hrovatin opened this issue · 1 comment
Hrovatin commented
Any ideas how to fix this? The data should be accessible on the server:
import numpy as np
import scanpy as sc
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity

adata = sc.read('/lustre/groups/ml01/workspace/lisa.sikkema/czi_tissue_references/data/scib/mock_scib/prepare/unscaled/hvg/adata_pre.h5ad')
adata = remove_sparsity(adata)  # densify adata.X for the model
adata.X = adata.X.astype(np.double)  # hoped this would fix it, but it did not

network = sca.models.scgen(adata=adata,
                           hidden_layer_sizes=[256, 128])
network.train()
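Note that the cast to np.double changes only one side of the comparison: torch.nn layers are created with float32 parameters by default. If the trainer passes the numpy dtype through to the minibatch tensors, casting the other way would be the more likely match; a minimal sketch of that alternative, under that assumption:

adata.X = adata.X.astype(np.float32)  # match torch's default float32 parameters

Either way, the train() call fails with the traceback below.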
Input In [21], in <cell line: 1>()
53 adata.X=adata.X.astype(np.double)
54 network = sca.models.scgen(adata = adata,
55 hidden_layer_sizes=[256,128])
---> 56 network.train()
57 adata=network.batch_removal(
58 adata,
59 batch_key=batch,
60 cell_label_key=cell_type,
61 return_latent=True)
62 # Edit object
63 # Remove unnecessary added slots
File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/scarches/models/scgen/vaearith_model.py:44, in scgen.train(self, n_epochs, lr, eps, batch_size, **kwargs)
42 def train(self, n_epochs: int = 100, lr: float = 0.001, eps: float = 1e-8, batch_size = 32, **kwargs):
43 self.trainer = vaeArithTrainer(self.model, self.adata, batch_size, **kwargs)
---> 44 self.trainer.train(n_epochs, lr, eps)
45 self.is_trained_ = True
File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/scarches/trainers/scgen/trainer.py:119, in vaeArithTrainer.train(self, n_epochs, lr, eps, **extras_kwargs)
117 if upper - lower > 1:
118 x_mb = x_mb.to(self.device) #to cuda or cpu
--> 119 reconstructions, mu, logvar = self.model(x_mb)
121 loss = self.model._loss_function(x_mb, reconstructions, mu, logvar)
123 self.optim.zero_grad()
File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/scarches/models/scgen/vaearith.py:163, in vaeArith.forward(self, x)
162 def forward(self, x: torch.Tensor):
--> 163 mu, logvar = self.encoder(x)
164 z = self._sample_z(mu, logvar)
165 x_hat = self.decoder(z)
File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/scarches/models/scgen/modules.py:51, in Encoder.forward(self, x)
49 def forward(self, x: torch.Tensor):
50 if self.FC is not None:
---> 51 x = self.FC(x)
53 mean = self.mean_encoder(x)
54 log_var = self.log_var_encoder(x)
File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/torch/nn/modules/container.py:141, in Sequential.forward(self, input)
139 def forward(self, input):
140 for module in self:
--> 141 input = module(input)
142 return input
File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
File ~/miniconda3/envs/scib-pipeline/lib/python3.8/site-packages/torch/nn/modules/linear.py:103, in Linear.forward(self, input)
102 def forward(self, input: Tensor) -> Tensor:
--> 103 return F.linear(input, self.weight, self.bias)
RuntimeError: expected scalar type Double but found Float
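The frame that raises is F.linear, which requires the input tensor and the layer weights to share a single floating-point dtype. The same RuntimeError can be reproduced in a few lines outside scarches (a sketch for illustration only; the layer sizes are arbitrary):

import torch

layer = torch.nn.Linear(4, 2)               # parameters default to float32
x = torch.zeros(1, 4, dtype=torch.float64)  # a double tensor, as produced from numpy float64 data
layer(x)  # RuntimeError: expected scalar type Double but found Float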
Koncopd commented
Quick fix:
network.model.float()
But I will check why this even happens.
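For context, a sketch of where the quick fix would go, assuming network.model exposes the underlying torch module as in the snippet above:

network = sca.models.scgen(adata=adata, hidden_layer_sizes=[256, 128])
network.model.float()  # align all model parameters to float32 before training
network.train()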