pytest-dev/pytest-asyncio

Async session fixture yields an async_generator and cannot be used as an actual session

brunolnetto opened this issue · 2 comments

Given the implementation below, I am trying to test it using an asynchronous session. My attempt goes as follows:

models.py

from typing import Optional, Union

from sqlalchemy import text
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncConnection

class Paginator:
    """Hold a raw SQL query, its bind parameters, and paging state.

    The async helpers ``await self.conn.execute(...)``, so for them ``conn``
    must be an async executor (an ``AsyncConnection``; the accompanying test
    passes an ``AsyncSession``, which exposes the same ``execute`` coroutine).
    """

    def __init__(
        self,
        conn: Union[Connection, AsyncConnection],
        query: str,
        params: Optional[dict] = None,
        batch_size: int = 10,
    ):
        """Store the query and initialise paging state.

        Args:
            conn: Connection (or session) used to run queries.
            query: SQL ``SELECT`` statement to paginate. It is embedded
                verbatim as a subquery, so it must come from trusted code and
                never from user input; pass values through ``params`` instead.
            params: Bind parameters referenced by ``query``; ``None`` means
                the query has no parameters.
            batch_size: Requested batch (page) size.
        """
        self.conn = conn
        self.query = query
        self.params = params
        self.batch_size = batch_size
        self.current_offset = 0
        # Total row count; None means "not known yet".
        self.total_count = None

    async def _get_total_count_async(self) -> int:
        """Fetch the total count of records asynchronously.

        Wraps ``self.query`` in ``SELECT COUNT(*) FROM (...)`` and binds
        ``self.params`` (or nothing, when it is ``None``) to the statement.

        Returns:
            The scalar value of the first column of the first result row.
        """
        count_sql = f"SELECT COUNT(*) FROM ({self.query}) as total"
        statement = text(count_sql).bindparams(**(self.params or {}))
        result = await self.conn.execute(statement)
        return result.scalar()

test_models.py

def prepare_db(uri: str, db_name: str):
    """Recreate ``db_name`` and fill its ``test_table`` with four rows.

    Args:
        uri: Server URI without a database component, e.g.
            ``'postgresql://localhost:5432'``. Used for DROP/CREATE DATABASE.
        db_name: Name of the database to (re)create and populate.

    A failure to drop the old database is only printed (best effort). Failing
    to create the database or populate the table fails the current test via
    ``pytest.fail``.
    """
    # DROP/CREATE DATABASE cannot run inside a transaction, hence AUTOCOMMIT.
    default_engine = create_engine(uri, isolation_level="AUTOCOMMIT")
    # Quote the identifier so names that need quoting (mixed case, reserved
    # words) are created exactly as spelled, matching the URL used below.
    quoted_name = default_engine.dialect.identifier_preparer.quote(db_name)
    # The table must live in the new database -- the one the async engine
    # connects to -- not in the server's default database.
    target_url = default_engine.url.set(database=db_name)

    try:
        # Drop the database if it exists
        with default_engine.connect() as conn:
            try:
                conn.execute(text(f"DROP DATABASE IF EXISTS {quoted_name};"))
                print(f"Dropped database '{db_name}' if it existed.")
            except Exception as e:
                print(f"Error dropping database: {e}")

        # Create the database
        try:
            with default_engine.connect() as conn:
                conn.execute(text(f"CREATE DATABASE {quoted_name};"))
                print(f"Database '{db_name}' created.")
        except Exception as e:
            print(f"Error creating database: {e}")
            pytest.fail(f"Database creation failed: {e}")
    finally:
        # Release pooled connections to the server's default database.
        default_engine.dispose()

    # Create the test table inside the new database and populate it.
    target_engine = create_engine(target_url)
    try:
        # begin() commits on successful exit and rolls back on error.
        with target_engine.begin() as conn:
            conn.execute(
                text("""
                    CREATE TABLE IF NOT EXISTS test_table (
                        id SERIAL PRIMARY KEY, 
                        name TEXT
                    );
                """)
            )

            # Clear existing data
            conn.execute(text("DELETE FROM test_table;"))

            # Insert new data
            conn.execute(
                text("""
                    INSERT INTO test_table (name) 
                    VALUES ('Alice'), ('Bob'), ('Charlie'), ('David');
                """)
            )
    except Exception as e:
        print(f"Error during table operations: {e}")
        pytest.fail(f"Table creation or data insertion failed: {e}")
    finally:
        # Do not keep connections to db_name open; a later DROP DATABASE
        # of the same name would otherwise be refused by PostgreSQL.
        target_engine.dispose()

@pytest.fixture(scope='function')
async def async_session():
    """Yield an ``AsyncSession`` inside an open transaction, rolled back afterwards.

    The ``db`` database is recreated by ``prepare_db`` before the engine is
    created.

    NOTE(review): in pytest-asyncio *strict* mode, an ``async def`` fixture
    decorated with plain ``@pytest.fixture`` is not driven by pytest-asyncio.
    The test then receives the raw async generator instead of the session,
    which produces ``'async_generator' object has no attribute 'execute'``.
    Decorate this with ``@pytest_asyncio.fixture`` or run with
    ``asyncio_mode = auto``.
    """
    prepare_db('postgresql://localhost:5432', 'db')
    async_engine = create_async_engine('postgresql+asyncpg://localhost:5432/db')
    # Named session_factory so it does not shadow the fixture's own name.
    session_factory = sessionmaker(
        expire_on_commit=False,
        autocommit=False,
        autoflush=False,
        bind=async_engine,
        class_=AsyncSession,
    )

    try:
        async with session_factory() as session:
            await session.begin()

            yield session

            # Undo whatever the test did so the next test starts clean.
            await session.rollback()
    finally:
        # Close pooled connections to "db". Otherwise they stay open and the
        # next test's prepare_db() cannot DROP the database.
        await async_engine.dispose()

@pytest.mark.asyncio
async def test_get_total_count_async(async_session):
    """Paginator counts every row that prepare_db() inserted into test_table."""
    # Prepare the paginator, passing the session provided by the fixture.
    paginator = Paginator(
        conn=async_session,
        query="SELECT * FROM test_table",
        batch_size=2,
    )

    # Perform the total count query asynchronously
    total_count = await paginator._get_total_count_async()

    # prepare_db() inserts four rows: Alice, Bob, Charlie and David.
    assert total_count == 4

When I run `pytest`, I get the following error: `AttributeError: 'async_generator' object has no attribute 'execute'`. I am pretty sure there is an easy way to do this, but I am not aware of it.

It's very hard for me to help out, because I cannot execute your example code even after adding the missing imports.

My wild guess is that pytest-asyncio's strict mode will result in the described error when used with @pytest.fixture. If you use strict mode, you need to decorate async fixtures with @pytest_asyncio.fixture.

If that doesn't solve your issue, I ask you to make your code example self-contained, so that it can be copied, pasted, and run. Ideally, the reproducer will also become more minimal in the process.

The test involves creating and populating a test database, defining the fixtures, providing the implementation of the paginator, and then writing the test itself.

Since the fixture yields an async_generator, I just used `async for async_session in async_session_factory:` and moved on. Thanks anyway.

Side note: Your profile picture looks neat. :D