diff --git a/pyproject.toml b/pyproject.toml index d585b81e9..12560ba39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,8 @@ dependencies = [ "PyGithub==2.5.0", "GitPython==3.1.44", "psutil>=5.8.0", + "fastapi[standard]<1.0.0,>=0.115.2", + "starlette<1.0.0,>=0.16.0", ] license = {file = "LICENSE"} classifiers = [ diff --git a/src/codegen/git/utils/branch_sync.py b/src/codegen/git/utils/branch_sync.py new file mode 100644 index 000000000..af7268a4e --- /dev/null +++ b/src/codegen/git/utils/branch_sync.py @@ -0,0 +1,53 @@ +import logging +from enum import StrEnum + +from git.remote import Remote + +from codegen.git.configs.constants import HIGHSIDE_REMOTE_NAME +from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator +from codegen.git.schemas.enums import FetchResult +from codegen.git.schemas.github import GithubType +from codegen.git.utils.clone_url import get_authenticated_clone_url_for_repo_config +from codegen.utils.performance.stopwatch_utils import stopwatch + +logger = logging.getLogger(__name__) + + +class BranchSyncResult(StrEnum): + SUCCESS = "SUCCESS" + BRANCH_NOT_FOUND = "BRANCH_NOT_FOUND" + SKIP = "SKIP" + + +def get_highside_origin(op: RemoteRepoOperator) -> Remote: + remote_url = get_authenticated_clone_url_for_repo_config(op.repo_config, github_type=GithubType.Github) + + if HIGHSIDE_REMOTE_NAME in op.git_cli.remotes: + highside_origin = op.git_cli.remote(HIGHSIDE_REMOTE_NAME) + highside_origin.set_url(remote_url) + else: + highside_origin = op.git_cli.create_remote(HIGHSIDE_REMOTE_NAME, remote_url) + return highside_origin + + +@stopwatch +def fetch_highside_branch(op: RemoteRepoOperator, branch_name: str) -> FetchResult: + """Checks out a a branch from highside origin""" + # Step 1: create highside origin + remote_url = get_authenticated_clone_url_for_repo_config(repo=op.repo_config, github_type=GithubType.Github) + op.create_remote(HIGHSIDE_REMOTE_NAME, remote_url) + + # Step 2: fetch the branch from highside + res = op.fetch_remote(HIGHSIDE_REMOTE_NAME, refspec=branch_name) + if res == FetchResult.REFSPEC_NOT_FOUND: + logger.warning(f"Branch: {branch_name} not found in highside. Skipping fetch.") + return FetchResult.REFSPEC_NOT_FOUND + + # Step 3: checkout (or update existing) local branch that tracks highside remote + if op.is_branch_checked_out(branch_name): + # update currently checked out branch to match the latest highside branch + op.git_cli.git.reset("--hard", f"{HIGHSIDE_REMOTE_NAME}/{branch_name}") + else: + # create a new local branch that tracks the remote highside branch + op.git_cli.create_head(branch_name, commit=f"{HIGHSIDE_REMOTE_NAME}/{branch_name}", force=True) + return FetchResult.SUCCESS diff --git a/src/codegen/runner/__init__.py b/src/codegen/runner/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/codegen/runner/constants/envvars.py b/src/codegen/runner/constants/envvars.py new file mode 100644 index 000000000..16b62f79b --- /dev/null +++ b/src/codegen/runner/constants/envvars.py @@ -0,0 +1,9 @@ +"""Environment variables used in the sandbox.""" + +# ==== [ Environment variable names ] ==== +CUSTOMER_REPO_ID = "CUSTOMER_REPO_ID" +FEATURE_FLAGS_BASE64 = "FEATURE_FLAGS_BASE64" +REPO_CONFIG_BASE64 = "REPO_CONFIG_BASE64" +LOWSIDE_TOKEN = "LOWSIDE_TOKEN" +HIGHSIDE_TOKEN = "HIGHSIDE_TOKEN" +IS_SANDBOX = "IS_SANDBOX" diff --git a/src/codegen/runner/diff/get_raw_diff.py b/src/codegen/runner/diff/get_raw_diff.py new file mode 100644 index 000000000..1a530fc17 --- /dev/null +++ b/src/codegen/runner/diff/get_raw_diff.py @@ -0,0 +1,94 @@ +import io +import logging + +from unidiff import LINE_TYPE_CONTEXT, Hunk, PatchedFile, PatchSet +from unidiff.patch import Line + +from codegen.sdk.core.codebase import Codebase + +logger = logging.getLogger(__name__) + + +def append_flag(file: PatchedFile, append_at: int, line_no: int, codebase: Codebase) -> None: + added_hunk = Hunk( + src_start=line_no, + src_len=1, + tgt_start=line_no, + tgt_len=1, + ) + line = codebase.get_file(file.path).content.split("\n")[line_no - 1] + added_hunk.append(Line(f"{line}\n", line_type=LINE_TYPE_CONTEXT)) + file.insert(append_at, added_hunk) + + +def patch_to_limited_diff_string(patch, codebase: Codebase, max_lines=10000): + diff_lines = [] + total_lines = 0 + + # Add flags that are not in the diff + filenames = [patched_file.path for patched_file in patch] + flags_not_in_diff = list(filter(lambda flag: flag.symbol.filepath not in filenames, codebase.G.flags._flags)) + + for flag in flags_not_in_diff: + filename = flag.symbol.filepath + patched_file = PatchedFile( + patch_info=f"diff --git a/{filename} b/{filename}\n", + source=f"a/{filename}", + target=f"b/{filename}", + ) + patch.append(patched_file) + + for patched_file in patch: + filtered_flags = filter(lambda flag: flag.symbol.filepath == patched_file.path, codebase.G.flags._flags) + sorted_flags = list(map(lambda flag: flag.symbol.start_point.row + 1, filtered_flags)) + sorted_flags.sort() + + for flag in sorted_flags: + is_in_diff = False + + for i, hunk in enumerate(patched_file): + contains_flag = hunk.source_start <= flag <= hunk.source_start + hunk.source_length + + if contains_flag: + is_in_diff = True + break + + is_after_flag = hunk.source_start > flag + + if is_after_flag: + is_in_diff = True + append_flag(patched_file, i, flag, codebase) + break + + if not is_in_diff: + append_flag(patched_file, len(patched_file), flag, codebase) + + # Add file header + raw_diff = str(patched_file) + diff_length = len(raw_diff.splitlines()) + + total_lines += diff_length + diff_lines.append(raw_diff) + + if total_lines >= max_lines: + break + + return "\n".join(diff_lines) + + +def get_raw_diff(codebase: Codebase, base: str = "HEAD", max_lines: int = 10000) -> str: + raw_diff = codebase.get_diff(base) + patch_set = PatchSet(io.StringIO(raw_diff)) + + raw_diff_length = len(raw_diff.split("\n")) + logger.info(f"Truncating diff (total: {raw_diff_length}) to {max_lines} lines ...") + raw_diff_trunc = patch_to_limited_diff_string(patch=patch_set, max_lines=max_lines, codebase=codebase) + + return raw_diff_trunc + + +def get_filenames_from_diff(diff: str) -> list[str]: + patch_set = PatchSet(io.StringIO(diff)) + filenames = [patched_file.path for patched_file in patch_set] + + return filenames diff --git a/src/codegen/runner/diff/syntax_highlight.py b/src/codegen/runner/diff/syntax_highlight.py new file mode 100644 index 000000000..f8342629e --- /dev/null +++ b/src/codegen/runner/diff/syntax_highlight.py @@ -0,0 +1,162 @@ +import io +import json +import logging +import os +import select +import subprocess +import time + +from unidiff import PatchedFile, PatchSet + +from codegen.utils.performance.stopwatch_utils import stopwatch + +logger = logging.getLogger(__name__) + +HIGHLIGHTED_DIFF_FILENAME = "highlighted_diff.json" + + +@stopwatch +def syntax_highlight_modified_files(codebase, raw_diff: str, flags: list[dict]) -> str: + modified_files = PatchSet(io.StringIO(raw_diff)) + highlighted_files = {} + highlighted_diff_files = {} + + # TODO: refactor this + with subprocess.Popen( + ". ~/.bashrc > /dev/null && nvm use > /dev/null && yarn run --silent highlight", + shell=True, + cwd="/codegen/codegen-frontend/app/modules/syntaxHighlight", + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + universal_newlines=True, + ) as highlighter: + poll = select.poll() + poll.register(highlighter.stdout, select.POLLIN) + + for file in modified_files: + filename = file.path + modified_filename = file.target_file + highlighted_files[filename] = ( + "" if file.is_removed_file else _highlight_file(highlighter, poll, modified_filename if not modified_filename.startswith("b/") else modified_filename[2:], flags) + ) + + codebase.stash_changes() + + for file in modified_files: + filename = file.path + original_filename = file.source_file + original = "" if file.is_added_file else _highlight_file(highlighter, poll, original_filename if not original_filename.startswith("a/") else original_filename[2:], flags) + modified = highlighted_files[filename] + highlighted_hunks = _construct_diff_highlight(codebase, original.splitlines(), modified.splitlines(), file) + highlighted_diff_files[filename] = highlighted_hunks + + try: + codebase.restore_stashed_changes() + except Exception as e: + # This can happen if there are no changes stashed in the first place + logger.warning(f"Error restoring stashed changes: {e}") + + _, err = highlighter.communicate() + returncode = highlighter.returncode + + if err: + logger.error(f"Highlighter exited with error: {err}") + + if returncode != 0: + raise Exception(f"Highlighter exited with code {returncode}") + + highlighted_diff = json.dumps(highlighted_diff_files) + logger.info(f"Generated highlighted diff (size={len(highlighted_diff)})") + return highlighted_diff + + +@stopwatch +def _highlight_file(highlighter: subprocess.Popen[str], poll: select.poll, filename: str, flags: list[dict]): + stdin_input = { + "file": f"{os.getcwd()}/{filename}", + "flags": list(filter(lambda flag: flag["filepath"] == filename, flags)), + } + stdin_input = json.dumps(stdin_input) + + logger.info(f"> Highlighting {filename}...") + highlighter.stdin.write(f"{stdin_input}\n") + highlighter.stdin.flush() + highlighted = "" + + while True: + # if monotonic.monotonic() > timeout_at: + # raise Exception("Syntax highlighter timed out") + # + # poll_result = poll.poll(0.01) + # + # if not poll_result: + # continue + + # TODO: this can deadlock in case the subprocess does not write a newline + line = highlighter.stdout.readline() + + if not line: + time.sleep(0.01) + + if line == "\x03\n": + break + + highlighted += line + + return highlighted + + +def _construct_diff_highlight(codebase, source: list[str], target: list[str], patched_file: PatchedFile) -> list: + original_lines = 0 + modified_lines = 0 + full_file = "" + full_file_lines = 0 + highlighted_hunks = [] + + for hunk in patched_file: + hunk_lines = "" + + while original_lines < (hunk.source_start - 1): + full_file += f" {source[original_lines]}\n" + full_file_lines += 1 + original_lines += 1 + modified_lines += 1 + + for line in hunk: + if line.is_removed: + full_file += f"-{source[original_lines]}\n" + hunk_lines += f"-{source[original_lines]}\n" + original_lines += 1 + full_file_lines += 1 + elif line.is_added: + full_file += f"+{target[modified_lines]}\n" + hunk_lines += f"+{target[modified_lines]}\n" + modified_lines += 1 + full_file_lines += 1 + else: + if len(source) > original_lines: + full_file += f" {source[original_lines]}\n" + hunk_lines += f" {source[original_lines]}\n" + elif len(target) > modified_lines: + full_file += f" {target[modified_lines]}\n" + hunk_lines += f" {target[modified_lines]}\n" + else: + logger.warning(f"Lines {original_lines}/{modified_lines} not found in {patched_file.path} in {codebase.current_commit.hexsha}: {line}") + original_lines += 1 + modified_lines += 1 + full_file_lines += 1 + + if hunk_lines.endswith("\n"): + hunk_lines = hunk_lines[:-1] + + highlighted_hunks.append({"lines": hunk_lines, "starts_at": full_file_lines - len(hunk), "ends_at": full_file_lines - 1}) + + if original_lines < len(source): + full_file += "\n ".join(source[original_lines:]) + + # TODO: we know the file length so we can add a property to diff and determine if we can expand down even if we haven't loaded the entire file on FE yet + + return highlighted_hunks diff --git a/src/codegen/runner/enums/warmup_state.py b/src/codegen/runner/enums/warmup_state.py new file mode 100644 index 000000000..c75c6f553 --- /dev/null +++ b/src/codegen/runner/enums/warmup_state.py @@ -0,0 +1,7 @@ +from enum import StrEnum + + +class WarmupState(StrEnum): + PENDING = "PENDING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" diff --git a/src/codegen/runner/models/apis.py b/src/codegen/runner/models/apis.py new file mode 100644 index 000000000..d33cadd6f --- /dev/null +++ b/src/codegen/runner/models/apis.py @@ -0,0 +1,73 @@ +"""Dataclasses used by the sandboxes server APIs""" + +from pydantic import BaseModel + +from codegen.runner.enums.warmup_state import WarmupState +from codegen.runner.models.codemod import BranchConfig, Codemod, CodemodRunResult, CreatedBranch, GroupingConfig + +SANDBOX_SERVER_PORT = 4000 +EPHEMERAL_SANDBOX_SERVER_PORT = 4001 + +# APIs +SIGNAL_SHUTDOWN_ENDPOINT = "/signal_shutdown" +DIFF_ENDPOINT = "/diff" +BRANCH_ENDPOINT = "/branch" + +# Ephemeral sandbox apis +RUN_ON_STRING_ENDPOINT = "/run_on_string" + + +class ServerInfo(BaseModel): + repo_id: int = 0 + container_id: str = "" + is_running_codemod: bool = False + is_shutting_down: bool = False + warmup_state: WarmupState = WarmupState.PENDING + label: str | None = "" + + +class UtilizationMetrics(BaseModel): + container_id: str + timestamp: str + memory_rss_gb: float + memory_vms_gb: float + cpu_percent: float + threads_count: int + open_files_count: int + + +class SignalShutdownResponse(BaseModel): + is_ready_to_shutdown: bool + + +class GetDiffRequest(BaseModel): + codemod: Codemod + max_transactions: int | None = None + max_seconds: int | None = None + + +class GetDiffResponse(BaseModel): + result: CodemodRunResult + + +class CreateBranchRequest(BaseModel): + codemod: Codemod + grouping_config: GroupingConfig + branch_config: BranchConfig + + +class CreateBranchResponse(BaseModel): + results: list[CodemodRunResult] | None = None + branches: list[CreatedBranch] | None = None + num_flags: int | None = None + group_segments: list[str] | None = None + + +class GetRunOnStringRequest(BaseModel): + codemod_source: str + language: str + files: dict[str, str] + + +class GetRunOnStringResult(BaseModel): + result: CodemodRunResult diff --git a/src/codegen/runner/models/codemod.py b/src/codegen/runner/models/codemod.py new file mode 100644 index 000000000..c09502d97 --- /dev/null +++ b/src/codegen/runner/models/codemod.py @@ -0,0 +1,57 @@ +"""Dataclasses used by the sandbox runners""" + +from datetime import datetime + +from pydantic import BaseModel + +from codegen.git.models.codemod_context import CodemodContext +from codegen.git.models.pr_options import PROptions +from codegen.sdk.codebase.flagging.groupers.enums import GroupBy + + +class Codemod(BaseModel): + run_id: int + version_id: int + epic_title: str + user_code: str + codemod_context: CodemodContext + + # Sentry tags + epic_id: int + is_customer: bool = True + + +class GroupingConfig(BaseModel): + subdirectories: list[str] | None = None + group_by: GroupBy | None = None + max_prs: int | None = None + + +class BranchConfig(BaseModel): + base_branch: str | None = None + custom_head_branch: str | None = None + force_push_head_branch: bool = False + + +class CodemodRunResult(BaseModel): + is_complete: bool = False + observation: str | None = None + visualization: dict | None = None + observation_meta: dict | None = None + base_commit: str | None = None + logs: str | None = None + error: str | None = None + completed_at: datetime | None = None + highlighted_diff: str | None = None + pr_options: PROptions | None = None + flags: list[dict] | None = None + + +class CreatedBranch(BaseModel): + base_branch: str + head_ref: str | None + + +class SandboxRunnerTag(BaseModel): + repo_id: str + runner_id: str diff --git a/src/codegen/runner/models/configs.py b/src/codegen/runner/models/configs.py new file mode 100644 index 000000000..c40b5bcfa --- /dev/null +++ b/src/codegen/runner/models/configs.py @@ -0,0 +1,54 @@ +import base64 +import os + +from pydantic import BaseModel, ConfigDict + +from codegen.git.schemas.repo_config import RepoConfig +from codegen.runner.constants.envvars import FEATURE_FLAGS_BASE64, REPO_CONFIG_BASE64 +from codegen.sdk.codebase.config import CodebaseConfig, GSFeatureFlags +from codegen.sdk.secrets import Secrets + + +class RunnerFeatureFlags(BaseModel): + """Feature flags for a runner""" + + model_config = ConfigDict(frozen=True) + + sync_enabled: bool = True + track_graph: bool = False + verify_graph: bool = False + + ts_language_engine: bool = False + v8_ts_engine: bool = False + ts_dependency_manager: bool = False + + import_resolution_overrides: dict[str, str] = {} + syntax_highlight: bool = False + + def encoded_json(self): + return base64.b64encode(self.model_dump_json().encode("utf-8")).decode("utf-8") + + @staticmethod + def from_encoded_json(encoded_json: str) -> "RunnerFeatureFlags": + decoded = base64.b64decode(encoded_json).decode("utf-8") + return RunnerFeatureFlags.model_validate_json(decoded) + + +def get_codebase_config() -> CodebaseConfig: + gs_ffs = GSFeatureFlags(**get_runner_feature_flags().model_dump()) + secrets = Secrets(openai_key=os.environ["OPENAI_PASS"]) + return CodebaseConfig(secrets=secrets, feature_flags=gs_ffs) + + +def get_runner_feature_flags() -> RunnerFeatureFlags: + encoded_ffs = os.environ.get(FEATURE_FLAGS_BASE64) + if not encoded_ffs: + raise ValueError("FEATURE_FLAGS_BASE64 environment variable not found") + return RunnerFeatureFlags.from_encoded_json(encoded_ffs) + + +def get_repo_config() -> RepoConfig: + encoded_repo_config = os.environ.get(REPO_CONFIG_BASE64) + if not encoded_repo_config: + raise ValueError("REPO_CONFIG_BASE64 environment variable not found") + return RepoConfig.from_encoded_json(encoded_repo_config) diff --git a/src/codegen/runner/sandbox/executor.py b/src/codegen/runner/sandbox/executor.py new file mode 100644 index 000000000..eb21b1edc --- /dev/null +++ b/src/codegen/runner/sandbox/executor.py @@ -0,0 +1,189 @@ +import logging +from collections.abc import Callable +from datetime import UTC, datetime + +from github.PullRequest import PullRequest + +from codegen.git.models.pr_options import PROptions +from codegen.runner.diff.get_raw_diff import get_raw_diff +from codegen.runner.diff.syntax_highlight import syntax_highlight_modified_files +from codegen.runner.models.codemod import BranchConfig, Codemod, CodemodRunResult, CreatedBranch, GroupingConfig +from codegen.runner.models.configs import get_runner_feature_flags +from codegen.runner.sandbox.repo import SandboxRepo +from codegen.runner.utils.branch_name import get_head_branch_name +from codegen.runner.utils.exception_utils import update_observation_meta +from codegen.sdk.codebase.config import SessionOptions +from codegen.sdk.codebase.factory.codebase_factory import CodebaseType +from codegen.sdk.codebase.flagging.code_flag import CodeFlag +from codegen.sdk.codebase.flagging.group import Group +from codegen.sdk.codebase.flagging.groupers.utils import get_grouper_by_group_by +from codegen.utils.exceptions.control_flow import StopCodemodException +from codegen.utils.performance.stopwatch_utils import stopwatch +from codegen.visualizations.viz_utils import get_graph_json + +logger = logging.getLogger(__name__) + + +class SandboxExecutor: + """Responsible for executing the user defined codemod in the sandbox.""" + + codebase: CodebaseType + remote_repo: SandboxRepo + _is_syntax_highlight_enabled: bool + + def __init__(self, codebase: CodebaseType): + self.codebase = codebase + self.remote_repo = SandboxRepo(self.codebase) + self._is_syntax_highlight_enabled = get_runner_feature_flags().syntax_highlight + + async def find_flags(self, execute_func: Callable) -> list[CodeFlag]: + """Runs the execute_func in find_mode to find flags""" + self.codebase.set_find_mode(True) + await self._execute_with_try_catch(execute_func, commit=False) + code_flags = self.codebase.G.flags._flags + logger.info(f"> Found {len(self.codebase.G.flags._flags)} CodeFlags") + return code_flags + + async def find_flag_groups(self, code_flags: list[CodeFlag], grouping_config: GroupingConfig) -> list[Group]: + """Groups the code flags as specified by grouping_config""" + if grouping_config.subdirectories and len(grouping_config.subdirectories) > 0: + logger.info(f"> Filtering flags by subdirectories: {grouping_config.subdirectories}") + code_flags = [flag for flag in code_flags if any([flag.filepath.startswith(x) for x in grouping_config.subdirectories])] + logger.info(f"> Flags remaining: {len(code_flags)}") + + # =====[ Group the code flags ]===== + logger.info(f"> Grouping CodeFlags by config: {grouping_config}") + grouper = get_grouper_by_group_by(grouping_config.group_by, repo_id=self.codebase.op.repo_config.id) + groups = grouper.create_all_groups(flags=code_flags, repo_operator=self.codebase.op) + logger.info(f"> Created {len(groups)} groups") + return groups + + async def execute_flag_groups(self, codemod: Codemod, execute_func: Callable, flag_groups: list[Group], branch_config: BranchConfig) -> tuple[list[CodemodRunResult], list[CreatedBranch]]: + run_results = [] + head_branches = [] + for idx, group in enumerate(flag_groups): + if idx > 0 and run_results[-1].error: + logger.info("Skipping remaining groups because of error in previous group") + break + if group: + logger.info(f"Running group {group.segment} ({idx + 1} out of {len(flag_groups)})...") + + head_branch = branch_config.custom_head_branch or get_head_branch_name(codemod, group) + logger.info(f"Running with head branch: {head_branch}") + self.remote_repo.reset_branch(branch_config.base_branch, head_branch) + + run_result = await self.execute(execute_func, group=group) + created_branch = CreatedBranch(base_branch=branch_config.base_branch, head_ref=None) + if self.remote_repo.push_changes_to_remote(codemod, head_branch, branch_config.force_push_head_branch): + created_branch.head_ref = head_branch + + self.codebase.reset() + run_results.append(run_result) + head_branches.append(created_branch) + + self.codebase.G.flags._flags.clear() + return run_results, head_branches + + async def execute(self, execute_func: Callable, group: Group | None = None, session_options: SessionOptions = SessionOptions()) -> CodemodRunResult: + """Runs the execute_func in edit_mode and returns the saved the result""" + self.codebase.set_find_mode(False) + if group: + self.codebase.set_active_group(group) + result = await self._execute_with_try_catch(execute_func, session_options=session_options) + return await self._get_structured_run_output(result) + + async def execute_on_pr(self, execute_func: Callable, pr: PullRequest, session_options: SessionOptions = SessionOptions()) -> CodemodRunResult: + """Runs the execute_func in edit_mode and returns the saved the result""" + # TODO: only difference is this sets `set_find_mode` to True to capture flags. Shouldn't need to do this, flags should always appear. + self.codebase.set_find_mode(True) + result = await self._execute_with_try_catch(execute_func, session_options=session_options, pr=pr) + return await self._get_structured_run_output(result) + + @stopwatch + async def _execute_with_try_catch( + self, + execute_func: Callable, + *, + sync_graph: bool = False, + commit: bool = True, + session_options: SessionOptions = SessionOptions(), + pr: PullRequest | None = None, + ) -> CodemodRunResult: + """Runs the execute_func in a try/catch with a codebase session""" + logger.info(f"Running safe execute with sync_graph: {sync_graph} commit: {commit} session_options: {session_options}") + result = CodemodRunResult() + pr_options = PROptions() + try: + with self.codebase.session(sync_graph, commit, session_options=session_options): + execute_func(self.codebase, pr_options, pr=pr) + result.is_complete = True + + except StopCodemodException as e: + logger.info(f"Stopping codemod due to {e.__class__.__name__}: {e}") + result.observation_meta = update_observation_meta(e, result.observation_meta) + result.is_complete = True + + except Exception as e: + error_message = str(e) + logger.exception(e) + result.error = error_message + result.is_complete = False + + finally: + # =====[ Capture completed_at ]===== + result.completed_at = datetime.now(tz=UTC) + + # =====[ Capture PR options ]===== + result.pr_options = pr_options + + # =====[ Build graph.json ]===== + viz_results = get_graph_json(self.codebase.op) + if viz_results is not None: + result.visualization = viz_results + + return result + + async def _get_structured_run_output(self, result: CodemodRunResult) -> CodemodRunResult: + """Formats output into a CodemodRunResult""" + # =====[ Save flags ]===== + # Note: I think we should just store this on the CodemodRunResult.flags, not meta + # Also note: we should type this object, since we end up using it in several locations + flags = [ + { + "filepath": flag.symbol.filepath, + "startLine": flag.symbol.start_point.row, + "startColumn": flag.symbol.start_point.column, + "endLine": flag.symbol.end_point.row, + "endColumn": flag.symbol.end_point.column, + "message": flag.message, + "messageType": str(flag.message_type), + "messageRecipient": flag.message_recipient, + } + for flag in self.codebase.G.flags._flags + ] + result.flags = flags + if result.observation_meta is None: + result.observation_meta = {} + result.observation_meta["flags"] = flags + + # =====[ Get and store raw diff ]===== + logger.info("> Extracting diff") + raw_diff = get_raw_diff(codebase=self.codebase) + result.observation = raw_diff + result.base_commit = self.codebase.current_commit.hexsha if self.codebase.current_commit else "HEAD" + + if self._is_syntax_highlight_enabled: + logger.info("> Syntax highlighting modified files") + try: + result.highlighted_diff = syntax_highlight_modified_files(self.codebase, raw_diff, flags) + except Exception as e: + # TODO: this doesn't work during webhooks. Maybe due to installation dependencies? + logger.exception(f"Error! Failed to syntax highlight modified files: {e}") + else: + logger.info("> Skipping syntax highlighting, because feature flag is not enabled") + + # =====[ Finalize CodemodRun state ]===== + # Include logs etc. + logger.info("> Extracting/formatting logs") + result.logs = self.codebase.get_finalized_logs() + return result diff --git a/src/codegen/runner/sandbox/middlewares.py b/src/codegen/runner/sandbox/middlewares.py new file mode 100644 index 000000000..18167389f --- /dev/null +++ b/src/codegen/runner/sandbox/middlewares.py @@ -0,0 +1,74 @@ +import logging +import traceback +from collections.abc import Callable +from functools import cached_property +from http import HTTPStatus # Add this import +from typing import TypeVar + +from starlette.background import BackgroundTasks +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from codegen.runner.models.apis import ServerInfo +from codegen.runner.sandbox.runner import SandboxRunner +from codegen.utils.compilation.exceptions import UserCodeException + +logger = logging.getLogger(__name__) + +TRequest = TypeVar("TRequest", bound=Request) +TResponse = TypeVar("TResponse", bound=Response) + + +class CodemodRunMiddleware[TRequest, TResponse](BaseHTTPMiddleware): + def __init__(self, app, path: str, server_info_fn: Callable[[], ServerInfo], runner_fn: Callable[[], SandboxRunner]) -> None: + super().__init__(app) + self.path = path + self.server_info_fn = server_info_fn + self.runner_fn = runner_fn + + async def dispatch(self, request: TRequest, call_next: RequestResponseEndpoint) -> TResponse: + if request.url.path == self.path: + return await self.process_request(request, call_next) + return await call_next(request) + + @cached_property + def server_info(self) -> ServerInfo: + return self.server_info_fn() + + @cached_property + def runner(self) -> SandboxRunner: + return self.runner_fn() + + async def process_request(self, request: TRequest, call_next: RequestResponseEndpoint) -> TResponse: + self.server_info.is_running_codemod = True + background_tasks = BackgroundTasks() + try: + logger.info(f"> (CodemodRunMiddleware) Request: {request.url.path}") + self.runner.codebase.viz.clear_graphviz_data() + response = await call_next(request) + background_tasks.add_task(self.cleanup_after_codemod, is_exception=False) + response.background = background_tasks + return response + + except UserCodeException as e: + message = f"Invalid user code for {request.url.path}" + logger.info(message) + self.server_info.is_running_codemod = False + return JSONResponse(status_code=HTTPStatus.BAD_REQUEST, content={"detail": message, "error": str(e), "traceback": traceback.format_exc()}) + + except Exception as e: + message = f"Unexpected error for {request.url.path}" + logger.exception(message) + res = JSONResponse(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, content={"detail": message, "error": str(e), "traceback": traceback.format_exc()}) + background_tasks.add_task(self.cleanup_after_codemod, is_exception=True) + res.background = background_tasks + return res + + async def cleanup_after_codemod(self, is_exception: bool = False): + if is_exception: + # TODO: instead of committing transactions, we should just rollback + logger.info("Committing pending transactions due to exception") + self.runner.codebase.G.commit_transactions(sync_graph=False) + self.runner.reset_runner() + self.server_info.is_running_codemod = False diff --git a/src/codegen/runner/sandbox/repo.py b/src/codegen/runner/sandbox/repo.py new file mode 100644 index 000000000..297b86030 --- /dev/null +++ b/src/codegen/runner/sandbox/repo.py @@ -0,0 +1,84 @@ +import logging + +from codegen.git.schemas.enums import FetchResult +from codegen.git.utils.branch_sync import BranchSyncResult, fetch_highside_branch, get_highside_origin +from codegen.runner.models.codemod import Codemod +from codegen.sdk.codebase.factory.codebase_factory import CodebaseType + +logger = logging.getLogger(__name__) + + +class SandboxRepo: + """Responsible for managing the state of the git repo stored in the sandbox runner""" + + codebase: CodebaseType + + def __init__(self, codebase: CodebaseType) -> None: + self.codebase = codebase + + def set_up_base_branch(self, base_branch: str | None) -> None: + """Set-up base branch by pushing latest highside branch to lowside and checking out the branch.""" + # If base branch is already checked out, do nothing + if self.codebase.op.is_branch_checked_out(base_branch): + return + + res = self._pull_highside_to_lowside(base_branch) + if res is BranchSyncResult.SUCCESS: + self.codebase.checkout(branch=base_branch, remote=True) + + def set_up_head_branch(self, head_branch: str, force_push_head_branch: bool): + """Set-up head branch by pushing latest highside branch to lowside and fetching the branch (so that it can be checked out later).""" + # If head branch is not specified, do nothing + if head_branch is None: + return + + if head_branch and head_branch == self.codebase.default_branch: + # NOTE: assuming that the main branch is always protected, instead should pull this from github (but it requires admin permissions) + error = f"Branch {head_branch} is protected and cannot be used as the head branch!" + logger.error(error) + raise ValueError(error) + + # If are force pushing the head branch, don't checkout the remote. + # This will cause set-up group to create a new branch off of master by the same name + if force_push_head_branch: + return + + res = self._pull_highside_to_lowside(head_branch) + if res is BranchSyncResult.SUCCESS: + self.codebase.op.fetch_remote("origin", refspec=f"{head_branch}:{head_branch}") + + def _pull_highside_to_lowside(self, branch_name: str): + """Grabs the latest highside branch `branch_name` and pushes it to the lowside.""" + # Step 1: checkout branch that tracks highside remote + res = fetch_highside_branch(op=self.codebase.op, branch_name=branch_name) + if res == FetchResult.REFSPEC_NOT_FOUND: + return BranchSyncResult.BRANCH_NOT_FOUND + + # Step 2: push branch up to lowside + logger.info(f"Pushing branch: {branch_name} from highside to lowside w/ force=False ...") + lowside_origin = self.codebase.op.git_cli.remote("origin") + self.codebase.op.push_changes(remote=lowside_origin, refspec=f"{branch_name}:{branch_name}", force=False) + return BranchSyncResult.SUCCESS + + def reset_branch(self, base_branch: str, head_branch: str) -> None: + logger.info(f"Checking out base branch {base_branch} ...") + self.codebase.checkout(branch=base_branch, create_if_missing=True) + # =====[ Checkout head branch ]===== + logger.info(f"Checking out head branch {head_branch} ...") + self.codebase.checkout(branch=head_branch, create_if_missing=True) + + def push_changes_to_remote(self, codemod: Codemod, head_branch: str, force_push: bool) -> bool: + """Takes current state of repo and pushes it""" + # =====[ Stage changes ]===== + has_staged_commit = self.codebase.git_commit(f"[Codegen] {codemod.epic_title}") + if not has_staged_commit: + logger.info(f"Skipping opening pull request for cm_run {codemod.run_id} b/c the codemod produced no changes") + return False + + # =====[ Push changes highside ]===== + highside_origin = get_highside_origin(self.codebase.op) + highside_res = self.codebase.op.push_changes(remote=highside_origin, refspec=f"{head_branch}:{head_branch}", force=force_push) + return not any(push_info.flags & push_info.ERROR for push_info in highside_res) + + # TODO: move bunch of codebase git operations into this class. + # The goal is to make the codebase class ONLY allow LocalRepoOperator. diff --git a/src/codegen/runner/sandbox/runner.py b/src/codegen/runner/sandbox/runner.py new file mode 100644 index 000000000..0f0d638dd --- /dev/null +++ b/src/codegen/runner/sandbox/runner.py @@ -0,0 +1,127 @@ +import logging +import sys + +import sentry_sdk +from git import Commit as GitCommit + +from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator +from codegen.git.schemas.repo_config import RepoConfig +from codegen.runner.models.apis import CreateBranchRequest, CreateBranchResponse, GetDiffRequest, GetDiffResponse +from codegen.runner.models.configs import get_codebase_config +from codegen.runner.sandbox.executor import SandboxExecutor +from codegen.sdk.codebase.config import ProjectConfig, SessionOptions +from codegen.sdk.codebase.factory.codebase_factory import CodebaseType +from codegen.sdk.core.codebase import Codebase +from codegen.sdk.enums import ProgrammingLanguage +from codegen.utils.compilation.string_to_code import create_execute_function_from_codeblock +from codegen.utils.performance.stopwatch_utils import stopwatch + +logger = logging.getLogger(__name__) + + +class SandboxRunner: + """Responsible for orchestrating the lifecycle of a warmed sandbox""" + + # =====[ __init__ instance attributes ]===== + container_id: str + repo: RepoConfig + commit: GitCommit + op: RemoteRepoOperator | None + + # =====[ computed instance attributes ]===== + codebase: CodebaseType + executor: SandboxExecutor + + def __init__( + self, + container_id: str, + repo_config: RepoConfig, + ) -> None: + self.container_id = container_id + self.repo = repo_config + self.op = RemoteRepoOperator(repo_config, base_dir=repo_config.base_dir) + self.commit = self.op.git_cli.head.commit + + async def warmup(self) -> None: + """Warms up this runner by cloning the repo and parsing the graph.""" + logger.info(f"===== Warming runner for {self.repo.full_name} (ID={self.repo.id}) =====") + sys.setrecursionlimit(10000) # for graph parsing + + self.codebase = await self._build_graph() + self.executor = SandboxExecutor(self.codebase) + + async def _build_graph(self) -> Codebase: + logger.info("> Building graph...") + programming_language = ProgrammingLanguage[self.op.repo_config.language.upper()] + projects = [ProjectConfig(programming_language=programming_language, repo_operator=self.op, base_path=self.op.repo_config.base_path, subdirectories=self.op.repo_config.subdirectories)] + return Codebase(projects=projects, config=get_codebase_config()) + + @stopwatch + def reset_runner(self) -> None: + """Reset the runner to a cleaned/stable state for the next job. + + At the start of every job the runner should be in the following state: + - Codebase is checked out to the pinned commit (i.e. self.commit) + - Codebase LRP (LocalRepoOperator) has only the origin remote and no branches + + This method puts the runner in the above state and should be called at the end of every job. + """ + # TODO: move self.codebase.reset() here instead of during run + # TODO assert codebase is on the default branch and its clean + # TODO re-enable this (i.e. rather than pinning the runner commit, always move it forward to the latest commit) + logger.info("=====[ reset_runner ]=====") + logger.info(f"Syncing runner to commit: {self.commit} ...") + self.codebase.checkout(commit=self.commit) + self.codebase.clean_repo() + self.codebase.checkout(branch=self.codebase.default_branch, create_if_missing=True) + + @staticmethod + def _set_sentry_tags(epic_id: int, is_customer: bool) -> None: + """Set the sentry tags for a CodemodRun""" + sentry_sdk.set_tag("epic_id", epic_id) # To easily get to the epic in the UI + sentry_sdk.set_tag("is_customer", is_customer) # To filter "prod" level errors, ex if customer hits an error vs an admin + + async def get_diff(self, request: GetDiffRequest) -> GetDiffResponse: + self._set_sentry_tags(epic_id=request.codemod.epic_id, is_customer=request.codemod.is_customer) + custom_scope = {"context": request.codemod.codemod_context} if request.codemod.codemod_context else {} + code_to_exec = create_execute_function_from_codeblock(codeblock=request.codemod.user_code, custom_scope=custom_scope) + session_options = SessionOptions(max_transactions=request.max_transactions, max_seconds=request.max_seconds) + + res = await self.executor.execute(code_to_exec, session_options=session_options) + + return GetDiffResponse(result=res) + + async def create_branch(self, request: CreateBranchRequest) -> CreateBranchResponse: + self._set_sentry_tags(epic_id=request.codemod.epic_id, is_customer=request.codemod.is_customer) + custom_scope = {"context": request.codemod.codemod_context} if request.codemod.codemod_context else {} + code_to_exec = create_execute_function_from_codeblock(codeblock=request.codemod.user_code, custom_scope=custom_scope) + branch_config = request.branch_config + + branch_config.base_branch = branch_config.base_branch or self.codebase.default_branch + self.executor.remote_repo.set_up_base_branch(branch_config.base_branch) + self.executor.remote_repo.set_up_head_branch(branch_config.custom_head_branch, branch_config.force_push_head_branch) + + response = CreateBranchResponse() + if "codebase.flag_instance" in request.codemod.user_code: + flags = await self.executor.find_flags(code_to_exec) + flag_groups = await self.executor.find_flag_groups(flags, request.grouping_config) + response.num_flags = len(flags) + response.group_segments = [group.segment for group in flag_groups] + if len(flag_groups) == 0: + logger.info("No flag groups found. Running without flagging.") + flag_groups = [None] + else: + flag_groups = [None] + + # TODO: do this as part of find_flag_groups? + max_prs = request.grouping_config.max_prs + if max_prs and len(flag_groups) >= max_prs: + logger.info(f"Max PRs limit reached: {max_prs}. Skipping remaining groups.") + flag_groups = flag_groups[:max_prs] + + run_results, branches = await self.executor.execute_flag_groups(request.codemod, code_to_exec, flag_groups, branch_config) + response.results = run_results + response.branches = branches + + self.codebase.G.flags._flags.clear() + return response diff --git a/src/codegen/runner/sandbox/server.py b/src/codegen/runner/sandbox/server.py new file mode 100644 index 000000000..76ec7bf6c --- /dev/null +++ b/src/codegen/runner/sandbox/server.py @@ -0,0 +1,109 @@ +import datetime as dt +import logging +import os +from contextlib import asynccontextmanager +from datetime import datetime + +import psutil +from fastapi import FastAPI + +from codegen.runner.constants.envvars import CUSTOMER_REPO_ID +from codegen.runner.enums.warmup_state import WarmupState +from codegen.runner.models.apis import ( + BRANCH_ENDPOINT, + DIFF_ENDPOINT, + SIGNAL_SHUTDOWN_ENDPOINT, + CreateBranchRequest, + CreateBranchResponse, + GetDiffRequest, + GetDiffResponse, + ServerInfo, + SignalShutdownResponse, + UtilizationMetrics, +) +from codegen.runner.models.configs import get_repo_config +from codegen.runner.sandbox.middlewares import CodemodRunMiddleware +from codegen.runner.sandbox.runner import SandboxRunner +from codegen.utils.performance.memory_utils import get_memory_stats + +logger = logging.getLogger(__name__) + +server_info: ServerInfo +runner: SandboxRunner + + +@asynccontextmanager +async def lifespan(server: FastAPI): + global server_info + global runner + + try: + server_info = ServerInfo(repo_id=int(os.getenv(CUSTOMER_REPO_ID)), container_id=os.getenv("MODAL_TASK_ID")) + logger.info(f"Starting up sandbox fastapi server for repo_id={server_info.repo_id} in container ID={server_info.container_id}") + + runner = SandboxRunner(container_id=server_info.container_id, repo_config=get_repo_config()) + server_info.warmup_state = WarmupState.PENDING + await runner.warmup() + server_info.warmup_state = WarmupState.COMPLETED + except Exception: + logger.exception("Failed to build graph during warmup") + server_info.warmup_state = WarmupState.FAILED + raise + + logger.info("Sandbox fastapi server is ready to accept requests") + yield + logger.info("Shutting down sandbox fastapi server") + + +app = FastAPI(lifespan=lifespan) +app.add_middleware( + CodemodRunMiddleware[GetDiffRequest, GetDiffResponse], + path=DIFF_ENDPOINT, + server_info_fn=lambda: server_info, + runner_fn=lambda: runner, +) +app.add_middleware( + CodemodRunMiddleware[CreateBranchRequest, CreateBranchResponse], + path=BRANCH_ENDPOINT, + server_info_fn=lambda: server_info, + runner_fn=lambda: runner, +) + + +@app.get("/") +def health() -> ServerInfo: + return server_info + + +@app.get("/metrics/utilization", response_model=UtilizationMetrics) +async def utilization_metrics() -> UtilizationMetrics: + # Get the current process + process = psutil.Process(os.getpid()) + memory_stats = get_memory_stats() + + return UtilizationMetrics( + container_id=os.getenv("MODAL_TASK_ID"), + timestamp=datetime.now(dt.UTC).isoformat(), + memory_rss_gb=memory_stats.memory_rss_gb, + memory_vms_gb=memory_stats.memory_vms_gb, + cpu_percent=process.cpu_percent(), + threads_count=process.num_threads(), + open_files_count=len(process.open_files()), + ) + + +@app.post(SIGNAL_SHUTDOWN_ENDPOINT) +async def signal_shutdown() -> SignalShutdownResponse: + logger.info(f"repo_id={server_info.repo_id} container ID={server_info.container_id} received signal_shutdown") + server_info.is_shutting_down = True + return SignalShutdownResponse(is_ready_to_shutdown=not server_info.is_running_codemod) + + +@app.post(DIFF_ENDPOINT) +async def get_diff(request: GetDiffRequest) -> GetDiffResponse: + return await runner.get_diff(request=request) + + +@app.post(BRANCH_ENDPOINT) +async def create_branch(request: CreateBranchRequest) -> CreateBranchResponse: + return await runner.create_branch(request=request) diff --git a/src/codegen/runner/utils/branch_name.py b/src/codegen/runner/utils/branch_name.py new file mode 100644 index 000000000..04c59946c --- /dev/null +++ b/src/codegen/runner/utils/branch_name.py @@ -0,0 +1,25 @@ +import re + +from codegen.runner.models.codemod import Codemod +from codegen.sdk.codebase.flagging.group import DEFAULT_GROUP_ID, Group + +# Codegen branches are of the format: codegen-codemod--version--run--group- +CODEGEN_BRANCH_PATTERN = r"codegen-codemod-(\d+)-version-(\d+)-run-(\d+)-group-(\d+)" + +# Regex used for parsing DB IDs from Codegen branch names +CODEGEN_BRANCH_REGEX = re.compile(f"^{CODEGEN_BRANCH_PATTERN}$") + +# Template used to create a Codegen branch name +CODEGEN_BRANCH_TEMPLATE = CODEGEN_BRANCH_PATTERN.replace("(\\d+)", "{}") + + +def get_head_branch_name(codemod: Codemod, group: Group | None = None) -> str: + if not codemod.version_id: + raise ValueError(f"CodemodRun: {codemod.run_id} does not have a codemod version!") + if not codemod.epic_id: + raise ValueError(f"CodemodRun: {codemod.run_id} does not have an epic!") + if group and group.id is None: + raise ValueError("Group ID is required to create a branch name") + + group_id = group.id if group else DEFAULT_GROUP_ID + return CODEGEN_BRANCH_TEMPLATE.format(codemod.epic_id, codemod.version_id, codemod.run_id, group_id) diff --git a/src/codegen/runner/utils/exception_utils.py b/src/codegen/runner/utils/exception_utils.py new file mode 100644 index 000000000..4f2de21cf --- /dev/null +++ b/src/codegen/runner/utils/exception_utils.py @@ -0,0 +1,15 @@ +from codegen.utils.exceptions.control_flow import StopCodemodException + + +def update_observation_meta( + e: StopCodemodException, + observation_meta: dict | None = None, +) -> dict: + observation_meta = observation_meta or {} + observation_meta.update( + { + "stop_codemod_exception_type": e.__class__.__name__, + "threshold": e.threshold, + }, + ) + return observation_meta