Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions paracelsus/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ class ColumnSorts(str, Enum):
preserve = "preserve-order"


class Layouts(str, Enum):
dagre = "dagre"
elk = "elk"


if "column_sort" in PYPROJECT_SETTINGS:
SORT_DEFAULT = ColumnSorts(PYPROJECT_SETTINGS["column_sort"]).value
else:
Expand Down Expand Up @@ -82,13 +87,22 @@ def graph(
help="Specifies the method of sorting columns in diagrams.",
),
] = SORT_DEFAULT, # type: ignore # Typer will fail to render the help message, but this code works.
layout: Annotated[
Optional[Layouts],
typer.Option(
help="Specifies the layout of the diagram. Only applicable for mermaid format.",
),
] = None,
):
settings = get_pyproject_settings()
base_class = get_base_class(base_class_path, settings)

if "imports" in settings:
import_module.extend(settings["imports"])

if layout and format != Formats.mermaid:
raise ValueError("The `layout` parameter can only be used with the `mermaid` format.")

typer.echo(
get_graph_string(
base_class_path=base_class,
Expand All @@ -98,6 +112,7 @@ def graph(
python_dir=python_dir,
format=format.value,
column_sort=column_sort,
layout=layout.value if layout else None,
)
)

Expand Down Expand Up @@ -166,11 +181,20 @@ def inject(
help="Specifies the method of sorting columns in diagrams.",
),
] = SORT_DEFAULT, # type: ignore # Typer will fail to render the help message, but this code works.
layout: Annotated[
Optional[Layouts],
typer.Option(
help="Specifies the layout of the diagram. Only applicable for mermaid format.",
),
] = None,
):
settings = get_pyproject_settings()
if "imports" in settings:
import_module.extend(settings["imports"])

if layout and format != Formats.mermaid:
raise ValueError("The `layout` parameter can only be used with the `mermaid` format.")

# Generate Graph
graph = get_graph_string(
base_class_path=base_class_path,
Expand All @@ -180,6 +204,7 @@ def inject(
python_dir=python_dir,
format=format.value,
column_sort=column_sort,
layout=layout.value if layout else None,
)

comment_format = transformers[format].comment_format # type: ignore
Expand Down
5 changes: 3 additions & 2 deletions paracelsus/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
from pathlib import Path
import re
from typing import List, Set
from typing import List, Set, Optional

from sqlalchemy.schema import MetaData
from .transformers.dot import Dot
Expand All @@ -26,6 +26,7 @@ def get_graph_string(
python_dir: List[Path],
format: str,
column_sort: str,
layout: Optional[str] = None,
) -> str:
# Update the PYTHON_PATH to allow more module imports.
sys.path.append(str(os.getcwd()))
Expand Down Expand Up @@ -61,7 +62,7 @@ def get_graph_string(
filtered_metadata = filter_metadata(metadata=metadata, include_tables=include_tables)

# Save the graph structure to string.
return str(transformer(filtered_metadata, column_sort))
return str(transformer(filtered_metadata, column_sort, layout=layout))


def resolve_included_tables(
Expand Down
17 changes: 15 additions & 2 deletions paracelsus/transformers/mermaid.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
from sqlalchemy.sql.schema import Column, MetaData, Table
from typing import Optional
import textwrap

from .utils import sort_columns

Expand All @@ -11,10 +13,12 @@ class Mermaid:
comment_format: str = "mermaid"
metadata: MetaData
column_sort: str
layout: Optional[str]

def __init__(self, metaclass: MetaData, column_sort: str) -> None:
def __init__(self, metaclass: MetaData, column_sort: str, layout: Optional[str] = None) -> None:
self.metadata = metaclass
self.column_sort = column_sort
self.layout = layout

def _table(self, table: Table) -> str:
output = f" {table.name}"
Expand Down Expand Up @@ -89,7 +93,16 @@ def _relationships(self, column: Column) -> str:
return output

def __str__(self) -> str:
output = "erDiagram\n"
output = ""
if self.layout:
yaml_front_matter = textwrap.dedent(f"""
---
config:
layout: {self.layout}
---
""")
output = yaml_front_matter + output
output += "erDiagram\n"
for table in self.metadata.tables.values():
output += self._table(table)

Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,6 @@ find = {}
[tool.setuptools_scm]
fallback_version = "0.0.0-dev"
write_to = "paracelsus/_version.py"

[tool.paracelsus]
layout = "dagre"
43 changes: 43 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,46 @@ def test_graph_with_exclusion_regex(package_path: Path):
assert "comments {" in result.stdout
assert "users {" in result.stdout
assert "post {" not in result.stdout


@pytest.mark.parametrize("layout_arg", ["dagre", "elk"])
def test_graph_layout(package_path: Path, layout_arg: Literal["dagre", "elk"]):
result = runner.invoke(
app,
[
"graph",
"example.base:Base",
"--import-module",
"example.models",
"--python-dir",
str(package_path),
"--layout",
layout_arg,
],
)

assert result.exit_code == 0
mermaid_assert(result.stdout)


@pytest.mark.parametrize("layout_arg", ["dagre", "elk"])
def test_inject_layout(package_path: Path, layout_arg: Literal["dagre", "elk"]):
result = runner.invoke(
app,
[
"inject",
str(package_path / "README.md"),
"example.base:Base",
"--import-module",
"example.models",
"--python-dir",
str(package_path),
"--layout",
layout_arg,
],
)
assert result.exit_code == 0

with open(package_path / "README.md") as fp:
readme = fp.read()
mermaid_assert(readme)
15 changes: 15 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,18 @@ def test_get_graph_string_with_include(package_path):
)
assert "posts {" in graph_string
assert "users ||--o{ posts" not in graph_string


@pytest.mark.parametrize("layout_arg", ["dagre", "elk"])
def test_get_graph_string_with_layout(layout_arg, package_path):
graph_string = get_graph_string(
base_class_path="example.base:Base",
import_module=["example.models"],
include_tables=set(),
exclude_tables=set(),
python_dir=[package_path],
format="mermaid",
column_sort="key-based",
layout=layout_arg,
)
mermaid_assert(graph_string)
30 changes: 30 additions & 0 deletions tests/transformers/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest
import textwrap

@pytest.fixture()
def mermaid_full_string_with_no_layout(mermaid_full_string_preseve_column_sort: str) -> str:
return mermaid_full_string_preseve_column_sort

@pytest.fixture()
def mermaid_full_string_with_dagre_layout(mermaid_full_string_preseve_column_sort: str) -> str:
front_matter = textwrap.dedent(
"""
---
config:
layout: dagre
---
"""
)
return f"{front_matter}{mermaid_full_string_preseve_column_sort}"

@pytest.fixture()
def mermaid_full_string_with_elk_layout(mermaid_full_string_preseve_column_sort: str) -> str:
front_matter = textwrap.dedent(
"""
---
config:
layout: elk
---
"""
)
return f"{front_matter}{mermaid_full_string_preseve_column_sort}"
15 changes: 15 additions & 0 deletions tests/transformers/test_mermaid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,18 @@ def test_mermaid(metaclass):
def test_mermaid_column_sort_preserve_order(metaclass, mermaid_full_string_preseve_column_sort):
mermaid = Mermaid(metaclass=metaclass, column_sort="preserve-order")
assert str(mermaid) == mermaid_full_string_preseve_column_sort


def test_mermaid_with_no_layout(metaclass, mermaid_full_string_with_no_layout):
mermaid = Mermaid(metaclass=metaclass, column_sort="preserve-order", layout=None)
assert str(mermaid) == mermaid_full_string_with_no_layout


def test_mermaid_with_dagre_layout(metaclass, mermaid_full_string_with_dagre_layout):
mermaid = Mermaid(metaclass=metaclass, column_sort="preserve-order", layout="dagre")
assert str(mermaid) == mermaid_full_string_with_dagre_layout


def test_mermaid_with_elk_layout(metaclass, mermaid_full_string_with_elk_layout):
mermaid = Mermaid(metaclass=metaclass, column_sort="preserve-order", layout="elk")
assert str(mermaid) == mermaid_full_string_with_elk_layout
Loading