diff --git a/cpp_linter_hooks/util.py b/cpp_linter_hooks/util.py index a4fe95e..c766123 100644 --- a/cpp_linter_hooks/util.py +++ b/cpp_linter_hooks/util.py @@ -1,20 +1,25 @@ import sys import shutil -import toml import subprocess from pathlib import Path import logging from typing import Optional, List -from packaging.version import Version, InvalidVersion + +try: + import tomllib +except ModuleNotFoundError: + import tomli as tomllib LOG = logging.getLogger(__name__) def get_version_from_dependency(tool: str) -> Optional[str]: + """Get the version of a tool from the pyproject.toml dependencies.""" pyproject_path = Path(__file__).parent.parent / "pyproject.toml" if not pyproject_path.exists(): return None - data = toml.load(pyproject_path) + with open(pyproject_path, "rb") as f: + data = tomllib.load(f) dependencies = data.get("project", {}).get("dependencies", []) for dep in dependencies: if dep.startswith(f"{tool}=="): @@ -22,8 +27,8 @@ def get_version_from_dependency(tool: str) -> Optional[str]: return None -DEFAULT_CLANG_FORMAT_VERSION = get_version_from_dependency("clang-format") or "20.1.7" -DEFAULT_CLANG_TIDY_VERSION = get_version_from_dependency("clang-tidy") or "20.1.0" +DEFAULT_CLANG_FORMAT_VERSION = get_version_from_dependency("clang-format") +DEFAULT_CLANG_TIDY_VERSION = get_version_from_dependency("clang-tidy") CLANG_FORMAT_VERSIONS = [ @@ -108,29 +113,21 @@ def get_version_from_dependency(tool: str) -> Optional[str]: def _resolve_version(versions: List[str], user_input: Optional[str]) -> Optional[str]: + """Resolve the version based on user input and available versions.""" if user_input is None: return None + if user_input in versions: + return user_input try: - user_ver = Version(user_input) - except InvalidVersion: + # Check if the user input is a valid version + return next(v for v in versions if v.startswith(user_input) or v == user_input) + except StopIteration: + LOG.warning("Version %s not found in available versions", user_input) return None - candidates = [Version(v) for v in versions] - if user_input.count(".") == 0: - matches = [v for v in candidates if v.major == user_ver.major] - elif user_input.count(".") == 1: - matches = [ - v - for v in candidates - if f"{v.major}.{v.minor}" == f"{user_ver.major}.{user_ver.minor}" - ] - else: - return str(user_ver) if user_ver in candidates else None - - return str(max(matches)) if matches else None - def _get_runtime_version(tool: str) -> Optional[str]: + """Get the runtime version of a tool.""" try: output = subprocess.check_output([tool, "--version"], text=True) if tool == "clang-tidy": @@ -144,6 +141,7 @@ def _get_runtime_version(tool: str) -> Optional[str]: def _install_tool(tool: str, version: str) -> Optional[Path]: + """Install a tool using pip.""" try: subprocess.check_call( [sys.executable, "-m", "pip", "install", f"{tool}=={version}"] @@ -155,6 +153,7 @@ def _install_tool(tool: str, version: str) -> Optional[Path]: def _resolve_install(tool: str, version: Optional[str]) -> Optional[Path]: + """Resolve the installation of a tool, checking for version and installing if necessary.""" user_version = _resolve_version( CLANG_FORMAT_VERSIONS if tool == "clang-format" else CLANG_TIDY_VERSIONS, version, @@ -191,6 +190,7 @@ def is_installed(tool: str) -> Optional[Path]: def ensure_installed(tool: str, version: Optional[str] = None) -> str: + """Ensure a tool is installed, resolving its version if necessary.""" LOG.info("Ensuring %s is installed", tool) tool_path = _resolve_install(tool, version) if tool_path: diff --git a/pyproject.toml b/pyproject.toml index 9ad8181..798646d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,8 +34,7 @@ classifiers = [ dependencies = [ "clang-format==20.1.7", "clang-tidy==20.1.0", - "toml>=0.10.2", - "packaging>=20.0", + "tomli>=1.1.0; python_version < '3.11'", ] dynamic = ["version"] diff --git a/testing/run.sh b/testing/run.sh index 3407656..b176b06 100644 --- a/testing/run.sh +++ b/testing/run.sh @@ -29,7 +29,7 @@ failed_cases=`grep -c "Failed" result.txt` echo $failed_cases " cases failed." -if [ $failed_cases -eq 9 ]; then +if [ $failed_cases -eq 10 ]; then echo "==============================" echo "Test cpp-linter-hooks success." echo "==============================" diff --git a/tests/test_util.py b/tests/test_util.py index dfeca88..86a4ecd 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -112,7 +112,7 @@ def test_get_version_from_dependency_success(): with ( patch("pathlib.Path.exists", return_value=True), - patch("toml.load", return_value=mock_toml_content), + patch("cpp_linter_hooks.util.tomllib.load", return_value=mock_toml_content), ): result = get_version_from_dependency("clang-format") assert result == "20.1.7" @@ -136,7 +136,7 @@ def test_get_version_from_dependency_missing_dependency(): with ( patch("pathlib.Path.exists", return_value=True), - patch("toml.load", return_value=mock_toml_content), + patch("cpp_linter_hooks.util.tomllib.load", return_value=mock_toml_content), ): result = get_version_from_dependency("clang-format") assert result is None @@ -149,7 +149,7 @@ def test_get_version_from_dependency_malformed_toml(): with ( patch("pathlib.Path.exists", return_value=True), - patch("toml.load", return_value=mock_toml_content), + patch("cpp_linter_hooks.util.tomllib.load", return_value=mock_toml_content), ): result = get_version_from_dependency("clang-format") assert result is None @@ -161,11 +161,11 @@ def test_get_version_from_dependency_malformed_toml(): "user_input,expected", [ (None, None), - ("20", "20.1.7"), # Should find latest 20.x - ("20.1", "20.1.7"), # Should find latest 20.1.x + ("20", "20.1.0"), # Should find first 20.x + ("20.1", "20.1.0"), # Should find first 20.1.x ("20.1.7", "20.1.7"), # Exact match - ("18", "18.1.8"), # Should find latest 18.x - ("18.1", "18.1.8"), # Should find latest 18.1.x + ("18", "18.1.0"), # Should find first 18.x + ("18.1", "18.1.0"), # Should find first 18.1.x ("99", None), # Non-existent major version ("20.99", None), # Non-existent minor version ("invalid", None), # Invalid version string @@ -182,9 +182,9 @@ def test_resolve_version_clang_format(user_input, expected): "user_input,expected", [ (None, None), - ("20", "20.1.0"), # Should find latest 20.x - ("18", "18.1.8"), # Should find latest 18.x - ("19", "19.1.0.1"), # Should find latest 19.x + ("20", "20.1.0"), # Should find first 20.x + ("18", "18.1.1"), # Should find first 18.x + ("19", "19.1.0"), # Should find first 19.x ("99", None), # Non-existent major version ], )