litestar-org/polyfactory

Bug: wrong factory is selected

Closed this issue · 7 comments

Description

Creating a custom base factory (with __is_base_factory__ = True) may have unexpected side effects even if the factory is never used explicitly. This happens when the factory is picked as the first one that supports a given model type. I first ran into this issue in #193 and thought I had fixed it (by reversing the iteration order of base factories), but apparently it's trickier than that.

In the MCVE below, MyBaseFactory is a Pydantic base model factory that appears to be unused, but in fact it is selected as the base of the dynamically generated factory for MyEmbeddedModel. The latter is an odmantic model, so it should have been handled by OdmanticModelFactory (which supports bson types since #193). MyBaseFactory is selected instead because:

  1. It is the last registered factory in BaseFactory._base_factories, and
  2. Base factories are iterated in reverse, and
  3. Odmantic models are also Pydantic models and thus MyBaseFactory.is_supported_type(MyEmbeddedModel) returns True.

Although the MCVE involves odmantic models, I think the issue is more general. IMO the crux of the matter is that the order of the _base_factories list (which reflects the order of their registration) should bear no relation to determining which factory to pick for a given model.
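The selection behaviour described above can be reproduced with a toy model. This is a minimal sketch, not polyfactory's actual code: every class here is a hypothetical stand-in, and Factory.registry only plays the role of BaseFactory._base_factories. It shows that merely registering one more base factory changes which factory is picked for an unrelated model.

```python
from typing import List, Optional, Type


class PydanticLikeModel:
    """Stand-in for pydantic.BaseModel."""


class OdmanticLikeModel(PydanticLikeModel):
    """Stand-in for an odmantic model, which is also a Pydantic model."""


class Factory:
    """Toy base factory; `registry` plays the role of BaseFactory._base_factories."""

    registry: List[Type["Factory"]] = []
    supported: type = object

    @classmethod
    def is_supported_type(cls, value: type) -> bool:
        return issubclass(value, cls.supported)


class PydanticFactory(Factory):
    supported = PydanticLikeModel


class OdmanticFactory(Factory):
    supported = OdmanticLikeModel


class MyBaseFactory(PydanticFactory):
    """User-defined base factory that is never used explicitly."""


def pick_factory(model: type) -> Optional[Type[Factory]]:
    # Iterate in reverse registration order; the first match wins.
    for factory in reversed(Factory.registry):
        if factory.is_supported_type(model):
            return factory
    return None


Factory.registry = [PydanticFactory, OdmanticFactory]
before = pick_factory(OdmanticLikeModel)  # OdmanticFactory

Factory.registry.append(MyBaseFactory)  # registration alone changes the result
after = pick_factory(OdmanticLikeModel)  # MyBaseFactory
```

In this sketch the only thing that differs between the two lookups is the registration order, which is the coupling the description argues against.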

MCVE

from datetime import datetime

import odmantic

from polyfactory.factories.odmantic_odm_factory import OdmanticModelFactory
from polyfactory.factories.pydantic_factory import ModelFactory


class MyEmbeddedModel(odmantic.EmbeddedModel):
    timestamp: datetime


class MyOdmanticModel(odmantic.Model):
    embedded: MyEmbeddedModel



class MyBaseFactory(ModelFactory):
    __is_base_factory__ = True


class MyOdmanticFactory(OdmanticModelFactory):
    __model__ = MyOdmanticModel


MyOdmanticFactory.build()

Logs

Traceback (most recent call last):
...
polyfactory.exceptions.ParameterException: Unsupported type: <class 'odmantic.bson._datetime'>

Litestar Version

2.0.1

Hmm, ok thanks for the report.

Solution in place was obviously too naive. We need a way to set the base factory per type.

Any suggestions?

I couldn't come up with a (working) solution for the general "need a way to set the base factory per type" problem. Still, given that the issue currently manifests only with custom base factories (the five defined in polyfactory itself seem to work fine when iterated in reversed() order), what if we take a step back and disallow the registration of user-defined base factories?

AFAICT the use case for user-defined base factories is handling custom types. This can probably be achieved with a more explicit {custom_type: provider_callable} registration mechanism instead of subclassing an existing base factory and overriding its get_provider_map. This would have the extra benefit of decoupling the custom type provider from the enclosing class flavor (dataclass, TypedDict, Pydantic model, etc.).
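To make the {custom_type: provider_callable} idea concrete, here is a rough sketch. It assumes a registry that does not exist in polyfactory: register_provider, _custom_providers and the toy build function are all hypothetical names used only for illustration. The provider is looked up by field type alone, so it does not depend on the flavor of the enclosing class.

```python
from dataclasses import dataclass, fields
from typing import Any, Callable, Dict

# Hypothetical registry: maps a custom type to a zero-argument provider.
_custom_providers: Dict[type, Callable[[], Any]] = {}


def register_provider(custom_type: type, provider: Callable[[], Any]) -> None:
    """Register a provider for custom_type, independent of any factory class."""
    _custom_providers[custom_type] = provider


class Foo:
    def __init__(self, value: str) -> None:
        self.value = value


@dataclass
class MyModelWithFoo:
    foo: Foo


def build(model: type) -> Any:
    """Toy builder for dataclasses that only consults the registry."""
    kwargs = {}
    for field in fields(model):
        provider = _custom_providers.get(field.type)
        if provider is None:
            raise TypeError(f"Unsupported type: {field.type!r}")
        kwargs[field.name] = provider()
    return model(**kwargs)


register_provider(Foo, lambda: Foo("foo"))
instance = build(MyModelWithFoo)
```

A real implementation would fall back to the built-in provider map for standard types; the sketch omits that to keep the registration step visible.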

For now I posted a draft PR that does the first part - disallow (or rather ignore) non-concrete factories from being registered in BaseFactory._base_factories. All tests are passing, including the one with custom base factory, even without introducing a new registration mechanism for custom types. I could follow up with a separate PR for it if you think it's worth going this direction.

I think the solution you are proposing @gsakkis is a bit too radical.

I would try to opt for something like this:

class MyEmbeddedModel(odmantic.EmbeddedModel):
    timestamp: datetime


class MyOdmanticModel(odmantic.Model):
    embedded: MyEmbeddedModel



class MyBaseFactory(ModelFactory):
    __is_base_factory__ = (odmantic.EmbeddedModel,) # a tuple of types OR boolean

If I understand correctly, the tuple of types would be the superclasses of the types which MyBaseFactory should handle? If so it wouldn't help here; what we want in this case is in fact the opposite: exclude odmantic.EmbeddedModel (and odmantic.Model) from being handled by MyBaseFactory (which are handled by default if __is_base_factory__ = True).

So perhaps __is_base_factory__ could be a callable __is_base_factory__(cls) that returns True if cls is a subclass of pydantic.BaseModel and is not a subclass of odmantic._BaseODMModel. But what about beanie.Document, should it be excluded too? What if a new framework that extends pydantic comes along? Coming up with a complete exclusion list doesn't seem like a realistic, let alone user-friendly, way to define custom base factories.
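A small sketch of what such a callable might look like, and why the exclusion list is brittle. All classes are hypothetical stand-ins for the real pydantic and odmantic bases, and is_base_factory_for is an assumed name, not an existing polyfactory hook.

```python
class BaseModel:
    """Stand-in for pydantic.BaseModel."""


class OdmBaseModel(BaseModel):
    """Stand-in for odmantic's base model class."""


class NewFrameworkModel(BaseModel):
    """Stand-in for a future framework that also extends Pydantic."""


# Exclusion list the author of the custom base factory would have to maintain.
EXCLUDED_BASES = (OdmBaseModel,)


def is_base_factory_for(cls: type) -> bool:
    """Hypothetical callable form of __is_base_factory__."""
    return issubclass(cls, BaseModel) and not issubclass(cls, EXCLUDED_BASES)


class PlainModel(BaseModel):
    pass


class EmbeddedModel(OdmBaseModel):
    pass


class FutureModel(NewFrameworkModel):
    pass


plain_ok = is_base_factory_for(PlainModel)  # accepted, as intended
odm_ok = is_base_factory_for(EmbeddedModel)  # rejected, as intended
future_ok = is_base_factory_for(FutureModel)  # accepted: nobody excluded it yet
```

The last call is the problem case: a model from a framework that is missing from the exclusion list is silently claimed by the custom base factory.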

Well, what we could do with a tuple of types is define precedence using a map:

base_factory_type_map: dict[type, type[BaseFactory]] = {
    odmantic.EmbeddedModel: OdmanticModelFactory,
}
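One possible reading of this proposal, sketched under assumptions: the map is resolved by walking the model's MRO so that the most specific mapped ancestor wins, regardless of registration order. The resolve_factory helper and the stand-in classes are hypothetical, and factory names are plain strings to keep the example self-contained.

```python
from typing import Dict, Optional


class BaseModel:
    """Stand-in for pydantic.BaseModel."""


class EmbeddedModel(BaseModel):
    """Stand-in for odmantic.EmbeddedModel."""


class MyEmbeddedModel(EmbeddedModel):
    pass


class PlainModel(BaseModel):
    pass


# Factory names are plain strings here to keep the sketch self-contained.
base_factory_type_map: Dict[type, str] = {
    BaseModel: "MyBaseFactory",
    EmbeddedModel: "OdmanticModelFactory",
}


def resolve_factory(model: type) -> Optional[str]:
    """Return the entry for the most specific mapped ancestor of model."""
    for ancestor in model.__mro__:
        if ancestor in base_factory_type_map:
            return base_factory_type_map[ancestor]
    return None


embedded_choice = resolve_factory(MyEmbeddedModel)  # "OdmanticModelFactory"
plain_choice = resolve_factory(PlainModel)  # "MyBaseFactory"
```

Because lookup follows the class hierarchy rather than the insertion order of the dict, adding another entry for an unrelated type would not change either result.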

Not sure I'm following but happy to check out an alternative PR, until then #199 works for me.

Another MCVE using plain vanilla dataclasses, no pydantic/odmantic models:

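from dataclasses import dataclass
from typing import Any, Dict, Type

from polyfactory.factories.dataclass_factory import DataclassFactory


def test_multiple_base_factories() -> None: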
    class Foo:
        def __init__(self, value: str) -> None:
            self.value = value

    class FooDataclassFactory(DataclassFactory):
        __is_base_factory__ = True

        @classmethod
        def get_provider_map(cls) -> Dict[Type, Any]:
            return {Foo: lambda: Foo("foo"), **super().get_provider_map()}

    class DummyDataclassFactory(DataclassFactory):
        __is_base_factory__ = True

    @dataclass
    class MyModelWithFoo:
        foo: Foo

    @dataclass
    class MyModel:
        nested: MyModelWithFoo

    class MyFactory(FooDataclassFactory):
        __model__ = MyModel

    MyFactory.build()

Output:

polyfactory.exceptions.ParameterException: Unsupported type: <class 'tests.test_factory_subclassing.test_multiple_base_factories.<locals>.Foo'>

       Either extend the providers map or add a factory function for this type.