element
- """
-
- stdout_body = re.split(r"_\s{3,}", text)[-1]
- stdout_filtered = list(
- filter(re.compile(r".*>E\s").match, stdout_body.splitlines())
- )
- html_body = "".join(f"{line}
" for line in stdout_filtered)
-
- test_runs = f"""Click here to expand
{html_body}
"""
- return test_runs
-
-
-class TestResultOutput(ipywidgets.VBox):
- """Class to display the test results in a structured way"""
-
- def __init__(
- self,
- name: str = "",
- syntax_error: bool = False,
- success: bool = False,
- test_outputs: List[TestResult] = None,
- ):
- output_config = format_success_failure(syntax_error, success, name)
- output_cell = ipywidgets.Output()
-
- with output_cell:
- custom_div_style = '"border: 1px solid; border-color: lightgray; background-color: whitesmoke; margin: 5px; padding: 10px;"'
- display(HTML("Test results
"))
- display(
- HTML(
- f"""{output_config.name}
{output_config.result}"""
- )
- )
-
- if not syntax_error:
- if len(test_outputs) > 0 and test_outputs[0].stdout:
- display(
- HTML(
- f"Code output:
{test_outputs[0].stdout}
"
- )
- )
-
- display(
- HTML(
- f"""
- We tested your solution solution_{name}
with {'1 input' if len(test_outputs) == 1 else str(len(test_outputs)) + ' different inputs'}.
- {"All tests passed!
" if success else "Below you find the details for each test run:"}
- """
- )
- )
-
- if not success:
- for test in test_outputs:
- test_name = test.test_name
- if match := re.search(r"\[.*?\]", test_name):
- test_name = re.sub(r"\[|\]", "", match.group())
-
- display(
- HTML(
- f"""
-
-
{"✔" if test.success else "❌"} Test {test_name}
- {format_long_stdout(filters.ansi.ansi2html(test.stderr)) if not test.success else ""}
-
- """
- )
- )
- else:
- display(
- HTML(
- "Your code cannot run because of the following error:
"
- )
- )
-
- super().__init__(children=[output_cell])
-
-
-class ResultCollector:
- """A class that will collect the result of a test. If behaves a bit like a visitor pattern"""
-
- def __init__(self) -> None:
- self.tests: Dict[str, TestResult] = {}
-
- def pytest_runtest_logreport(self, report: pytest.TestReport):
- # Only collect the results if it did not fail
- if report.when == "teardown" and report.nodeid not in self.tests:
- self.tests[report.nodeid] = TestResult(
- report.capstdout, report.capstderr, report.nodeid, not report.failed
- )
-
- def pytest_exception_interact(
- self, node: pytest.Item, call: pytest.CallInfo, report: pytest.TestReport
- ):
- # We need to collect the results and the stderr if the test failed
- if report.failed:
- self.tests[node.nodeid] = TestResult(
- report.capstdout,
- str(call.excinfo.getrepr() if call.excinfo else ""),
- report.nodeid,
- False,
- )
-
-
-@pytest.fixture
-def function_to_test():
- """Function to test, overridden at runtime by the cell magic"""
-
-
@magics_class
class TestMagic(Magics):
"""Class to add the test cell magic"""
- shell: InteractiveShell
+ shell: Optional[InteractiveShell] # type: ignore
+ cells: Dict[str, int] = {}
@cell_magic
def ipytest(self, line: str, cell: str):
"""The `%%ipytest` cell magic"""
+ # Check that the magic is called from a notebook
+ if not self.shell:
+ raise InstanceNotFoundError("InteractiveShell")
+
# Get the module containing the test(s)
module_name = get_module_name(line, self.shell.user_global_ns)
@@ -253,7 +95,21 @@ def ipytest(self, line: str, cell: str):
functions_to_run[name.removeprefix("solution_")] = function
if not functions_to_run:
- raise ValueError("No function to test defined in the cell")
+ raise FunctionNotFoundError
+
+ # Store execution count information for each cell
+ if (ipython := get_ipython()) is None:
+ raise InstanceNotFoundError("IPython")
+
+ cell_id = ipython.parent_header["metadata"]["cellId"]
+ if cell_id in self.cells:
+ self.cells[cell_id] += 1
+ else:
+ self.cells[cell_id] = 1
+
+ # Parse the AST tree of the file containing the test functions,
+ # to extract and store all information of function definitions and imports
+ ast_parser = AstParser(module_file)
outputs = []
for name, function in functions_to_run.items():
@@ -277,12 +133,19 @@ def ipytest(self, line: str, cell: str):
pytest_stdout.getvalue()
pytest_stderr.getvalue()
+ # reset execution count on success
+ success = result == pytest.ExitCode.OK
+ if success:
+ self.cells[cell_id] = 0
+
outputs.append(
TestResultOutput(
+ list(result_collector.tests.values()),
name,
False,
- result == pytest.ExitCode.OK,
- list(result_collector.tests.values()),
+ success,
+ self.cells[cell_id],
+ ast_parser.get_solution_code(name),
)
)
@@ -292,31 +155,16 @@ def ipytest(self, line: str, cell: str):
display(
Javascript(
"""
- var output_divs = document.querySelectorAll(".jp-OutputArea-executeResult");
- for (let div of output_divs) {
- div.setAttribute("style", "display: none;");
- }
- """
- )
- )
-
- # remove syntax error styling
- display(
- Javascript(
+ var output_divs = document.querySelectorAll(".jp-OutputArea-executeResult");
+ for (let div of output_divs) {
+ div.setAttribute("style", "display: none;");
+ }
"""
- var output_divs = document.querySelectorAll(".jp-Cell-outputArea");
- for (let div of output_divs) {
- var div_str = String(div.innerHTML);
- if (div_str.includes("alert-success") | div_str.includes("alert-danger")) {
- div.setAttribute("style", "padding-bottom: 0;");
- }
- }
- """
)
)
- except Exception:
- # Catches syntax errors and creates a custom warning
+ except SyntaxError:
+ # Catches syntax errors
display(
TestResultOutput(
syntax_error=True,
@@ -324,21 +172,6 @@ def ipytest(self, line: str, cell: str):
)
)
- display(
- Javascript(
- """
- var syntax_error_containers = document.querySelectorAll('div[data-mime-type="application/vnd.jupyter.stderr"]');
- for (let container of syntax_error_containers) {
- var syntax_error_div = container.parentNode;
- var container_div = syntax_error_div.parentNode;
- const container_style = "position: relative; padding-bottom: " + syntax_error_div.clientHeight + "px;";
- container_div.setAttribute("style", container_style);
- syntax_error_div.setAttribute("style", "position: absolute; bottom: 10px;");
- }
- """
- )
- )
-
def load_ipython_extension(ipython):
"""
@@ -346,4 +179,5 @@ def load_ipython_extension(ipython):
can be loaded via `%load_ext module.path` or be configured to be
autoloaded by IPython at startup time.
"""
+
ipython.register_magics(TestMagic)
diff --git a/tutorial/tests/testsuite_helpers.py b/tutorial/tests/testsuite_helpers.py
new file mode 100644
index 00000000..7e52cd36
--- /dev/null
+++ b/tutorial/tests/testsuite_helpers.py
@@ -0,0 +1,380 @@
+import ast
+import pathlib
+import re
+from dataclasses import dataclass
+from typing import Callable, Dict, List, Optional, Set
+
+import ipywidgets
+import pytest
+from IPython.core.display import HTML, Javascript
+from IPython.display import Code, display
+from nbconvert import filters
+
+
+@dataclass
+class TestResult:
+ """Container class to store the test results when we collect them"""
+
+ stdout: str
+ stderr: str
+ test_name: str
+ success: bool
+
+
+@dataclass
+class OutputConfig:
+ """Container class to store the information to display in the test output"""
+
+ style: str
+ name: str
+ result: str
+
+
+def format_success_failure(
+ syntax_error: bool, success: bool, name: str
+) -> OutputConfig:
+ """
+ Depending on the test results, returns a fragment that represents
+ either an error message, a success message, or a syntax error warning
+ """
+
+ if syntax_error:
+ return OutputConfig(
+ "alert-warning",
+ "Tests COULD NOT RUN for this cell.",
+ "🤔 Careful, looks like you have a syntax error.",
+ )
+
+ if not success:
+ return OutputConfig(
+ "alert-danger",
+ f"Tests FAILED for the function {name}
",
+ "😱 Your solution was not correct!",
+ )
+
+ return OutputConfig(
+ "alert-success",
+ f"Tests PASSED for the function {name}
",
+ "🙌 Congratulations, your solution was correct!",
+ )
+
+
+def format_long_stdout(text: str) -> str:
+ """
+ Format the error message lines of a long test stdout
+ as an HTML that expands, by using the element
+ """
+
+ stdout_body = re.split(r"_\s{3,}", text)[-1]
+ stdout_filtered = list(
+ filter(re.compile(r".*>E\s").match, stdout_body.splitlines())
+ )
+ stdout_str = "".join(f"{line}
" for line in stdout_filtered)
+ stdout_edited = re.sub(r"E\s+[\+\s]*", "", stdout_str)
+ stdout_edited = re.sub(
+ r"\bfunction\ssolution_[\w\s\d]*", "your_solution", stdout_edited
+ )
+ stdout_edited = re.sub(r"\breference_\w+\(", "reference_solution(", stdout_edited)
+
+ test_runs = f"""
+
+ Click here to expand
+ {stdout_edited}
+
+ """
+ return test_runs
+
+
+class TestResultOutput(ipywidgets.VBox):
+ """Class to display the test results in a structured way"""
+
+ def __init__(
+ self,
+ test_outputs: Optional[List[TestResult]] = None,
+ name: str = "",
+ syntax_error: bool = False,
+ success: bool = False,
+ cell_exec_count: int = 0,
+ solution_body: str = "",
+ ):
+ reveal_solution = cell_exec_count > 2 or success
+ output_config = format_success_failure(syntax_error, success, name)
+ output_cell = ipywidgets.Output()
+
+ # For each test, create an alert box with the appropriate message,
+ # print the code output and display code errors in case of failure
+ with output_cell:
+ custom_div_style = '"border: 1px solid; border-color: lightgray; background-color: #FAFAFA; margin: 5px; padding: 10px;"'
+ display(HTML("Test results
"))
+ display(
+ HTML(
+ f"""{output_config.name}
{output_config.result}"""
+ )
+ )
+
+ if not syntax_error and isinstance(test_outputs, List):
+ if len(test_outputs) > 0 and test_outputs[0].stdout:
+ display(
+ HTML(
+ f"""
+ 👉 Code output:
+ {test_outputs[0].stdout}
+ """
+ )
+ )
+
+ display(
+ HTML(
+ f"""
+ 👉 We tested your solution solution_{name}
with {'1 input' if len(test_outputs) == 1 else str(len(test_outputs)) + ' different inputs'}.
+ {"All tests passed!
" if success else "Below you find the details for each test run:"}
+ """
+ )
+ )
+
+ if not success:
+ for test in test_outputs:
+ test_name = test.test_name
+ if match := re.search(r"\[.*?\]", test_name):
+ test_name = re.sub(r"\[|\]", "", match.group())
+
+ display(
+ HTML(
+ f"""
+
+
{"✔" if test.success else "❌"} Test {test_name}
+ {format_long_stdout(filters.ansi.ansi2html(test.stderr)) if not test.success else ""}
+
+ """
+ )
+ )
+
+ if not reveal_solution:
+ display(
+ HTML(
+ f"📝 A proposed solution will appear after {3 - cell_exec_count} more failed attempt{'s' if cell_exec_count < 2 else ''}.
"
+ )
+ )
+ else:
+ # display syntax error custom alert
+ display(
+ HTML(
+ "👉 Your code cannot run because of the following error:
"
+ )
+ )
+
+ # fix syntax error styling
+ display(
+ Javascript(
+ """
+ var syntax_error_containers = document.querySelectorAll('div[data-mime-type="application/vnd.jupyter.stderr"]');
+ for (let container of syntax_error_containers) {
+ var syntax_error_div = container.parentNode;
+ var container_div = syntax_error_div.parentNode;
+ const container_style = "position: relative; padding-bottom: " + syntax_error_div.clientHeight + "px;";
+ container_div.setAttribute("style", container_style);
+ syntax_error_div.setAttribute("style", "position: absolute; bottom: 0;");
+ }
+ """
+ )
+ )
+
+ # fix css styling
+ display(
+ Javascript(
+ """
+ var divs = document.querySelectorAll(".jupyter-widget-Collapse-contents");
+ for (let div of divs) {
+ div.setAttribute("style", "padding: 0");
+ }
+ divs = document.querySelectorAll(".widget-vbox");
+ for (let div of divs) {
+ div.setAttribute("style", "background: #EAF0FB");
+ }
+ """
+ )
+ )
+
+ display(
+ Javascript(
+ """
+ var output_divs = document.querySelectorAll(".jp-Cell-outputArea");
+ for (let div of output_divs) {
+ var div_str = String(div.innerHTML);
+ if (div_str.includes("alert-success") | div_str.includes("alert-danger")) {
+ div.setAttribute("style", "padding-bottom: 0;");
+ }
+ }
+ """
+ )
+ )
+
+ # After 3 failed attempts or on success, reveal the proposed solution
+ # using a Code box inside an Accordion to display the str containing all code
+ solution_output = ipywidgets.Output()
+ with solution_output:
+ display(HTML("👉 Proposed solution:
"))
+
+ solution_code = ipywidgets.Output()
+ with solution_code:
+ display(Code(language="python", data=f"{solution_body}"))
+
+ solution_accordion = ipywidgets.Accordion(
+ titles=("Click here to reveal",), children=[solution_code]
+ )
+
+ solution_box = ipywidgets.Box(
+ children=[solution_output, solution_accordion],
+ layout={
+ "display": "block" if reveal_solution else "none",
+ "padding": "0 20px 0 0",
+ },
+ )
+
+ super().__init__(children=[output_cell, solution_box])
+
+
+@pytest.fixture
+def function_to_test():
+ """Function to test, overridden at runtime by the cell magic"""
+
+
+class FunctionInjectionPlugin:
+ """A class to inject a function to test"""
+
+ def __init__(self, function_to_test: Callable) -> None:
+ self.function_to_test = function_to_test
+
+ def pytest_generate_tests(self, metafunc: pytest.Metafunc) -> None:
+ # Override the abstract `function_to_test` fixture function
+ if "function_to_test" in metafunc.fixturenames:
+ metafunc.parametrize("function_to_test", [self.function_to_test])
+
+
+class ResultCollector:
+ """A class that will collect the result of a test. If behaves a bit like a visitor pattern"""
+
+ def __init__(self) -> None:
+ self.tests: Dict[str, TestResult] = {}
+
+ def pytest_runtest_logreport(self, report: pytest.TestReport):
+ # Only collect the results if it did not fail
+ if report.when == "teardown" and report.nodeid not in self.tests:
+ self.tests[report.nodeid] = TestResult(
+ report.capstdout, report.capstderr, report.nodeid, not report.failed
+ )
+
+ def pytest_exception_interact(
+ self, node: pytest.Item, call: pytest.CallInfo, report: pytest.TestReport
+ ):
+ # We need to collect the results and the stderr if the test failed
+ if report.failed:
+ self.tests[node.nodeid] = TestResult(
+ report.capstdout,
+ str(call.excinfo.getrepr() if call.excinfo else ""),
+ report.nodeid,
+ False,
+ )
+
+
+class AstParser:
+ """
+ Helper class for extraction of function definitions and imports.
+ To find all reference solutions:
+ Parse the module file using the AST module and retrieve all function definitions and imports.
+ For each reference solution store the names of all other functions used inside of it.
+ """
+
+ def __init__(self, module_file: pathlib.Path) -> None:
+ self.module_file = module_file
+ self.function_defs = {}
+ self.function_imports = {}
+ self.called_function_names = {}
+
+ tree = ast.parse(self.module_file.read_text(encoding="utf-8"))
+
+ for node in tree.body:
+ if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+ self.function_defs[node.name] = node
+ elif isinstance(node, (ast.Import, ast.ImportFrom)) and hasattr(
+ node, "module"
+ ):
+ for n in node.names:
+ self.function_imports[n.name] = node.module
+
+ for node in tree.body:
+ if (
+ node in self.function_defs.values()
+ and hasattr(node, "name")
+ and node.name.startswith("reference_")
+ ):
+ self.called_function_names[node.name] = self.retrieve_functions(
+ {**self.function_defs, **self.function_imports}, node, {node.name}
+ )
+
+ def retrieve_functions(
+ self, all_functions: Dict, node: object, called_functions: Set[object]
+ ) -> Set[object]:
+ """
+ Recursively walk the AST tree to retrieve all function definitions in a file
+ """
+
+ if isinstance(node, ast.AST):
+ for n in ast.walk(node):
+ match n:
+ case ast.Call(ast.Name(id=name)):
+ called_functions.add(name)
+ if name in all_functions:
+ called_functions = self.retrieve_functions(
+ all_functions, all_functions[name], called_functions
+ )
+ for child in ast.iter_child_nodes(n):
+ called_functions = self.retrieve_functions(
+ all_functions, child, called_functions
+ )
+
+ return called_functions
+
+ def get_solution_code(self, name):
+ """
+ Find the respective reference solution for the executed function.
+ Create a str containing its code and the code of all other functions used,
+ whether coming from the same file or an imported one.
+ """
+
+ solution_functions = self.called_function_names[f"reference_{name}"]
+ solution_code = ""
+
+ for f in solution_functions:
+ if f in self.function_defs:
+ solution_code += ast.unparse(self.function_defs[f]) + "\n\n"
+ elif f in self.function_imports:
+ function_file = pathlib.Path(
+ f"{self.function_imports[f].replace('.', '/')}.py"
+ )
+ if function_file.exists():
+ function_file_tree = ast.parse(
+ function_file.read_text(encoding="utf-8")
+ )
+ for node in function_file_tree.body:
+ if (
+ isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
+ and node.name == f
+ ):
+ solution_code += ast.unparse(node) + "\n\n"
+
+ return solution_code
+
+
+class FunctionNotFoundError(Exception):
+ """Custom exception raised when the solution code cannot be parsed"""
+
+ def __init__(self) -> None:
+ super().__init__("No functions to test defined in the cell")
+
+
+class InstanceNotFoundError(Exception):
+ """Custom exception raised when an instance cannot be found"""
+
+ def __init__(self, name: str) -> None:
+ super().__init__(f"Could not get {name} instance")
diff --git a/tutorial/toc.py b/tutorial/toc.py
index 2a8814a0..ab4896a7 100755
--- a/tutorial/toc.py
+++ b/tutorial/toc.py
@@ -3,12 +3,18 @@
import argparse as ap
import pathlib
import re
-from collections import namedtuple
+from typing import NamedTuple
import nbformat
from nbformat import NotebookNode
-TocEntry = namedtuple("TocEntry", ["level", "text", "anchor"])
+
+class TocEntry(NamedTuple):
+ """Table of contents entry"""
+
+ level: int
+ text: str
+ anchor: str
def extract_markdown_cells(notebook: NotebookNode) -> str:
@@ -21,30 +27,32 @@ def extract_markdown_cells(notebook: NotebookNode) -> str:
def extract_toc(notebook: str) -> list[TocEntry]:
"""Extract the table of contents from a markdown string"""
toc = []
- line_re = re.compile(r"(#+)\s+(.+)")
- for line in notebook.splitlines():
- if groups := re.match(line_re, line):
- heading, text, *_ = groups.groups()
- level = len(heading)
+
+ for match in re.findall(r"```py.*\n#|^(#{1,6})\s+(.+)", notebook, re.MULTILINE):
+ if all(match):
+ level, text = match
anchor = "-".join(text.replace("`", "").split())
- toc.append(TocEntry(level, text, anchor))
+ toc.append(TocEntry(len(level), text, anchor))
+
return toc
def markdown_toc(toc: list[TocEntry]) -> str:
"""Build a string representation of the toc as a nested markdown list"""
- lines = []
- for entry in toc:
- line = f"{' ' * entry.level}- [{entry.text}](#{entry.anchor})"
- lines.append(line)
- return "\n".join(lines)
+ return "\n".join(
+ f"{' ' * entry.level}- [{entry.text}](#{entry.anchor})" for entry in toc
+ )
-def build_toc(nb_path: pathlib.Path, placeholder: str = "[TOC]") -> NotebookNode:
- """Build a table of contents for a notebook and insert it at the location of a placeholder"""
- # Read the notebook
- nb_obj: NotebookNode = nbformat.read(nb_path, nbformat.NO_CONVERT)
- md_cells = extract_markdown_cells(nb_obj)
+def add_toc_and_backlinks(
+ notebook_path: pathlib.Path, placeholder: str = "[TOC]"
+) -> NotebookNode:
+ """Replace a `placeholder` cell with a table of contents and add backlinks to each header"""
+ # Read notebook
+ notebook: NotebookNode = nbformat.read(notebook_path, nbformat.NO_CONVERT)
+
+ # Extract markdown cells
+ md_cells = extract_markdown_cells(notebook)
# Build tree
toc_tree = extract_toc(md_cells)
@@ -55,12 +63,19 @@ def build_toc(nb_path: pathlib.Path, placeholder: str = "[TOC]") -> NotebookNode
# Insert it a the location of a placeholder
toc_header = "# Table of Contents"
- for cell in nb_obj.cells:
+ for cell in notebook.cells:
+ # Add backlinks
+ if cell.cell_type == "markdown":
+ cell.source = re.sub(
+ r"^(#{1,6})\s+(.+)", r"\1 \2 [↩](#Table-of-Contents)", cell.source
+ )
+
+ # Replace placeholder with toc
if cell.source.startswith((placeholder, toc_header)):
cell.source = f"{toc_header}\n{toc_repr}"
cell.cell_type = "markdown"
- return nb_obj
+ return notebook
def main():
@@ -90,7 +105,7 @@ def main():
output_nb = pathlib.Path(args.output)
with output_nb.open("w", encoding="utf-8") as file:
- nbformat.write(build_toc(input_nb), file)
+ nbformat.write(add_toc_and_backlinks(input_nb), file)
if args.force:
input_nb.unlink()