erm0l0v/django-fake-model

Support for Django Content type ?

Closed this issue · 3 comments

Hello,

Is support for the content types framework planned?

I understand that, technically, the point of fake models is to avoid needing a Django app, whereas content types are tied to Django apps.
But as a library developer, I use content types for some auto-generation, and I don't have a Django app in the library.

魔改一下,可以强行支持!
With a few hacks, it can be forced to work!
应该还不支持 GenericRelation ,需要阅读 contentype 和 AppConfig 代码
It probably doesn't support GenericRelation yet, you'll need to read the code for ContentType and AppConfig.

-------------------------show code -------------------------
Python 3.9.16
django 4.0.3

const.py

# App label shared by every fake model (FakeModel.Meta.app_label) and used as the
# name of the fake AppConfig that is registered for them during tests.
APP_LABEL = "django_fake_models"

typings.py

from typing import Type, Union
from typing_extensions import TypeAlias
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.backends.mysql.schema import DatabaseSchemaEditor

# Type alias for a schema-editor class: Django's backend-agnostic base editor or
# the MySQL backend's editor.
AnySchemaEditorClass: TypeAlias = Type[Union[BaseDatabaseSchemaEditor, DatabaseSchemaEditor]]

utils.py

from contextlib import contextmanager
from functools import lru_cache
from django.db.utils import OperationalError


@contextmanager
def ignore_already_exists_operational_error():
    """Suppress an ``OperationalError`` whose message mentions "already exists".

    Any other ``OperationalError`` (and any other exception) propagates
    unchanged.
    """
    try:
        yield
    except OperationalError as err:
        # Only "the object is already there" is tolerated; anything else is a
        # real failure and is re-raised.
        if "already exists" not in str(err):
            raise


@contextmanager
def ignore_duplicate_foreign_key_operational_error():
    """Suppress an ``OperationalError`` about a duplicate foreign key constraint name.

    Only errors whose message contains "Duplicate foreign key constraint name"
    are swallowed; every other ``OperationalError`` propagates unchanged.
    """
    try:
        yield
    except OperationalError as err:
        # The constraint already exists under that name: treat as done.
        if "Duplicate foreign key constraint name" not in str(err):
            raise


@lru_cache(maxsize=None)
def detect_django_maybe_keepdb():
    """
    Guess whether the test database is kept between test runs.

    Returns True when:

    * ``--keepdb`` (Django test runner) or ``--reuse-db`` (pytest-django) is on
      the command line, or
    * a ``pytest.ini`` in the current working directory lists ``--reuse-db`` in
      the ``addopts`` of its ``[pytest]`` section.

    Otherwise returns False. The answer is computed once and cached for the
    rest of the process.
    """
    import configparser
    import os
    import sys

    if "--keepdb" in sys.argv or "--reuse-db" in sys.argv:
        return True

    pytest_path = os.path.join(os.getcwd(), "pytest.ini")
    if not os.path.exists(pytest_path):
        return False

    config = configparser.ConfigParser()
    with open(pytest_path, "r", encoding="utf-8") as fd:
        config.read_file(fd)
    # A pytest.ini without a [pytest] section, or without addopts, is valid and
    # must not raise NoSectionError. raw=True: addopts is not written for
    # ConfigParser %-interpolation.
    addopts = config.get("pytest", "addopts", raw=True, fallback="")
    # Several options may share one line ("addopts = --reuse-db -v"), so split
    # on any whitespace rather than only on newlines.
    return "--reuse-db" in addopts.split()

models.py

from __future__ import unicode_literals
from functools import wraps
import warnings
from django.core.management.color import no_style
from django.db import connection, models
from django.test import SimpleTestCase

from .const import APP_LABEL
from .typings import AnySchemaEditorClass
from .case_extension import CaseExtension
from .utils import (
    ignore_already_exists_operational_error,
    ignore_duplicate_foreign_key_operational_error,
    detect_django_maybe_keepdb,
)


class FakeModel(models.Model):
    """
    FakeModel

    Base class for all fake models.

    Subclasses are Django models under the ``APP_LABEL`` app label whose tables
    are created and dropped on demand in the test database, usually through the
    ``fake_me`` decorator.
    """

    class Meta:
        abstract = True
        app_label = APP_LABEL

    @classmethod
    def create_table(cls):
        """
        create_table

        Manually create a temporary table for model in test data base.

        "already exists" and "Duplicate foreign key constraint name"
        ``OperationalError``s are ignored, so re-creating a table that survived
        a previous run (kept test database) is not an error.
        :return:
        """
        schema_editor = getattr(connection, "schema_editor", None)
        if schema_editor:
            # The ignore-contexts wrap the whole schema editor, not just
            # ``create_model``: the editor runs its deferred SQL (indexes, FK
            # constraints) when its context exits. If CREATE TABLE fails because
            # the table already exists, the error has to leave the editor so that
            # deferred SQL is skipped rather than failing on objects that also
            # already exist. Duplicate-constraint errors raised by the deferred
            # SQL itself are ignored here as well.
            with ignore_already_exists_operational_error(), ignore_duplicate_foreign_key_operational_error():
                with schema_editor() as editor:
                    editor.create_model(cls)
        else:
            # Legacy path for connections without a schema editor: build and
            # run the CREATE TABLE SQL by hand.
            raw_sql, _ = connection.creation.sql_create_model(cls, no_style(), [])
            cls.delete_table()
            cursor = connection.cursor()
            try:
                cursor.execute(*raw_sql)
            finally:
                cursor.close()

    @classmethod
    def delete_table(cls):
        """
        delete_table

        Manually delete a temporary table for model in test data base.

        When the test database looks like it is kept between runs
        (``detect_django_maybe_keepdb()``), the table is kept and only its rows
        are deleted.
        :return:
        """
        schema_editor = getattr(connection, "schema_editor", None)
        if schema_editor:
            with schema_editor() as editor:
                if not detect_django_maybe_keepdb():
                    editor.delete_model(cls)
                else:
                    # Keep the table for the next run; only clear its rows.
                    cls.objects.all().delete()
        else:
            cursor = connection.cursor()
            try:
                with warnings.catch_warnings():
                    # Silence "unknown table" warnings emitted when the table
                    # does not exist.
                    warnings.filterwarnings("ignore", "unknown table")
                    cursor.execute("DROP TABLE IF EXISTS {0}".format(cls._meta.db_table))
            finally:
                cursor.close()

    @classmethod
    def fake_me(cls, source):
        """
        fake_me

        Class or method decorator

        Class decorator: create temporary table for all tests in SimpleTestCase.
        Method decorator: create temporary model only for given test method.
        :param source: SimpleTestCase subclass or test function
        :return: the extended test case class, or the wrapped test function
        :raises AttributeError: if source is neither
        """
        if source and isinstance(source, type) and issubclass(source, SimpleTestCase):
            return cls._class_extension(source)
        elif callable(source):
            return cls._decorator(source)
        else:
            raise AttributeError("source - must be a SimpleTestCase subclass or function")

    @classmethod
    def _class_extension(cls, source):
        """Mix ``CaseExtension`` into *source* (once) and register this model on it."""
        if not issubclass(source, CaseExtension):

            # Subclass that copies the original __module__ and __name__ —
            # presumably so the test case keeps its reported identity.
            @wraps(
                source,
                assigned=(
                    "__module__",
                    "__name__",
                ),
                updated=[],
            )
            class __wrap_class__(source, CaseExtension):
                pass

            source = __wrap_class__
        source.append_model(cls)
        return source

    @classmethod
    def _decorator(cls, source):
        """Wrap test function *source* so this model's table exists only during the call."""

        @wraps(source)
        def __wrapper__(*args, **kwargs):
            try:
                cls.create_table()
                return source(*args, **kwargs)
            finally:
                # Runs even if create_table or the test itself fails.
                cls.delete_table()

        return __wrapper__

case_extension.py

from __future__ import unicode_literals
from django.test import SimpleTestCase

from .utils import detect_django_maybe_keepdb, ignore_duplicate_foreign_key_operational_error
from .case_apps import DjangoFakeModelsApp, DjangoFakeModelsCollector


class CaseExtension(SimpleTestCase):
    """
    Mixin that manages fake-model tables around each test.

    ``FakeModel.fake_me`` mixes this class into decorated test cases and
    registers each decorated model through ``append_model``. Before each test,
    the tables are created, the models are collected (loading their
    ContentTypes) and the fake AppConfig is installed. After each test, the
    models are uncollected (unless the test database is kept) and the tables are
    dropped in reverse order.
    """

    # Fake models registered on this test case, in registration order.
    _models = tuple()

    @classmethod
    def append_model(cls, model):
        """Register *model*; rebinding the tuple keeps registrations per class."""
        cls._models += (model,)

    def _pre_setup(self):
        super(CaseExtension, self)._pre_setup()
        with ignore_duplicate_foreign_key_operational_error():
            self._map_models("create_table")
        DjangoFakeModelsCollector.singleton().collect(*self._models)
        self._init_appconfig()

    def _post_teardown(self):
        # If we don't remove them in reverse order, then if we created table A
        # after table B and it has a foreignkey to table B, then trying to
        # remove B first will fail on some configurations, as documented
        # in issue #1
        if not detect_django_maybe_keepdb():
            DjangoFakeModelsCollector.singleton().uncollect(*self._models)
        self._map_models("delete_table", reverse=True)
        super(CaseExtension, self)._post_teardown()

    def _map_models(self, method_name, reverse=False):
        """
        Call ``model.<method_name>()`` on every registered model.

        :param method_name: name of a classmethod such as "create_table"
        :param reverse: iterate models in reverse registration order
        :raises TypeError: if a model has no attribute *method_name*
        """
        models = reversed(self._models) if reverse else self._models
        for model in models:
            try:
                method = getattr(model, method_name)
            except AttributeError as err:
                raise TypeError("{0} doesn't support table method {1}".format(model, method_name)) from err
            # Call outside the try so an AttributeError raised *inside* the
            # method is not misreported as a missing method.
            method()

    def _init_appconfig(self):
        """Install a ``DjangoFakeModelsApp`` for the fake app label if none exists."""
        from django.apps import apps
        from .models import FakeModel

        app_label = FakeModel.Meta.app_label

        if apps.app_configs.get(app_label):
            return

        fake_app = DjangoFakeModelsApp(app_label, None)
        apps.app_configs[app_label] = fake_app
        return

case_apps.py

import threading
from django.apps import AppConfig
from django.contrib.contenttypes.models import ContentType

from .const import APP_LABEL


class DjangoFakeModelsCollector:
    """
    Process-wide registry of the fake models currently in use.

    Collecting a model also loads its ``ContentType`` into the ``ContentType``
    manager's cache; uncollecting evicts it from that cache again.
    """

    _instance = None
    # Concurrency lock guarding ``_instance`` and ``_models``.
    _locker = threading.Lock()
    _models = []

    @classmethod
    def singleton(cls):
        """Return the shared collector, creating it on first use."""
        with cls._locker:
            if not cls._instance:
                cls._instance = cls()
        return cls._instance

    def collect(self, *models):
        """Register *models* (already-registered ones are skipped) and cache their ContentTypes."""
        with self._locker:
            for model in models:
                if model in self._models:
                    continue
                self._models.append(model)
                # Load the model's ContentType into the manager cache.
                ContentType.objects.get_for_model(model)

    def uncollect(self, *models):
        """
        Unregister *models* and evict their ContentTypes from the manager cache.

        Models that are not currently registered are skipped.
        """
        with self._locker:
            for model in models:
                if model not in self._models:
                    continue
                self._models.remove(model)
                ct = ContentType.objects.get_for_model(model)
                # The manager cache is keyed by the database alias
                # (``ContentType.objects.db``), then by (app_label, model) and
                # by id. ``ContentType.objects.using`` is a method, not the
                # alias, so it can never match a cache key.
                ct_cache = ContentType.objects._cache.get(ContentType.objects.db, {})
                ct_cache.pop((ct.app_label, ct.model), None)
                ct_cache.pop(ct.id, None)

    def models(self):
        """Return a snapshot (copy) of the registered models."""
        with self._locker:
            return list(self._models)


class DjangoFakeModelsApp(AppConfig):
    """
    AppConfig standing in for the fake models' app label.

    It is not populated like a real app: its models are whatever
    ``DjangoFakeModelsCollector`` currently holds, so registry lookups under
    this label resolve the fake models in use.
    """

    default_auto_field = "django.db.models.BigAutoField"
    name = APP_LABEL
    # Fixed class-level path; the config is instantiated without a module.
    path = "."

    def get_model(self, model_name, require_ready=True):
        """
        Return the model with the given case-insensitive model_name.

        Raise LookupError if no model exists with this name.
        ``require_ready`` is accepted for API compatibility and ignored.
        """
        wanted = model_name.lower()
        for model in DjangoFakeModelsCollector.singleton().models():
            # ``_meta.model_name`` is lower-case, so compare against the
            # lower-cased name to honour the case-insensitive contract.
            if model._meta.model_name == wanted:
                return model
        raise LookupError("App '%s' doesn't have a '%s' model." % (self.label, model_name))

    def get_models(self, include_auto_created=False, include_swapped=False):
        """
        Return an iterable of models.

        Unlike ``AppConfig.get_models``, this yields every collected fake model;
        ``include_auto_created`` and ``include_swapped`` are accepted for API
        compatibility and currently ignored.
        """
        yield from DjangoFakeModelsCollector.singleton().models()

prev test

# fake_models.py
from __future__ import unicode_literals
from ...django_fake_model import models as f
from django.db import models


class DFM_MyFakeModel(f.FakeModel):
    """Minimal fake model with one text field, used by the decorator tests."""

    name = models.CharField(max_length=100)

# test_class_extension.py
from __future__ import unicode_literals
from django.test import TransactionTestCase
from .fake_models import DFM_MyFakeModel


@DFM_MyFakeModel.fake_me
class MyFakeModelTests(TransactionTestCase):
    """Class-decorator usage: the fake table is available to every test of the case."""

    def test_create_model(self):
        DFM_MyFakeModel.objects.create(name="123")
        fetched = DFM_MyFakeModel.objects.get(name="123")
        self.assertEqual(fetched.name, "123")

# test_func_decorator.py
from __future__ import unicode_literals
from django.test import TransactionTestCase
from .fake_models import DFM_MyFakeModel


class MyFakeModelTests(TransactionTestCase):
    """Method-decorator usage: the fake table exists only while the test runs."""

    @DFM_MyFakeModel.fake_me
    def test_create_model(self):
        DFM_MyFakeModel.objects.create(name="123")
        fetched = DFM_MyFakeModel.objects.get(name="123")
        self.assertEqual(fetched.name, "123")


# test_related_model.py
from __future__ import unicode_literals

from django.db import models
from django.test import TransactionTestCase
from ...django_fake_model import models as f


class DFM_RelatedModel(f.FakeModel):
    """Fake model referenced by ``DFM_MyModel.related_model``."""

    text = models.CharField(max_length=400)


class DFM_MyModel(f.FakeModel):
    """Fake model with a ForeignKey to another fake model."""

    text = models.CharField(max_length=400)
    # Deleting the related row deletes this row too (CASCADE).
    related_model = models.ForeignKey(DFM_RelatedModel, on_delete=models.CASCADE)


@DFM_MyModel.fake_me
@DFM_RelatedModel.fake_me
class TestRelatedModelsClassDecorator(TransactionTestCase):
    """Fake models linked by a ForeignKey can be saved and read back together."""

    def test_create_models(self):
        related_model = DFM_RelatedModel()
        related_model.text = "qwerty"
        related_model.save()
        my_model = DFM_MyModel()
        # Assign the real ``text`` field; a misspelled attribute would be
        # silently ignored by Django and leave the column empty.
        my_model.text = "qwerty"
        my_model.related_model = related_model
        my_model.save()
        # Verify the rows round-trip through the database instead of only
        # checking that local objects exist.
        stored = DFM_MyModel.objects.get(pk=my_model.pk)
        self.assertEqual(stored.text, "qwerty")
        self.assertEqual(stored.related_model_id, related_model.pk)
        self.assertEqual(stored.related_model.text, "qwerty")

new test

# test_content_type.py
from __future__ import unicode_literals

from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.db import models
from django.test import TransactionTestCase

from ...django_fake_model.case_apps import DjangoFakeModelsCollector
from ...django_fake_model import models as f


class DFM_MyGenericFakeModel(f.FakeModel):
    """Fake model holding a generic link (``entity``) to a row of any model."""

    name = models.CharField(max_length=100)
    # ContentType of the linked row. The verbose_name "实体类型" means "entity type".
    # PROTECT: a ContentType that rows still point at cannot be deleted.
    entity_type = models.ForeignKey(
        ContentType,
        verbose_name="实体类型",
        on_delete=models.PROTECT,
        null=True,
        blank=True,
    )
    # Primary key of the linked row; optional, like entity_type.
    entity_id = models.PositiveIntegerField(null=True, blank=True)
    # Combines entity_type + entity_id into one attribute.
    entity = GenericForeignKey("entity_type", "entity_id")


class DFM_MyEntityModel(f.FakeModel):
    """First fake target model for the generic link tests."""

    text = models.CharField(max_length=400)


class DFM_MyEntity2Model(f.FakeModel):
    """Second fake target model, used to check links to different models."""

    descr = models.CharField(max_length=400)


@DFM_MyGenericFakeModel.fake_me
@DFM_MyEntityModel.fake_me
class TestRelatedModelsClassDecorator(TransactionTestCase):
    """GenericForeignKey and ContentType lookups work with fake models."""

    def test_create_models(self):
        entity_obj = DFM_MyEntityModel()
        entity_obj.text = "xxx"
        entity_obj.save()
        generic_obj = DFM_MyGenericFakeModel()
        generic_obj.name = "qwerty"
        generic_obj.entity = entity_obj
        generic_obj.save()

        # After refresh_from_db(), the GenericForeignKey still resolves.
        generic_obj.refresh_from_db()
        self.assertIsNotNone(generic_obj.entity)

        # A freshly fetched instance resolves it too, to the right row.
        generic_obj = DFM_MyGenericFakeModel.objects.get(pk=generic_obj.pk)
        resolved = generic_obj.entity
        self.assertIsNotNone(resolved)
        self.assertEqual(resolved.text, "xxx")

    def test_content_type(self):
        ct = ContentType.objects.get_for_model(DFM_MyEntityModel)
        # Looking the ContentType up again by id yields the same model.
        by_id = ContentType.objects.get_for_id(ct.id)
        self.assertEqual(by_id.model, "dfm_myentitymodel")


def _test_multi_entity_from_ct(testcase: TransactionTestCase):
    """Check that each fake entity model gets its own ContentType, resolvable by id."""
    expected = (
        (DFM_MyEntityModel, "dfm_myentitymodel"),
        (DFM_MyEntity2Model, "dfm_myentity2model"),
    )
    ids = [ContentType.objects.get_for_model(model).id for model, _ in expected]
    for ct_id, (_, model_name) in zip(ids, expected):
        resolved: ContentType = ContentType.objects.get_for_id(ct_id)
        testcase.assertEqual(resolved.model, model_name)


def _test_multi_entity_from_gfk(testcase: TransactionTestCase):
    """Point generic links at two different fake models and check both resolve after a reload."""
    x_entity = DFM_MyEntityModel()
    x_entity.text = "xxx"
    x_entity.save()
    y_entity = DFM_MyEntity2Model()
    y_entity.descr = "yyy"
    y_entity.save()

    links = []
    for label, target in (("g_x", x_entity), ("g_y", y_entity)):
        link = DFM_MyGenericFakeModel()
        link.name = label
        link.entity = target
        link.save()
        links.append(link)

    for link in links:
        link.refresh_from_db()

    x_link, y_link = links
    testcase.assertEqual(x_link.entity_type.model_class(), DFM_MyEntityModel)
    testcase.assertEqual(y_link.entity_type.model_class(), DFM_MyEntity2Model)


@DFM_MyGenericFakeModel.fake_me
@DFM_MyEntityModel.fake_me
@DFM_MyEntity2Model.fake_me
class TestMultiFakeMe1(TransactionTestCase):
    """
    Run the shared ContentType / GenericForeignKey checks on three fake models.

    NOTE(review): presumably duplicated as TestMultiFakeMe2 to check that the
    fake models still work when another test case reuses them.
    """

    def test_ct(self):
        _test_multi_entity_from_ct(self)

    def test_gfk(self):
        _test_multi_entity_from_gfk(self)


@DFM_MyGenericFakeModel.fake_me
@DFM_MyEntityModel.fake_me
@DFM_MyEntity2Model.fake_me
class TestMultiFakeMe2(TransactionTestCase):
    """
    Second test case running the same multi-entity checks.

    NOTE(review): presumably exists to catch state leaking between test cases
    that decorate the same fake models.
    """

    def test_ct(self):
        _test_multi_entity_from_ct(self)

    def test_gfk(self):
        _test_multi_entity_from_gfk(self)


@DFM_MyGenericFakeModel.fake_me
@DFM_MyEntityModel.fake_me
@DFM_MyEntity2Model.fake_me
class TestAppConfig(TransactionTestCase):
    """The collector holds every fake model decorated on this test case."""

    def test_collect(self):
        app_models = DjangoFakeModelsCollector.singleton().models()
        # Only membership is checked: with --keepdb, model classes created by
        # other test files stay collected, so an exact set comparison fails.
        collected = set(app_models)
        for expected in (DFM_MyGenericFakeModel, DFM_MyEntityModel, DFM_MyEntity2Model):
            self.assertIn(
                expected,
                collected,
                msg=f"model {expected} not in collected models: {app_models}",
            )

-------------------------end code -------------------------

The code is bad, but it runs! :)

Thank you for the answer!
I honestly don't remember which project this need came from, but I hope it will help someone in the future.

给后面看到开发者留个记录吧 :)


这段代码还是不支持 ManyToMany + prefetch_related
Leaving a note here for developers who come across this later:
This piece of code still does not support ManyToMany + prefetch_related

-------------------------show code -------------------------
case_apps.py

import threading
from django.apps import AppConfig
from django.contrib.contenttypes.models import ContentType

from .const import APP_LABEL


class DjangoFakeModelsCollector:
    """
    Process-wide registry of the fake models currently in use.

    Collecting a model also loads its ``ContentType`` into the ``ContentType``
    manager's cache; uncollecting evicts it from that cache again.
    """

    _instance = None
    # Concurrency lock guarding ``_instance`` and ``_models``.
    _locker = threading.Lock()
    _models = []

    @classmethod
    def with_lock(cls):
        """Return a context manager that holds the (non-reentrant) collector lock."""
        # threading.Lock is itself a context manager.
        return cls._locker

    @classmethod
    def singleton(cls):
        """Return the shared collector, creating it on first use."""
        with cls.with_lock():
            if not cls._instance:
                cls._instance = cls()
        return cls._instance

    def collect(self, *models):
        """Register *models* (already-registered ones are skipped) and cache their ContentTypes."""
        with self.with_lock():
            for model in models:
                if model in self._models:
                    continue
                self._models.append(model)
                # Load the model's ContentType into the manager cache.
                ContentType.objects.get_for_model(model)

    def uncollect(self, *models):
        """
        Unregister *models* and evict their ContentTypes from the manager cache.

        Models that are not currently registered are skipped.
        """
        with self.with_lock():
            for model in models:
                if model not in self._models:
                    continue
                self._models.remove(model)
                ct = ContentType.objects.get_for_model(model)
                # The manager cache is keyed by the database alias
                # (``ContentType.objects.db``), then by (app_label, model) and
                # by id. ``ContentType.objects.using`` is a method, not the
                # alias, so it can never match a cache key.
                ct_cache = ContentType.objects._cache.get(ContentType.objects.db, {})
                ct_cache.pop((ct.app_label, ct.model), None)
                ct_cache.pop(ct.id, None)

    def models(self):
        """Return a snapshot (copy) of the registered models."""
        with self.with_lock():
            return list(self._models)


class DjangoFakeModelsApp(AppConfig):
    """
    AppConfig standing in for the fake models' app label.

    It is not populated like a real app: its models are whatever
    ``DjangoFakeModelsCollector`` currently holds, so registry lookups under
    this label resolve the fake models in use.
    """

    default_auto_field = "django.db.models.BigAutoField"
    name = APP_LABEL
    # Fixed class-level path; the config is instantiated without a module.
    path = "."

    def get_model(self, model_name, require_ready=True):
        """
        Return the model with the given case-insensitive model_name.

        Raise LookupError if no model exists with this name.
        ``require_ready`` is accepted for API compatibility and ignored.
        """
        wanted = model_name.lower()
        for model in DjangoFakeModelsCollector.singleton().models():
            # ``_meta.model_name`` is lower-case, so compare against the
            # lower-cased name to honour the case-insensitive contract.
            if model._meta.model_name == wanted:
                return model
        raise LookupError("App '%s' doesn't have a '%s' model." % (self.label, model_name))

    def get_models(self, include_auto_created=False, include_swapped=False):
        """
        Return an iterable of models.

        Unlike ``AppConfig.get_models``, this yields every collected fake model;
        ``include_auto_created`` and ``include_swapped`` are accepted for API
        compatibility and currently ignored.
        """
        yield from DjangoFakeModelsCollector.singleton().models()


def attach_fake_appconfig():
    """
    Register a ``DjangoFakeModelsApp`` for the fake models' app label, once.

    Does nothing if an app config with that label already exists. Otherwise it
    installs one and clears the registry's ``get_models`` and
    ``get_swappable_settings_name`` caches, all while holding the collector lock.
    """
    from django.apps import apps
    from .models import FakeModel

    label = FakeModel.Meta.app_label

    with DjangoFakeModelsCollector.singleton().with_lock():
        if not apps.app_configs.get(label):
            apps.app_configs[label] = DjangoFakeModelsApp(label, None)
            # WARNING: django-fake-model does not support prefetch_related by default.
            #   * relation_objects must be processed to build the reverse
            #     ManyToMany references.
            #   * DEBUG: PR_RConnection._meta._populate_directed_relation_graph()
            #   * Without AppConfig.ready(), the caches have to be cleared by
            #     hand so that relation_objects are regenerated.
            apps.get_models.cache_clear()
            apps.get_swappable_settings_name.cache_clear()


def apps_cache_clear():
    """
    Clear the app registry's ``get_models`` and ``get_swappable_settings_name`` caches.

    With several test cases things still crash, so the caches have to be
    cleared by hand. Runs while holding the collector lock.
    """
    from django.apps import apps

    with DjangoFakeModelsCollector.singleton().with_lock():
        for cached in (apps.get_models, apps.get_swappable_settings_name):
            cached.cache_clear()

case_extension.py

from __future__ import unicode_literals
from django.test import SimpleTestCase

from .utils import detect_django_maybe_keepdb, ignore_duplicate_foreign_key_operational_error
from .case_apps import DjangoFakeModelsCollector, attach_fake_appconfig


class CaseExtension(SimpleTestCase):
    """
    Mixin that manages fake-model tables around each test.

    Before each test, the registered tables are created, the models are
    collected (loading their ContentTypes) and the fake AppConfig is attached.
    After each test, the models are uncollected (unless the test database is
    kept) and the tables are dropped in reverse order.
    """

    # Fake models registered on this test case, in registration order.
    _models = tuple()

    @classmethod
    def append_model(cls, model):
        """Register *model*; rebinding the tuple keeps registrations per class."""
        cls._models += (model,)

    def _pre_setup(self):
        super(CaseExtension, self)._pre_setup()
        with ignore_duplicate_foreign_key_operational_error():
            self._map_models("create_table")
        DjangoFakeModelsCollector.singleton().collect(*self._models)
        attach_fake_appconfig()

    def _post_teardown(self):
        # If we don't remove them in reverse order, then if we created table A
        # after table B and it has a foreignkey to table B, then trying to
        # remove B first will fail on some configurations, as documented
        # in issue #1
        if not detect_django_maybe_keepdb():
            DjangoFakeModelsCollector.singleton().uncollect(*self._models)
        self._map_models("delete_table", reverse=True)
        super(CaseExtension, self)._post_teardown()

    def _map_models(self, method_name, reverse=False):
        """
        Call ``model.<method_name>()`` on every registered model.

        :param method_name: name of a classmethod such as "create_table"
        :param reverse: iterate models in reverse registration order
        :raises TypeError: if a model has no attribute *method_name*
        """
        models = reversed(self._models) if reverse else self._models
        for model in models:
            try:
                method = getattr(model, method_name)
            except AttributeError as err:
                raise TypeError("{0} doesn't support table method {1}".format(model, method_name)) from err
            # Call outside the try so an AttributeError raised *inside* the
            # method is not misreported as a missing method.
            method()

test_m2m_prefetch_related.py

from __future__ import unicode_literals

from django.db import models
from django.test import TransactionTestCase

from common.tutils.django_fake_model.case_apps import apps_cache_clear
from ...django_fake_model import models as f


class DFM_RConnection(f.FakeModel):
    """Target of DFM_RModel's ManyToManyField; used to test a custom related model."""

    name = models.CharField(max_length=100)


class DFM_RModel(f.FakeModel):
    """Fake model with a ManyToMany relation to DFM_RConnection."""

    name = models.CharField(max_length=100)
    # No explicit ``through`` model: Django auto-creates the join table.
    rconnections = models.ManyToManyField(
        DFM_RConnection,
        blank=True,
    )


@DFM_RModel.fake_me
@DFM_RConnection.fake_me
class TestM2mPrefetch(TransactionTestCase):
    """ManyToMany access between fake models, with and without prefetch_related."""

    def test_create_models(self):
        # With several test cases the app registry still crashes, so its
        # caches are cleared by hand first.
        apps_cache_clear()
        # clean: start from empty tables
        DFM_RConnection.objects.all().delete()
        DFM_RModel.objects.all().delete()

        r1 = DFM_RModel.objects.create(name="test1")
        r1_c1 = DFM_RConnection.objects.create(name="test1_c1")
        r1_c2 = DFM_RConnection.objects.create(name="test1_c2")
        r1.rconnections.add(r1_c1, r1_c2)

        r2 = DFM_RModel.objects.create(name="test2")
        r2_c1 = DFM_RConnection.objects.create(name="test2_c1")
        r2_c2 = DFM_RConnection.objects.create(name="test2_c2")
        r2_c3 = DFM_RConnection.objects.create(name="test2_c3")
        r2.rconnections.add(r2_c1, r2_c2, r2_c3)

        def _assert_data(_qs):
            # Positional indexing below requires _qs to be ordered by name.
            self.assertEqual(_qs.count(), 2)
            self.assertEqual(_qs[0].rconnections.count(), 2)
            self.assertEqual(_qs[1].rconnections.count(), 3)

            self.assertEqual(_qs[0].name, "test1")
            self.assertEqual(_qs[1].name, "test2")

            self.assertTrue(_qs[0].rconnections.filter(name="test1_c1").exists())
            self.assertTrue(_qs[0].rconnections.filter(name="test1_c2").exists())
            self.assertTrue(_qs[1].rconnections.filter(name="test2_c1").exists())
            self.assertTrue(_qs[1].rconnections.filter(name="test2_c2").exists())
            self.assertTrue(_qs[1].rconnections.filter(name="test2_c3").exists())

        # test: no prefetch; iterating evaluates the queryset, but nothing is
        # prefetched onto the instances.
        qs = DFM_RModel.objects.order_by("name")
        for obj in qs:
            prefetched = getattr(obj, "_prefetched_objects_cache", {}) or {}
            self.assertNotIn("rconnections", prefetched)
        _assert_data(qs)

        # test: prefetch; every instance carries the prefetched relation.
        qs = DFM_RModel.objects.order_by("name").prefetch_related("rconnections")
        for obj in qs:
            prefetched = getattr(obj, "_prefetched_objects_cache", {}) or {}
            self.assertIn("rconnections", prefetched)
        _assert_data(qs)