Skip to content

Commit d1c2ac8

Browse files
committed
Switch from expecttest/assertExpectedInline to assertExpectedJournal
Implements our own assertExpectedJournal that writes expected results to a separate file rather than inline. This: 1) Make test files easier to read/edit (especially for AI coding tools) 2) Fixes a race in expecttest where multiple edits to the same file errored stack-info: PR: #241, branch: jansel/stack/80
1 parent 98a3d61 commit d1c2ac8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+11871
-7459
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ repos:
55
- id: check-symlinks
66
- id: destroyed-symlinks
77
- id: trailing-whitespace
8+
exclude: '\.expected$'
89
- id: end-of-file-fixer
10+
exclude: '\.expected$'
911
- id: check-yaml
1012
- id: check-toml
1113
- id: check-ast

helion/_testing.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
from __future__ import annotations
22

3+
import collections
34
import importlib
5+
import inspect
6+
import operator
7+
import os
48
from pathlib import Path
9+
import re
510
import sys
611
from typing import TYPE_CHECKING
712
from typing import Callable
13+
import unittest
814

915
import torch
1016
from triton.testing import do_bench
@@ -139,3 +145,145 @@ def check_example(
139145
)
140146
skip_accuracy or torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-2)
141147
return code
148+
149+
150+
class AssertExpectedJournal:
151+
"""
152+
Manages a <testfile>.expected file that contains expected output for TestCase.assertExpectedJournal() calls.
153+
154+
This replaces the previous `expecttest` assertExpectedInline approach by storing expected output
155+
in external .expected files rather than inline strings in test files. This provides better
156+
organization and avoids cluttering test files with large code blocks.
157+
158+
The .expected file format uses sections like:
159+
--- assertExpectedJournal(TestClass.test_method)
160+
expected output here
161+
162+
--- assertExpectedJournal(TestClass.test_method)
163+
second expected output for same test
164+
165+
Environment variable EXPECTTEST_ACCEPT=1 can be used to update expected outputs.
166+
"""
167+
168+
def __init__(self, cls: type[TestCase]) -> None:
169+
pyfile = os.path.abspath(inspect.getfile(cls))
170+
assert "/test/" in pyfile
171+
assert pyfile.endswith(".py")
172+
self.filename: Path = Path(pyfile[:-3] + ".expected")
173+
self._cache: dict[str, list[str]] | None = None
174+
self._current_id: str | None = None
175+
self._current_index: int = 0
176+
177+
@property
178+
def cache(self) -> dict[str, list[str]]:
179+
if self._cache is None:
180+
return self.reload()
181+
return self._cache
182+
183+
def reload(self) -> dict[str, list[str]]:
184+
if self.filename.exists():
185+
data = self.filename.read_text()
186+
else:
187+
data = ""
188+
result = collections.defaultdict(list)
189+
for name, expected in re.findall(
190+
r"--- assertExpectedJournal\(([^)]*)\)\n(.*?)(?=^--- assertExpectedJournal\(|\Z)",
191+
data,
192+
re.MULTILINE | re.DOTALL,
193+
):
194+
result[name].append(expected.strip())
195+
self._cache = result
196+
return result
197+
198+
def save(self) -> None:
199+
tmp = f"{self.filename}.tmp{os.getpid()}"
200+
with open(tmp, "w") as f:
201+
f.write(
202+
f"This file is automatically generated by assertExpectedJournal calls in {self.filename.stem}.py.\n"
203+
"Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.\n\n"
204+
)
205+
for name, expected_values in sorted(
206+
self.cache.items(), key=operator.itemgetter(0)
207+
):
208+
f.writelines(
209+
f"--- assertExpectedJournal({name})\n{expected}\n\n"
210+
for expected in expected_values
211+
)
212+
os.rename(tmp, self.filename)
213+
214+
@staticmethod
215+
def normalize_id(test_id: str) -> str:
216+
match = re.search(r"\b([^.]+\.[^.]+)$", test_id)
217+
assert match, f"Test ID '{test_id}' does not match expected format"
218+
return match.group(1)
219+
220+
def lookup(self, test_id: str, value: str) -> tuple[str, str]:
221+
test_id = self.normalize_id(test_id)
222+
if self._current_id != test_id:
223+
self._current_id = test_id
224+
self._current_index = 0
225+
226+
expected_values = self.cache[test_id]
227+
if self._current_index < len(expected_values):
228+
expected = expected_values[self._current_index]
229+
else:
230+
assert self._current_index == len(expected_values)
231+
expected_values.append("")
232+
expected = ""
233+
234+
value = value.strip()
235+
if value != expected and os.environ.get("EXPECTTEST_ACCEPT", "0") not in {
236+
"0",
237+
"false",
238+
"False",
239+
"",
240+
}:
241+
expected_values[self._current_index] = value
242+
# Reload to play nicer with other processes
243+
self.reload()[test_id][:] = expected_values
244+
self.save()
245+
expected = value
246+
print(
247+
f"Expected output for {test_id} updated: {len(expected)} => {len(value)} bytes",
248+
file=sys.stderr,
249+
)
250+
self._current_index += 1
251+
return value, expected
252+
253+
254+
class TestCase(unittest.TestCase):
255+
maxDiff = 16384
256+
257+
@classmethod
258+
def setUpClass(cls) -> None:
259+
cls._expected_journal = AssertExpectedJournal(cls)
260+
super().setUpClass()
261+
262+
@classmethod
263+
def tearDownClass(cls) -> None:
264+
super().tearDownClass()
265+
del cls._expected_journal
266+
267+
def assertExpectedJournal(self, value: str) -> None:
268+
"""
269+
Assert that the given value matches the expected output stored in <testfile>.expected.
270+
271+
This method replaces assertExpectedInline for code generation tests. Instead of storing
272+
expected output as inline strings in test files, it uses external .expected files for
273+
better organization.
274+
275+
Args:
276+
value: The actual output to compare (usually generated Triton code)
277+
278+
Raises:
279+
AssertionError: If value doesn't match expected output
280+
281+
Note:
282+
Use EXPECTTEST_ACCEPT=1 environment variable to update expected outputs.
283+
"""
284+
value, expected = self._expected_journal.lookup(self.id(), value)
285+
self.assertMultiLineEqual(
286+
value,
287+
expected,
288+
msg="To accept the new output, re-run test with env EXPECTTEST_ACCEPT=1",
289+
)

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ dependencies = [
2424

2525
[project.optional-dependencies]
2626
dev = [
27-
"expecttest",
2827
"pytest",
2928
"pre-commit"
3029
]

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
expecttest
21
pytest
32
typing_extensions
43
pre-commit

0 commit comments

Comments
 (0)