Skip to content

Switch from expecttest/assertExpectedInline to assertExpectedJournal #241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ repos:
- id: check-symlinks
- id: destroyed-symlinks
- id: trailing-whitespace
exclude: '\.expected$'
- id: end-of-file-fixer
exclude: '\.expected$'
- id: check-yaml
- id: check-toml
- id: check-ast
Expand Down
148 changes: 148 additions & 0 deletions helion/_testing.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from __future__ import annotations

import collections
import importlib
import inspect
import operator
import os
from pathlib import Path
import re
import sys
from typing import TYPE_CHECKING
from typing import Callable
import unittest

import torch
from triton.testing import do_bench
Expand Down Expand Up @@ -139,3 +145,145 @@ def check_example(
)
skip_accuracy or torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-2)
return code


class AssertExpectedJournal:
"""
Manages a <testfile>.expected file that contains expected output for TestCase.assertExpectedJournal() calls.

This replaces the previous `expecttest` assertExpectedInline approach by storing expected output
in external .expected files rather than inline strings in test files. This provides better
organization and avoids cluttering test files with large code blocks.

The .expected file format uses sections like:
--- assertExpectedJournal(TestClass.test_method)
expected output here

--- assertExpectedJournal(TestClass.test_method)
second expected output for same test

Environment variable EXPECTTEST_ACCEPT=1 can be used to update expected outputs.
"""

def __init__(self, cls: type[TestCase]) -> None:
pyfile = os.path.abspath(inspect.getfile(cls))
assert "/test/" in pyfile
assert pyfile.endswith(".py")
self.filename: Path = Path(pyfile[:-3] + ".expected")
self._cache: dict[str, list[str]] | None = None
self._current_id: str | None = None
self._current_index: int = 0

@property
def cache(self) -> dict[str, list[str]]:
if self._cache is None:
return self.reload()
return self._cache

def reload(self) -> dict[str, list[str]]:
if self.filename.exists():
data = self.filename.read_text()
else:
data = ""
result = collections.defaultdict(list)
for name, expected in re.findall(
r"--- assertExpectedJournal\(([^)]*)\)\n(.*?)(?=^--- assertExpectedJournal\(|\Z)",
data,
re.MULTILINE | re.DOTALL,
):
result[name].append(expected.strip())
self._cache = result
return result

def save(self) -> None:
tmp = f"{self.filename}.tmp{os.getpid()}"
with open(tmp, "w") as f:
f.write(
f"This file is automatically generated by assertExpectedJournal calls in {self.filename.stem}.py.\n"
"Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.\n\n"
)
for name, expected_values in sorted(
self.cache.items(), key=operator.itemgetter(0)
):
f.writelines(
f"--- assertExpectedJournal({name})\n{expected}\n\n"
for expected in expected_values
)
os.rename(tmp, self.filename)

@staticmethod
def normalize_id(test_id: str) -> str:
match = re.search(r"\b([^.]+\.[^.]+)$", test_id)
assert match, f"Test ID '{test_id}' does not match expected format"
return match.group(1)

def lookup(self, test_id: str, value: str) -> tuple[str, str]:
test_id = self.normalize_id(test_id)
if self._current_id != test_id:
self._current_id = test_id
self._current_index = 0

expected_values = self.cache[test_id]
if self._current_index < len(expected_values):
expected = expected_values[self._current_index]
else:
assert self._current_index == len(expected_values)
expected_values.append("")
expected = ""

value = value.strip()
if value != expected and os.environ.get("EXPECTTEST_ACCEPT", "0") not in {
"0",
"false",
"False",
"",
}:
expected_values[self._current_index] = value
# Reload to play nicer with other processes
self.reload()[test_id][:] = expected_values
self.save()
expected = value
print(
f"Expected output for {test_id} updated: {len(expected)} => {len(value)} bytes",
file=sys.stderr,
)
self._current_index += 1
return value, expected


class TestCase(unittest.TestCase):
maxDiff = 16384

@classmethod
def setUpClass(cls) -> None:
cls._expected_journal = AssertExpectedJournal(cls)
super().setUpClass()

@classmethod
def tearDownClass(cls) -> None:
super().tearDownClass()
del cls._expected_journal

def assertExpectedJournal(self, value: str) -> None:
"""
Assert that the given value matches the expected output stored in <testfile>.expected.

This method replaces assertExpectedInline for code generation tests. Instead of storing
expected output as inline strings in test files, it uses external .expected files for
better organization.

Args:
value: The actual output to compare (usually generated Triton code)

Raises:
AssertionError: If value doesn't match expected output

Note:
Use EXPECTTEST_ACCEPT=1 environment variable to update expected outputs.
"""
value, expected = self._expected_journal.lookup(self.id(), value)
self.assertMultiLineEqual(
value,
expected,
msg="To accept the new output, re-run test with env EXPECTTEST_ACCEPT=1",
)
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ dependencies = [

[project.optional-dependencies]
dev = [
"expecttest",
"pytest",
"pre-commit"
]
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
expecttest
pytest
typing_extensions
pre-commit
Expand Down
Loading
Loading