IrisRainbowNeko/HCP-Diffusion

Error when generating images

Isotr0py opened this issue · 2 comments

  • Error in the final VAE decode step when generating images
  • To reproduce: run the "Generate Images" section of the Colab example
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <cell line: 16>:16                                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115 in decorate_context       │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/hcpdiff/visualizer.py:202 in vis_images                  │
│                                                                                                  │
│   199 │   │   │   │   for feeder in self.pipe.unet.input_feeder:                                 │
│   200 │   │   │   │   │   feeder(ex_input_dict)                                                  │
│   201 │   │   │                                                                                  │
│ ❱ 202 │   │   │   images = self.pipe(prompt_embeds=emb_p, negative_prompt_embeds=emb_n, **kwar   │
│   203 │   │   return images                                                                      │
│   204 │                                                                                          │
│   205 │   def save_images(self, images, root, prompt, negative_prompt='', save_cfg=True):        │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115 in decorate_context       │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_dif │
│ fusion.py:755 in __call__                                                                        │
│                                                                                                  │
│   752 │   │   │   │   │   │   callback(i, t, latents)                                            │
│   753 │   │                                                                                      │
│   754 │   │   if not output_type == "latent":                                                    │
│ ❱ 755 │   │   │   image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dic   │
│   756 │   │   │   image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embe   │
│   757 │   │   else:                                                                              │
│   758 │   │   │   image = latents                                                                │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/utils/accelerate_utils.py:46 in wrapper        │
│                                                                                                  │
│   43 │   def wrapper(self, *args, **kwargs):                                                     │
│   44 │   │   if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"):             │
│   45 │   │   │   self._hf_hook.pre_forward(self)                                                 │
│ ❱ 46 │   │   return method(self, *args, **kwargs)                                                │
│   47 │                                                                                           │
│   48 │   return wrapper                                                                          │
│   49                                                                                             │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/autoencoder_kl.py:191 in decode         │
│                                                                                                  │
│   188 │   │   │   decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]      │
│   189 │   │   │   decoded = torch.cat(decoded_slices)                                            │
│   190 │   │   else:                                                                              │
│ ❱ 191 │   │   │   decoded = self._decode(z).sample                                               │
│   192 │   │                                                                                      │
│   193 │   │   if not return_dict:                                                                │
│   194 │   │   │   return (decoded,)                                                              │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/autoencoder_kl.py:177 in _decode        │
│                                                                                                  │
│   174 │   │   if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] >   │
│   175 │   │   │   return self.tiled_decode(z, return_dict=return_dict)                           │
│   176 │   │                                                                                      │
│ ❱ 177 │   │   z = self.post_quant_conv(z)                                                        │
│   178 │   │   dec = self.decoder(z)                                                              │
│   179 │   │                                                                                      │
│   180 │   │   if not return_dict:                                                                │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py:463 in forward                  │
│                                                                                                  │
│    460 │   │   │   │   │   │   self.padding, self.dilation, self.groups)                         │
│    461 │                                                                                         │
│    462 │   def forward(self, input: Tensor) -> Tensor:                                           │
│ ❱  463 │   │   return self._conv_forward(input, self.weight, self.bias)                          │
│    464                                                                                           │
│    465 class Conv3d(_ConvNd):                                                                    │
│    466 │   __doc__ = r"""Applies a 3D convolution over an input signal composed of several inpu  │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py:459 in _conv_forward            │
│                                                                                                  │
│    456 │   │   │   return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=sel  │
│    457 │   │   │   │   │   │   │   weight, bias, self.stride,                                    │
│    458 │   │   │   │   │   │   │   _pair(0), self.dilation, self.groups)                         │
│ ❱  459 │   │   return F.conv2d(input, weight, bias, self.stride,                                 │
│    460 │   │   │   │   │   │   self.padding, self.dilation, self.groups)                         │
│    461 │                                                                                         │
│    462 │   def forward(self, input: Tensor) -> Tensor:                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.HalfTensor) should be the same
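The error means the latents and the VAE's `post_quant_conv` weights are not the same tensor type: the input is a `torch.cuda.HalfTensor` (fp16 on GPU) while the weights are a plain `torch.HalfTensor` (fp16 still on CPU). A minimal CPU-only sketch of the same error class, illustrative only and not HCP-Diffusion code (on CPU the mismatch shown is dtype rather than device, but the failing check and the fix are analogous):

```python
import torch
import torch.nn as nn

# Weights stay in one tensor type while the input arrives in another,
# mimicking the VAE's post_quant_conv receiving mismatched latents.
conv = nn.Conv2d(4, 3, kernel_size=3)        # weights: float32, CPU
latents = torch.randn(1, 4, 64, 64).half()   # input: float16

try:
    conv(latents)                            # mismatched types -> RuntimeError
except RuntimeError as err:
    print(type(err).__name__)                # RuntimeError

# Fix: put input and weights on the same device/dtype before the call.
out = conv(latents.float())                  # align the input to the weights
print(out.shape)                             # torch.Size([1, 3, 62, 62])
```

In the issue itself the fix goes the other way: the VAE module must be moved back to the GPU (to match the latents) before `decode` is called.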

Commenting out `to_cpu(self.pipe.text_encoder)` at line 206 of the code below makes image generation work normally:
https://github.com/7eu7d7/HCP-Diffusion/blob/c7dff0c185a34c6c20cf4dba0e0206fd25468589/hcpdiff/visualizer.py#L190-L215
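That workaround suggests the CPU-offload bookkeeping leaves the VAE (or a module sharing its device moves) on the CPU when decode runs. A minimal sketch of the invariant a fix has to guarantee, using a hypothetical helper that is not part of the hcpdiff API:

```python
import torch
from torch import nn

def ensure_same_device(module: nn.Module, reference: torch.Tensor) -> nn.Module:
    """Hypothetical helper: move `module` onto the device of `reference`
    if any of its parameters live elsewhere. This mirrors what must hold
    for the VAE right before `vae.decode(latents)` is called."""
    if any(p.device != reference.device for p in module.parameters()):
        module.to(reference.device)
    return module
```

For example, calling `ensure_same_device(pipe.vae, latents)` immediately before decoding would prevent the CPU-weights / CUDA-input mismatch regardless of what the offload hooks did earlier.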

Fixed in the new version.