argmaxinc/DiffusionKit

ValueError: Shapes (...) and (...) cannot be broadcast.


It seems like when using a long prompt, I always get this error.

Example:

diffusionkit-cli --prompt "A realistic squirrel standing on its hind legs, holding a glowing blue laser saber in its front paws. The squirrel has detailed brown and gray fur, with sharp, focused eyes. The scene is set in a lush forest with sunlight filtering through the trees, creating soft shadows. The laser saber is sleek and futuristic, emitting a soft glow. The background is slightly blurred to emphasize the squirrel, adding a sense of adventure and action." --output-path image.png --model-version FLUX.1-schnell --steps 4
scikit-learn version 1.5.1 is not supported. Minimum required version: 0.17. Maximum required version: 1.1.2. Disabling scikit-learn conversion API.
Torch version 2.2.1 has not been tested with coremltools. You may run into unexpected errors. Torch 2.2.0 is the most recent version that has been tested.
WARNING:diffusionkit.mlx.scripts.generate_images:Disabling CFG for FLUX.1-schnell model.
INFO:diffusionkit.mlx.scripts.generate_images:Output image resolution will be 512x512
INFO:diffusionkit.mlx:Pre text encoding peak memory: 0.0GB
INFO:diffusionkit.mlx:Pre text encoding active memory: 0.0GB
Traceback (most recent call last):
  File "/Users/alban/.pyenv/versions/3.11.7/bin/diffusionkit-cli", line 8, in <module>
    sys.exit(cli())
             ^^^^^
  File "/Users/alban/.pyenv/versions/3.11.7/lib/python3.11/site-packages/diffusionkit/mlx/scripts/generate_images.py", line 164, in cli
    image, _ = sd.generate_image(
               ^^^^^^^^^^^^^^^^^^
  File "/Users/alban/.pyenv/versions/3.11.7/lib/python3.11/site-packages/diffusionkit/mlx/__init__.py", line 310, in generate_image
    conditioning, pooled_conditioning = self.encode_text(
                                        ^^^^^^^^^^^^^^^^^
  File "/Users/alban/.pyenv/versions/3.11.7/lib/python3.11/site-packages/diffusionkit/mlx/__init__.py", line 605, in encode_text
    conditioning_l = self.clip_l(tokens_l[[0], :])  # Ignore negative text
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alban/.pyenv/versions/3.11.7/lib/python3.11/site-packages/diffusionkit/mlx/clip.py", line 98, in __call__
    x = x + self.position_embedding.weight[:N]
        ~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ValueError: Shapes (1,89,768) and (77,768) cannot be broadcast.
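
For reference, here is a minimal sketch of the shape mismatch in isolation (assuming MLX broadcasting follows NumPy-style rules; the 89 and 77 come from the error above, and the arrays here are just placeholders):

import mlx.core as mx

# The long prompt tokenizes to 89 CLIP tokens, but the CLIP position
# embedding table only has 77 rows, so the addition cannot broadcast.
x = mx.zeros((1, 89, 768))     # token embeddings for the long prompt
pos = mx.zeros((77, 768))      # CLIP position embedding weight
x = x + pos[:89]               # slicing past 77 still yields 77 rows -> ValueError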

Hello,
For FLUX.1-schnell, the max context length for CLIP models and T5-XXL is 77 and 256 tokens, respectively.
PR #18 truncates the extra tokens to fix the issue.
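
As a rough illustration of the fix (a hypothetical helper, not the actual code in PR #18), truncating the token arrays to each encoder's context length before the embedding lookup avoids the mismatch:

MAX_CLIP_TOKENS = 77   # CLIP context length for FLUX.1-schnell
MAX_T5_TOKENS = 256    # T5-XXL context length for FLUX.1-schnell

def truncate_tokens(tokens, max_length):
    # Keep at most max_length tokens per sequence so the (max_length, dim)
    # position embedding can broadcast against the token embeddings.
    return tokens[:, :max_length]

# e.g. tokens_l with shape (batch, 89) becomes (batch, 77) after truncation:
# tokens_l = truncate_tokens(tokens_l, MAX_CLIP_TOKENS)
# tokens_t5 = truncate_tokens(tokens_t5, MAX_T5_TOKENS)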