Skip to content

Commit 57b59ec

Browse files
authored
Chore!: Refactor sqlmesh test output to use rich (#4715)
1 parent 60fa7bc commit 57b59ec

File tree

16 files changed

+459
-205
lines changed

16 files changed

+459
-205
lines changed

sqlmesh/core/console.py

Lines changed: 94 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import uuid
88
import logging
99
import textwrap
10+
from itertools import zip_longest
1011
from pathlib import Path
1112
from hyperscript import h
1213
from rich.console import Console as RichConsole
@@ -26,6 +27,7 @@
2627
from rich.tree import Tree
2728
from sqlglot import exp
2829

30+
from sqlmesh.core.test.result import ModelTextTestResult
2931
from sqlmesh.core.environment import EnvironmentNamingInfo, EnvironmentSummary
3032
from sqlmesh.core.linter.rule import RuleViolation
3133
from sqlmesh.core.model import Model
@@ -46,6 +48,7 @@
4648
NodeAuditsErrors,
4749
format_destructive_change_msg,
4850
)
51+
from sqlmesh.utils.rich import strip_ansi_codes
4952

5053
if t.TYPE_CHECKING:
5154
import ipywidgets as widgets
@@ -316,6 +319,17 @@ def log_destructive_change(
316319
"""Display a destructive change error or warning to the user."""
317320

318321

322+
class UnitTestConsole(abc.ABC):
323+
@abc.abstractmethod
324+
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
325+
"""Display the test result and output.
326+
327+
Args:
328+
result: The unittest test result that contains metrics like num success, fails, ect.
329+
target_dialect: The dialect that tests were run against. Assumes all tests run against the same dialect.
330+
"""
331+
332+
319333
class Console(
320334
PlanBuilderConsole,
321335
LinterConsole,
@@ -327,6 +341,7 @@ class Console(
327341
DifferenceConsole,
328342
TableDiffConsole,
329343
BaseConsole,
344+
UnitTestConsole,
330345
abc.ABC,
331346
):
332347
"""Abstract base class for defining classes used for displaying information to the user and also interact
@@ -460,18 +475,6 @@ def plan(
460475
fail. Default: False
461476
"""
462477

463-
@abc.abstractmethod
464-
def log_test_results(
465-
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
466-
) -> None:
467-
"""Display the test result and output.
468-
469-
Args:
470-
result: The unittest test result that contains metrics like num success, fails, ect.
471-
output: The generated output from the unittest.
472-
target_dialect: The dialect that tests were run against. Assumes all tests run against the same dialect.
473-
"""
474-
475478
@abc.abstractmethod
476479
def show_sql(self, sql: str) -> None:
477480
"""Display to the user SQL."""
@@ -668,9 +671,7 @@ def plan(
668671
if auto_apply:
669672
plan_builder.apply()
670673

671-
def log_test_results(
672-
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
673-
) -> None:
674+
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
674675
pass
675676

676677
def show_sql(self, sql: str) -> None:
@@ -1952,10 +1953,12 @@ def _prompt_promote(self, plan_builder: PlanBuilder) -> None:
19521953
):
19531954
plan_builder.apply()
19541955

1955-
def log_test_results(
1956-
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
1957-
) -> None:
1956+
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
19581957
divider_length = 70
1958+
1959+
self._log_test_details(result)
1960+
self._print("\n")
1961+
19591962
if result.wasSuccessful():
19601963
self._print("=" * divider_length)
19611964
self._print(
@@ -1972,9 +1975,13 @@ def log_test_results(
19721975
)
19731976
for test, _ in result.failures + result.errors:
19741977
if isinstance(test, ModelTest):
1975-
self._print(f"Failure Test: {test.model.name} {test.test_name}")
1978+
self._print(f"Failure Test: {test.path}::{test.test_name}")
19761979
self._print("=" * divider_length)
1977-
self._print(output)
1980+
1981+
def _captured_unit_test_results(self, result: ModelTextTestResult) -> str:
1982+
with self.console.capture() as capture:
1983+
self._log_test_details(result)
1984+
return strip_ansi_codes(capture.get())
19781985

19791986
def show_sql(self, sql: str) -> None:
19801987
self._print(Syntax(sql, "sql", word_wrap=True), crop=False)
@@ -2492,6 +2499,63 @@ def show_linter_violations(
24922499
else:
24932500
self.log_warning(msg)
24942501

2502+
def _log_test_details(self, result: ModelTextTestResult) -> None:
2503+
"""
2504+
This is a helper method that encapsulates the logic for logging the relevant unittest for the result.
2505+
The top level method (`log_test_results`) reuses `_log_test_details` differently based on the console.
2506+
2507+
Args:
2508+
result: The unittest test result that contains metrics like num success, fails, ect.
2509+
"""
2510+
tests_run = result.testsRun
2511+
errors = result.errors
2512+
failures = result.failures
2513+
skipped = result.skipped
2514+
is_success = not (errors or failures)
2515+
2516+
infos = []
2517+
if failures:
2518+
infos.append(f"failures={len(failures)}")
2519+
if errors:
2520+
infos.append(f"errors={len(errors)}")
2521+
if skipped:
2522+
infos.append(f"skipped={skipped}")
2523+
2524+
self._print("\n", end="")
2525+
2526+
for (test_case, failure), test_failure_tables in zip_longest( # type: ignore
2527+
failures, result.failure_tables
2528+
):
2529+
self._print(unittest.TextTestResult.separator1)
2530+
self._print(f"FAIL: {test_case}")
2531+
2532+
if test_description := test_case.shortDescription():
2533+
self._print(test_description)
2534+
self._print(f"{unittest.TextTestResult.separator2}")
2535+
2536+
if not test_failure_tables:
2537+
self._print(failure)
2538+
else:
2539+
for failure_table in test_failure_tables:
2540+
self._print(failure_table)
2541+
self._print("\n", end="")
2542+
2543+
for test_case, error in errors:
2544+
self._print(unittest.TextTestResult.separator1)
2545+
self._print(f"ERROR: {test_case}")
2546+
self._print(f"{unittest.TextTestResult.separator2}")
2547+
self._print(error)
2548+
2549+
# Output final report
2550+
self._print(unittest.TextTestResult.separator2)
2551+
test_duration_msg = f" in {result.duration:.3f}s" if result.duration else ""
2552+
self._print(
2553+
f"\nRan {tests_run} {'tests' if tests_run > 1 else 'test'}{test_duration_msg} \n"
2554+
)
2555+
self._print(
2556+
f"{'OK' if is_success else 'FAILED'}{' (' + ', '.join(infos) + ')' if infos else ''}"
2557+
)
2558+
24952559

24962560
def _cells_match(x: t.Any, y: t.Any) -> bool:
24972561
"""Helper function to compare two cells and returns true if they're equal, handling array objects."""
@@ -2763,9 +2827,7 @@ def radio_button_selected(change: t.Dict[str, t.Any]) -> None:
27632827
)
27642828
self.display(radio)
27652829

2766-
def log_test_results(
2767-
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
2768-
) -> None:
2830+
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
27692831
import ipywidgets as widgets
27702832

27712833
divider_length = 70
@@ -2781,12 +2843,14 @@ def log_test_results(
27812843
h(
27822844
"span",
27832845
{"style": {**shared_style, **success_color}},
2784-
f"Successfully Ran {str(result.testsRun)} Tests Against {target_dialect}",
2846+
f"Successfully Ran {str(result.testsRun)} tests against {target_dialect}",
27852847
)
27862848
)
27872849
footer = str(h("span", {"style": shared_style}, "=" * divider_length))
27882850
self.display(widgets.HTML("<br>".join([header, message, footer])))
27892851
else:
2852+
output = self._captured_unit_test_results(result)
2853+
27902854
fail_color = {"color": "#db3737"}
27912855
fail_shared_style = {**shared_style, **fail_color}
27922856
header = str(h("span", {"style": fail_shared_style}, "-" * divider_length))
@@ -3137,21 +3201,22 @@ def stop_promotion_progress(self, success: bool = True) -> None:
31373201
def log_success(self, message: str) -> None:
31383202
self._print(message)
31393203

3140-
def log_test_results(
3141-
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
3142-
) -> None:
3204+
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
31433205
if result.wasSuccessful():
31443206
self._print(
31453207
f"**Successfully Ran `{str(result.testsRun)}` Tests Against `{target_dialect}`**\n\n"
31463208
)
31473209
else:
3210+
self._print("```")
3211+
self._log_test_details(result)
3212+
self._print("```\n\n")
3213+
31483214
self._print(
31493215
f"**Num Successful Tests: {result.testsRun - len(result.failures) - len(result.errors)}**\n\n"
31503216
)
31513217
for test, _ in result.failures + result.errors:
31523218
if isinstance(test, ModelTest):
31533219
self._print(f"* Failure Test: `{test.model.name}` - `{test.test_name}`\n\n")
3154-
self._print(f"```{output}```\n\n")
31553220

31563221
def log_skipped_models(self, snapshot_names: t.Set[str]) -> None:
31573222
if snapshot_names:
@@ -3530,9 +3595,7 @@ def show_model_difference_summary(
35303595
for modified in context_diff.modified_snapshots:
35313596
self._write(f" Modified: {modified}")
35323597

3533-
def log_test_results(
3534-
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
3535-
) -> None:
3598+
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
35363599
self._write("Test Results:", result)
35373600

35383601
def show_sql(self, sql: str) -> None:

sqlmesh/core/context.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
import time
4141
import traceback
4242
import typing as t
43-
import unittest.result
4443
from functools import cached_property
4544
from io import StringIO
4645
from itertools import chain
@@ -2061,7 +2060,7 @@ def test(
20612060

20622061
test_meta = self.load_model_tests(tests=tests, patterns=match_patterns)
20632062

2064-
return run_tests(
2063+
result = run_tests(
20652064
model_test_metadata=test_meta,
20662065
models=self._models,
20672066
config=self.config,
@@ -2074,6 +2073,13 @@ def test(
20742073
default_catalog_dialect=self.config.dialect or "",
20752074
)
20762075

2076+
self.console.log_test_results(
2077+
result,
2078+
self.test_connection_config._engine_adapter.DIALECT,
2079+
)
2080+
2081+
return result
2082+
20772083
@python_api_analytics
20782084
def audit(
20792085
self,
@@ -2496,28 +2502,20 @@ def import_state(self, input_file: Path, clear: bool = False, confirm: bool = Tr
24962502

24972503
def _run_tests(
24982504
self, verbosity: Verbosity = Verbosity.DEFAULT
2499-
) -> t.Tuple[unittest.result.TestResult, str]:
2505+
) -> t.Tuple[ModelTextTestResult, str]:
25002506
test_output_io = StringIO()
25012507
result = self.test(stream=test_output_io, verbosity=verbosity)
25022508
return result, test_output_io.getvalue()
25032509

2504-
def _run_plan_tests(
2505-
self, skip_tests: bool = False
2506-
) -> t.Tuple[t.Optional[unittest.result.TestResult], t.Optional[str]]:
2510+
def _run_plan_tests(self, skip_tests: bool = False) -> t.Optional[ModelTextTestResult]:
25072511
if not skip_tests:
2508-
result, test_output = self._run_tests()
2509-
if result.testsRun > 0:
2510-
self.console.log_test_results(
2511-
result,
2512-
test_output,
2513-
self.test_connection_config._engine_adapter.DIALECT,
2514-
)
2512+
result = self.test()
25152513
if not result.wasSuccessful():
25162514
raise PlanError(
25172515
"Cannot generate plan due to failing test(s). Fix test(s) and run again."
25182516
)
2519-
return result, test_output
2520-
return None, None
2517+
return result
2518+
return None
25212519

25222520
@property
25232521
def _model_tables(self) -> t.Dict[str, str]:

0 commit comments

Comments
 (0)