lucidrains/imagen-pytorch

Importing Imagen raises BeartypeDecorHintParamDefaultViolation

Closed this issue · 2 comments

This is my code, as run on Google Colab with a V100 GPU:

!pip install imagen-pytorch
import torch
from imagen_pytorch import Unet, Imagen

(This is the same as in the README.)

I get this error just from trying to import Imagen:

---------------------------------------------------------------------------
BeartypeDecorHintParamDefaultViolation    Traceback (most recent call last)
[<ipython-input-2-a3d7ec49eaab>](https://localhost:8080/#) in <cell line: 4>()
----> 4 from imagen_pytorch import Unet, Imagen

13 frames
[/usr/local/lib/python3.10/dist-packages/imagen_pytorch/__init__.py](https://localhost:8080/#) in <module>
----> 1 from imagen_pytorch.imagen_pytorch import Imagen, Unet
      2 from imagen_pytorch.imagen_pytorch import NullUnet
      3 from imagen_pytorch.imagen_pytorch import BaseUnet64, SRUnet256, SRUnet1024
      4 from imagen_pytorch.trainer import ImagenTrainer
      5 from imagen_pytorch.version import __version__

[/usr/local/lib/python3.10/dist-packages/imagen_pytorch/imagen_pytorch.py](https://localhost:8080/#) in <module>
   1780 # main imagen ddpm class, which is a cascading DDPM from Ho et al.
   1781 
-> 1782 class Imagen(nn.Module):
   1783     def __init__(
   1784         self,

[/usr/local/lib/python3.10/dist-packages/imagen_pytorch/imagen_pytorch.py](https://localhost:8080/#) in Imagen()
   2287     @eval_decorator
   2288     @beartype
-> 2289     def sample(
   2290         self,
   2291         texts: List[str] = None,

[/usr/local/lib/python3.10/dist-packages/beartype/_decor/decorcache.py](https://localhost:8080/#) in beartype(obj, conf)
     75     # trusted to violate PEP 561-compliance if they so choose. So... *shrug*
     76     elif obj is not None:
---> 77         return beartype_object(obj, conf)
     78     # Else, we were passed *NO* object to be decorated. In this case, this
     79     # decorator is in configuration rather than decoration mode.

[/usr/local/lib/python3.10/dist-packages/beartype/_decor/decorcore.py](https://localhost:8080/#) in beartype_object(obj, conf, **kwargs)
     85     # Return either...
     86     return (
---> 87         _beartype_object_fatal(obj, conf=conf, **kwargs)
     88         # If this beartype configuration requests that this decorator raise
     89         # fatal exceptions at decoration time, defer to the lower-level

[/usr/local/lib/python3.10/dist-packages/beartype/_decor/decorcore.py](https://localhost:8080/#) in _beartype_object_fatal(obj, **kwargs)
    134         # Else, this object is a non-class. In this case, this non-class
    135         # decorated with type-checking.
--> 136         beartype_nontype(obj, **kwargs)  # type: ignore[return-value]
    137     )
    138 

[/usr/local/lib/python3.10/dist-packages/beartype/_decor/_decornontype.py](https://localhost:8080/#) in beartype_nontype(obj, **kwargs)
    172 
    173     # Return a new callable decorating that callable with type-checking.
--> 174     return beartype_func(obj, **kwargs)  # type: ignore[return-value]
    175 
    176 # ....................{ DECORATORS ~ func                  }....................

[/usr/local/lib/python3.10/dist-packages/beartype/_decor/_decornontype.py](https://localhost:8080/#) in beartype_func(func, conf, **kwargs)
    237 
    238     # Generate the raw string of Python statements implementing this wrapper.
--> 239     func_wrapper_code = generate_code(bear_call)
    240 
    241     # If that callable requires *NO* type-checking, silently reduce to a noop

[/usr/local/lib/python3.10/dist-packages/beartype/_decor/wrap/wrapmain.py](https://localhost:8080/#) in generate_code(bear_call)
    116     # such parameters are annotated with unignorable type hints *OR* the empty
    117     # string otherwise.
--> 118     code_check_params = _code_check_args(bear_call)
    119 
    120     # Python code snippet type-checking the callable return if this return is

[/usr/local/lib/python3.10/dist-packages/beartype/_decor/wrap/_wrapargs.py](https://localhost:8080/#) in code_check_args(bear_call)
    307         # annotated parameter.
    308         except Exception as exception:
--> 309             reraise_exception_placeholder(
    310                 exception=exception,
    311                 #FIXME: Embed the kind of parameter both here and above as well

[/usr/local/lib/python3.10/dist-packages/beartype/_util/error/utilerrraise.py](https://localhost:8080/#) in reraise_exception_placeholder(exception, target_str, source_str)
    136 
    137     # Re-raise this exception while preserving its original traceback.
--> 138     raise exception.with_traceback(exception.__traceback__)

[/usr/local/lib/python3.10/dist-packages/beartype/_decor/wrap/_wrapargs.py](https://localhost:8080/#) in code_check_args(bear_call)
    203                 # If this parameter is optional *AND* the default value of this
    204                 # optional parameter violates this hint, raise an exception.
--> 205                 _die_if_arg_default_unbearable(
    206                     bear_call=bear_call, arg_default=arg_default, hint=hint)
    207                 # Else, this parameter is either optional *OR* the default value

[/usr/local/lib/python3.10/dist-packages/beartype/_decor/wrap/_wrapargs.py](https://localhost:8080/#) in _die_if_arg_default_unbearable(bear_call, arg_default, hint)
    471 
    472     # Raise this type of violation exception.
--> 473     die_if_unbearable(
    474         obj=arg_default,
    475         hint=hint,

[/usr/local/lib/python3.10/dist-packages/beartype/door/_doorcheck.py](https://localhost:8080/#) in die_if_unbearable(obj, hint, conf, exception_prefix)
    104     # Either raise an exception or emit a warning only if the passed object
    105     # violates this hint.
--> 106     func_raiser(obj)  # pyright: ignore[reportUnboundVariable]
    107 
    108 # ....................{ TESTERS                            }....................

<@beartype(__beartype_checker_31) at 0x5583892d0520> in __beartype_checker_31(__beartype_pith_0, __beartype_getrandbits, __beartype_exception_prefix, __beartype_get_violation, __beartype_hint, __beartype_conf)

BeartypeDecorHintParamDefaultViolation: Method imagen_pytorch.imagen_pytorch.Imagen.sample() parameter "texts" default value "None" violates type hint list[str], as <class "builtins.NoneType"> "None" not instance of list.

Any help would be appreciated. Surely importing a module shouldn't be this difficult?

@benjaminorr oops, could you try the latest version?

@lucidrains It works now! Thank you so much for your quick response.