lucidrains/deep-daze

RuntimeError: cosine_similarity requires both inputs to have the same sizes

leoauri opened this issue · 2 comments

Hi there,
I ran

pip install deep-daze
imagine "a prompt"

(and also with other prompts) and got this error:

Setting jit to False because torch version is not 1.7.1.
Starting up...
Imagining "a_prompt" from the depths of my weights...
Traceback (most recent call last):
  File "/opt/conda/bin/imagine", line 8, in <module>
    sys.exit(main())
  File "/opt/conda/lib/python3.8/site-packages/deep_daze/cli.py", line 151, in main
    fire.Fire(train)
  File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 466, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/deep_daze/cli.py", line 147, in train
    imagine()
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1056, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/deep_daze/deep_daze.py", line 574, in forward
    self.model(self.clip_encoding, dry_run=True) # do one warmup step due to potential issue with CLIP and CUDA
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1056, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/deep_daze/deep_daze.py", line 235, in forward
    general_loss = -self.loss_coef * torch.cosine_similarity(text_embed, image_embed, dim=-1).mean()
RuntimeError: cosine_similarity requires both inputs to have the same sizes, but x1 has [1, 512] and x2 has [4, 512]

Any idea why?

Running in nvcr.io/nvidia/pytorch:21.08-py3 container with NVIDIA Release 21.08 (build 26011915) and PyTorch Version 1.10.0a0+3fd9dcf
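For reference, the mismatch can be reproduced outside deep-daze with tensors of the shapes from the traceback. This is just a minimal sketch; I'm assuming the [4, 512] side corresponds to a batch of CLIP image embeddings (one per image cutout) while the text embedding stays [1, 512]:

    import torch

    text_embed = torch.randn(1, 512)   # CLIP text embedding
    image_embed = torch.randn(4, 512)  # batch of CLIP image embeddings

    # On PyTorch builds that enforce identical sizes (as in the traceback above),
    # the call in deep_daze.py line 235 fails:
    #   torch.cosine_similarity(text_embed, image_embed, dim=-1)
    #   RuntimeError: cosine_similarity requires both inputs to have the same sizes,
    #   but x1 has [1, 512] and x2 has [4, 512]

    # Expanding the text embedding to match the image batch avoids the error:
    sim = torch.cosine_similarity(text_embed.expand_as(image_embed), image_embed, dim=-1)
    print(sim.shape)  # torch.Size([4])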

Thanks

Setting the batch size explicitly to 1, like

imagine --batch_size=1 "a prompt"

is a workaround... 🥴
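If you're calling deep-daze from Python instead of the CLI, the same workaround presumably applies via the Imagine class. A sketch, assuming Imagine accepts a batch_size keyword mirroring the CLI flag:

    from deep_daze import Imagine

    # Assumed equivalent of `imagine --batch_size=1 "a prompt"`;
    # batch_size=1 sidesteps the [1, 512] vs [4, 512] mismatch.
    imagine = Imagine(
        text="a prompt",
        batch_size=1,
    )
    imagine()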

It's not working