Skip to content

Commit a25691c

Browse files
authored
Add step magics (#493)
* Add step magics * fix * fix test cases * fix test cases * Bump version * Update pyproject.toml * update libraries * add partial svelte support * unlimit call llm * remove depscan test for debugging
1 parent c955ada commit a25691c

File tree

37 files changed

+916
-893
lines changed

37 files changed

+916
-893
lines changed

patchwork/common/context_strategy/javascript.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
(statement_block) @node
2727
""".strip()
2828

29-
_javascript_exts = [".js", ".ts"]
29+
_javascript_exts = [".js", ".ts", ".svelte"]
3030
_jsx_exts = [".jsx", ".tsx"]
3131

3232

patchwork/step.py

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,66 @@
1-
from typing_extensions import Protocol
1+
import abc
2+
from enum import Flag, auto
23

4+
from patchwork.logger import logger
35

4-
class Step(Protocol):
5-
"""
6-
Protocol for a Step.
7-
Steps do not have to inherit from this class, but they must implement the run method.
8-
The __init__ method should have a single argument, inputs, which is a dictionary of inputs.
9-
This is the only opportunity to set the inputs as a class property to be used in `run`.
10-
"""
116

7+
class StepStatus(Flag):
8+
COMPLETED = auto()
9+
FAILED = auto()
10+
SKIPPED = auto()
11+
12+
def __str__(self):
13+
return self.name.lower()
14+
15+
16+
class Step(abc.ABC):
17+
def __init__(self, inputs: dict):
18+
"""
19+
Initializes the step.
20+
:param inputs: a dictionary of inputs
21+
"""
22+
self.__status = StepStatus.COMPLETED
23+
self.__status_msg = None
24+
self.__step_name = self.__class__.__name__
25+
# abit of a hack to wrap the implemented run method
26+
self.original_run = self.run
27+
self.run = self.__managed_run
28+
29+
def __managed_run(self, *args, **kwargs):
30+
logger.info(f"Run started {self.__step_name}")
31+
32+
exc = None
33+
try:
34+
output = self.original_run(*args, **kwargs)
35+
except Exception as e:
36+
exc = e
37+
38+
is_fail = self.__status == StepStatus.FAILED or exc is not None
39+
if self.__status_msg is not None:
40+
message_logger = logger.error if is_fail else logger.info
41+
message_logger(f"Step {self.__step_name} message: {self.__status_msg}")
42+
43+
if exc is not None:
44+
logger.error(f"Step {self.__step_name} failed")
45+
raise exc
46+
47+
if is_fail:
48+
raise ValueError(f"Step {self.__step_name} failed")
49+
50+
logger.info(f"Run {self.__status} {self.__step_name}")
51+
return output
52+
53+
def set_status(self, status: StepStatus, msg: str = None):
54+
if status not in StepStatus:
55+
raise ValueError(f"Invalid status: {status}")
56+
self.__status = status
57+
self.__status_msg = msg
58+
59+
@property
60+
def status(self) -> StepStatus:
61+
return self.__status
62+
63+
@abc.abstractmethod
1264
def run(self) -> dict:
1365
"""
1466
Runs the step.

patchwork/steps/AnalyzeImpact/AnalyzeImpact.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pathlib import Path
33

44
from patchwork.logger import logger
5-
from patchwork.step import Step
5+
from patchwork.step import Step, StepStatus
66

77
_PURL_TO_LANGUAGE_ = {
88
"pypi": "python",
@@ -76,8 +76,7 @@ def find_dependency_usage(directory, dependency, language, methods):
7676

7777
class AnalyzeImpact(Step):
7878
def __init__(self, inputs: dict):
79-
logger.info(f"Run started {self.__class__.__name__}")
80-
79+
super().__init__(inputs)
8180
required_keys = {"extracted_responses", "library_name", "platform_type"}
8281
if not all(key in inputs.keys() for key in required_keys):
8382
raise ValueError(f'Missing required data: "{required_keys}"')
@@ -86,6 +85,10 @@ def __init__(self, inputs: dict):
8685

8786
def run(self) -> dict:
8887
extracted_responses = self.inputs["extracted_responses"]
88+
if len(extracted_responses) == 0:
89+
self.set_status(StepStatus.SKIPPED, "No extracted responses found")
90+
return dict(files_to_patch=[])
91+
8992
name = self.inputs["library_name"]
9093
platform_type = self.inputs["platform_type"]
9194

@@ -142,5 +145,4 @@ def run(self) -> dict:
142145
)
143146
extracted_data.append(data)
144147

145-
logger.info(f"Run completed {self.__class__.__name__}")
146148
return dict(files_to_patch=extracted_data)

patchwork/steps/CallAPI/CallAPI.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
class CallAPI(Step):
99
def __init__(self, inputs):
10+
super().__init__(inputs)
1011
self.url = inputs["url"]
1112
self.method = inputs["method"]
1213
possible_headers = inputs.get("headers", {})

patchwork/steps/CallCode2Prompt/CallCode2Prompt.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
import subprocess
33
from pathlib import Path
44

5-
from patchwork.logger import logger
6-
from patchwork.step import Step
5+
from patchwork.step import Step, StepStatus
76

87
FOLDER_PATH = "folder_path"
98

@@ -12,8 +11,7 @@ class CallCode2Prompt(Step):
1211
required_keys = {FOLDER_PATH}
1312

1413
def __init__(self, inputs: dict):
15-
logger.info(f"Run started {self.__class__.__name__}")
16-
14+
super().__init__(inputs)
1715
if not all(key in inputs.keys() for key in self.required_keys):
1816
raise ValueError(f'Missing required data: "{self.required_keys}"')
1917

@@ -28,9 +26,6 @@ def __init__(self, inputs: dict):
2826
with open(self.code_file_path, "a") as file:
2927
pass # No need to write anything, just create the file if it doesn't exist
3028

31-
# Prepare for data extraction
32-
self.extracted_data = []
33-
3429
def run(self) -> dict:
3530
cmd = [
3631
"code2prompt",
@@ -54,13 +49,13 @@ def run(self) -> dict:
5449
with open(self.code_file_path, "r") as file:
5550
file_content = file.read()
5651
except FileNotFoundError:
57-
logger.info(f"Unable to find file: {self.code_file_path}")
52+
self.set_status(StepStatus.FAILED, f"Unable to find file: {self.code_file_path}")
53+
return dict(files_to_patch=[])
5854

5955
lines = file_content.splitlines(keepends=True)
6056

61-
self.extracted_data.append(
62-
dict(uri=self.code_file_path, startLine=0, endLine=len(lines), fullContent=prompt_content_md)
57+
return dict(
58+
files_to_patch=[
59+
dict(uri=self.code_file_path, startLine=0, endLine=len(lines), fullContent=prompt_content_md)
60+
]
6361
)
64-
65-
logger.info(f"Run completed {self.__class__.__name__}")
66-
return dict(files_to_patch=self.extracted_data)

patchwork/steps/CallLLM/CallLLM.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing_extensions import Any, Protocol
1313

1414
from patchwork.logger import logger
15-
from patchwork.step import Step
15+
from patchwork.step import Step, StepStatus
1616

1717
TOKEN_URL = "https://app.patched.codes/signin"
1818
_DEFAULT_PATCH_URL = "https://patchwork.patched.codes/v1"
@@ -158,8 +158,7 @@ def parse_model_args(self, model_args: dict) -> dict:
158158

159159
class CallLLM(Step):
160160
def __init__(self, inputs: dict):
161-
logger.info(f"Run started {self.__class__.__name__}")
162-
161+
super().__init__(inputs)
163162
# Set 'openai_key' from inputs or environment if not already set
164163
inputs.setdefault("openai_api_key", os.environ.get("OPENAI_API_KEY"))
165164

@@ -178,7 +177,7 @@ def __init__(self, inputs: dict):
178177
else:
179178
raise ValueError('Missing required data: "prompt_file" or "prompts"')
180179

181-
self.call_limit = int(inputs.get("max_llm_calls", 50))
180+
self.call_limit = int(inputs.get("max_llm_calls", -1))
182181
self.model_args = {key[len("model_") :]: value for key, value in inputs.items() if key.startswith("model_")}
183182
self.client_args = {key[len("client_") :]: value for key, value in inputs.items() if key.startswith("client_")}
184183
self.save_responses_to_file = inputs.get("save_responses_to_file", None)
@@ -223,13 +222,22 @@ def __init__(self, inputs: dict):
223222

224223
def run(self) -> dict:
225224
prompt_length = len(self.prompts)
225+
if prompt_length == 0:
226+
self.set_status(StepStatus.SKIPPED, "No prompts to process")
227+
return dict(openai_responses=[])
228+
226229
if prompt_length > self.call_limit:
227230
logger.debug(
228231
f"Number of prompts ({prompt_length}) exceeds the call limit ({self.call_limit}). "
229232
f"Only the first {self.call_limit} prompts will be processed."
230233
)
231234

232-
contents = self.llm.call(list(islice(self.prompts, self.call_limit)))
235+
if self.call_limit > 0:
236+
prompts = list(islice(self.prompts, self.call_limit))
237+
else:
238+
prompts = self.prompts
239+
240+
contents = self.llm.call(prompts)
233241

234242
if self.save_responses_to_file:
235243
# Convert relative path to absolute path
@@ -247,5 +255,4 @@ def run(self) -> dict:
247255
}
248256
f.write(json.dumps(data) + "\n")
249257

250-
logger.info(f"Run completed {self.__class__.__name__}")
251258
return dict(openai_responses=contents)

patchwork/steps/CommitChanges/CommitChanges.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from patchwork.common.utils.utils import get_current_branch
1111
from patchwork.logger import logger
12-
from patchwork.step import Step
12+
from patchwork.step import Step, StepStatus
1313

1414

1515
@contextlib.contextmanager
@@ -100,8 +100,7 @@ class CommitChanges(Step):
100100
required_keys = {"modified_code_files"}
101101

102102
def __init__(self, inputs: dict):
103-
logger.info(f"Run started {self.__class__.__name__}")
104-
103+
super().__init__(inputs)
105104
if not all(key in inputs.keys() for key in self.required_keys):
106105
raise ValueError(f'Missing required data: "{self.required_keys}"')
107106

@@ -127,7 +126,7 @@ def run(self) -> dict:
127126
modified_files = {Path(modified_code_file["path"]).resolve() for modified_code_file in self.modified_code_files}
128127
true_modified_files = modified_files.intersection(repo_changed_files.union(repo_untracked_files))
129128
if not self.enabled or len(true_modified_files) < 1:
130-
logger.debug("Branch creation is disabled.")
129+
self.set_status(StepStatus.SKIPPED, "Branch creation is disabled.")
131130
from_branch = get_current_branch(repo)
132131
from_branch_name = from_branch.name if not from_branch.is_remote() else from_branch.remote_head
133132
return dict(target_branch=from_branch_name)
@@ -142,7 +141,6 @@ def run(self) -> dict:
142141
repo.git.add(modified_file)
143142
commit_with_msg(repo, f"Patched {modified_file}")
144143

145-
logger.info(f"Run completed {self.__class__.__name__}")
146144
return dict(
147145
base_branch=from_branch,
148146
target_branch=to_branch,

patchwork/steps/CreateIssue/CreateIssue.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,14 @@
88
ScmPlatformClientProtocol,
99
get_slug_from_remote_url,
1010
)
11-
from patchwork.logger import logger
1211
from patchwork.step import Step
1312

1413

1514
class CreateIssue(Step):
1615
required_keys = {"issue_title", "issue_text", "scm_url"}
1716

1817
def __init__(self, inputs: dict):
19-
logger.info(f"Run started {self.__class__.__name__}")
20-
18+
super().__init__(inputs)
2119
if not all(key in inputs.keys() for key in self.required_keys):
2220
raise ValueError(f'Missing required data: "{self.required_keys}"')
2321

patchwork/steps/CreateIssueComment/CreateIssueComment.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,14 @@
33
GitlabClient,
44
ScmPlatformClientProtocol,
55
)
6-
from patchwork.logger import logger
7-
from patchwork.step import Step
6+
from patchwork.step import Step, StepStatus
87

98

109
class CreateIssueComment(Step):
1110
required_keys = {"issue_url", "issue_text"}
1211

1312
def __init__(self, inputs: dict):
14-
logger.info(f"Run started {self.__class__.__name__}")
15-
13+
super().__init__(inputs)
1614
if not all(key in inputs.keys() for key in self.required_keys):
1715
raise ValueError(f'Missing required data: "{self.required_keys}"')
1816

@@ -35,7 +33,7 @@ def run(self) -> dict:
3533
slug, issue_id = self.scm_client.get_slug_and_id_from_url(self.issue_url)
3634
url = self.scm_client.create_issue_comment(slug, self.issue_text, issue_id=issue_id)
3735
except Exception as e:
38-
logger.error(e)
39-
return {}
36+
self.set_status(StepStatus.FAILED, f"Failed to create issue comment")
37+
raise e
4038

4139
return dict(issue_comment_url=url)

patchwork/steps/CreatePR/CreatePR.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ class CreatePR(Step):
1717
required_keys = {"target_branch"}
1818

1919
def __init__(self, inputs: dict):
20-
logger.info(f"Run started {self.__class__.__name__}")
21-
20+
super().__init__(inputs)
2221
if not all(key in inputs.keys() for key in self.required_keys):
2322
raise ValueError(f'Missing required data: "{self.required_keys}"')
2423

@@ -87,7 +86,6 @@ def run(self) -> dict:
8786
)
8887

8988
logger.info(f"[green]PR created at [link={url}]{url}[/link][/]", extra={"markup": True})
90-
logger.info(f"Run completed {self.__class__.__name__}")
9189
return {"pr_url": url}
9290

9391

0 commit comments

Comments
 (0)