yang-song/score_sde_pytorch

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)

dorazhiyuyang opened this issue · 1 comment


RuntimeError Traceback (most recent call last)
Cell In[29], line 1
----> 1 x, n = sampling_fn(score_model)
2 show_samples(x)

File /workspace/pytorchcode/score_sde_pytorch-main/sampling.py:407, in get_pc_sampler.<locals>.pc_sampler(model)
405 vec_t = torch.ones(shape[0], device=t.device) * t
406 x, x_mean = corrector_update_fn(x, vec_t, model=model)
--> 407 x, x_mean = predictor_update_fn(x, vec_t, model=model)
409 return inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1)

File /workspace/pytorchcode/score_sde_pytorch-main/sampling.py:341, in shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous)
339 else:
340 predictor_obj = predictor(sde, score_fn, probability_flow)
--> 341 return predictor_obj.update_fn(x, t)

File /workspace/pytorchcode/score_sde_pytorch-main/sampling.py:196, in ReverseDiffusionPredictor.update_fn(self, x, t)
195 def update_fn(self, x, t):
--> 196 f, G = self.rsde.discretize(x, t)
197 z = torch.randn_like(x)
198 x_mean = x - f

File /workspace/pytorchcode/score_sde_pytorch-main/sde_lib.py:104, in SDE.reverse.<locals>.RSDE.discretize(self, x, t)
102 def discretize(self, x, t):
103 """Create discretized iteration rules for the reverse diffusion sampler."""
--> 104 f, G = discretize_fn(x, t)
105 rev_f = f - G[:, None, None, None] ** 2 * score_fn(x, t) * (0.5 if self.probability_flow else 1.)
106 rev_G = torch.zeros_like(G) if self.probability_flow else G

File /workspace/pytorchcode/score_sde_pytorch-main/sde_lib.py:251, in VESDE.discretize(self, x, t)
248 timestep = (t * (self.N - 1) / self.T).long()
249 sigma = self.discrete_sigmas.to(t.device)[timestep]
250 adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
--> 251 self.discrete_sigmas[timestep - 1].to(t.device))
252 f = torch.zeros_like(x)
253 G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)

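Changing self.discrete_sigmas[timestep - 1].to(t.device) to self.discrete_sigmas.to(t.device)[timestep - 1] on line 251 of sde_lib.py (in VESDE.discretize) seems to fix the problem. In the original expression, self.discrete_sigmas is still on the CPU when it is indexed with timestep, which lives on t's device (here the GPU); the .to(t.device) call only runs after the indexing has already failed. Moving the tensor to t.device before indexing, as line 249 already does when computing sigma, puts the indexed tensor and the indices on the same device.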
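A minimal standalone sketch of the corrected indexing is below. It is not the repository's code: vesde_discretize_sigmas is a hypothetical helper that only mirrors the sigma / adjacent_sigma lookup from VESDE.discretize (the real method also builds f and G from x). It assumes discrete_sigmas is a 1-D float tensor of length N and that t holds times in [0, T]. The key point is that the lookup table is moved to t's device once, before any indexing.

```python
import torch


def vesde_discretize_sigmas(discrete_sigmas: torch.Tensor, t: torch.Tensor,
                            N: int, T: float = 1.0):
    """Hypothetical helper: return (sigma, adjacent_sigma) for each entry of t,
    using the same indexing as VESDE.discretize with the proposed fix applied."""
    timestep = (t * (N - 1) / T).long()
    # Move the lookup table to t's device BEFORE indexing, so the index
    # tensor and the indexed tensor are always on the same device.
    sigmas = discrete_sigmas.to(t.device)
    sigma = sigmas[timestep]
    # When timestep == 0, timestep - 1 is -1 (wraps to the last entry);
    # torch.where discards that value and substitutes zero instead.
    adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
                                 sigmas[timestep - 1])
    return sigma, adjacent_sigma


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    discrete_sigmas = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])  # stays on CPU
    t = torch.tensor([0.0, 0.5, 1.0], device=device)

    sigma, adjacent_sigma = vesde_discretize_sigmas(discrete_sigmas, t, N=5)
    # Same formula as line 253 of sde_lib.py.
    G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
    print(sigma, adjacent_sigma, G)

    if device == "cuda":
        # The original pattern: indexing a CPU tensor with CUDA indices.
        # On the PyTorch version in the traceback above this raises
        # the RuntimeError reported in this issue.
        try:
            discrete_sigmas[(t * 4).long() - 1].to(t.device)
        except RuntimeError as err:
            print("original pattern fails:", err)
```

With N = 5 and T = 1, the times 0.0, 0.5 and 1.0 map to timesteps 0, 2 and 4, so sigma picks entries 1, 3 and 5 and adjacent_sigma is 0, 2 and 4. Converting once and reusing sigmas also avoids copying discrete_sigmas to the GPU twice per call, which the patched one-line version in sde_lib.py would still do.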