diff --git a/.gitignore b/.gitignore index a4fa1397f..4b1da69de 100644 --- a/.gitignore +++ b/.gitignore @@ -65,3 +65,4 @@ graph-sitter-types/typings/** coverage.json tests/integration/verified_codemods/codemod_data/repo_commits.json .codegen/* +.benchmarks/* \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index d32e379d0..f7474a234 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,6 +139,8 @@ dev-dependencies = [ "black>=24.8.0", "isort>=5.13.2", "emoji>=2.14.0", + "pytest-benchmark[histogram]>=5.1.0", + "loguru>=0.7.3", ] keyring-provider = "subprocess" #extra-index-url = ["https://aws@codegen-922078275900.d.codeartifact.us-east-1.amazonaws.com/pypi/codegen/simple/"] diff --git a/src/codegen/git/repo_operator/local_repo_operator.py b/src/codegen/git/repo_operator/local_repo_operator.py index 635e703c5..151a2cace 100644 --- a/src/codegen/git/repo_operator/local_repo_operator.py +++ b/src/codegen/git/repo_operator/local_repo_operator.py @@ -72,8 +72,8 @@ def create_from_files(cls, repo_path: str, files: dict[str, str], bot_commit: bo def create_from_commit(cls, repo_path: str, commit: str, url: str) -> Self: """Do a shallow checkout of a particular commit to get a repository from a given remote URL.""" op = cls(repo_config=BaseRepoConfig(), repo_path=repo_path, bot_commit=False) + op.discard_changes() if op.get_active_branch_or_commit() != commit: - op.discard_changes() op.create_remote("origin", url) op.git_cli.remotes["origin"].fetch(commit, depth=1) op.checkout_commit(commit) diff --git a/src/codegen/git/repo_operator/remote_repo_operator.py b/src/codegen/git/repo_operator/remote_repo_operator.py index 3384fbfd0..e0d526236 100644 --- a/src/codegen/git/repo_operator/remote_repo_operator.py +++ b/src/codegen/git/repo_operator/remote_repo_operator.py @@ -42,8 +42,9 @@ def __init__( setup_option: SetupOption = SetupOption.PULL_OR_CLONE, shallow: bool = True, github_type: GithubType = GithubType.GithubEnterprise, + bot_commit: bool = 
True, ) -> None: - super().__init__(repo_config=repo_config, base_dir=base_dir) + super().__init__(repo_config=repo_config, base_dir=base_dir, bot_commit=bot_commit) self.github_type = github_type self.setup_repo_dir(setup_option=setup_option, shallow=shallow) diff --git a/src/codegen/git/repo_operator/repo_operator.py b/src/codegen/git/repo_operator/repo_operator.py index 4f12f06bb..3a96d06ad 100644 --- a/src/codegen/git/repo_operator/repo_operator.py +++ b/src/codegen/git/repo_operator/repo_operator.py @@ -67,12 +67,21 @@ def viz_file_path(self) -> str: def git_cli(self) -> GitCLI: """Note: this is recursive, may want to look out""" git_cli = GitCLI(self.repo_path) + has_username = False + has_email = False + with git_cli.config_reader(None) as reader: + if reader.has_option("user", "name"): + has_username = True + if reader.has_option("user", "email"): + has_email = True with git_cli.config_writer("repository") as writer: - if self.bot_commit: + if not has_username or not has_email or self.bot_commit: if not writer.has_section("user"): writer.add_section("user") - writer.set("user", "name", CODEGEN_BOT_NAME) - writer.set("user", "email", CODEGEN_BOT_EMAIL) + if not has_username or self.bot_commit: + writer.set("user", "name", CODEGEN_BOT_NAME) + if not has_email or self.bot_commit: + writer.set("user", "email", CODEGEN_BOT_EMAIL) return git_cli @property diff --git a/src/codegen/sdk/codebase/codebase_graph.py b/src/codegen/sdk/codebase/codebase_graph.py index f3ffbc8be..415a9b201 100644 --- a/src/codegen/sdk/codebase/codebase_graph.py +++ b/src/codegen/sdk/codebase/codebase_graph.py @@ -30,6 +30,7 @@ from codegen.sdk.core.interfaces.importable import Importable from codegen.sdk.core.node_id_factory import NodeId from codegen.sdk.enums import Edge, EdgeType, NodeType, ProgrammingLanguage +from codegen.sdk.extensions.io import write_changes from codegen.sdk.extensions.sort import sort_editables from codegen.sdk.extensions.utils import uncache_all from 
codegen.sdk.typescript.external.ts_declassify.ts_declassify import TSDeclassify @@ -107,6 +108,7 @@ class CodebaseGraph: flags: Flags session_options: SessionOptions = SessionOptions() projects: list[ProjectConfig] + unapplied_diffs: list[DiffLite] def __init__( self, @@ -161,6 +163,7 @@ def __init__( self.synced_commit = None self.pending_syncs = [] self.all_syncs = [] + self.unapplied_diffs = [] self.pending_files = set() self.flags = Flags() @@ -232,9 +235,40 @@ def apply_diffs(self, diff_list: list[DiffLite]) -> None: self.generation += 1 self._process_diff_files(by_sync_type) + def _reset_files(self, syncs: list[DiffLite]) -> None: + files_to_write = [] + files_to_remove = [] + modified_files = set() + for sync in syncs: + if sync.path in modified_files: # only the first sync seen for a path is used + continue + if sync.change_type == ChangeType.Removed: + files_to_write.append((sync.path, sync.old_content)) # a removed file is restored by rewriting its old content + modified_files.add(sync.path) + logger.info(f"Restoring {sync.path} to disk") + elif sync.change_type == ChangeType.Modified: + files_to_write.append((sync.path, sync.old_content)) + modified_files.add(sync.path) + elif sync.change_type == ChangeType.Renamed: + files_to_write.append((sync.rename_from, sync.old_content)) + files_to_remove.append(sync.rename_to) + modified_files.add(sync.rename_from) + modified_files.add(sync.rename_to) + elif sync.change_type == ChangeType.Added: + files_to_remove.append(sync.path) # a file created by an added diff is deleted again + modified_files.add(sync.path) + logger.info(f"Writing {len(files_to_write)} files to disk and removing {len(files_to_remove)} files") + write_changes(files_to_remove, files_to_write) + + @stopwatch + def reset_codebase(self) -> None: + self._reset_files(self.all_syncs + self.pending_syncs + self.unapplied_diffs) + self.unapplied_diffs.clear() + @stopwatch def undo_applied_diffs(self) -> None: self.transaction_manager.clear_transactions() + self.reset_codebase() self.check_changes() self.pending_syncs.clear() # Discard pending changes if len(self.all_syncs) > 0: @@ -256,6 +290,9 @@ def
_revert_diffs(self, diff_list: list[DiffLite]) -> None: def save_commit(self, commit: GitCommit) -> None: if commit is not None: + logger.info(f"Saving commit {commit.hexsha} to graph") + self.all_syncs.clear() + self.unapplied_diffs.clear() self.synced_commit = commit if self.config.feature_flags.verify_graph: self.old_graph = self._graph.copy() @@ -630,9 +667,11 @@ def commit_transactions(self, sync_graph: bool = True, sync_file: bool = True, f # Commit transactions for all contexts files_to_lock = self.transaction_manager.to_commit(files) diffs = self.transaction_manager.commit(files_to_lock) - # Filter diffs to only include files that are still in the graph - diffs = [diff for diff in diffs if self.get_file(diff.path) is not None] - self.pending_syncs.extend(diffs) + for diff in diffs: + if self.get_file(diff.path) is None: + self.unapplied_diffs.append(diff) + else: + self.pending_syncs.append(diff) # Write files if requested if sync_file: diff --git a/src/codegen/sdk/codebase/diff_lite.py b/src/codegen/sdk/codebase/diff_lite.py index e8ed12e89..f38839ef5 100644 --- a/src/codegen/sdk/codebase/diff_lite.py +++ b/src/codegen/sdk/codebase/diff_lite.py @@ -1,4 +1,4 @@ -from enum import Enum, auto +from enum import IntEnum, auto from os import PathLike from pathlib import Path from typing import NamedTuple, Self @@ -7,7 +7,7 @@ from watchfiles import Change -class ChangeType(Enum): +class ChangeType(IntEnum): Modified = auto() Removed = auto() Renamed = auto() @@ -40,8 +40,9 @@ class DiffLite(NamedTuple): change_type: ChangeType path: Path - rename_from: str | None = None - rename_to: str | None = None + rename_from: Path | None = None + rename_to: Path | None = None + old_content: bytes | None = None @classmethod def from_watch_change(cls, change: Change, path: PathLike) -> Self: @@ -52,11 +53,15 @@ def from_watch_change(cls, change: Change, path: PathLike) -> Self: @classmethod def from_git_diff(cls, git_diff: Diff): + old = None + if git_diff.a_blob: + old = 
git_diff.a_blob.data_stream.read() return cls( change_type=ChangeType.from_git_change_type(git_diff.change_type), - path=Path(git_diff.a_path), - rename_from=git_diff.rename_from, - rename_to=git_diff.rename_to, + path=Path(git_diff.a_path) if git_diff.a_path else None, + rename_from=Path(git_diff.rename_from) if git_diff.rename_from else None, + rename_to=Path(git_diff.rename_to) if git_diff.rename_to else None, + old_content=old, ) @classmethod diff --git a/src/codegen/sdk/codebase/transaction_manager.py b/src/codegen/sdk/codebase/transaction_manager.py index b7b050822..f91e13142 100644 --- a/src/codegen/sdk/codebase/transaction_manager.py +++ b/src/codegen/sdk/codebase/transaction_manager.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import TYPE_CHECKING -from codegen.sdk.codebase.diff_lite import DiffLite +from codegen.sdk.codebase.diff_lite import ChangeType, DiffLite from codegen.sdk.codebase.transactions import ( EditTransaction, FileAddTransaction, @@ -163,16 +163,16 @@ def to_commit(self, files: set[Path] | None = None) -> set[Path]: return set(self.queued_transactions.keys()) return files.intersection(self.queued_transactions) - def commit(self, files: set[Path]) -> set[DiffLite]: + def commit(self, files: set[Path]) -> list[DiffLite]: """Execute transactions in bulk for each file, in reverse order of start_byte. - Returns the set of diffs that were committed. + Returns the list of diffs that were committed. 
""" if self._commiting: logger.warn("Skipping commit, already committing") - return set() + return [] self._commiting = True try: - diffs: set[DiffLite] = set() + diffs: list[DiffLite] = [] if not self.queued_transactions or len(self.queued_transactions) == 0: return diffs @@ -187,9 +187,16 @@ def commit(self, files: set[Path]) -> set[DiffLite]: logger.info(f"Committing {len(self.queued_transactions[file])} transactions for {file}") for file_path in files: file_transactions = self.queued_transactions.pop(file_path, []) + modified = False for transaction in file_transactions: # Add diff IF the file is a source file - diffs.add(transaction.get_diff()) + diff = transaction.get_diff() + if diff.change_type == ChangeType.Modified: + if not modified: + modified = True + diffs.append(diff) + else: + diffs.append(diff) transaction.execute() return diffs finally: diff --git a/src/codegen/sdk/codebase/transactions.py b/src/codegen/sdk/codebase/transactions.py index 7fcdfd8c1..a43d0c399 100644 --- a/src/codegen/sdk/codebase/transactions.py +++ b/src/codegen/sdk/codebase/transactions.py @@ -127,7 +127,7 @@ def execute(self) -> None: def get_diff(self) -> DiffLite: """Gets the diff produced by this transaction""" - return DiffLite(ChangeType.Modified, self.file_path) + return DiffLite(ChangeType.Modified, self.file_path, old_content=self.file.content_bytes) def diff_str(self) -> str: """Human-readable string representation of the change""" @@ -170,7 +170,7 @@ def execute(self) -> None: def get_diff(self) -> DiffLite: """Gets the diff produced by this transaction""" - return DiffLite(ChangeType.Modified, self.file_path) + return DiffLite(ChangeType.Modified, self.file_path, old_content=self.file.content_bytes) def diff_str(self) -> str: """Human-readable string representation of the change""" @@ -205,7 +205,7 @@ def execute(self) -> None: def get_diff(self) -> DiffLite: """Gets the diff produced by this transaction""" - return DiffLite(ChangeType.Modified, self.file_path) + 
return DiffLite(ChangeType.Modified, self.file_path, old_content=self.file.content_bytes) def diff_str(self) -> str: """Human-readable string representation of the change""" @@ -269,7 +269,7 @@ def execute(self) -> None: def get_diff(self) -> DiffLite: """Gets the diff produced by this transaction""" - return DiffLite(ChangeType.Renamed, self.file_path, self.file_path, self.new_file_path) + return DiffLite(ChangeType.Renamed, self.file_path, self.file_path, self.new_file_path, old_content=self.file.content_bytes) def diff_str(self) -> str: """Human-readable string representation of the change""" @@ -294,7 +294,7 @@ def execute(self) -> None: def get_diff(self) -> DiffLite: """Gets the diff produced by this transaction""" - return DiffLite(ChangeType.Removed, self.file_path) + return DiffLite(ChangeType.Removed, self.file_path, old_content=self.file.content_bytes) def diff_str(self) -> str: """Human-readable string representation of the change""" diff --git a/src/codegen/sdk/core/codebase.py b/src/codegen/sdk/core/codebase.py index 54675dbf1..b94b5eae2 100644 --- a/src/codegen/sdk/core/codebase.py +++ b/src/codegen/sdk/core/codebase.py @@ -738,7 +738,7 @@ def current_commit(self) -> GitCommit | None: return self._op.git_cli.head.commit @stopwatch - def reset(self) -> None: + def reset(self, git_reset: bool = False) -> None: """Resets the codebase by: - Discarding any staged/unstaged changes - Resetting stop codemod limits: (max seconds, max transactions, max AI requests) @@ -751,7 +751,8 @@ def reset(self) -> None: - .ipynb files (Jupyter notebooks, where you are likely developing) """ logger.info("Resetting codebase ...") - self._op.discard_changes() # Discard any changes made to the raw file state + if git_reset: + self._op.discard_changes() # Discard any changes made to the raw file state self._num_ai_requests = 0 self.reset_logs() self.G.undo_applied_diffs() @@ -818,12 +819,14 @@ def get_diffs(self, base: str | None = None) -> list[Diff]: return 
self._op.get_diffs(base) @noapidoc - def get_diff(self, base: str | None = None) -> str: + def get_diff(self, base: str | None = None, stage_files: bool = False) -> str: """Produce a single git diff for all files.""" - self._op.git_cli.git.add(A=True) # add all changes to the index so untracked files are included in the diff + if stage_files: # NOTE(review): staging is now opt-in, so by default untracked files are presumably missing from the diff; confirm callers that expect new files pass stage_files=True + self._op.git_cli.git.add(A=True) # add all changes to the index so untracked files are included in the diff if base is None: - return self._op.git_cli.git.diff(patch=True, full_index=True, staged=True) - return self._op.git_cli.git.diff(base, full_index=True) + diff = self._op.git_cli.git.diff("HEAD", patch=True, full_index=True) + return diff + return self._op.git_cli.git.diff(base, patch=True, full_index=True) @noapidoc def clean_repo(self): diff --git a/src/codegen/sdk/extensions/io.pyx b/src/codegen/sdk/extensions/io.pyx new file mode 100644 index 000000000..6b2c8fcd8 --- /dev/null +++ b/src/codegen/sdk/extensions/io.pyx @@ -0,0 +1,12 @@ +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor +import os + + +def write_changes(files_to_remove: list[Path], files_to_write: list[tuple[Path, bytes]]): + # Submit every removal and every write to a thread pool; leaving the with-block waits for all submitted tasks. NOTE(review): the futures are discarded, so an exception raised inside os.remove or write_bytes is never surfaced here; confirm that best-effort behaviour is intended + with ThreadPoolExecutor() as executor: + for file_to_remove in files_to_remove: + executor.submit(os.remove, file_to_remove) + for file_to_write, content in files_to_write: + executor.submit(file_to_write.write_bytes, content) diff --git a/tests/integration/codegen/git/repo_operator/test_remote_repo_operator.py b/tests/integration/codegen/git/repo_operator/test_remote_repo_operator.py index c033181d0..99604e0d4 100644 --- a/tests/integration/codegen/git/repo_operator/test_remote_repo_operator.py +++ b/tests/integration/codegen/git/repo_operator/test_remote_repo_operator.py @@ -12,7 +12,8 @@ @pytest.fixture def op(repo_config, request, tmpdir): - yield RemoteRepoOperator(repo_config, shallow=request.param, base_dir=tmpdir) + op =
RemoteRepoOperator(repo_config, shallow=request.param, base_dir=tmpdir, bot_commit=False) + yield op @pytest.mark.parametrize("op", shallow_options, ids=lambda x: f"shallow={x}", indirect=True) diff --git a/tests/integration/codemod/conftest.py b/tests/integration/codemod/conftest.py index 33363c3a9..20fd0b362 100644 --- a/tests/integration/codemod/conftest.py +++ b/tests/integration/codemod/conftest.py @@ -156,8 +156,9 @@ def _codebase(repo: Repo, op: RepoOperator, request) -> YieldFixture[Codebase]: projects = [ProjectConfig(repo_operator=op, programming_language=repo.language, subdirectories=repo.subdirectories, base_path=repo.base_path)] Codebases[repo.name] = Codebase(projects=projects, config=CodebaseConfig(feature_flags=feature_flags)) codebase = Codebases[repo.name] - codebase.reset() + codebase.reset(git_reset=True) yield codebase + codebase.reset(git_reset=True) @pytest.fixture diff --git a/tests/shared/codemod/codebase_comparison_utils.py b/tests/shared/codemod/codebase_comparison_utils.py index ec4feb384..35306cf75 100644 --- a/tests/shared/codemod/codebase_comparison_utils.py +++ b/tests/shared/codemod/codebase_comparison_utils.py @@ -76,7 +76,7 @@ def compare_codebase_diff( diff = codebase.get_diff() + "\n" if not snapshot._snapshot_update: modified = gather_modified_files(codebase) - codebase.reset() + codebase.reset(git_reset=True) logger.info("Converting diff file to expected repository") if convert_diff_to_repo(expected_dir, expected_diff, codebase): return compare_codebase_with_snapshot(codebase, expected_dir, diff_path, snapshot, modified) diff --git a/tests/unit/codegen/sdk/benchmark/codebase/test_codebase_reset.py b/tests/unit/codegen/sdk/benchmark/codebase/test_codebase_reset.py new file mode 100644 index 000000000..3a8ebf689 --- /dev/null +++ b/tests/unit/codegen/sdk/benchmark/codebase/test_codebase_reset.py @@ -0,0 +1,46 @@ +from pathlib import Path + +import pytest + +from codegen.sdk.codebase.factory.get_session import 
get_codebase_session +from codegen.sdk.core.codebase import Codebase +from codegen.sdk.enums import ProgrammingLanguage + + +def generate_files(num_files: int, extension: str = "py") -> dict[str, str]: + return {f"file{i}.{extension}": f"# comment {i}" for i in range(num_files)} + + +NUM_FILES = 1000 + + +def setup_codebase(num_files: int, extension: str, tmp_path: Path): + files = generate_files(num_files, extension) + with get_codebase_session(files=files, programming_language=ProgrammingLanguage.PYTHON, tmpdir=Path(tmp_path), sync_graph=False) as codebase: + for file in files: + codebase.get_file(file).edit(f"# comment2 {file}") + return codebase, files + + +def reset_codebase(codebase: Codebase): + codebase.reset() + + +@pytest.mark.benchmark(group="sdk-benchmark", min_time=1, max_time=5, disable_gc=True) +@pytest.mark.parametrize("extension", ["txt", "py"]) +def test_codebase_reset_stress_test(extension: str, tmp_path, benchmark): + def setup(): + codebase, _ = setup_codebase(NUM_FILES, extension, tmp_path) + return ((codebase,), {}) + + benchmark.pedantic(reset_codebase, setup=setup) + + +@pytest.mark.timeout(5, func_only=True) +@pytest.mark.parametrize("extension", ["txt", "py"]) +def test_codebase_reset_correctness(extension: str, tmp_path): + codebase, files = setup_codebase(NUM_FILES, extension, tmp_path) + codebase.reset() + for file, original_content in files.items(): + assert (tmp_path / file).exists() + assert (tmp_path / file).read_text() == original_content diff --git a/tests/unit/codegen/sdk/code_generation/test_api_doc_generation.py b/tests/unit/codegen/sdk/code_generation/test_api_doc_generation.py index aa1753837..0e1222c64 100644 --- a/tests/unit/codegen/sdk/code_generation/test_api_doc_generation.py +++ b/tests/unit/codegen/sdk/code_generation/test_api_doc_generation.py @@ -48,6 +48,7 @@ def test_api_doc_generation_sanity(codebase, language: ProgrammingLanguage) -> N @pytest.mark.timeout(120) +@pytest.mark.xdist_group("codegen") def 
test_mdx_api_doc_generation_sanity(codebase) -> None: docs_json = generate_docs_json(codebase, "HEAD") diff --git a/tests/unit/codegen/sdk/codebase/codebase_graph/test_codebase_reset.py b/tests/unit/codegen/sdk/codebase/codebase_graph/test_codebase_reset.py new file mode 100644 index 000000000..a084ae143 --- /dev/null +++ b/tests/unit/codegen/sdk/codebase/codebase_graph/test_codebase_reset.py @@ -0,0 +1,346 @@ +import pytest + +from codegen.sdk.core.codebase import Codebase + + +@pytest.mark.parametrize( + "original, expected", + [ + ({"a.py": "a", "b.py": "b"}, {"a.py": "b", "b.py": "b"}), + ], + indirect=["original", "expected"], +) +def test_codebase_reset(codebase: Codebase, assert_expected, tmp_path): + # External change should be preserved + (tmp_path / "a.py").write_text("b") + # Programmatic change should be reset + codebase.get_file("b.py").edit("changed") + codebase.commit() + codebase.reset() + assert_expected(codebase) + + +@pytest.mark.parametrize( + "original, expected", + [ + ({"a.py": "a"}, {"a.py": "b"}), + ], + indirect=["original", "expected"], +) +def test_codebase_reset_external_changes(codebase: Codebase, assert_expected): + # External change should be preserved + codebase.get_file("a.py").path.write_text("b") + codebase.commit() + codebase.reset() + assert_expected(codebase) + + +@pytest.mark.parametrize( + "original, expected", + [ + ({"a.py": "a"}, {"a.py": "a", "new.py": "new content"}), + ], + indirect=["original", "expected"], +) +def test_codebase_reset_manual_file_add(codebase: Codebase, assert_expected, tmp_path): + # Manually create a new file - should be preserved + new_file = tmp_path / "new.py" + new_file.write_text("new content") + # Make programmatic change that should be reset + codebase.get_file("a.py").edit("changed") + codebase.commit() + codebase.reset() + assert_expected(codebase) + + +@pytest.mark.parametrize( + "original, expected", + [ + ({"a.py": "a", "b.py": "b"}, {"a.py": "a"}), + ], + indirect=["original", 
"expected"], +) +def test_codebase_reset_manual_file_delete(codebase: Codebase, assert_expected): + # Manual deletion should be preserved + codebase.get_file("b.py").path.unlink() + # Programmatic change should be reset + codebase.get_file("a.py").edit("changed") + codebase.commit() + codebase.reset() + assert_expected(codebase) + + +@pytest.mark.parametrize( + "original, expected", + [ + ({"old.py": "content"}, {"new.py": "content"}), + ], + indirect=["original", "expected"], +) +def test_codebase_reset_manual_file_rename(codebase: Codebase, tmp_path, assert_expected): + # Manual rename should be preserved + old_path = codebase.get_file("old.py").path + new_path = tmp_path / "new.py" + old_path.rename(new_path) + codebase.reset() + assert_expected(codebase) + + +@pytest.mark.parametrize( + "original, expected", + [ + ( + { + "src/main.py": "def main():\n print('hello')", + "src/utils/helpers.py": "def helper():\n return True", + "tests/test_main.py": "def test_main():\n assert True", + }, + { + "src/main.py": "def main():\n print('modified')", + "src/utils/helpers.py": "def helper():\n return True", + "tests/test_main.py": "def test_main():\n assert False", + }, + ), + ], + indirect=["original", "expected"], +) +def test_codebase_reset_nested_directories(codebase: Codebase, assert_expected, tmp_path): + """Test reset with nested directory structure.""" + # External changes should be preserved + (tmp_path / "src/main.py").write_text("def main():\n print('modified')") + (tmp_path / "tests/test_main.py").write_text("def test_main():\n assert False") + # Programmatic changes should be reset + codebase.get_file("src/utils/helpers.py").edit("def helper():\n return False") + codebase.reset() + assert_expected(codebase) + + +@pytest.mark.parametrize( + "original, expected", + [ + ( + { + "app.py": "import json\n\ndata = {\n 'name': 'test',\n 'value': 123\n}", + "config.json": '{\n "debug": true,\n "port": 8080\n}', + "README.md": "# Project\nThis is a test project.", + }, 
+ { + "app.py": "import json\n\ndata = {\n 'name': 'test',\n 'value': 123\n}", + "config.json": '{\n "debug": false,\n "env": "prod"\n}', + "README.md": "# Modified Project\nUpdated documentation.", + }, + ), + ], + indirect=["original", "expected"], +) +def test_codebase_reset_mixed_content(codebase: Codebase, assert_expected, tmp_path): + """Test reset with different types of file content.""" + # External changes should be preserved + (tmp_path / "config.json").write_text('{\n "debug": false,\n "env": "prod"\n}') + (tmp_path / "README.md").write_text("# Modified Project\nUpdated documentation.") + # Programmatic changes should be reset + codebase.get_file("app.py").edit("import json\n\ndata = {'name': 'modified'}") + codebase.commit() + codebase.reset() + assert_expected(codebase) + + +@pytest.mark.parametrize( + "original, expected", + [ + ( + { + "module.py": """class ComplexClass: + def __init__(self): + self.value = 0 + + def increment(self): + self.value += 1 + return self.value + + def decrement(self): + self.value -= 1 + return self.value + + def reset(self): + self.value = 0 + return self.value""", + }, + { + "module.py": """class ComplexClass: + def __init__(self): + self.value = 0 + + def increment(self): + self.value += 1 + return self.value + + def decrement(self): + self.value -= 1 + return self.value + + def reset(self): + self.value = 0 + return self.value""", + }, + ), + ], + indirect=["original", "expected"], +) +def test_codebase_reset_large_file(codebase: Codebase, assert_expected): + """Test reset with a larger file containing multiple methods.""" + codebase.get_file("module.py").edit("""class ModifiedClass: + def __init__(self): + self.value = 100""") + codebase.commit() + codebase.reset() + assert_expected(codebase) + + +@pytest.mark.parametrize( + "original, expected", + [ + ({"src/a.py": "original content"}, {"src/a.py": "modified content", "src/b.py": "new file content"}), + ], + indirect=["original", "expected"], +) +def 
test_codebase_reset_preserves_external_changes(codebase: Codebase, assert_expected, tmp_path): + # Make external changes to existing file + src_dir = tmp_path / "src" + src_dir.mkdir(exist_ok=True) + (src_dir / "a.py").write_text("modified content") + + # Add new file externally + (src_dir / "b.py").write_text("new file content") + + # Reset should detect and preserve these changes + codebase.commit() + codebase.reset() + assert_expected(codebase) + + +@pytest.mark.parametrize( + "original, expected", + [ + ( + {"src/main.py": "def main():\n pass", "src/utils.py": "def helper():\n pass"}, + {"src/main.py": "def main():\n return 42", "src/utils.py": "def helper():\n pass", "src/new_module.py": "# New module"}, + ), + ], + indirect=["original", "expected"], +) +def test_codebase_reset_mixed_changes(codebase: Codebase, assert_expected, tmp_path): + # Make programmatic change that should be reset + codebase.get_file("src/utils.py").edit("def helper():\n return None") + + # Make external changes that should be preserved + src_dir = tmp_path / "src" + (src_dir / "main.py").write_text("def main():\n return 42") + (src_dir / "new_module.py").write_text("# New module") + + codebase.commit() + codebase.reset() + assert_expected(codebase) + + +@pytest.mark.parametrize( + "original, expected", + [ + ({"config/settings.py": "DEBUG = False"}, {"config/settings.py": "DEBUG = True", "config/local.py": "# Local overrides"}), + ], + indirect=["original", "expected"], +) +def test_codebase_reset_nested_external_changes(codebase: Codebase, assert_expected, tmp_path): + # Create nested directory structure with changes + config_dir = tmp_path / "config" + config_dir.mkdir(exist_ok=True) + + # Modify existing file + (config_dir / "settings.py").write_text("DEBUG = True") + + # Add new file in nested directory + (config_dir / "local.py").write_text("# Local overrides") + + codebase.commit() + codebase.reset() + assert_expected(codebase) + + +@pytest.mark.xfail(reason="Needs CG-10484") 
+@pytest.mark.parametrize( + "original, expected", + [ + ( + {"file.py": "initial content"}, + {"file.py": "final external content"}, + ), + ], + indirect=["original", "expected"], +) +def test_codebase_reset_multiple_programmatic_edits(codebase: Codebase, assert_expected): + """Test reset after multiple programmatic edits to the same file.""" + # Make multiple programmatic changes that should all be reset + codebase.get_file("file.py").edit("first edit") + codebase.get_file("file.py").edit("second edit") + codebase.get_file("file.py").edit("third edit") + + # Make external change that should be preserved + codebase.get_file("file.py").path.write_text("final external content") + codebase.commit() + codebase.reset() + assert_expected(codebase) + + +@pytest.mark.xfail(reason="Needs CG-10484") +@pytest.mark.parametrize( + "original, expected", + [ + ( + {"file.py": "def main():\n return 0"}, + {"file.py": "def main():\n return 42"}, + ), + ], + indirect=["original", "expected"], +) +def test_codebase_reset_interleaved_changes(codebase: Codebase, assert_expected): + """Test reset with interleaved programmatic and external changes.""" + # Interleave programmatic and external changes + codebase.get_file("file.py").edit("def main():\n return 1") + codebase.get_file("file.py").path.write_text("def main():\n return 42") + codebase.get_file("file.py").edit("def main():\n return 2") + codebase.commit() + codebase.reset() + assert_expected(codebase) + + +@pytest.mark.parametrize( + "original, expected", + [ + ( + { + "file.py": """ +class Test: + def method1(self): + pass +""" + }, + { + "file.py": """ +class Test: + def method1(self): + pass +""" + }, + ), + ], + indirect=["original", "expected"], +) +def test_codebase_reset_complex_changes(codebase: Codebase, assert_expected): + """Test reset with a mix of content additions, modifications, and external changes.""" + # Make several programmatic changes + for i in range(5): + codebase.get_file("file.py").insert_after(f"# 
comment {i}") + codebase.commit() + + codebase.reset() + assert_expected(codebase) diff --git a/tests/unit/codegen/sdk/conftest.py b/tests/unit/codegen/sdk/conftest.py new file mode 100644 index 000000000..a9b12f9a5 --- /dev/null +++ b/tests/unit/codegen/sdk/conftest.py @@ -0,0 +1,41 @@ +import pytest + +from codegen.sdk.codebase.factory.get_session import get_codebase_session +from codegen.sdk.core.codebase import Codebase +from codegen.sdk.enums import ProgrammingLanguage + + +@pytest.fixture +def original(request): + return request.param + + +@pytest.fixture +def expected(request): + return request.param + + +@pytest.fixture +def programming_language(request): + return request.param + + +@pytest.fixture +def codebase(tmp_path, original: dict[str, str], programming_language: ProgrammingLanguage = ProgrammingLanguage.PYTHON): + with get_codebase_session(files=original, programming_language=programming_language, tmpdir=tmp_path) as codebase: + yield codebase + + +@pytest.fixture +def assert_expected(expected: dict[str, str], tmp_path): + def assert_expected(codebase: Codebase): + codebase.commit() + for file in expected: + assert tmp_path.joinpath(file).exists() + assert tmp_path.joinpath(file).read_text() == expected[file] + assert codebase.get_file(file).content.strip() == expected[file].strip() + for file in codebase.files: + if file.file.path.exists(): + assert file.filepath in expected + + return assert_expected diff --git a/uv.lock b/uv.lock index 7a4d001bd..41dbf5b5b 100644 --- a/uv.lock +++ b/uv.lock @@ -367,7 +367,6 @@ wheels = [ [[package]] name = "codegen" -version = "0.5.3.dev23+g28194e2" source = { editable = "." 
} dependencies = [ { name = "anthropic" }, @@ -453,6 +452,7 @@ dev = [ { name = "pre-commit" }, { name = "pre-commit-uv" }, { name = "pytest" }, + { name = "pytest-benchmark", extra = ["histogram"] }, { name = "pytest-cov" }, { name = "pytest-mock" }, { name = "pytest-timeout" }, @@ -482,7 +482,7 @@ requires-dist = [ { name = "hatchling", specifier = ">=1.25.0" }, { name = "humanize", specifier = ">=4.10.0,<5.0.0" }, { name = "lazy-object-proxy", specifier = ">=0.0.0" }, - { name = "loguru", specifier = ">=0.7.2,<1.0.0" }, + { name = "loguru", specifier = ">=0.7.3" }, { name = "mini-racer", specifier = ">=0.12.4" }, { name = "networkx", specifier = ">=3.4.1" }, { name = "openai", specifier = "==1.59.9" }, @@ -545,6 +545,7 @@ dev = [ { name = "pre-commit", specifier = ">=4.0.1" }, { name = "pre-commit-uv", specifier = ">=4.1.4" }, { name = "pytest", specifier = ">=8.3.3" }, + { name = "pytest-benchmark", extras = ["histogram"], specifier = ">=5.1.0" }, { name = "pytest-cov", specifier = ">=6.0.0,<6.0.1" }, { name = "pytest-mock", specifier = ">=3.14.0,<4.0.0" }, { name = "pytest-timeout", specifier = ">=2.3.1" }, @@ -1079,6 +1080,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 }, ] +[[package]] +name = "importlib-metadata" +version = "8.6.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "zipp" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/33/08/c1395a292bb23fd03bdf572a1357c5a733d3eecbab877641ceacab23db6e/importlib_metadata-8.6.1.tar.gz", hash = "sha256:310b41d755445d74569f993ccfc22838295d9fe005425094fad953d7f15c8580", size = 55767 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/79/9d/0fb148dc4d6fa4a7dd1d8378168d9b4cd8d4560a6fbf6f0121c5fc34eb68/importlib_metadata-8.6.1-py3-none-any.whl", hash = 
"sha256:02a89390c1e15fdfdc0d7c6b25cb3e62650d0494005c97d6f148bf5b9787525e", size = 26971 }, +] + [[package]] name = "inflect" version = "5.6.2" @@ -1693,6 +1706,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7b/d7/7831438e6c3ebbfa6e01a927127a6cb42ad3ab844247f3c5b96bea25d73d/psutil-6.1.1-cp37-abi3-win_amd64.whl", hash = "sha256:f35cfccb065fff93529d2afb4a2e89e363fe63ca1e4a5da22b603a85833c2649", size = 254444 }, ] +[[package]] +name = "py-cpuinfo" +version = "9.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/37/a8/d832f7293ebb21690860d2e01d8115e5ff6f2ae8bbdc953f0eb0fa4bd2c7/py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690", size = 104716 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335 }, +] + [[package]] name = "pycparser" version = "2.22" @@ -1769,6 +1791,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d4/d7/f1b7db88d8e4417c5d47adad627a93547f44bdc9028372dbd2313f34a855/pyflakes-3.2.0-py2.py3-none-any.whl", hash = "sha256:84b5be138a2dfbb40689ca07e2152deb896a65c3a3e24c251c5c62489568074a", size = 62725 }, ] +[[package]] +name = "pygal" +version = "3.0.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "importlib-metadata" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/49/7b/8f50821a0f1585881ef40ae13ecb7603b0d81ef99fedf992ec35e6b6f7d5/pygal-3.0.5.tar.gz", hash = "sha256:c0a0f34e5bc1c01975c2bfb8342ad521e293ad42e525699dd00c4d7a52c14b71", size = 80489 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/22/7d/b5d656dbeb73f488ce7409a75108a775f6cf8e20624ed8025a9476cbc1bb/pygal-3.0.5-py3-none-any.whl", hash = 
"sha256:a3268a5667b470c8fbbb0eca7e987561a7321caeba589d40e4c1bc16dbe71393", size = 129548 }, +] + +[[package]] +name = "pygaljs" +version = "1.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/75/19/3a53f34232a9e6ddad665e71c83693c5db9a31f71785105905c5bc9fbbba/pygaljs-1.0.2.tar.gz", hash = "sha256:0b71ee32495dcba5fbb4a0476ddbba07658ad65f5675e4ad409baf154dec5111", size = 89711 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/6f/07dab31ca496feda35cf3455b9e9380c43b5c685bb54ad890831c790da38/pygaljs-1.0.2-py2.py3-none-any.whl", hash = "sha256:d75e18cb21cc2cda40c45c3ee690771e5e3d4652bf57206f20137cf475c0dbe8", size = 91111 }, +] + [[package]] name = "pygit2" version = "1.17.0" @@ -1967,6 +2010,26 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 }, ] +[[package]] +name = "pytest-benchmark" +version = "5.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "py-cpuinfo" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/39/d0/a8bd08d641b393db3be3819b03e2d9bb8760ca8479080a26a5f6e540e99c/pytest-benchmark-5.1.0.tar.gz", hash = "sha256:9ea661cdc292e8231f7cd4c10b0319e56a2118e2c09d9f50e1b3d150d2aca105", size = 337810 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/d6/b41653199ea09d5969d4e385df9bbfd9a100f28ca7e824ce7c0a016e3053/pytest_benchmark-5.1.0-py3-none-any.whl", hash = "sha256:922de2dfa3033c227c96da942d1878191afa135a29485fb942e85dff1c592c89", size = 44259 }, +] + +[package.optional-dependencies] +histogram = [ + { name = "pygal" }, + { name = "pygaljs" }, + { name = "setuptools" }, +] + [[package]] name = "pytest-cov" version = "6.0.0" @@ -2931,3 +2994,12 @@ sdist = { url = 
"https://files.pythonhosted.org/packages/50/05/51dcca9a9bf5e1bce wheels = [ { url = "https://files.pythonhosted.org/packages/d6/45/fc303eb433e8a2a271739c98e953728422fa61a3c1f36077a49e395c972e/xmltodict-0.14.2-py2.py3-none-any.whl", hash = "sha256:20cc7d723ed729276e808f26fb6b3599f786cbc37e06c65e192ba77c40f20aac", size = 9981 }, ] + +[[package]] +name = "zipp" +version = "3.21.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3f/50/bad581df71744867e9468ebd0bcd6505de3b275e06f202c2cb016e3ff56f/zipp-3.21.0.tar.gz", hash = "sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4", size = 24545 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/1a/7e4798e9339adc931158c9d69ecc34f5e6791489d469f5e50ec15e35f458/zipp-3.21.0-py3-none-any.whl", hash = "sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931", size = 9630 }, +]