diff --git a/src/codegen/extensions/langchain/tools.py b/src/codegen/extensions/langchain/tools.py index 877b59f05..3ae1d6ae9 100644 --- a/src/codegen/extensions/langchain/tools.py +++ b/src/codegen/extensions/langchain/tools.py @@ -23,6 +23,7 @@ from codegen.extensions.tools.replacement_edit import replacement_edit from codegen.extensions.tools.reveal_symbol import reveal_symbol from codegen.extensions.tools.search import search +from codegen.extensions.tools.search_files_by_name import search_files_by_name from codegen.extensions.tools.semantic_edit import semantic_edit from codegen.extensions.tools.semantic_search import semantic_search from codegen.sdk.core.codebase import Codebase @@ -1024,3 +1025,30 @@ def _run( result = perform_reflection(context_summary=context_summary, findings_so_far=findings_so_far, current_challenges=current_challenges, reflection_focus=reflection_focus, codebase=self.codebase) return result.render() + + +class SearchFilesByNameInput(BaseModel): + """Input for searching files by name pattern.""" + + pattern: str = Field(..., description="Glob pattern to search for (e.g. '*.py', 'test_*.py')") + + +class SearchFilesByNameTool(BaseTool): + """Tool for searching files by filename across a codebase.""" + + name: ClassVar[str] = "search_files_by_name" + description: ClassVar[str] = """ + Search for files and directories by glob pattern across the active codebase. This is useful when you need to: + - Find specific file types (e.g., '*.py', '*.tsx') + - Locate configuration files (e.g., 'package.json', 'requirements.txt') + - Find files with specific names (e.g., 'README.md', 'Dockerfile') + """ + args_schema: ClassVar[type[BaseModel]] = SearchFilesByNameInput + codebase: Codebase = Field(exclude=True) + + def __init__(self, codebase: Codebase): + super().__init__(codebase=codebase) + + def _run(self, pattern: str) -> str: + """Execute the glob pattern search using fd.""" + return search_files_by_name(self.codebase, pattern).render() diff --git a/src/codegen/extensions/tools/__init__.py b/src/codegen/extensions/tools/__init__.py index 8f49b68a8..44305e61a 100644 --- a/src/codegen/extensions/tools/__init__.py +++ b/src/codegen/extensions/tools/__init__.py @@ -22,6 +22,7 @@ from .reveal_symbol import reveal_symbol from .run_codemod import run_codemod from .search import search +from .search_files_by_name import search_files_by_name from .semantic_edit import semantic_edit from .semantic_search import semantic_search from .view_file import view_file @@ -52,6 +53,7 @@ "run_codemod", # Search operations "search", + "search_files_by_name", # Edit operations "semantic_edit", "semantic_search", diff --git a/src/codegen/extensions/tools/search_files_by_name.py b/src/codegen/extensions/tools/search_files_by_name.py new file mode 100644 index 000000000..bc595a25c --- /dev/null +++ b/src/codegen/extensions/tools/search_files_by_name.py @@ -0,0 +1,72 @@ +import shutil +import subprocess +from typing import ClassVar + +from pydantic import Field + +from codegen.extensions.tools.observation import Observation +from codegen.sdk.core.codebase import Codebase +from codegen.shared.logging.get_logger import get_logger + +logger = get_logger(__name__) + + +class SearchFilesByNameResultObservation(Observation): + """Response from searching files by filename pattern.""" + + pattern: str = Field( + description="The glob pattern that was searched for", + ) + files: list[str] = Field( + description="List of matching file paths", + ) + + str_template: ClassVar[str] = "Found {total} files matching pattern: {pattern}" + + @property + def total(self) -> int: + return len(self.files) + + +def search_files_by_name( + codebase: Codebase, + pattern: str, +) -> SearchFilesByNameResultObservation: + """Search for files by name pattern in the codebase. + + Args: + codebase: The codebase to search in + pattern: Glob pattern to search for (e.g. "*.py", "test_*.py") + """ + try: + if shutil.which("fd") is None: + logger.warning("fd is not installed, falling back to find") + results = subprocess.check_output( + ["find", "-name", pattern], + cwd=codebase.repo_path, + timeout=30, + ) + files = [path.removeprefix("./") for path in results.decode("utf-8").strip().split("\n")] if results.strip() else [] + + else: + logger.info(f"Searching for files with pattern: {pattern}") + results = subprocess.check_output( + ["fd", "-g", pattern], + cwd=codebase.repo_path, + timeout=30, + ) + files = results.decode("utf-8").strip().split("\n") if results.strip() else [] + + return SearchFilesByNameResultObservation( + status="success", + pattern=pattern, + files=files, + ) + + except Exception as e: + return SearchFilesByNameResultObservation( + status="error", + error=f"Error searching files: {e!s}", + pattern=pattern, + files=[], + ) diff --git a/tests/unit/codegen/extensions/test_tools.py b/tests/unit/codegen/extensions/test_tools.py index 70bf68512..006d4f1e3 100644 --- a/tests/unit/codegen/extensions/test_tools.py +++ b/tests/unit/codegen/extensions/test_tools.py @@ -15,6 +15,7 @@ replacement_edit, reveal_symbol, run_codemod, + search_files_by_name, semantic_edit, semantic_search, view_file, @@ -282,6 +283,21 @@ def test_move_symbol(codebase): assert result.target_file == "src/target.py" +def test_search_files_by_name(codebase): + """Test searching files by name.""" + create_file(codebase, "src/main.py", "print('hello')") + create_file(codebase, "src/target.py", "print('world')") + result = search_files_by_name(codebase, "*.py") + assert result.status == "success" + assert len(result.files) == 2 + assert "src/main.py" in result.files + assert "src/target.py" in result.files + result = search_files_by_name(codebase, "main.py") + assert result.status == "success" + assert len(result.files) == 1 + assert "src/main.py" in result.files + + def test_reveal_symbol(codebase): """Test revealing symbol relationships.""" result = reveal_symbol(