Skip to content

Async session yields as async_generator and cannot be used as actual session #945

@brunolnetto

Description

@brunolnetto

Given below implementation, I try to test using an asynchronous session. My attempt goes in the following way:

models.py

from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncConnection

class Paginator:
    def __init__(
        self, 
        conn: Union[Connection, AsyncConnection],
        query: str, 
        params: dict = None, 
        batch_size: int = 10
    ):
        self.conn =  conn
        self.query = query
        self.params = params
        self.batch_size = batch_size
        self.current_offset = 0
        self.total_count = None

    async def _get_total_count_async(self) -> int:
        """Fetch the total count of records asynchronously."""
        count_query = f"SELECT COUNT(*) FROM ({self.query}) as total"
        query=text(count_query).bindparams(**(self.params or {}))
        result = await self.conn.execute(query)
        return result.scalar()

test_models.py

def prepare_db(uri: str, db_name: str):
    # Connect to the default PostgreSQL database
    default_engine = create_engine(uri, isolation_level="AUTOCOMMIT")
    
    # Drop the database if it exists
    with default_engine.connect() as conn:
        try:
            query=text(f"DROP DATABASE IF EXISTS {db_name};")
            conn.execute(query)
            print(f"Dropped database '{db_name}' if it existed.")
        except Exception as e:
            print(f"Error dropping database: {e}")

    # Create the database
    try:
        with default_engine.connect() as conn:
            conn.execute(text(f"CREATE DATABASE {db_name};"))
            print(f"Database '{db_name}' created.")
    except Exception as e:
        print(f"Error creating database: {e}")
        pytest.fail(f"Database creation failed: {e}")
        
    # Create the test table and populate it with data
    with default_engine.connect() as conn:
        try:
            # Create the table
            conn.execute(
                text("""
                    CREATE TABLE IF NOT EXISTS test_table (
                        id SERIAL PRIMARY KEY, 
                        name TEXT
                    );
                """)
            )

            # Check available tables
            result = conn.execute(
                text("""
                    SELECT 
                        table_name 
                    FROM 
                        information_schema.tables 
                    WHERE 
                        table_schema='public';
                """)
            )
            tables = [row[0] for row in result]

            # Clear existing data
            conn.execute(text("DELETE FROM test_table;"))

            # Insert new data
            conn.execute(
                text("""
                    INSERT INTO test_table (name) 
                    VALUES ('Alice'), ('Bob'), ('Charlie'), ('David');
                """)
            )
            conn.commit()

        except Exception as e:
            print(f"Error during table operations: {e}")
            pytest.fail(f"Table creation or data insertion failed: {e}")

@pytest.fixture(scope='function')
async def async_session():
    async_engine=create_async_engine('postgresql+asyncpg://localhost:5432/db')
    prepare_db('postgresql://localhost:5432', 'db')
    async_session = sessionmaker(
        expire_on_commit=False,
        autocommit=False,
        autoflush=False,
        bind=async_engine,
        class_=AsyncSession,
    )

    async with async_session() as session:
        await session.begin()

        yield session

        await session.rollback()

@pytest.mark.asyncio
async def test_get_total_count_async(async_session):
    # Prepare the paginator
    paginator = Paginator(
        conn=session,
        query="SELECT * FROM test_table",
        batch_size=2
    )

    # Perform the total count query asynchronously
    total_count = await paginator._get_total_count_async()

    # Assertion to verify the result
    assert total_count == 0

When I run the command pytest, I obtained following error: AttributeError: 'async_generator' object has no attribute 'execute'. I am pretty sure, there is an easy way to do so, but I am unaware of it.

Metadata

Metadata

Assignees

No one assigned

    Labels

    needsinfoRequires additional information from the issue author

    Type

    No type

    Projects

    No projects

    Milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions