Bug: factories inside a nested pydantic model with custom types do not inherit the provider map
potatoUnicornDev opened this issue · 1 comment
potatoUnicornDev commented
Hello all, I really like the project; it saves me tons of time when writing tests :).
I encountered a problem with nested pydantic models that have custom types.
The following example with a nested pydantic model only works if you override `_get_or_create_factory`
so that it replaces the `get_provider_map`
of the dynamically created factory with the one of the parent factory class. If you do not override this method, you will get a `ParameterException`:
from typing import Any

from polyfactory.exceptions import ParameterException
from polyfactory.factories.base import BaseFactory
from polyfactory.factories.pydantic_factory import ModelFactory
from pydantic import BaseModel
class MyClass:
    """A plain (non-pydantic) custom type used as a field in the models below.

    The factories in this example register an explicit provider for it.
    """

    def __init__(self, value: int) -> None:
        """Keep *value* as the ``value`` attribute of the new instance."""
        self.value = value
class B(BaseModel):
    """Pydantic model with a single field of the custom ``MyClass`` type."""

    my_class: MyClass

    class Config:
        # MyClass is not a pydantic type, so pydantic must be told to accept it.
        arbitrary_types_allowed = True
class ANested(BaseModel):
    """Pydantic model that nests ``B``, so ``MyClass`` appears one level down."""

    b: B
class A(BaseModel):
    """Pydantic model holding ``MyClass`` directly (no nesting)."""

    my_class: MyClass

    class Config:
        # Needed because MyClass is an arbitrary (non-pydantic) class.
        arbitrary_types_allowed = True
class AFactory(ModelFactory):
    """Factory for ``A`` that knows how to produce ``MyClass`` values."""

    __model__ = A

    @classmethod
    def get_provider_map(cls) -> dict[type, Any]:
        """Return the inherited provider map plus a provider for ``MyClass``.

        The ``MyClass`` entry is added last, so it wins over any inherited entry.
        """

        def make_my_class() -> MyClass:
            return MyClass(value=1)

        return {**super().get_provider_map(), MyClass: make_my_class}
class ANestedFactory(ModelFactory):
    """Factory for ``ANested`` whose dynamically created sub-factories reuse its providers.

    ``ANested.b`` is a ``B`` model containing the custom ``MyClass`` type. A factory
    for ``B`` is looked up or created on demand by ``_get_or_create_factory``. A
    freshly created one would only see the default provider map, so this class
    hands its own provider map (and its own lookup method) to every factory it
    creates.
    """

    __model__ = ANested

    @classmethod
    def get_provider_map(cls) -> dict[type, Any]:
        """Return the inherited provider map extended with a provider for ``MyClass``."""
        providers_map = super().get_provider_map()
        return {
            **providers_map,
            MyClass: lambda: MyClass(value=1),
        }

    @classmethod
    def _inherit_overrides(cls, factory: type[BaseFactory[Any]]) -> type[BaseFactory[Any]]:
        """Make a newly created *factory* use this class's provider map and factory lookup.

        Propagating ``_get_or_create_factory`` as well means models nested more than
        one level deep also receive the provider map.
        """
        factory.get_provider_map = cls.get_provider_map
        factory._get_or_create_factory = cls._get_or_create_factory
        return factory

    @classmethod
    def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]:
        """Get a factory from registered factories or generate a factory dynamically.

        Dynamically generated factories inherit this class's provider map.

        :param model: A model type.
        :returns: A Factory sub-class.
        :raises ParameterException: If no base factory supports ``model``.
        """
        if factory := BaseFactory._factory_type_mapping.get(model):
            # Registered factories are shared across the whole test session.
            # Return them unmodified so this class's providers do not leak into them.
            return factory

        if cls.__base_factory_overrides__:
            for model_ancestor in model.mro():
                if factory := cls.__base_factory_overrides__.get(model_ancestor):
                    return cls._inherit_overrides(factory.create_factory(model))

        for factory in reversed(BaseFactory._base_factories):
            if factory.is_supported_type(model):
                # Previously an early `return factory.create_factory(model)` sat here,
                # which made the provider-map propagation below it unreachable.
                return cls._inherit_overrides(factory.create_factory(model))

        raise ParameterException(f"unsupported model type {model.__name__}")  # pragma: no cover
Funding
- If you would like to see an issue prioritized, make a pledge towards it!
- We receive the pledge once the issue is completed & verified
potatoUnicornDev commented
Alternatively, override it like this if you don't mind also overriding the provider map of factories that were not created dynamically (i.e. factories that were already registered):
@classmethod
def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]:
    """Resolve a factory for *model* via the parent class and hand it this class's overrides.

    NOTE(review): the returned factory is patched unconditionally. That includes
    factories that were already registered, not only dynamically created ones, so
    the patch is visible wherever that registered factory is used.
    """
    factory = super()._get_or_create_factory(model)
    # Same order as before: the provider map first, then the factory lookup itself.
    for attribute in ("get_provider_map", "_get_or_create_factory"):
        setattr(factory, attribute, getattr(cls, attribute))
    return factory