Skip to content

Commit 54a15b3

Browse files
eacodegenbagel897codegen-botcaroljung-cgrushilpatel0
authored
Make codebase.reset only reset changes made by the sdk (#74)
# Motivation <!-- Why is this change necessary? --> # Content <!-- Please include a summary of the change --> # Testing <!-- How was the change tested? --> # Please check the following before marking your PR as ready for review - [ ] I have added tests for my changes - [ ] I have updated the documentation or added new documentation as needed - [ ] I have read and agree to the [Contributor License Agreement](../CLA.md) --------- Co-authored-by: bagel897 <ellenagarwal897@gmail.com> Co-authored-by: codegen-bot <team+codegenbot@codegen.sh> Co-authored-by: Carol Jung <165736129+caroljung-cg@users.noreply.github.com> Co-authored-by: Rushil Patel <rpatel@codegen.com> Co-authored-by: Christine Wang <christine@codegen.com> Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: Jay Hack <jayhack@users.noreply.github.com> Co-authored-by: tomcodgen <tkucar@codegen.com> Co-authored-by: jemeza-codegen <165736868+jemeza-codegen@users.noreply.github.com>
1 parent 927ebe4 commit 54a15b3

File tree

19 files changed

+624
-37
lines changed

19 files changed

+624
-37
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,4 @@ graph-sitter-types/typings/**
6565
coverage.json
6666
tests/integration/verified_codemods/codemod_data/repo_commits.json
6767
.codegen/*
68+
.benchmarks/*

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ dev-dependencies = [
139139
"black>=24.8.0",
140140
"isort>=5.13.2",
141141
"emoji>=2.14.0",
142+
"pytest-benchmark[histogram]>=5.1.0",
143+
"loguru>=0.7.3",
142144
]
143145
keyring-provider = "subprocess"
144146
#extra-index-url = ["https://aws@codegen-922078275900.d.codeartifact.us-east-1.amazonaws.com/pypi/codegen/simple/"]

src/codegen/git/repo_operator/local_repo_operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def create_from_files(cls, repo_path: str, files: dict[str, str], bot_commit: bo
7272
def create_from_commit(cls, repo_path: str, commit: str, url: str) -> Self:
7373
"""Do a shallow checkout of a particular commit to get a repository from a given remote URL."""
7474
op = cls(repo_config=BaseRepoConfig(), repo_path=repo_path, bot_commit=False)
75+
op.discard_changes()
7576
if op.get_active_branch_or_commit() != commit:
76-
op.discard_changes()
7777
op.create_remote("origin", url)
7878
op.git_cli.remotes["origin"].fetch(commit, depth=1)
7979
op.checkout_commit(commit)

src/codegen/git/repo_operator/remote_repo_operator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,9 @@ def __init__(
4242
setup_option: SetupOption = SetupOption.PULL_OR_CLONE,
4343
shallow: bool = True,
4444
github_type: GithubType = GithubType.GithubEnterprise,
45+
bot_commit: bool = True,
4546
) -> None:
46-
super().__init__(repo_config=repo_config, base_dir=base_dir)
47+
super().__init__(repo_config=repo_config, base_dir=base_dir, bot_commit=bot_commit)
4748
self.github_type = github_type
4849
self.setup_repo_dir(setup_option=setup_option, shallow=shallow)
4950

src/codegen/git/repo_operator/repo_operator.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,21 @@ def viz_file_path(self) -> str:
6767
def git_cli(self) -> GitCLI:
6868
"""Note: this is recursive, may want to look out"""
6969
git_cli = GitCLI(self.repo_path)
70+
has_username = False
71+
has_email = False
72+
with git_cli.config_reader(None) as reader:
73+
if reader.has_option("user", "name"):
74+
has_username = True
75+
if reader.has_option("user", "email"):
76+
has_email = True
7077
with git_cli.config_writer("repository") as writer:
71-
if self.bot_commit:
78+
if not has_username or not has_email or self.bot_commit:
7279
if not writer.has_section("user"):
7380
writer.add_section("user")
74-
writer.set("user", "name", CODEGEN_BOT_NAME)
75-
writer.set("user", "email", CODEGEN_BOT_EMAIL)
81+
if not has_username or self.bot_commit:
82+
writer.set("user", "name", CODEGEN_BOT_NAME)
83+
if not has_email or self.bot_commit:
84+
writer.set("user", "email", CODEGEN_BOT_EMAIL)
7685
return git_cli
7786

7887
@property

src/codegen/sdk/codebase/codebase_graph.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from codegen.sdk.core.interfaces.importable import Importable
3131
from codegen.sdk.core.node_id_factory import NodeId
3232
from codegen.sdk.enums import Edge, EdgeType, NodeType, ProgrammingLanguage
33+
from codegen.sdk.extensions.io import write_changes
3334
from codegen.sdk.extensions.sort import sort_editables
3435
from codegen.sdk.extensions.utils import uncache_all
3536
from codegen.sdk.typescript.external.ts_declassify.ts_declassify import TSDeclassify
@@ -107,6 +108,7 @@ class CodebaseGraph:
107108
flags: Flags
108109
session_options: SessionOptions = SessionOptions()
109110
projects: list[ProjectConfig]
111+
unapplied_diffs: list[DiffLite]
110112

111113
def __init__(
112114
self,
@@ -161,6 +163,7 @@ def __init__(
161163
self.synced_commit = None
162164
self.pending_syncs = []
163165
self.all_syncs = []
166+
self.unapplied_diffs = []
164167
self.pending_files = set()
165168
self.flags = Flags()
166169

@@ -232,9 +235,40 @@ def apply_diffs(self, diff_list: list[DiffLite]) -> None:
232235
self.generation += 1
233236
self._process_diff_files(by_sync_type)
234237

238+
def _reset_files(self, syncs: list[DiffLite]) -> None:
239+
files_to_write = []
240+
files_to_remove = []
241+
modified_files = set()
242+
for sync in syncs:
243+
if sync.path in modified_files:
244+
continue
245+
if sync.change_type == ChangeType.Removed:
246+
files_to_write.append((sync.path, sync.old_content))
247+
modified_files.add(sync.path)
248+
logger.info(f"Removing {sync.path} from disk")
249+
elif sync.change_type == ChangeType.Modified:
250+
files_to_write.append((sync.path, sync.old_content))
251+
modified_files.add(sync.path)
252+
elif sync.change_type == ChangeType.Renamed:
253+
files_to_write.append((sync.rename_from, sync.old_content))
254+
files_to_remove.append(sync.rename_to)
255+
modified_files.add(sync.rename_from)
256+
modified_files.add(sync.rename_to)
257+
elif sync.change_type == ChangeType.Added:
258+
files_to_remove.append(sync.path)
259+
modified_files.add(sync.path)
260+
logger.info(f"Writing {len(files_to_write)} files to disk and removing {len(files_to_remove)} files")
261+
write_changes(files_to_remove, files_to_write)
262+
263+
@stopwatch
264+
def reset_codebase(self) -> None:
265+
self._reset_files(self.all_syncs + self.pending_syncs + self.unapplied_diffs)
266+
self.unapplied_diffs.clear()
267+
235268
@stopwatch
236269
def undo_applied_diffs(self) -> None:
237270
self.transaction_manager.clear_transactions()
271+
self.reset_codebase()
238272
self.check_changes()
239273
self.pending_syncs.clear() # Discard pending changes
240274
if len(self.all_syncs) > 0:
@@ -256,6 +290,9 @@ def _revert_diffs(self, diff_list: list[DiffLite]) -> None:
256290

257291
def save_commit(self, commit: GitCommit) -> None:
258292
if commit is not None:
293+
logger.info(f"Saving commit {commit.hexsha} to graph")
294+
self.all_syncs.clear()
295+
self.unapplied_diffs.clear()
259296
self.synced_commit = commit
260297
if self.config.feature_flags.verify_graph:
261298
self.old_graph = self._graph.copy()
@@ -630,9 +667,11 @@ def commit_transactions(self, sync_graph: bool = True, sync_file: bool = True, f
630667
# Commit transactions for all contexts
631668
files_to_lock = self.transaction_manager.to_commit(files)
632669
diffs = self.transaction_manager.commit(files_to_lock)
633-
# Filter diffs to only include files that are still in the graph
634-
diffs = [diff for diff in diffs if self.get_file(diff.path) is not None]
635-
self.pending_syncs.extend(diffs)
670+
for diff in diffs:
671+
if self.get_file(diff.path) is None:
672+
self.unapplied_diffs.append(diff)
673+
else:
674+
self.pending_syncs.append(diff)
636675

637676
# Write files if requested
638677
if sync_file:

src/codegen/sdk/codebase/diff_lite.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from enum import Enum, auto
1+
from enum import IntEnum, auto
22
from os import PathLike
33
from pathlib import Path
44
from typing import NamedTuple, Self
@@ -7,7 +7,7 @@
77
from watchfiles import Change
88

99

10-
class ChangeType(Enum):
10+
class ChangeType(IntEnum):
1111
Modified = auto()
1212
Removed = auto()
1313
Renamed = auto()
@@ -40,8 +40,9 @@ class DiffLite(NamedTuple):
4040

4141
change_type: ChangeType
4242
path: Path
43-
rename_from: str | None = None
44-
rename_to: str | None = None
43+
rename_from: Path | None = None
44+
rename_to: Path | None = None
45+
old_content: bytes | None = None
4546

4647
@classmethod
4748
def from_watch_change(cls, change: Change, path: PathLike) -> Self:
@@ -52,11 +53,15 @@ def from_watch_change(cls, change: Change, path: PathLike) -> Self:
5253

5354
@classmethod
5455
def from_git_diff(cls, git_diff: Diff):
56+
old = None
57+
if git_diff.a_blob:
58+
old = git_diff.a_blob.data_stream.read()
5559
return cls(
5660
change_type=ChangeType.from_git_change_type(git_diff.change_type),
57-
path=Path(git_diff.a_path),
58-
rename_from=git_diff.rename_from,
59-
rename_to=git_diff.rename_to,
61+
path=Path(git_diff.a_path) if git_diff.a_path else None,
62+
rename_from=Path(git_diff.rename_from) if git_diff.rename_from else None,
63+
rename_to=Path(git_diff.rename_to) if git_diff.rename_to else None,
64+
old_content=old,
6065
)
6166

6267
@classmethod

src/codegen/sdk/codebase/transaction_manager.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pathlib import Path
44
from typing import TYPE_CHECKING
55

6-
from codegen.sdk.codebase.diff_lite import DiffLite
6+
from codegen.sdk.codebase.diff_lite import ChangeType, DiffLite
77
from codegen.sdk.codebase.transactions import (
88
EditTransaction,
99
FileAddTransaction,
@@ -163,16 +163,16 @@ def to_commit(self, files: set[Path] | None = None) -> set[Path]:
163163
return set(self.queued_transactions.keys())
164164
return files.intersection(self.queued_transactions)
165165

166-
def commit(self, files: set[Path]) -> set[DiffLite]:
166+
def commit(self, files: set[Path]) -> list[DiffLite]:
167167
"""Execute transactions in bulk for each file, in reverse order of start_byte.
168-
Returns the set of diffs that were committed.
168+
Returns the list of diffs that were committed.
169169
"""
170170
if self._commiting:
171171
logger.warn("Skipping commit, already committing")
172-
return set()
172+
return []
173173
self._commiting = True
174174
try:
175-
diffs: set[DiffLite] = set()
175+
diffs: list[DiffLite] = []
176176
if not self.queued_transactions or len(self.queued_transactions) == 0:
177177
return diffs
178178

@@ -187,9 +187,16 @@ def commit(self, files: set[Path]) -> set[DiffLite]:
187187
logger.info(f"Committing {len(self.queued_transactions[file])} transactions for {file}")
188188
for file_path in files:
189189
file_transactions = self.queued_transactions.pop(file_path, [])
190+
modified = False
190191
for transaction in file_transactions:
191192
# Add diff IF the file is a source file
192-
diffs.add(transaction.get_diff())
193+
diff = transaction.get_diff()
194+
if diff.change_type == ChangeType.Modified:
195+
if not modified:
196+
modified = True
197+
diffs.append(diff)
198+
else:
199+
diffs.append(diff)
193200
transaction.execute()
194201
return diffs
195202
finally:

src/codegen/sdk/codebase/transactions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def execute(self) -> None:
127127

128128
def get_diff(self) -> DiffLite:
129129
"""Gets the diff produced by this transaction"""
130-
return DiffLite(ChangeType.Modified, self.file_path)
130+
return DiffLite(ChangeType.Modified, self.file_path, old_content=self.file.content_bytes)
131131

132132
def diff_str(self) -> str:
133133
"""Human-readable string representation of the change"""
@@ -170,7 +170,7 @@ def execute(self) -> None:
170170

171171
def get_diff(self) -> DiffLite:
172172
"""Gets the diff produced by this transaction"""
173-
return DiffLite(ChangeType.Modified, self.file_path)
173+
return DiffLite(ChangeType.Modified, self.file_path, old_content=self.file.content_bytes)
174174

175175
def diff_str(self) -> str:
176176
"""Human-readable string representation of the change"""
@@ -205,7 +205,7 @@ def execute(self) -> None:
205205

206206
def get_diff(self) -> DiffLite:
207207
"""Gets the diff produced by this transaction"""
208-
return DiffLite(ChangeType.Modified, self.file_path)
208+
return DiffLite(ChangeType.Modified, self.file_path, old_content=self.file.content_bytes)
209209

210210
def diff_str(self) -> str:
211211
"""Human-readable string representation of the change"""
@@ -269,7 +269,7 @@ def execute(self) -> None:
269269

270270
def get_diff(self) -> DiffLite:
271271
"""Gets the diff produced by this transaction"""
272-
return DiffLite(ChangeType.Renamed, self.file_path, self.file_path, self.new_file_path)
272+
return DiffLite(ChangeType.Renamed, self.file_path, self.file_path, self.new_file_path, old_content=self.file.content_bytes)
273273

274274
def diff_str(self) -> str:
275275
"""Human-readable string representation of the change"""
@@ -294,7 +294,7 @@ def execute(self) -> None:
294294

295295
def get_diff(self) -> DiffLite:
296296
"""Gets the diff produced by this transaction"""
297-
return DiffLite(ChangeType.Removed, self.file_path)
297+
return DiffLite(ChangeType.Removed, self.file_path, old_content=self.file.content_bytes)
298298

299299
def diff_str(self) -> str:
300300
"""Human-readable string representation of the change"""

src/codegen/sdk/core/codebase.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,7 @@ def current_commit(self) -> GitCommit | None:
738738
return self._op.git_cli.head.commit
739739

740740
@stopwatch
741-
def reset(self) -> None:
741+
def reset(self, git_reset: bool = False) -> None:
742742
"""Resets the codebase by:
743743
- Discarding any staged/unstaged changes
744744
- Resetting stop codemod limits: (max seconds, max transactions, max AI requests)
@@ -751,7 +751,8 @@ def reset(self) -> None:
751751
- .ipynb files (Jupyter notebooks, where you are likely developing)
752752
"""
753753
logger.info("Resetting codebase ...")
754-
self._op.discard_changes() # Discard any changes made to the raw file state
754+
if git_reset:
755+
self._op.discard_changes() # Discard any changes made to the raw file state
755756
self._num_ai_requests = 0
756757
self.reset_logs()
757758
self.G.undo_applied_diffs()
@@ -818,12 +819,14 @@ def get_diffs(self, base: str | None = None) -> list[Diff]:
818819
return self._op.get_diffs(base)
819820

820821
@noapidoc
821-
def get_diff(self, base: str | None = None) -> str:
822+
def get_diff(self, base: str | None = None, stage_files: bool = False) -> str:
822823
"""Produce a single git diff for all files."""
823-
self._op.git_cli.git.add(A=True) # add all changes to the index so untracked files are included in the diff
824+
if stage_files:
825+
self._op.git_cli.git.add(A=True) # add all changes to the index so untracked files are included in the diff
824826
if base is None:
825-
return self._op.git_cli.git.diff(patch=True, full_index=True, staged=True)
826-
return self._op.git_cli.git.diff(base, full_index=True)
827+
diff = self._op.git_cli.git.diff("HEAD", patch=True, full_index=True)
828+
return diff
829+
return self._op.git_cli.git.diff(base, patch=True, full_index=True)
827830

828831
@noapidoc
829832
def clean_repo(self):

0 commit comments

Comments
 (0)