litestar-org/polyfactory

Bug: factories inside a nested pydantic model with custom types do not inherit the provider map

potatoUnicornDev opened this issue · 1 comments

Hello all, I really like the project, saves me tons of time when writing tests :).

I encountered a problem with nested pydantic models that have custom types.

The following example with a nested pydantic model only works if you override _get_or_create_factory so that the get_provider_map of the created factory is replaced with the one from the parent factory class. Without this override, you get a ParameterException.

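from typing import Any

from polyfactory.exceptions import ParameterException
from polyfactory.factories.base import BaseFactory
from polyfactory.factories.pydantic_factory import ModelFactory
from pydantic import BaseModel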


class MyClass:
    def __init__(self, value: int) -> None:
        self.value = value


class B(BaseModel):
    my_class: MyClass

    class Config:
        arbitrary_types_allowed = True


class ANested(BaseModel):
    b: B


class A(BaseModel):
    my_class: MyClass

    class Config:
        arbitrary_types_allowed = True


class AFactory(ModelFactory):
    __model__ = A

    @classmethod
    def get_provider_map(cls) -> dict[type, Any]:
        providers_map = super().get_provider_map()

        return {
            **providers_map,
            MyClass: lambda: MyClass(value=1),
        }


class ANestedFactory(ModelFactory):
    __model__ = ANested

    @classmethod
    def get_provider_map(cls) -> dict[type, Any]:
        providers_map = super().get_provider_map()

        return {
            **providers_map,
            MyClass: lambda: MyClass(value=1),
        }

    @classmethod
    def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]:
        """Get a factory from registered factories or generate a factory dynamically.

        :param model: A model type.
        :returns: A Factory sub-class.

        """
        if factory := BaseFactory._factory_type_mapping.get(model):
            return factory

        if cls.__base_factory_overrides__:
            for model_ancestor in model.mro():
                if factory := cls.__base_factory_overrides__.get(model_ancestor):
                    return factory.create_factory(model)

        for factory in reversed(BaseFactory._base_factories):
            if factory.is_supported_type(model):
                # What it was originally:
                # return factory.create_factory(model)

                # --- CHANGE START --- this makes it work
                created_factory = factory.create_factory(model)
                created_factory.get_provider_map = cls.get_provider_map
                return created_factory
                # --- CHANGE END ---

        raise ParameterException(f"unsupported model type {model.__name__}")  # pragma: no cover
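
The line created_factory.get_provider_map = cls.get_provider_map works because of how Python binds classmethods: looking up cls.get_provider_map yields a method already bound to cls, so calls made through the generated factory still run against the parent factory. Below is a minimal standalone sketch of that mechanism. Base, Parent, Generated and Custom are hypothetical toy classes, not polyfactory's own, and nothing here assumes anything about the library's internals.

```python
class Base:
    @classmethod
    def get_provider_map(cls) -> dict[type, object]:
        # Stand-in for the default providers every factory knows about.
        return {int: lambda: 0}


class Custom:
    """Hypothetical custom type that has no default provider."""


class Parent(Base):
    @classmethod
    def get_provider_map(cls) -> dict[type, object]:
        # Extend the defaults with a provider for the custom type.
        return {**super().get_provider_map(), Custom: lambda: Custom()}


# Stand-in for a factory generated on the fly for a nested model.
# It derives from Base, not Parent, so it only sees the defaults.
Generated = type("Generated", (Base,), {})

before = Custom in Generated.get_provider_map()

# Parent.get_provider_map is a method already bound to Parent, so calls
# made through Generated still execute with cls=Parent.
Generated.get_provider_map = Parent.get_provider_map

after = Custom in Generated.get_provider_map()
value = Generated.get_provider_map()[Custom]()
```

In this toy setup, before is False and after is True: the generated class only gains the custom provider once the bound method is assigned to it.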


Alternatively, override it like this, if you do not mind that the provider map is also replaced on factories that were not newly created (that is, ones that already existed):

    @classmethod
    def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]:
        created_factory = super()._get_or_create_factory(model)
        created_factory.get_provider_map = cls.get_provider_map
        created_factory._get_or_create_factory = cls._get_or_create_factory
        return created_factory
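
The caveat with this shorter override is that the parent lookup may return a factory that was already registered (in the longer version above, that is the BaseFactory._factory_type_mapping branch), and patching it changes that factory for every other user as well. The sketch below illustrates the side effect and one possible guard. It uses a hypothetical toy registry and toy classes (Base, Patching, Guarded, Marker), so treat it as an illustration of the trade-off rather than polyfactory's actual behaviour.

```python
class Base:
    # Hypothetical registry of already-declared factories, keyed by model.
    registry: dict[type, type] = {}

    @classmethod
    def get_provider_map(cls) -> dict[type, object]:
        return {}

    @classmethod
    def get_or_create(cls, model: type) -> type:
        # Reuse a registered factory if there is one, else generate one.
        if model in Base.registry:
            return Base.registry[model]
        return type(f"{model.__name__}Factory", (Base,), {})


class Marker: ...
class ModelA: ...
class ModelB: ...
class ModelC: ...


class FactoryA(Base): ...
class FactoryB(Base): ...


Base.registry.update({ModelA: FactoryA, ModelB: FactoryB})


class Patching(Base):
    @classmethod
    def get_provider_map(cls) -> dict[type, object]:
        return {Marker: lambda: Marker()}

    @classmethod
    def get_or_create(cls, model: type) -> type:
        # Mirrors the short override: patch whatever comes back.
        created = super().get_or_create(model)
        created.get_provider_map = cls.get_provider_map
        return created


class Guarded(Patching):
    @classmethod
    def get_or_create(cls, model: type) -> type:
        created = Base.get_or_create(model)
        # Only patch factories that were generated just now.
        if created not in Base.registry.values():
            created.get_provider_map = cls.get_provider_map
        return created


Patching.get_or_create(ModelA)
Guarded.get_or_create(ModelB)
generated = Guarded.get_or_create(ModelC)

leaked = Marker in FactoryA.get_provider_map()         # registered factory was modified
untouched = Marker not in FactoryB.get_provider_map()  # registered factory left alone
patched = Marker in generated.get_provider_map()       # new factory got the provider
```

Whether the guard is worth it depends on whether the explicitly declared factories are meant to keep their own provider maps.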