Skip to content

Commit a3790bc

Browse files
committed
Add some initial tests
1 parent 064d406 commit a3790bc

15 files changed

+1194
-2
lines changed

.github/workflows/ci-checks.yaml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,29 @@ jobs:
3636

3737
- name: Run pre-commit
3838
run: uv run pre-commit run --all-files
39+
40+
job-image-processing-unit-tests:
41+
name: Image Processing Unit Tests
42+
runs-on: ubuntu-latest
43+
44+
steps:
45+
- name: Checkout code
46+
uses: actions/checkout@v3
47+
48+
- name: Set up Python
49+
uses: actions/setup-python@v3
50+
with:
51+
python-version: ${{ env.MIN_PYTHON_VERSION }}
52+
53+
- name: Install uv
54+
uses: astral-sh/setup-uv@v4
55+
with:
56+
enable-cache: true
57+
58+
- name: Install the project
59+
run: uv sync
60+
working-directory: image_processing
61+
62+
- name: Run PyTest
63+
run: uv run pytest --cov=image_processing
64+
working-directory: image_processing

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ repos:
1818

1919
# Python checks
2020
- id: name-tests-test
21+
args: [--pytest-test-first]
2122

2223
# JSON files
2324
- id: pretty-format-json

image_processing/.coveragerc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[run]
2+
omit =
3+
tests/*
4+
*/__init__.py
5+
6+
[report]
7+
omit =
8+
tests/*
9+
*/__init__.py
10+
exclude_lines =
11+
if __name__ == "__main__":

image_processing/pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,9 @@ dev = [
4343
"pygments>=2.18.0",
4444
"ruff>=0.8.1",
4545
"python-dotenv>=1.0.1",
46+
"coverage>=7.6.12",
47+
"pytest>=8.3.4",
48+
"pytest-asyncio>=0.25.3",
49+
"pytest-cov>=6.0.0",
50+
"pytest-mock>=3.14.0",
4651
]

image_processing/pytest.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[pytest]
2+
pythonpath = src/image_processing

image_processing/src/image_processing/__init__.py

Whitespace-only changes.

image_processing/src/image_processing/mark_up_cleaner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def get_sections(self, text) -> list:
1818
list: The sections related to text
1919
"""
2020
# Updated regex pattern to capture markdown headers like ### Header
21-
combined_pattern = r"(?<=\n|^)[#]+\s*(.*?)(?=\n)"
22-
doc_metadata = re.findall(combined_pattern, text, re.DOTALL)
21+
combined_pattern = r"^[#]+\s*(.*?)(?=\n|$)"
22+
doc_metadata = re.findall(combined_pattern, text, re.MULTILINE)
2323
return self.clean_sections(doc_metadata)
2424

2525
def get_figure_ids(self, text: str) -> list:

image_processing/tests/image_processing/__init__.py

Whitespace-only changes.
Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
4+
import pytest
5+
import base64
6+
import io
7+
from PIL import Image
8+
from unittest.mock import AsyncMock, MagicMock
9+
from tenacity import RetryError
10+
from openai import OpenAIError, RateLimitError
11+
from figure_analysis import FigureAnalysis
12+
from layout_holders import FigureHolder
13+
from httpx import Response, Request
14+
15+
# ------------------------
16+
# Fixtures for Image Data
17+
# ------------------------
18+
19+
20+
@pytest.fixture
def image_data_100x100():
    """Return a base64-encoded PNG image of size 100x100."""
    # Render a solid 100x100 red square and serialize it as PNG bytes.
    png_buffer = io.BytesIO()
    Image.new("RGB", (100, 100), color="red").save(png_buffer, format="PNG")
    return base64.b64encode(png_buffer.getvalue()).decode("utf-8")
28+
29+
30+
@pytest.fixture
def image_data_50x50():
    """Return a base64-encoded PNG image of size 50x50 (small image)."""
    # Render a solid 50x50 blue square and serialize it as PNG bytes.
    png_buffer = io.BytesIO()
    Image.new("RGB", (50, 50), color="blue").save(png_buffer, format="PNG")
    return base64.b64encode(png_buffer.getvalue()).decode("utf-8")
38+
39+
40+
# ------------------------
41+
# Fixtures for FigureHolder
42+
# ------------------------
43+
44+
45+
@pytest.fixture
def valid_figure(image_data_100x100):
    """
    A valid figure with sufficient size.
    Example: FigureHolder(figure_id='12345', description="Figure 1", uri="https://example.com/12345.png", offset=50, length=17)
    """
    # Assemble the holder's fields up front so the construction reads as data.
    figure_fields = {
        "figure_id": "12345",
        "description": "Figure 1",
        "uri": "https://example.com/12345.png",
        "offset": 50,
        "length": 17,
        "data": image_data_100x100,
    }
    return FigureHolder(**figure_fields)
59+
60+
61+
@pytest.fixture
def small_figure(image_data_50x50):
    """A figure whose image is too small (both dimensions below 75)."""
    # Same shape as valid_figure, but backed by the undersized 50x50 image.
    figure_fields = {
        "figure_id": "small1",
        "description": "",
        "uri": "https://example.com/small1.png",
        "offset": 0,
        "length": 10,
        "data": image_data_50x50,
    }
    return FigureHolder(**figure_fields)
72+
73+
74+
# ------------------------
75+
# Tests for get_image_size
76+
# ------------------------
77+
78+
79+
def test_get_image_size(valid_figure):
    """get_image_size should report the 100x100 dimensions of the valid figure."""
    width, height = FigureAnalysis().get_image_size(valid_figure)
    assert (width, height) == (100, 100)
84+
85+
86+
def test_get_image_size_small(small_figure):
    """get_image_size should report the 50x50 dimensions of the small figure."""
    width, height = FigureAnalysis().get_image_size(small_figure)
    assert (width, height) == (50, 50)
91+
92+
93+
# ------------------------
94+
# Tests for understand_image_with_gptv
95+
# ------------------------
96+
97+
98+
@pytest.mark.asyncio
async def test_understand_image_with_gptv_small(small_figure):
    """
    If both width and height are below 75, the image should be considered too small,
    and its description set to "Irrelevant Image".
    """
    # No mocking needed: the size check short-circuits before any API call.
    result = await FigureAnalysis().understand_image_with_gptv(small_figure)
    assert result.description == "Irrelevant Image"
107+
108+
109+
@pytest.mark.asyncio
async def test_understand_image_with_gptv_success(valid_figure, monkeypatch):
    """
    Test the success branch of understand_image_with_gptv.
    Patch AsyncAzureOpenAI to simulate a successful response.
    """
    analysis = FigureAnalysis()

    # Environment the code under test reads its OpenAI configuration from.
    monkeypatch.setenv("OpenAI__ApiVersion", "2023-07-01-preview")
    monkeypatch.setenv("OpenAI__MiniCompletionDeployment", "deployment123")
    monkeypatch.setenv("OpenAI__Endpoint", "https://example.openai.azure.com")

    # Build a fake completion whose single choice carries our canned content.
    fake_message = MagicMock()
    fake_message.content = "Generated image description"
    fake_choice = MagicMock(message=fake_message)
    fake_response = MagicMock(choices=[fake_choice])

    # Async client stub: chat.completions.create resolves to the fake response.
    fake_client = AsyncMock()
    fake_client.chat.completions.create.return_value = fake_response

    # Async context manager stub handing back the stub client.
    fake_context = AsyncMock()
    fake_context.__aenter__.return_value = fake_client

    # Instantiating AsyncAzureOpenAI now yields our context-manager stub.
    monkeypatch.setattr(
        "figure_analysis.AsyncAzureOpenAI", lambda **kwargs: fake_context
    )

    # Non-empty content becomes the figure's description.
    updated_figure = await analysis.understand_image_with_gptv(valid_figure)
    assert updated_figure.description == "Generated image description"

    # Empty content is treated as an irrelevant image.
    fake_message.content = ""
    updated_figure = await analysis.understand_image_with_gptv(valid_figure)
    assert updated_figure.description == "Irrelevant Image"
151+
152+
153+
@pytest.mark.asyncio
async def test_understand_image_with_gptv_policy_violation(valid_figure, monkeypatch):
    """
    If the OpenAI API raises an error with "ResponsibleAIPolicyViolation" in its message,
    the description should be set to "Irrelevant Image".
    """
    analysis = FigureAnalysis()
    for env_name, env_value in (
        ("OpenAI__ApiVersion", "2023-07-01-preview"),
        ("OpenAI__MiniCompletionDeployment", "deployment123"),
        ("OpenAI__Endpoint", "https://example.openai.azure.com"),
    ):
        monkeypatch.setenv(env_name, env_value)

    # Mimics an OpenAI error carrying a content-filter violation message.
    class DummyOpenAIError(OpenAIError):
        def __init__(self, message):
            self.message = message

    async def failing_create(*args, **kwargs):
        raise DummyOpenAIError("Error: ResponsibleAIPolicyViolation occurred")

    fake_client = AsyncMock()
    fake_client.chat.completions.create.side_effect = failing_create
    fake_context = AsyncMock()
    fake_context.__aenter__.return_value = fake_client
    monkeypatch.setattr(
        "figure_analysis.AsyncAzureOpenAI", lambda **kwargs: fake_context
    )

    updated_figure = await analysis.understand_image_with_gptv(valid_figure)
    assert updated_figure.description == "Irrelevant Image"
182+
183+
184+
@pytest.mark.asyncio
async def test_understand_image_with_gptv_general_error(valid_figure, monkeypatch):
    """
    If the OpenAI API raises an error that does not include "ResponsibleAIPolicyViolation",
    the error should propagate (wrapped by tenacity in a RetryError once retries
    are exhausted), and the original error must be recoverable from it.
    """
    analysis = FigureAnalysis()
    monkeypatch.setenv("OpenAI__ApiVersion", "2023-07-01-preview")
    monkeypatch.setenv("OpenAI__MiniCompletionDeployment", "deployment123")
    monkeypatch.setenv("OpenAI__Endpoint", "https://example.openai.azure.com")

    # Mimics a generic OpenAI error with no policy-violation marker.
    class DummyOpenAIError(OpenAIError):
        def __init__(self, message):
            self.message = message

    async def dummy_create(*args, **kwargs):
        raise DummyOpenAIError("Some other error")

    dummy_client = AsyncMock()
    dummy_client.chat.completions.create.side_effect = dummy_create
    dummy_async_context = AsyncMock()
    dummy_async_context.__aenter__.return_value = dummy_client
    monkeypatch.setattr(
        "figure_analysis.AsyncAzureOpenAI", lambda **kwargs: dummy_async_context
    )

    with pytest.raises(RetryError) as e:
        await analysis.understand_image_with_gptv(valid_figure)

    # BUG FIX: pytest.raises yields an ExceptionInfo wrapper; the RetryError
    # itself is e.value, so last_attempt must be read from e.value, not e.
    root_cause = e.value.last_attempt.exception()
    assert isinstance(root_cause, DummyOpenAIError)
215+
216+
217+
# ------------------------
218+
# Tests for analyse
219+
# ------------------------
220+
221+
222+
@pytest.mark.asyncio
async def test_analyse_success(valid_figure, monkeypatch):
    """
    Test the successful execution of the analyse method.
    Patch understand_image_with_gptv to return a figure with an updated description.
    """
    analysis = FigureAnalysis()
    record = {"recordId": "rec1", "data": {"figure": valid_figure.model_dump()}}

    # Stand-in for the (expensive) vision call: just rewrite the description.
    async def fake_understand(figure):
        figure.description = "Updated Description"
        return figure

    monkeypatch.setattr(analysis, "understand_image_with_gptv", fake_understand)

    result = await analysis.analyse(record)

    assert result["recordId"] == "rec1"
    assert result["data"]["updated_figure"]["description"] == "Updated Description"
    assert result["errors"] is None
240+
241+
242+
@pytest.mark.asyncio
async def test_analyse_retry_rate_limit(valid_figure, monkeypatch):
    """
    Simulate a RetryError whose last attempt raised a RateLimitError.
    The analyse method should return an error message indicating a rate limit error.
    """
    analysis = FigureAnalysis()
    record = {"recordId": "rec2", "data": {"figure": valid_figure.model_dump()}}

    # RateLimitError requires an httpx response that carries its originating request.
    http_request = Request(
        method="POST", url="https://api.openai.com/v1/chat/completions"
    )
    http_response = Response(
        status_code=429, content=b"Rate limit exceeded", request=http_request
    )
    rate_limit_error = RateLimitError(
        message="Rate limit exceeded",
        response=http_response,
        body="Rate limit exceeded",
    )
    # Wrap it the way tenacity would after exhausting its retries.
    retry_error = RetryError(
        last_attempt=MagicMock(exception=lambda: rate_limit_error)
    )

    async def raising_understand(figure):
        raise retry_error

    monkeypatch.setattr(analysis, "understand_image_with_gptv", raising_understand)

    result = await analysis.analyse(record)

    assert result["recordId"] == "rec2"
    assert result["data"] is None
    assert result["errors"] is not None
    assert "rate limit error" in result["errors"][0]["message"].lower()
280+
281+
282+
@pytest.mark.asyncio
async def test_analyse_general_exception(valid_figure, monkeypatch):
    """
    If understand_image_with_gptv raises a general Exception,
    analyse should catch it and return an error response.
    """
    analysis = FigureAnalysis()
    record = {"recordId": "rec3", "data": {"figure": valid_figure.model_dump()}}

    # Any uncategorized failure should be swallowed and reported via "errors".
    async def exploding_understand(figure):
        raise Exception("General error")

    monkeypatch.setattr(analysis, "understand_image_with_gptv", exploding_understand)

    result = await analysis.analyse(record)

    assert result["recordId"] == "rec3"
    assert result["data"] is None
    assert result["errors"] is not None
    assert "check the logs for more details" in result["errors"][0]["message"].lower()

0 commit comments

Comments
 (0)