litestar-org/polyfactory

Support for relationship building models in `SQLAlchemyFactory`

0x587 opened this issue · 10 comments

Summary

Suppose I have a series of cascading models like A, B, C. When I call the factory for A and enable the __set_relationships__ parameter, the factory will build model A and have B in each A, but no C in B.
For example:
A
|- B1
|  |- expected C, but none was built
|- B2
|  |- expected C, but none was built

I read the source code and tried to modify it, and found that the cause is that the __set_relationships__ parameter is not set when sub-factories are constructed recursively. When I tried to set the parameter manually, it triggered a recursion error, so this feature may require a larger change to the structure of the code.

I am trying to use this library to generate fake data for a database. If this feature were implemented, it would let me move forward with my work.

Basic Example

Here's a sample test

class Author(Base):
    __tablename__ = "authors"

    id: Any = Column(Integer(), primary_key=True)
    books: Any = orm.relationship(
        "Book",
        uselist=True,
        back_populates="author",
    )


class Book(Base):
    __tablename__ = "books"

    id: Any = Column(Integer(), primary_key=True)
    author_id: Any = Column(
        Integer(),
        ForeignKey(Author.id),
        nullable=False,
    )
    author: Any = orm.relationship(
        Author,
        uselist=False,
        back_populates="books",
    )
    comments: Any = orm.relationship(
        "Comment",
        uselist=True,
        back_populates="book",
    )


class Comment(Base):
    __tablename__ = "comments"

    id: Any = Column(Integer(), primary_key=True)
    book_id: Any = Column(
        Integer(),
        ForeignKey(Book.id),
        nullable=False,
    )
    book: Any = orm.relationship(
        Book,
        uselist=False,
        back_populates="comments",
    )

def test_recursive_relationship_resolution() -> None:
    class AuthorFactory(SQLAlchemyFactory[Author]):
        __model__ = Author
        __set_relationships__ = True

    result = AuthorFactory.build()
    assert isinstance(result.books, list)
    assert isinstance(result.books[0], Book)
    assert isinstance(result.books[0].comments, list)
    assert isinstance(result.books[0].comments[0], Comment)

Drawbacks and Impact

I tried adding code at
https://github.com/litestar-org/polyfactory/blob/8dc8e1a4594a75ad9a16e1b6f5041b6044fc4f51/polyfactory/factories/base.py#L650C1-L650C80

like this:

if BaseFactory.is_batch_factory_type(annotation=unwrapped_annotation):
    factory = cls._get_or_create_factory(model=field_meta.type_args[0])
    if hasattr(factory, '__set_relationships__'):
        setattr(factory, '__set_relationships__', getattr(cls, '__set_relationships__'))
    ......

This only solves recursive construction for one-to-many or many-to-many relationships and is very inelegant. It raises recursion errors when I try to extend it to one-to-one or many-to-one relationships,

like this:

if BaseFactory.is_factory_type(annotation=unwrapped_annotation):
    factory = cls._get_or_create_factory(model=unwrapped_annotation)
    if hasattr(factory, '__set_relationships__'):
        setattr(factory, '__set_relationships__', getattr(cls, '__set_relationships__'))

    return factory.build(
        **(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}),
    )
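
The recursion error can be reproduced without SQLAlchemy or polyfactory. The following is a minimal sketch, assuming a hypothetical `RELATIONSHIPS` mapping that stands in for the `back_populates` pairs of the Author/Book/Comment models. `build_naive` is not polyfactory code; it only illustrates what happens when every sub-build also follows every relationship.

```python
# Hypothetical stand-in for the model graph: each model maps to the models
# its relationship attributes point to (both directions, as with
# back_populates).
RELATIONSHIPS = {
    "Author": ["Book"],
    "Book": ["Author", "Comment"],
    "Comment": ["Book"],
}


def build_naive(model: str) -> dict:
    """Build a model and, unconditionally, every related model."""
    return {
        "model": model,
        "related": [build_naive(target) for target in RELATIONSHIPS[model]],
    }


def naive_build_recurses(model: str) -> bool:
    """Return True if the naive build hits Python's recursion limit."""
    try:
        build_naive(model)
    except RecursionError:
        return True
    return False
```

Because Author points to Book and Book points back to Author, the naive build never reaches a leaf. Any fix has to cut the cycle somewhere, for example by not following a relationship back to a model that is already being built.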

Unresolved questions

No response


I'm trying to build a factory for each model, building from the bottom up in the order in which the models are referenced.
Referring to the example above, the AuthorFactory constructs a number of authors, but when the BookFactory starts constructing, the book's author attribute reconstructs new authors instead of using the authors already constructed by the AuthorFactory.
Perhaps we could prioritize the use of already generated factories when constructing relationships instead of constructing a new factory.
Maybe we could fix it this way?
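
One way to picture the idea of reusing already generated instances is sketched below. It assumes a hypothetical `InstancePool` helper, which is not part of polyfactory, and uses plain dicts in place of model instances. The pool returns a previously built object when one exists and only builds a new one otherwise.

```python
import random
from typing import Any, Callable, Dict, List


class InstancePool:
    """Hypothetical registry of already built instances, keyed by model name."""

    def __init__(self, seed: int = 0) -> None:
        self._instances: Dict[str, List[Any]] = {}
        self._random = random.Random(seed)

    def add(self, model: str, instance: Any) -> Any:
        """Record an instance and return it unchanged."""
        self._instances.setdefault(model, []).append(instance)
        return instance

    def count(self, model: str) -> int:
        return len(self._instances.get(model, []))

    def get_or_build(self, model: str, build: Callable[[], Any]) -> Any:
        """Reuse an existing instance of `model` if any, else build and record one."""
        existing = self._instances.get(model)
        if existing:
            return self._random.choice(existing)
        return self.add(model, build())


pool = InstancePool()
authors = [pool.add("Author", {"id": i}) for i in range(3)]

# Each book picks one of the three existing authors instead of creating a new one.
books = [
    {"id": i, "author": pool.get_or_build("Author", lambda: {"id": -1})}
    for i in range(5)
]
```

In a real factory, the `build` callable would be something like `AuthorFactory.build`, and the pool would be consulted from the classmethod that supplies the `author` field.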

I've implemented a usable multilevel generation example with some factory configurations that I hope will be helpful.

Rule

My thinking is to only set __set_relationships__ on the factories of classes that have foreign keys, such as Grade and Course. If a relationship attribute points to a class whose factory has __set_relationships__ set, the attribute returns None or [] to prevent recursion, such as Course.grades.


Here's a diagram of the foreign key associations for these models.
[Image: diagram of the foreign key associations between the models]

Model definitions

class DBStudent(Base):
    __tablename__ = "student"
    id: Mapped[uuid.UUID] = mapped_column(sqltypes.UUID, primary_key=True)
    name: Mapped[str] = mapped_column(sqltypes.String, nullable=False)

    grades: Mapped[List["DBGrade"]] = relationship(back_populates="student")

class DBCourse(Base):
    __tablename__ = "course"
    id: Mapped[uuid.UUID] = mapped_column(sqltypes.UUID, primary_key=True)
    name: Mapped[str] = mapped_column(sqltypes.String, nullable=False)

    grades: Mapped[List["DBGrade"]] = relationship(back_populates="course")
    teacher: Mapped['DBTeacher'] = relationship(back_populates="courses")
    fk_teacher_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("teacher.id"), nullable=False)

class DBTeacher(Base):
    __tablename__ = "teacher"
    id: Mapped[uuid.UUID] = mapped_column(sqltypes.UUID, primary_key=True)
    courses: Mapped[List["DBCourse"]] = relationship(back_populates="teacher")

class DBGrade(Base):
    __tablename__ = "grade"
    id: Mapped[uuid.UUID] = mapped_column(sqltypes.UUID, primary_key=True)
    grade: Mapped[float] = mapped_column(sqltypes.Float, nullable=False)

    student: Mapped['DBStudent'] = relationship(back_populates="grades")
    fk_student_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("student.id"), nullable=False)
    course: Mapped['DBCourse'] = relationship(back_populates="grades")
    fk_course_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("course.id"), nullable=False)

Factory definitions

class StudentFactory(SQLAlchemyFactory[DBStudent]):
    __sync_persistence__ = True

class CourseFactory(SQLAlchemyFactory[DBCourse]):
    __sync_persistence__ = True
    __set_relationships__ = True

    @classmethod
    def teacher(cls):
        return TeacherFactory.build()

    @classmethod
    def grades(cls):
        return []

class TeacherFactory(SQLAlchemyFactory[DBTeacher]):
    __sync_persistence__ = True

class GradeFactory(SQLAlchemyFactory[DBGrade]):
    __sync_persistence__ = True
    __set_relationships__ = True

    @classmethod
    def student(cls):
        return StudentFactory.build()

    @classmethod
    def course(cls):
        return CourseFactory.build()

Main call

async def main():
    async with sessionmanager.session() as session:
        GradeFactory.__async_session__ = session
        await GradeFactory.create_batch_async(size=10)

I think the settings here may cover some of the desired use cases https://polyfactory.litestar.dev/reference/factories/base.html#polyfactory.factories.base.BaseFactory.__set_as_default_factory_for_type__.

class MySQLAlchemyFactory(SQLAlchemyFactory):
    __is_base_factory__ = True
    __set_relationships__ = True

will allow changing this globally. Inheriting config for dynamically created subfactories is a known issue (see #426).


My thinking is to only set __set_relationships__ on the factories of classes that have foreign keys, such as Grade and Course. If a relationship attribute points to a class whose factory has __set_relationships__ set, the attribute returns None or [] to prevent recursion, such as Course.grades.

I am not sure about this rule. It works for Grade, but when calling TeacherFactory, would generating some grades be expected? Maybe an enum or literal value would be appropriate to allow a different strategy on a per-factory basis, e.g. __set_relationships__: bool | Literal['all', 'collections', 'foreign_key'] or similar.
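
To make the suggested literal values concrete, here is a rough sketch of how such a setting might be interpreted. Everything in it is an assumption: the thread only discusses a bool for __set_relationships__, the `should_populate` helper is hypothetical, and the reading of 'collections' as list relationships and 'foreign_key' as scalar relationships is one possible interpretation of the proposal.

```python
from typing import Literal, Union

# Hypothetical extension of __set_relationships__; the thread only
# discusses the bool form.
RelationshipStrategy = Union[bool, Literal["all", "collections", "foreign_key"]]


def should_populate(strategy: RelationshipStrategy, uselist: bool) -> bool:
    """Decide whether a relationship attribute should be generated.

    `uselist` mirrors SQLAlchemy's relationship flag: True for collection
    relationships, False for scalar ones.
    """
    if strategy is True or strategy == "all":
        return True
    if strategy is False:
        return False
    if strategy == "collections":
        return uselist
    if strategy == "foreign_key":
        return not uselist
    raise ValueError(f"unknown strategy: {strategy!r}")
```

Under this reading, a factory using 'foreign_key' would fill in Grade.student and Grade.course but leave Course.grades empty, which matches the manual rule described earlier in the thread.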

Hi @adhtruong. Thanks for the reply.

Regarding your suggestion:

I think the settings here may cover some of the desired use cases https://polyfactory.litestar.dev/reference/factories/base.html#polyfactory.factories.base.BaseFactory.__set_as_default_factory_for_type__.

class MySQLAlchemyFactory(SQLAlchemyFactory):
    __is_base_factory__ = True
    __set_relationships__ = True

will allow changing this globally. Inheriting config for dynamically created subfactories is a known issue (see #426).

This would cause the program to go into infinite recursion, which is why I need to manually disable generation of certain fields, such as Course.grades.

In the example I gave, the only way to get the expected result is to call GradeFactory. I built a graph of the model relationships and computed a spanning tree, then defined the factories with reference to this tree. GradeFactory is the root of this tree, so only GradeFactory can generate the expected results.

Giving __set_relationships__ more possible values might be a good solution, but would that be a big change to the whole library?

I wrote a script that uses networkx to compute the spanning tree of the model relationships and then generates the factory definition code. That is my current solution.
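
The script itself is not shown in the thread. As a rough illustration of the idea, the sketch below computes a spanning tree with a plain breadth-first search from the standard library rather than networkx. The `GRAPH` mapping is a hand-written assumption derived from the relationship pairs in the models above, and `spanning_tree` is a hypothetical helper.

```python
from collections import deque
from typing import Dict, List

# Undirected relationship graph for the models above (one edge per
# back_populates pair).
GRAPH: Dict[str, List[str]] = {
    "DBGrade": ["DBStudent", "DBCourse"],
    "DBStudent": ["DBGrade"],
    "DBCourse": ["DBGrade", "DBTeacher"],
    "DBTeacher": ["DBCourse"],
}


def spanning_tree(graph: Dict[str, List[str]], root: str) -> Dict[str, List[str]]:
    """Return {model: [children]} for a breadth-first spanning tree from root."""
    tree: Dict[str, List[str]] = {root: []}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for neighbour in graph[node]:
            if neighbour not in tree:  # first visit: keep this edge in the tree
                tree[neighbour] = []
                tree[node].append(neighbour)
                queue.append(neighbour)
    return tree


tree = spanning_tree(GRAPH, "DBGrade")
```

Tree edges correspond to relationships a factory would populate (GradeFactory builds a student and a course, CourseFactory builds a teacher). Edges left out of the tree, such as the back-reference Course.grades, are the ones that get [] or None in the factory definitions above.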

Perhaps we can find a more elegant way to generate a series of correlated models.

Giving __set_relationships__ more possible values might be a good solution, but would that be a big change to the whole library?

This configuration is currently only used by SQLAlchemyFactory. A configuration option here makes sense if it is useful to have these distinct behaviours in the library itself.


Going back to your previous comment

This would cause the program to go into infinite recursion, which is why I need to manually disable generation of certain fields, such as Course.grades.

Would just keeping track of seen types beforehand handle this case? Note that this may be a more naive solution than full graph resolution, but it may suffice for a lot of use cases. Here's a quick prototype based on an example in the docs:

from __future__ import annotations

import contextlib
from typing import Any, Iterator, List, TypedDict

from sqlalchemy import ForeignKey, create_engine, inspect
from sqlalchemy.orm import DeclarativeBase, Mapped, Mapper, Session, mapped_column, relationship

from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory
from polyfactory.field_meta import FieldMeta


class Base(DeclarativeBase):
    ...


class Author(Base):
    __tablename__ = "authors"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]

    books: Mapped[list["Book"]] = relationship("Book", uselist=True)


class Book(Base):
    __tablename__ = "books"

    id: Mapped[int] = mapped_column(primary_key=True)
    author_id: Mapped[int] = mapped_column(ForeignKey(Author.id))

    author: Mapped[Author] = relationship("Author", uselist=False)


class Context(TypedDict):
    seen: set[Any]


_context: Context = {"seen": set()}  # This should not be global


@contextlib.contextmanager
def add_to_context(model: type) -> Iterator[None]:
    _context["seen"].add(model)
    try:
        yield
    finally:
        # Always unmark the model, even if the build raises.
        _context["seen"].remove(model)


class ImprovedSQLAlchemyFactory(SQLAlchemyFactory):
    __is_base_factory__ = True

    @classmethod
    def build(cls, **kwargs: Any) -> Any:
        with add_to_context(cls.__model__):
            return super().build(**kwargs)

    @classmethod
    def get_model_fields(cls) -> list[FieldMeta]:
        fields_meta = super().get_model_fields()

        table: Mapper = inspect(cls.__model__)  # type: ignore[assignment]
        for name, relationship_ in table.relationships.items():
            class_ = relationship_.entity.class_
            if class_ in _context["seen"]:
                continue

            annotation = class_ if not relationship_.uselist else List[class_]  # type: ignore[valid-type]
            fields_meta.append(
                FieldMeta.from_type(
                    name=name,
                    annotation=annotation,
                    random=cls.__random__,
                ),
            )

        return fields_meta


def test_sqla_factory() -> None:
    author: Author = ImprovedSQLAlchemyFactory.create_factory(Author).build()
    assert isinstance(author.books[0], Book)
    assert author.books[0].author is None

    book: Book = ImprovedSQLAlchemyFactory.create_factory(Book).build()
    assert book.author is not None
    assert book.author.books == []


def test_sqla_factory_create() -> None:
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    ImprovedSQLAlchemyFactory.__session__ = Session(engine)

    author: Author = ImprovedSQLAlchemyFactory.create_factory(Author).create_sync()
    assert isinstance(author.books[0], Book)
    assert author.books[0].author is author

    book = ImprovedSQLAlchemyFactory.create_factory(Book).create_sync()
    assert book.author is not None
    assert book.author.books == [book]

Edit: Add test for example with session
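
The prototype marks its module-level `_context` with "This should not be global". One possible alternative, sketched here as an assumption and not as what the library does, is to keep the seen set in a `contextvars.ContextVar` so that each thread or asyncio task gets its own view. The `is_seen` helper is hypothetical.

```python
import contextlib
import contextvars
from typing import Iterator

# Per-context set of model types currently being built. A frozenset is used
# so that every update creates a new value instead of mutating a shared one.
_seen = contextvars.ContextVar("seen", default=frozenset())


@contextlib.contextmanager
def add_to_context(model: type) -> Iterator[None]:
    """Mark `model` as being built for the duration of the with-block."""
    token = _seen.set(_seen.get() | {model})
    try:
        yield
    finally:
        # Restore the previous set, even if the build raises.
        _seen.reset(token)


def is_seen(model: type) -> bool:
    """Return True if `model` is currently being built in this context."""
    return model in _seen.get()
```

With this variant, the check `class_ in _context["seen"]` in `get_model_fields` would become `is_seen(class_)`.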

The results of this prototype don't seem to be what was expected.

This assert

assert author.books[0].author is None

I think it should be assert author.books[0].author == author

@0x587 I think that difference comes from using create_sync / create_async vs build. The former add the instances to a session, so the relationships are resolved by SQLAlchemy. Out of the box, these won't necessarily be set correctly with build.

I've extended the example to use create which does match the expected assertion. Do you think this logic would meet your use case? I think this logic is generic enough to be in the library itself if so.

The example of this extension perfectly meets my needs, thank you.

I think this need commonly arises when fake data has to be generated for a database. It would be worth adding a well-named subclass of SQLAlchemyFactory like this one to the library.

Maybe we can close this issue.

Great, thanks for checking and confirming!

I would be in favour of keeping this issue open so the workaround is documented here, as I agree with you that this is probably a common issue. I'll see if this feature can be made part of the main library, or at least document the above as a workaround.