Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Configuration and shared fixtures for pytest."""

import shutil
import tempfile
from collections.abc import AsyncGenerator, Generator
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -35,6 +37,23 @@
TestSessionLocal = async_sessionmaker(bind=test_engine, expire_on_commit=False)


class _TempUploadsContainer:
"""Container to hold the temporary uploads directory path."""

path: Path | None = None


@pytest.fixture(scope="session", autouse=True)
def temp_uploads_dir_session() -> Generator[Path, None, None]:
"""Create a temporary directory for test uploads and clean up after session."""
temp_dir = Path(tempfile.mkdtemp(prefix="projectvote_test_uploads_"))
_TempUploadsContainer.path = temp_dir
yield temp_dir
# Clean up after all tests
if temp_dir.exists():
shutil.rmtree(temp_dir)


@pytest_asyncio.fixture(scope="session", autouse=True)
async def dispose_test_engine() -> AsyncGenerator[None, None]:
"""Ensure the test database engine is properly disposed after the test session."""
Expand Down Expand Up @@ -80,16 +99,20 @@ def get_test_board_members() -> list[str]:
def get_overridden_settings() -> Settings:
settings_data: dict[str, Any] = {
"board_members": ",".join(TEST_BOARD_MEMBERS),
"mail_driver": "console",
"mail_driver": "console", # Use console driver - no actual email sending
"mail_password": SecretStr("test-password"),
}
if settings_override:
settings_data.update(settings_override)
return Settings(**settings_data)
settings = Settings(**settings_data)
# Override project_root to use temp directory for file uploads
if _TempUploadsContainer.path:
settings.project_root = _TempUploadsContainer.path
return settings

mocker.patch(
"projectvote.backend.email_service.send_email", new_callable=mocker.AsyncMock
)
# Mock the send_email function to prevent actual email sending
# Patch where it's USED (in main.py), not where it's defined
mocker.patch("projectvote.backend.main.send_email", new_callable=mocker.AsyncMock)

app.dependency_overrides[get_db] = get_test_db
app.dependency_overrides[get_board_members] = get_test_board_members
Expand All @@ -100,3 +123,16 @@ def get_overridden_settings() -> Settings:
) as client:
yield client
app.dependency_overrides.clear()


@pytest.fixture(name="test_settings")
def test_settings_fixture() -> Settings:
"""Provide the test settings with temp directory override."""
settings = Settings(
board_members=",".join(TEST_BOARD_MEMBERS),
mail_driver="console",
mail_password=SecretStr("test-password"),
)
if _TempUploadsContainer.path:
settings.project_root = _TempUploadsContainer.path
return settings
36 changes: 4 additions & 32 deletions tests/test_file_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@

import io
from http import HTTPStatus
from pathlib import Path

import pytest
from httpx import AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from projectvote.backend.config import Settings
from projectvote.backend.models import Application, Attachment, VoteRecord


@pytest.mark.asyncio
async def test_create_application_with_attachment(
client: AsyncClient, session: AsyncSession
client: AsyncClient, session: AsyncSession, test_settings: Settings
) -> None:
"""Test creating an application with a file attachment."""
# Arrange
Expand Down Expand Up @@ -54,14 +54,11 @@ async def test_create_application_with_attachment(
assert attachment.filename == file_name
assert attachment.mime_type == "text/plain"

# Verify file on disk
attachment_path = Path(attachment.filepath)
# Verify file on disk (attachment.filepath is relative to project_root)
attachment_path = test_settings.project_root / attachment.filepath
assert attachment_path.exists()
assert attachment_path.read_bytes() == file_content

# Clean up the created file
attachment_path.unlink()


@pytest.mark.asyncio
async def test_create_application_without_attachment(
Expand Down Expand Up @@ -139,11 +136,6 @@ async def test_get_attachment(client: AsyncClient, session: AsyncSession) -> Non
assert response.content == file_content
assert response.headers["content-type"] == "text/plain; charset=utf-8"

# Clean up the created file
attachment_path = Path(attachment.filepath)
if attachment_path.exists():
attachment_path.unlink()


@pytest.mark.asyncio
async def test_get_attachment_invalid_token(
Expand Down Expand Up @@ -176,11 +168,6 @@ async def test_get_attachment_invalid_token(
# Assert
assert response.status_code == HTTPStatus.NOT_FOUND

# Clean up
attachment_path = Path(attachment.filepath)
if attachment_path.exists():
attachment_path.unlink()


@pytest.mark.asyncio
async def test_get_attachment_wrong_application(
Expand Down Expand Up @@ -233,11 +220,6 @@ async def test_get_attachment_wrong_application(
assert response.status_code == HTTPStatus.NOT_FOUND
assert response.json()["detail"] == "Attachment not found."

# Clean up
attachment_path = Path(attachment_app1.filepath)
if attachment_path.exists():
attachment_path.unlink()


@pytest.mark.asyncio
async def test_archive_and_vote_details_include_attachments(
Expand Down Expand Up @@ -271,7 +253,6 @@ async def test_archive_and_vote_details_include_attachments(
assert "attachments" in app_in_archive
assert len(app_in_archive["attachments"]) == 1
assert app_in_archive["attachments"][0]["filename"] == file_name
attachment_id = app_in_archive["attachments"][0]["id"]

# --- Test /vote/{token} ---
vote_result = await session.execute(
Expand All @@ -287,12 +268,3 @@ async def test_archive_and_vote_details_include_attachments(
assert "attachments" in vote_data["application"]
assert len(vote_data["application"]["attachments"]) == 1
assert vote_data["application"]["attachments"][0]["filename"] == file_name

# --- Cleanup ---
attachment_result = await session.execute(
select(Attachment).where(Attachment.id == attachment_id)
)
attachment = attachment_result.scalar_one()
attachment_path = Path(attachment.filepath)
if attachment_path.exists():
attachment_path.unlink()