diff --git a/.codegen/.gitignore b/.codegen/.gitignore deleted file mode 100644 index 77d89d205..000000000 --- a/.codegen/.gitignore +++ /dev/null @@ -1,15 +0,0 @@ -# Codegen -docs/ -examples/ -prompts/ -jupyter/ -.venv/ -.env -codegen-system-prompt.txt - -# Python cache files -**/__pycache__/ -*.py[cod] -*$py.class -*.txt -*.pyc diff --git a/.codegen/codemods/no_link_backticks/no_link_backticks.py b/.codegen/codemods/no_link_backticks/no_link_backticks.py deleted file mode 100644 index e8cda5323..000000000 --- a/.codegen/codemods/no_link_backticks/no_link_backticks.py +++ /dev/null @@ -1,44 +0,0 @@ -import codegen -from codegen import Codebase - - -@codegen.function(name="no-link-backticks", subdirectories=["test/unit"]) -def run(codebase: Codebase): - import re - - # Define the pattern for Markdown links with backticks in the link text - link_pattern = re.compile(r"\[([^\]]*`[^\]]*`[^\]]*)\]\(([^)]+)\)") - - # Iterate over all .mdx files in the codebase - for file in codebase.files(extensions=["mdx"]): - if file.extension == ".mdx": - print(f"Processing {file.path}") - new_content = file.content - - # Find all markdown links with backticks in link text - matches = link_pattern.finditer(new_content) - - for match in matches: - # Original link text with backticks - original_text = match.group(1) - - # Remove backticks from the link text - new_text = original_text.replace("`", "") - - # Replace the link in content - new_content = new_content.replace(match.group(0), f"[{new_text}]({match.group(2)})") - - # Update file content if changes were made - if new_content != file.content: - file.edit(new_content) - - # Commit all changes - codebase.commit() - - -if __name__ == "__main__": - print("Parsing codebase...") - codebase = Codebase("./") - - print("Running function...") - codegen.run(run) diff --git a/.codegen/codemods/test_language/test_language.py b/.codegen/codemods/test_language/test_language.py deleted file mode 100644 index 19ae4c0bd..000000000 --- a/.codegen/codemods/test_language/test_language.py +++ /dev/null @@ -1,19 +0,0 @@ -import codegen -from codegen.sdk.core.codebase import Codebase -from codegen.shared.enums.programming_language import ProgrammingLanguage - - -@codegen.function("test-language", subdirectories=["src/codegen/cli"], language=ProgrammingLanguage.PYTHON) -def run(codebase: Codebase): - file = codebase.get_file("src/codegen/cli/errors.py") - print(f"File: {file.path}") - for s in file.symbols: - print(s.name) - - -if __name__ == "__main__": - print("Parsing codebase...") - codebase = Codebase("./") - - print("Running...") - run(codebase) diff --git a/.codegen/codemods/update_loggers/update_loggers.py b/.codegen/codemods/update_loggers/update_loggers.py deleted file mode 100644 index 74edee3e1..000000000 --- a/.codegen/codemods/update_loggers/update_loggers.py +++ /dev/null @@ -1,18 +0,0 @@ -import codegen -from codegen.sdk.core.codebase import PyCodebaseType - - -@codegen.function("update-loggers") -def run(codebase: PyCodebaseType) -> None: - """Updates all loggers in src/codegen to use the new get_logger function.""" - for file in codebase.files: - if not str(file.filepath).startswith("src/codegen/"): - continue - - if file.get_import("logging") is None: - continue - - if (logger := file.get_global_var("logger")) and logger.value.source == "logging.getLogger(__name__)": - print(f"Updating logger in {file.filepath}") - logger.set_value("get_logger(__name__)") - file.add_import_from_import_string("\nfrom codegen.shared.logging.get_logger import get_logger") diff --git a/.github/actions/setup-oss-repos/action.yml b/.github/actions/setup-oss-repos/action.yml index 4ec25be83..c7951d599 100644 --- a/.github/actions/setup-oss-repos/action.yml +++ b/.github/actions/setup-oss-repos/action.yml @@ -1,23 +1,8 @@ -# yaml-language-server: $schema=https://json.schemastore.org/github-action.json -name: "Setup OSS repos" -description: "Setup OSS repos" -# TODO: add size filter +name: "Setup OSS repos (disabled)" +description: "OSS repos setup has been disabled" runs: using: "composite" steps: - - name: Cache oss-repos - id: cache-oss-repos - uses: actions/cache@v4 - with: - path: oss_repos - key: ${{ runner.os }}-repo-cache-2-${{hashFiles('codegen-backend/codegen_tests/graph_sitter/codemod/repos/open_source/*.json')}} - - name: Populate oss-repos if the cache is empty - if: steps.cache-oss-repos.outputs.cache-hit != 'true' + - name: Skip OSS repos setup shell: bash - run: | - uv run --frozen python -m tests.shared.codemod.commands clone-repos --clean-cache - env: - GITHUB_WORKSPACE: $GITHUB_WORKSPACE - - name: Verify cache contents - shell: bash - run: ls -la $GITHUB_WORKSPACE/oss_repos/ + run: echo "OSS repos setup is disabled" diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 153fe1d9b..483a1a0a2 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -37,8 +37,9 @@ jobs: - run: uv run --frozen pre-commit run --show-diff-on-failure --color=always --all-files --source ${{ github.event.pull_request.base.sha || github.event.before }} --origin ${{ github.event.pull_request.head.sha || github.event.after }} shell: bash - - uses: stefanzweifel/git-auto-commit-action@v5 - if: ${{ always() && env.REPO_SCOPED_TOKEN && github.event_name == 'pull_request' }} - with: - commit_message: "Automated pre-commit update" - push_options: "--no-verify" + # Temporarily disabled to prevent infinite loop with version updates + # - uses: stefanzweifel/git-auto-commit-action@v5 + # if: ${{ always() && env.REPO_SCOPED_TOKEN && github.event_name == 'pull_request' }} + # with: + # commit_message: "Automated pre-commit update" + # push_options: "--no-verify" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c2d87b75f..b982d045f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -26,22 +26,8 @@ permissions: jobs: build: - name: Build 3.${{ matrix.python }} ${{ matrix.os }} - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [ - ubuntu-latest, - ubuntu-22.04-arm, # https://github.com/actions/partner-runner-images/issues/37 https://github.com/orgs/community/discussions/148648#discussioncomment-12099554 - macos-latest, - macos-14-large - ] - python: [ - 12, - 13, - ] - + name: Build Pure Python Wheel + runs-on: ubuntu-latest steps: - name: Github context env: @@ -58,11 +44,10 @@ jobs: uses: astral-sh/setup-uv@v5.4 id: setup-uv with: - enable-cache: false + enable-cache: true prune-cache: false - python-version: 3.${{ matrix.python }} + python-version: "3.12" # Use single Python version for building version: '0.5.24' - cache-suffix: 3.${{ matrix.python }} - name: Fetch tags if: ${{ inputs.release-tag || startsWith(github.ref, 'refs/tags/') }} @@ -70,16 +55,14 @@ jobs: git branch git fetch --depth=1 origin +refs/tags/*:refs/tags/* - # TODO: add cbuildwheel cache - name: Build wheel - uses: pypa/cibuildwheel@v2.23.3 - env: - CIBW_BUILD: "*cp3${{ matrix.python }}*" + run: | + uv build --wheel --out-dir dist/ - uses: actions/upload-artifact@v4 with: - name: wheels-${{ matrix.os }}-3.${{ matrix.python }} - path: ./wheelhouse/*.whl + name: wheels + path: ./dist/*.whl release: if: ${{ inputs.release-tag || startsWith(github.ref, 'refs/tags/') }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 95936543e..02c884cd3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -99,76 +99,7 @@ jobs: env: GITHUB_WORKSPACE: $GITHUB_WORKSPACE - parse-tests: - needs: access-check - if: contains(github.event.pull_request.labels.*.name, 'parse-tests') || github.event_name == 'push' || github.event_name == 'workflow_dispatch' - runs-on: ubuntu-latest-32 - steps: - - uses: actions/checkout@v4 - with: - ref: ${{ github.event.pull_request.head.sha }} - - - name: Setup environment - uses: ./.github/actions/setup-environment - - - name: Cache oss-repos - uses: ./.github/actions/setup-oss-repos - - - name: Install yarn and pnpm - run: | - npm install -g yarn & - npm install -g pnpm - - name: Test with pytest - timeout-minutes: 15 - env: - GITHUB_WORKSPACE: $GITHUB_WORKSPACE - run: | - uv run pytest \ - -n auto \ - -o junit_suite_name="${{github.job}}" \ - tests/integration/codemod/test_parse.py - - - uses: ./.github/actions/report - with: - flag: no-flag - codecov_token: ${{ secrets.CODECOV_TOKEN }} - - - name: Notify parse tests failure - uses: slackapi/slack-github-action@v2.1.0 - if: failure() && github.event_name == 'push' && false - with: - webhook: ${{ secrets.SLACK_WEBHOOK_URL }} - webhook-type: incoming-webhook - payload: | - { - "blocks": [ - { - "type": "header", - "text": { - "type": "plain_text", - "text": "❌ Parse Tests Failed", - "emoji": true - } - }, - { - "type": "section", - "text": { - "type": "mrkdwn", - "text": "*Branch:* ${{ github.ref_name }}\n*Triggered by:* <${{ github.server_url }}/${{ github.actor }}|@${{ github.actor }}>\n\n*Details:*\n• <${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|View workflow run>" - } - }, - { - "type": "context", - "elements": [ - { - "type": "mrkdwn", - "text": "Failed at " - } - ] - } - ] - } integration-tests: needs: access-check diff --git a/.gitignore b/.gitignore index 2c38ccae0..00f68e676 100644 --- a/.gitignore +++ b/.gitignore @@ -57,6 +57,7 @@ alembic_versions_backup /.nvmrc **/build/test-results/test/TEST*.xml src/codegen/sdk/__init__.py +src/codegen/_version.py src/**/*.html .ccache/ uv-*.tar.gz diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index eeea3f677..dda073d87 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,6 @@ default_language_version: python: python3.13 repos: - - repo: https://github.com/ComPWA/taplo-pre-commit rev: v0.9.3 hooks: @@ -24,12 +23,7 @@ repos: - id: biome-check language: node additional_dependencies: ["@biomejs/biome@1.9.4"] - exclude: (src/codemods/eval)|(tests/unit/skills/snapshots)|(tests/unit/codegen/sdk/output)|(tests/integration/verified_codemods)|(docs/samples) - - repo: https://github.com/MarcoGorelli/cython-lint - rev: v0.16.6 - hooks: - - id: cython-lint - - id: double-quote-cython-strings + exclude: (src/codemods/eval)|(tests/unit/skills/snapshots)|(tests/unit/codegen/sdk/output)|(tests/integration/verified_codemods)|(docs/) - repo: https://github.com/kynan/nbstripout rev: 0.8.1 @@ -88,13 +82,13 @@ repos: args: ["--frozen", "--all-packages", "--all-extras"] - repo: https://github.com/hukkin/mdformat - rev: 0.7.22 # Use the ref you want to point at + rev: 0.7.22 # Use the ref you want to point at hooks: - - id: mdformat - language: python - # Optionally add plugins - additional_dependencies: - - mdformat-gfm - - mdformat-ruff - - mdformat-config - - mdformat-pyproject + - id: mdformat + language: python + # Optionally add plugins + additional_dependencies: + - mdformat-gfm + - mdformat-ruff + - mdformat-config + - mdformat-pyproject diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index a71cfdd77..000000000 --- a/Dockerfile +++ /dev/null @@ -1,38 +0,0 @@ -ARG PYTHON_VERSION=3.13 -ARG CODEGEN_BOT_GHE_TOKEN="" -FROM ghcr.io/astral-sh/uv:python${PYTHON_VERSION}-bookworm-slim AS base_uv -ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy -ENV GITHUB_WORKSPACE=/workspace -## Change the working directory to the `codegen-sdk` directory -FROM base_uv AS install-tools -RUN apt-get update && apt-get install -y build-essential curl git -RUN curl -fsSL https://deb.nodesource.com/setup_23.x -o nodesource_setup.sh -RUN bash nodesource_setup.sh -RUN apt-get update && apt-get install -y jq nodejs -RUN corepack enable -RUN --mount=type=cache,target=/root/.cache/uv uv pip install --system coverage -RUN --mount=type=cache,target=/root/.cache/uv uv tool install codecov-cli --python 3.10 -RUN --mount=type=cache,target=/root/.cache/uv uv tool install pre-commit --with pre-commit-uv -WORKDIR /codegen-sdk -ENTRYPOINT [ "uv", "run", "--frozen", "/bin/bash"] -FROM install-tools AS base-image -## Install dependencies -RUN --mount=type=cache,target=/root/.cache/uv \ - --mount=type=bind,source=uv.lock,target=uv.lock \ - --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ - --mount=type=bind,source=hatch.toml,target=hatch.toml \ - uv sync --frozen --no-install-workspace --all-extras -ADD . /codegen-sdk -## Sync the project -RUN --mount=type=cache,target=/root/.cache/uv \ - uv sync --frozen --all-extras -FROM base-image AS pre-commit -RUN --mount=type=cache,target=/root/.cache/uv \ - --mount=type=cache,target=/root/.cache/pre-commit \ - uv run pre-commit install-hooks -FROM base-image AS extra-repos -ARG CODEGEN_BOT_GHE_TOKEN="" -RUN uv run python -m tests.shared.codemod.commands clone-repos --clean-cache --extra-repos --token ${CODEGEN_BOT_GHE_TOKEN} -FROM base-image AS oss-repos -ARG CODEGEN_BOT_GHE_TOKEN="" -RUN uv run python -m tests.shared.codemod.commands clone-repos --clean-cache --token ${CODEGEN_BOT_GHE_TOKEN} diff --git a/architecture/1. plumbing/file-discovery.md b/architecture/1. plumbing/file-discovery.md deleted file mode 100644 index f4c3998d0..000000000 --- a/architecture/1. plumbing/file-discovery.md +++ /dev/null @@ -1,19 +0,0 @@ -# File Discovery - -The file discovery process is responsible for identifying and organizing all relevant files in a project that need to be processed by the SDK. - -## Initialization - -- We take in either a list of projects or a path to a filesystem. -- If we get a path, we'll detect the programming language, initialize the git client based on the path and get a Project - -## File discovery - -- We discover files using the git client so we can respect gitignored files -- We then filter files based on the language and the project configuration - - If specified, we filter by subdirectories - - We also filter by file extensions - -## Next Step - -After file discovery is complete, the files are passed to the [Tree-sitter Parsing](../parsing/tree-sitter.md) phase, where each file is parsed into a concrete syntax tree. diff --git a/architecture/2. parsing/A. Tree Sitter.md b/architecture/2. parsing/A. Tree Sitter.md deleted file mode 100644 index 3500b65fd..000000000 --- a/architecture/2. parsing/A. Tree Sitter.md +++ /dev/null @@ -1,33 +0,0 @@ -# Tree-sitter Parsing - -Tree-sitter is used as the primary parsing engine for converting source code into concrete syntax trees. Tree-sitter supports two modes of operation: - -```python -def my_function(): - pass -``` - -Tree sitter parses this as the following: - -``` -module [0, 0] - [3, 0] - function_definition [0, 0] - [1, 8] - name: identifier [0, 4] - [0, 15] - parameters: parameters [0, 15] - [0, 17] - body: block [1, 4] - [1, 8] - pass_statement [1, 4] - [1, 8] -``` - -- An CST mode which includes syntax nodes (for example, the `def` keyword, spaces, or parentheses). The syntax nodes are "anonymous" and don't have any semantic meaning. - - You don't see these nodes in the tree-sitter output, but they are there. -- A AST mode where we only focus on the semantic nodes (for example, the `my_function` identifier, and the `pass` statement). These are 'named nodes' and have semantic meaning. - - This is different from field names (like 'body'). These mean nothing about the node, they indicate what role the child node ('block') plays in the parent node ('function_definition'). - -## Implementation Details - -- We construct a mapping between file type and the tree-sitter grammar -- For each file given to us (via git), we parse it using the appropriate grammar - -## Next Step - -Once the concrete syntax trees are built, they are transformed into our abstract syntax tree representation in the [AST Construction](./B.%20AST%20Construction.md) phase. diff --git a/architecture/2. parsing/B. AST Construction.md b/architecture/2. parsing/B. AST Construction.md deleted file mode 100644 index 06a1cd48c..000000000 --- a/architecture/2. parsing/B. AST Construction.md +++ /dev/null @@ -1,77 +0,0 @@ -# AST Construction - -The tree-sitter CST/AST is powerful but it focuses on syntax highlighting and not semantic meaning. -For example, take decorators: - -```python -@decorator -def my_function(): - pass -``` - -``` -module [0, 0] - [3, 0] - decorated_definition [0, 0] - [2, 8] - decorator [0, 0] - [0, 10] - identifier [0, 1] - [0, 10] - definition: function_definition [1, 0] - [2, 8] - name: identifier [1, 4] - [1, 15] - parameters: parameters [1, 15] - [1, 17] - body: block [2, 4] - [2, 8] - pass_statement [2, 4] - [2, 8] - -``` - -You can see the decorated_definition node has a decorator and a definition. This makes sense for syntax highlighting - the decorator is highlighted seperately from the function definition. - -However, this is not useful for semantic analysis. We need to know that the decorator is decorating the function definition - there is a single function definition which may contain multiple decorators. -This becomes visibile when we consider function call chains: - -```python -a().b().c().d() -``` - -``` -module [0, 0] - [2, 0] - expression_statement [0, 0] - [0, 15] - call [0, 0] - [0, 15] - function: attribute [0, 0] - [0, 13] - object: call [0, 0] - [0, 11] - function: attribute [0, 0] - [0, 9] - object: call [0, 0] - [0, 7] - function: attribute [0, 0] - [0, 5] - object: call [0, 0] - [0, 3] - function: identifier [0, 0] - [0, 1] - arguments: argument_list [0, 1] - [0, 3] - attribute: identifier [0, 4] - [0, 5] - arguments: argument_list [0, 5] - [0, 7] - attribute: identifier [0, 8] - [0, 9] - arguments: argument_list [0, 9] - [0, 11] - attribute: identifier [0, 12] - [0, 13] - arguments: argument_list [0, 13] - [0, 15] -``` - -You can see that the chain of calls is represented as a deeply nested structure. This is not useful for semantic analysis or performing edits on these nodes. Therefore, when parsing we need to build an AST that is more useful for semantic analysis. - -## Implementation - -- For each file, we parse a file-specific AST -- We offer two modes of parsing: - - Pattern based parsing: It maps a particular node type to a semantic node type. For example, we broadly map all identifiers to the `Name` node type. - - Custom parsing: It takes a CST and builds a custom node type. For example, we can turn a decorated_definition node into a function_definition node with decorators. This involves careful arranging of the CST nodes into a new structure. - -## Pattern based parsing - -To do this, we need to build a mapping between the tree-sitter node types and our semantic node types. These mappings are language specific and stored in node_classes. They are processed by parser.py at runtime. We can access these via many functions - child_by_field_name, \_parse_expression, etc. These methods both wrap the tree-sitter methods and parse the tree-sitter node into our semantic node. - -## Custom parsing - -These are more complex and require more work. Most symbols (classes, functions, etc), imports, exports, and other complex constructs are parsed using custom parsing. - -## Statement parsing - -Statements have another layer of complexity. They are essentially pattern based but the mapping and logic is defined directly in the parser.py file. - -## Next Step - -After the AST is constructed, the system moves on to [Directory Parsing](./C.%20Directory%20Parsing.md) to build a hierarchical representation of the codebase's directory structure. diff --git a/architecture/2. parsing/C. Directory Parsing.md b/architecture/2. parsing/C. Directory Parsing.md deleted file mode 100644 index f25de2e29..000000000 --- a/architecture/2. parsing/C. Directory Parsing.md +++ /dev/null @@ -1,50 +0,0 @@ -# Directory Parsing - -The Directory Parsing system is responsible for creating and maintaining a hierarchical representation of the codebase's directory structure in memory. Directories do not hold references to the file itself, but instead holds the names to the files and does a dynamic lookup when needed. - -In addition to providing a more cohesive API for listing directory files, the Directory API is also used for [TSConfig](../3.%20imports-exports/C.%20TSConfig.md)-based (Import Resolution)[../3.%20imports-exports/A.%20Imports.md]. - -## Core Components - -The Directory Tree is constructed during the initial build_graph step in codebase_context.py, and is recreated from scratch on every re-sync. More details are below: - -## Directory Tree Construction - -The directory tree is built through the following process: - -1. The `build_directory_tree` method in `CodebaseContext` is called during graph initialization or when the codebase structure changes. -1. The method iterates through all files in the repository, creating directory objects for each directory path encountered. -1. For each file, it adds the file to its parent directory using the `_add_file` method. -1. Directories are created recursively as needed using the `get_directory` method with create_on_missing=True\`. - -## Directory Representation - -The `Directory` class provides a rich interface for working with directories: - -- **Hierarchy Navigation**: Access parent directories and subdirectories -- **File Access**: Retrieve files by name or extension -- **Symbol Access**: Find symbols (classes, functions, etc.) within files in the directory -- **Directory Operations**: Rename, remove, or update directories - -Each `Directory` instance maintains: - -- A reference to its parent directory -- Lists of files and subdirectories -- Methods to recursively traverse the directory tree - -## File Representation - -Files are represented by the `File` class and its subclasses: - -- `File`: Base class for all files, supporting basic operations like reading and writing content -- `SourceFile`: Specialized class for source code files that can be parsed into an AST - -Files maintain references to: - -- Their parent directory -- Their content (loaded dynamically to preserve the source of truth) -- For source files, the parsed AST and symbols - -## Next Step - -After the directory structure is parsed, the system can perform [Import Resolution](../3.%20imports-exports/A.%20Imports.md) to analyze module dependencies and resolve symbols across files. diff --git a/architecture/3. imports-exports/A. Imports.md b/architecture/3. imports-exports/A. Imports.md deleted file mode 100644 index cca5951ab..000000000 --- a/architecture/3. imports-exports/A. Imports.md +++ /dev/null @@ -1,60 +0,0 @@ -# Import Resolution - -Import resolution follows AST construction in the code analysis pipeline. It identifies dependencies between modules and builds a graph of relationships across the codebase. - -> NOTE: This is an actively evolving part of Codegen SDK, so some details here may be imcomplete, outdated, or incorrect. - -## Purpose - -The import resolution system serves these purposes: - -1. **Dependency Tracking**: Maps relationships between files by resolving import statements. -1. **Symbol Resolution**: Connects imported symbols to their definitions. -1. **Module Graph Construction**: Builds a directed graph of module dependencies. -1. **(WIP) Cross-Language Support**: Provides implementations for different programming languages. - -## Core Components - -### ImportResolution Class - -The `ImportResolution` class represents the outcome of resolving an import statement. It contains: - -- The source file containing the imported symbol -- The specific symbol being imported (if applicable) -- Whether the import references an entire file/module - -### Import Base Class - -The `Import` class is the foundation for language-specific import implementations. It: - -- Stores metadata about the import (module path, symbol name, alias) -- Provides the abstract `resolve_import()` method -- Adds symbol resolution edges to the codebase graph - -### Language-Specific Implementations - -#### Python Import Resolution - -The `PyImport` class extends the base `Import` class with Python-specific logic: - -- Handles relative imports -- Supports module imports, named imports, and wildcard imports -- Resolves imports using configurable resolution paths and `sys.path` -- Handles special cases like `__init__.py` files - -#### TypeScript Import Resolution - -The `TSImport` class implements TypeScript-specific resolution: - -- Supports named imports, default imports, and namespace imports -- Handles type imports and dynamic imports -- Resolves imports using TSConfig path mappings -- Supports file extension resolution - -## Implementation - -After file and directory parse, we loop through all import nodes and perform `add_symbol_resolution_edge`. This then invokes the language-specific `resolve_import` method that converts the import statement into a resolvable `ImportResolution` object (or None if the import cannot be resolved). This import symbol and the `ImportResolution` object are then used to add a symbol resolution edge to the graph, where it can then be used in future steps to resolve symbols. - -## Next Step - -After import resolution, the system analyzes [Export Analysis](./B.%20Exports.md) and handles [TSConfig Support](./C.%20TSConfig.md) for TypeScript projects. This is followed by [Type Analysis](../4.%20type-analysis/A.%20Type%20Analysis.md). diff --git a/architecture/3. imports-exports/B. Exports.md b/architecture/3. imports-exports/B. Exports.md deleted file mode 100644 index 0e42c98c4..000000000 --- a/architecture/3. imports-exports/B. Exports.md +++ /dev/null @@ -1,75 +0,0 @@ -# Export Analysis - -Some languages contain additional metadata on "exported" symbols, specifying which symbols are made available to other modules. Export analysis follows import resolution in the code analysis pipeline. It identifies and processes exported symbols from modules, enabling the system to track what each module makes available to others. - -## Core Components - -### Export Base Class - -The `Export` class serves as the foundation for language-specific export implementations. It: - -- Stores metadata about the export (symbol name, is default, etc.) -- Tracks the relationship between the export and its declared symbol -- Adds export edges to the codebase graph - -### TypeScript Export Implementation - -The `TSExport` class implements TypeScript-specific export handling: - -- Supports various export styles (named exports, default exports, re-exports) -- Handles export declarations with and without values -- Processes wildcard exports (`export * from 'module'`) -- Manages export statements with multiple exports - -#### Export Types and Symbol Resolution - -The TypeScript implementation handles several types of exports: - -1. **Declaration Exports** - - - Function declarations (including generators) - - Class declarations - - Interface declarations - - Type alias declarations - - Enum declarations - - Namespace declarations - - Variable/constant declarations - -1. **Value Exports** - - - Object literals with property exports - - Arrow functions and function expressions - - Classes and class expressions - - Assignment expressions - - Primitive values and expressions - -1. **Special Export Forms** - - - Wildcard exports (`export * from 'module'`) - - Named re-exports (`export { name as alias } from 'module'`) - - Default exports with various value types - -#### Symbol Tracking and Dependencies - -The export system: - -- Maintains relationships between exported symbols and their declarations -- Validates export names match their declared symbols -- Tracks dependencies through the codebase graph -- Handles complex scenarios like: - - Shorthand property exports in objects - - Nested function and class declarations - - Re-exports from other modules - -#### Integration with Type System - -Exports are tightly integrated with the type system: - -- Exported type declarations are properly tracked -- Symbol resolution considers both value and type exports -- Re-exports preserve type information -- Export edges in the codebase graph maintain type relationships - -## Next Step - -After export analysis is complete, for TypeScript projects, the system processes [TSConfig Support](./C.%20TSConfig.md) configurations. Then it moves on to [Type Analysis](../4.%20type-analysis/A.%20Type%20Analysis.md) to build a complete understanding of types and symbols. diff --git a/architecture/3. imports-exports/C. TSConfig.md b/architecture/3. imports-exports/C. TSConfig.md deleted file mode 100644 index b2362a7c8..000000000 --- a/architecture/3. imports-exports/C. TSConfig.md +++ /dev/null @@ -1,81 +0,0 @@ -# TSConfig Support - -TSConfig support is a critical component for TypeScript projects in the import resolution system. It processes TypeScript configuration files (tsconfig.json) to correctly resolve module paths and dependencies. - -## Purpose - -The TSConfig support system serves these purposes: - -1. **Path Mapping**: Resolves custom module path aliases defined in the tsconfig.json file. -1. **Base URL Resolution**: Handles non-relative module imports using the baseUrl configuration. -1. **Project References**: Manages dependencies between TypeScript projects using the references field. -1. **Directory Structure**: Respects rootDir and outDir settings for maintaining proper directory structures. - -## Core Components - -### TSConfig Class - -The `TSConfig` class represents a parsed TypeScript configuration file. It: - -- Parses and stores the configuration settings from tsconfig.json -- Handles inheritance through the "extends" field -- Provides methods for translating between import paths and absolute file paths -- Caches computed values for performance optimization - -## Configuration Processing - -### Configuration Inheritance - -TSConfig files can extend other configuration files through the "extends" field: - -1. Base configurations are loaded and parsed first -1. Child configurations inherit and can override settings from their parent -1. Path mappings, base URLs, and other settings are merged appropriately - -### Path Mapping Resolution - -The system processes the "paths" field in tsconfig.json to create a mapping between import aliases and file paths: - -1. Path patterns are normalized (removing wildcards, trailing slashes) -1. Relative paths are converted to absolute paths -1. Mappings are stored for efficient lookup during import resolution - -### Project References - -The "references" field defines dependencies between TypeScript projects: - -1. Referenced projects are identified and loaded -1. Their configurations are analyzed to determine import paths -1. Import resolution can cross project boundaries using these references - -## Import Resolution Process - -### Path Translation - -When resolving an import path in TypeScript: - -1. Check if the path matches any path alias in the tsconfig.json -1. If a match is found, translate the path according to the mapping -1. Apply baseUrl resolution for non-relative imports -1. Handle project references for cross-project imports - -### Optimization Techniques - -The system employs several optimizations: - -1. Caching computed values to avoid redundant processing -1. Early path checking for common patterns (e.g., paths starting with "@" or "~") -1. Hierarchical resolution that respects the configuration inheritance chain - -## Integration with Import Resolution - -The TSConfig support integrates with the broader import resolution system: - -1. Each TypeScript file is associated with its nearest tsconfig.json -1. Import statements are processed using the file's associated configuration -1. Path mappings are applied during the module resolution process -1. Project references are considered when resolving imports across project boundaries - -## Next Step - -After TSConfig processing is complete, the system proceeds to [Type Analysis](../4.%20type-analysis/A.%20Type%20Analysis.md) where it builds a complete understanding of types, symbols, and their relationships. diff --git a/architecture/4. type-analysis/A. Type Analysis.md b/architecture/4. type-analysis/A. Type Analysis.md deleted file mode 100644 index 9f2d9c28c..000000000 --- a/architecture/4. type-analysis/A. Type Analysis.md +++ /dev/null @@ -1,25 +0,0 @@ -# Type Analysis - -The type analysis system builds a complete understanding of types and symbols across the codebase. - -## Basic flow - -- Discover names that need to be resolved -- Resolve names -- Convert resolutions into graph edges - -## The resolution stack - -To accomplish this, we have an in house computation engine - the ResolutionStack. Each stack frame contains a reference to it's parent frame. However, a parent can have multiple child frames (IE: Union Types). - -When we resolve types on a node, we call resolved_type_frames to get the resolved types. Once we know what goes in the next frame, we call with_resolution_frame to construct the next frame. This is a generator that yields the next frame until we've resolved all the types. Resolved_type_frames is a property caches a list of the generated frames. -Therefore, once you have computed type resolution on a node, you don't need to recompute it. That way, we can start at arbitrary nodes without performance overhead. - -This is similar to how other's implement incremental computation engines with a few weaknesses: - -- There is only 1 query in the query engine -- Partial cache invalidation isn't implemented - -## Next Step - -After understanding the type analysis system overview, let's look at how we [walk the syntax tree](./B.%20Tree%20Walking.md) to analyze code structure. diff --git a/architecture/4. type-analysis/B. Tree Walking.md b/architecture/4. type-analysis/B. Tree Walking.md deleted file mode 100644 index c0c777dc4..000000000 --- a/architecture/4. type-analysis/B. Tree Walking.md +++ /dev/null @@ -1,49 +0,0 @@ -# Tree Walking - -To compute dependencies, we have to walk the entire AST for every file. -At a high level, the procedure is pretty simple - -```python -def compute_dependencies(self): - for child in self.children: - compute_dependencies(child) -``` - -We start at the root node and walk the tree until we have computed all dependencies. - -## Usage Kind identification - -We have to identify the kind of usage for each node. This is done by looking at the parent node and the child node. - -```python -def foo() -> c: - c() -``` - -We will classify the usage kind of the `c` callsite differently from the return type. - -```python -class PyFunction(...): - ... - - def _compute_dependencies(self, usage_kind: UsageKind): - self.return_type._compute_dependencies(UsageKind.RETURN_TYPE) - self.body._compute_dependencies(UsageKind.BODY) -``` - -By default, we just pass the usage kind to the children. - -## Resolvable Nodes - -At no step in this process described so far have we actually computed any dependencies. That's because there are some special nodes ("Resolvables") that do the heavy lifting. All of the tree walking is just to identify these nodes and the context they are used in. Resolvables are anything inheriting from `Resolvable`: - -- [Name Resolution](./C.%20Name%20Resolution.md) -- [Chained Attributes](./D.%20Chained%20Attributes.md) -- [Function Calls](./E.%20Function%20Calls.md) -- [Subscript Expression](./G.%20Subscript%20Expression.md) - -These are all processed using the [Type Analysis](./A.%20Type%20Analysis.md) to get the definition of the node. They are then converted into [Graph Edges](./H.%20Graph%20Edges.md) and added to the graph. - -## Next Step - -After understanding how we walk the tree, let's look at how we [resolve names](./C.%20Name%20Resolution.md) in the code. diff --git a/architecture/4. type-analysis/C. Name Resolution.md b/architecture/4. type-analysis/C. Name Resolution.md deleted file mode 100644 index bd6516708..000000000 --- a/architecture/4. type-analysis/C. Name Resolution.md +++ /dev/null @@ -1,70 +0,0 @@ -# Name Resolution - -The name resolution system handles symbol references, scoping rules, and name binding across the codebase. - -## What's in a name? - -A name is a `Name` node. It is just a string of text. -For example, `foo` is a name. - -```python -from my_module import foo - -foo() -``` - -Tree sitter parses this into: - -``` -module [0, 0] - [2, 0] - import_from_statement [0, 0] - [0, 25] - module_name: dotted_name [0, 5] - [0, 14] - identifier [0, 5] - [0, 14] - name: dotted_name [0, 22] - [0, 25] - identifier [0, 22] - [0, 25] - expression_statement [1, 0] - [1, 5] - call [1, 0] - [1, 5] - function: identifier [1, 0] - [1, 3] - arguments: argument_list [1, 3] - [1, 5] -``` - -We can map the identifier nodes to `Name` nodes. -You'll see there are actually 3 name nodes here: `foo`, `my_module`, and `foo`. - -- `my_module` is the module name. -- `foo` is the name imported from the module. -- `foo` is the name of the function being called. - -## Name Resolution - -Name resolution is the process of resolving a name to its definition. To do this, all we need to do is - -1. Get the name we're looking for. (e.g. `foo`) -1. Find the scope we're looking in. (in this case, the global file scope) -1. Recursively search the scope for the name (which will return the node corresponding `from my_module import foo`). -1. Use the type engine to get the definition of the name (which will return the function definition). - -## Scoping - -```python -# Local vs global scope -from my_module import foo, bar, fuzz - - -def outer(): - def foo(): ... - - foo() - bar() - fuzz() - - def fuzz(): ... -``` - -If we wanted to resolve `foo` in this case, we would start at the name foo, then check it's parent recursively till we arrive at the function outer. We would then check for the name foo and find there is a nested function with that name. We would then return the function definition. -However, if we wanted to resolve `bar`, we would then check for the name bar and find there is no nested function, variable, or parameter with that name. We would then return the import statement. -Finally for fuzz, when we check for the name fuzz, we would find there is a nested function with that name, but it is defined after the call to `fuzz()`. We would then return the import. - -## Next Step - -These simple cases let us build up to more complex cases. [Chained Attributes](./D.%20Chained%20Attributes.md) covers how we handle method and property access chains. diff --git a/architecture/4. type-analysis/D. Chained Attributes.md b/architecture/4. type-analysis/D. Chained Attributes.md deleted file mode 100644 index 57a3b941c..000000000 --- a/architecture/4. type-analysis/D. Chained Attributes.md +++ /dev/null @@ -1,89 +0,0 @@ -# Chained Attributes - -```python -class Foo: - def foo(self): ... - - -a = Foo() -a.foo() -``` - -A core functionality is to be able to calculate that `a.foo()` is a usage of `foo` in the `Foo` class. -To do this, we must first understand how tree-sitter parses the code. - -``` -module [0, 0] - [5, 0] - class_definition [0, 0] - [2, 11] - name: identifier [0, 6] - [0, 9] - body: block [1, 4] - [2, 11] - function_definition [1, 4] - [2, 11] - name: identifier [1, 8] - [1, 11] - parameters: parameters [1, 11] - [1, 17] - identifier [1, 12] - [1, 16] - body: block [2, 8] - [2, 11] - expression_statement [2, 8] - [2, 11] - ellipsis [2, 8] - [2, 11] - expression_statement [3, 0] - [3, 9] - assignment [3, 0] - [3, 9] - left: identifier [3, 0] - [3, 1] - right: call [3, 4] - [3, 9] - function: identifier [3, 4] - [3, 7] - arguments: argument_list [3, 7] - [3, 9] - expression_statement [4, 0] - [4, 7] - call [4, 0] - [4, 7] - function: attribute [4, 0] - [4, 5] - object: identifier [4, 0] - [4, 1] - attribute: identifier [4, 2] - [4, 5] - arguments: argument_list [4, 5] - [4, 7] -``` - -If we look at this parse tree - we can see that the `a.foo()` call has a name of type attribute. The object of the call is an identifier for `a`, and the `foo` is an attribute of the identifier for `a`. Typescript has a similar structure. These are the core building blocks of chained attributes. -Chained attributes contain 2 parts: - -1. The object: `a` -1. The attribute: `foo` - -All we must do to resolve the definition of `a.foo` is - -1. Find the definition of the object `a` (the class `Foo`) -1. Get the attribute (`foo`) on the resolved object (`Foo`) (the function `foo`) -1. Resolve the attribute to it's original definition (in this case, the function `foo`) - -## Step 1: Resolve the object - -We can resolve the object by calling resolved_types to get potential types of the object. -If it is a name (like `a`) we can use the name resolution to get the definition of the name. -If it is another chained attribute, we can recursively resolve the chained attribute. -If the original type is a union, we can operate on multiple types and return all the possible results. - -## Step 2: Get the attribute - -We can get the attribute by calling resolve_attribute on the resolved object. Nodes which implement this inherit from `HasAttribute`. Examples include: - -- Class -- File -- Type aliases -- Enums - -## Step 3: Resolve the attribute - -Finally, we can resolve the attribute by calling resolved_types on the attribute. This is useful in cases, particularly for attributes of the class like the following: - -```python -def fuzz(): ... - - -class Foo: - foo = fuzz - - -a = Foo() -a.foo() -``` - -We can resolve the attribute by calling resolved_types on the attribute to go from the attribute (foo) to the underlying resolved type (fuzz). - -## Next Step - -After handling chained attributes, the system moves on to [Function Calls](./E.%20Function%20Calls.md) analysis for handling function and method invocations. diff --git a/architecture/4. type-analysis/E. Function Calls.md b/architecture/4. type-analysis/E. Function Calls.md deleted file mode 100644 index d4db8cd6b..000000000 --- a/architecture/4. type-analysis/E. Function Calls.md +++ /dev/null @@ -1,64 +0,0 @@ -# Function Call - -At a first glance, function calls are simple. We can resolve the function call by looking up the function name in the current scope. - -However, there are some complexities to consider. - -## Constructors - -In Python, we can call a class definition as if it were a function. This is known as a constructor. - -```python -class Foo: - def __init__(self): ... - - -a = Foo() -``` - -This changes the behavior of the function call from the name. The name resolves to Foo (the class definition) but the constructor resolves to the function definition. - -## Imports - -```typescript -require('foo') -``` - -In this case, we need to resolve the import statement to the module definition. - -## Return Types - -```python -class Foo: - def foo(self) -> int: - return 1 - - -class Bar: - def bar(self) -> Foo: ... - - -a = Bar() -a.bar().foo() -``` - -In this case, we need to resolve the return type of the function to the type of the return value. However, the function definition is not the same as the return type. This means we now have 3 different things going on with function calls: - -1. Resolving the function definition -1. Resolving the return type -1. Computing what this function call depends on (both the function definition and the arguments passed to the function) - -## Generics - -```python -def foo[T](a: list[T]) -> T: ... - - -foo([1, 2, 3]) -``` - -Generics depend on the types of the arguments to the function. We need to resolve the types of the arguments to the function to determine the type of the generic. [Generics](./F.%20Generics.md) covers how we handle generics. - -## Next Step - -After understanding function calls, let's look at how we handle [Generics](./F.%20Generics.md) in the type system. diff --git a/architecture/4. type-analysis/F. Generics.md b/architecture/4. type-analysis/F. Generics.md deleted file mode 100644 index 46df52bfc..000000000 --- a/architecture/4. type-analysis/F. Generics.md +++ /dev/null @@ -1,7 +0,0 @@ -# Generics Analysis - -TODO - -## Next Step - -After generics analysis, the system handles [Subscript Expressions](./G.%20Subscript%20Expression.md) for array and dictionary access. diff --git a/architecture/4. type-analysis/G. Subscript Expression.md b/architecture/4. type-analysis/G. Subscript Expression.md deleted file mode 100644 index e2bb1a80a..000000000 --- a/architecture/4. type-analysis/G. Subscript Expression.md +++ /dev/null @@ -1,7 +0,0 @@ -# Subscript Expression - -TODO - -## Next Step - -After handling subscript expressions, the system builds [Graph Edges](./H.%20Graph%20Edges.md) to represent relationships between types and symbols. diff --git a/architecture/4. type-analysis/H. Graph Edges.md b/architecture/4. type-analysis/H. Graph Edges.md deleted file mode 100644 index 46efd3c46..000000000 --- a/architecture/4. type-analysis/H. Graph Edges.md +++ /dev/null @@ -1,59 +0,0 @@ -# Graph Edges - -The SDK contains a graph of nodes and edges. -Nodes are the core of the graph and represent the symbols in the codebase. Examples include: - -- Symbols: Classes, functions, Assignments, etc. -- Imports, Exports -- Files -- Parameters, Attributes - Edges are between - each containes 4 elements: -- Source: The node that the edge is coming from -- Target: The node that the edge is going to -- Type: The type of the edge -- Metadata: Additional information about the edge - -## Edge Types - -We have 4 types of [edges](../src/codegen/sdk/enums.py#L10) - -- IMPORT_SYMBOL_RESOLUTION: An edge from an import to a symbol -- EXPORT: An edge from a symbol to an export -- SUBCLASS: An edge from a symbol to a subclass -- SYMBOL_USAGE: An edge from a symbol to a usage - -The only edges that are used in almost every API are SYMBOL_USAGE edges. They are also the only ones that have additional metadata. - -## Edge construction order - -To compute the graph we follow a specific order: - -1. Import edges are added first - - This is completely independent of the type engine -1. Symbol edges are added next - - these may export symbols that are imported from other files. - - This is almost entirely independent of the type engine -1. Subclass edges are added next - - these may reference symbols that are imported or exported from other files. - - This is fully dependent on the type engine -1. Usage edges are added last - - they reference symbols that are imported or exported from other files - - This is fully dependent on the type engine - - Subclass edges are computed beforehand as a performance optimization - -## Usages - -SYMBOL_USAGE edges contain additional [metadata](../src/codegen/sdk/core/dataclasses/usage.py) - -- match: The exact match of the usage -- usage_symbol: The symbol this object is used in. Derived from the match object -- usage_type: How this symbol was used. Derived from the resolution stack -- imported_by: The import that imported this symbol. Derived from the resolution stack -- kind: Where this symbol was used (IE: in a type parameter or in the body of the class, etc). Derived from the compute dependencies function - You may notice these edges are actually between the usage symbol and the match object but the match object is not on the graph. This way we have constructed triple edges. -- They are technically edges between the usage symbol and the symbol contained in the match object -- The edge metadata contains the match object - -## Next Step - -After constructing the type graph, the system moves on to [Edit Operations](../5.%20performing-edits/A.%20Edit%20Operations.md) where it can safely modify code while preserving type relationships. diff --git a/architecture/5. performing-edits/A. Transactions.md b/architecture/5. performing-edits/A. Transactions.md deleted file mode 100644 index c27c7e65f..000000000 --- a/architecture/5. performing-edits/A. Transactions.md +++ /dev/null @@ -1,54 +0,0 @@ -# Transactions - -Transactions represent atomic changes to files in the codebase. Each transaction defines a specific modification that can be queued, validated, and executed. - -## Transaction Types - -The transaction system is built around a base `Transaction` class with specialized subclasses: - -### Content Transactions - -- **RemoveTransaction**: Removes content between specified byte positions -- **InsertTransaction**: Inserts new content at a specified byte position -- **EditTransaction**: Replaces content between specified byte positions - -### File Transactions - -- **FileAddTransaction**: Creates a new file -- **FileRenameTransaction**: Renames an existing file -- **FileRemoveTransaction**: Deletes a file - -## Transaction Priority - -Transactions are executed in a specific order defined by the `TransactionPriority` enum: - -1. **Remove** (highest priority) -1. **Edit** -1. **Insert** -1. **FileAdd** -1. **FileRename** -1. **FileRemove** - -This ordering ensures that content is removed before editing or inserting, and that all content operations happen before file operations. - -## Key Concepts - -### Byte-Level Operations - -All content transactions operate at the byte level rather than on lines or characters. This provides precise control over modifications and allows transactions to work with any file type, regardless of encoding or line ending conventions. - -### Content Generation - -Transactions support both static content (direct strings) and dynamic content (generated at execution time). This flexibility allows for complex transformations where the new content depends on the state of the codebase at execution time. - -Most content transactions use static content, but dynamic content is supported for rare cases where the new content depends on the state of other transactions. One common example is handling whitespace during add and remove transactions. - -### File Operations - -File transactions are used to create, rename, and delete files. - -> NOTE: It is important to note that most file transactions such as `FileAddTransaction` are no-ops (AKA skiping Transaction Manager) and instead applied immediately once the `create_file` API is called. This allows for created files to be immediately available for edit and use. The reason file operations are still added to Transaction Manager is to help with optimizing graph re-parse and diff generation. (Keeping track of which files exist and don't exist anymore). - -## Next Step - -After understanding the transaction system, they are managed by the [Transaction Manager](./B.%20Transaction%20Manager.md) to ensure consistency and atomicity. diff --git a/architecture/5. performing-edits/B. Transaction Manager.md b/architecture/5. performing-edits/B. Transaction Manager.md deleted file mode 100644 index 4ed78a750..000000000 --- a/architecture/5. performing-edits/B. Transaction Manager.md +++ /dev/null @@ -1,93 +0,0 @@ -# Transaction Manager - -The Transaction Manager coordinates the execution of transactions across multiple files, handling conflict resolution, and enforcing resource limits. - -## High-level Concept - -Since all node operations are on byte positions of the original file, multiple operations that change the total byte length of the file will result in offset errors and broken code. - -Give this example over here: - -``` -Original: FooBar -Operations: Remove "Foo" (bytes 0-3), Insert "Hello" (bytes 0-5) - Remove "Bar" (bytes 3-6), Insert "World" (bytes 3-7) -``` - -If these operations were applied in order, the result would be: - -``` -Result: FooBar -Operation: Remove "Foo" (bytes 0-3), Insert "Hello" (bytes 0-5) -Result: HelloBar -Operation: Remove "Bar" (bytes 3-6), Insert "World" (bytes 3-7) -Result: HelWorldar -``` - -Resulting in an invalid output. - -⭐ The key with TransactionManager is that it queues up all transactions in a given Codemod run, the applies all of the ***backwards*** from the last byte range to the first. Given the same example as above but applied backwards: - -``` -Result: FooBar -Operation: Remove "Bar" (bytes 3-6), Insert "World" (bytes 3-7) -Result: FooWorld -Operation: Remove "Foo" (bytes 0-3), Insert "Hello" (bytes 0-5) -Result: HelloWorld -``` - -TransactionManager also performs some additional operations such detecting conflicts and coordinating (some basic) conflict resolutions. Overall, the core responsibilities are as follows: - -1. **Transaction Queueing**: Maintains a queue of pending transactions organized by file -1. **Conflict Resolution**: Detects and resolves conflicts between transactions -1. **Transaction Execution**: Applies transactions in the correct order -1. **Resource Management**: Enforces limits on transaction count and execution time -1. **Change Tracking**: Generates diffs for applied changes - -## Sorting Transactions - -Before execution, transactions are sorted based on (in this priority): - -1. Position in the file (higher byte positions first) -1. Transaction type (following the priority order) -1. User-defined priority -1. Creation order - -This sorting ensures that transactions are applied in a deterministic order that minimizes conflicts. Larger byte ranges are always edited first, removals happen before insertions, and older transactions are applied before newer ones. - -## Conflict Resolution - -### Conflict Types - -The manager identifies several types of conflicts: - -1. **Overlapping Transactions**: Multiple transactions affecting the same byte range -1. **Contained Transactions**: One transaction completely contained within another -1. **Adjacent Transactions**: Transactions affecting adjacent byte ranges - -In it's current implementation, TransactionManager only handles Contained Transactions that are trivially sovable. (If a remove transaction completely overlaps with another remove transaction, only the larger one will be kept) - -## Resource Management - -The Transaction Manager enforces two types of limits: - -1. **Transaction Count**: Optional maximum number of transactions -1. **Execution Time**: Optional time limit for transaction processing - -These limits prevent excessive resource usage and allow for early termination of long-running operations. - -## Commit Process - -The commit process applies queued transactions to the codebase: - -1. Transactions are sorted according to priority rules -1. Files are processed one by one -1. For each file, transactions are executed in order -1. Diffs are collected for each modified file -1. The queue is cleared after successful commit - -The diff's are later used during resyc to efficiently update the codebase graph as changes occur. See [Incremental Computation](../6.%20incremental-computation/A.%20Overview.md) for more details. - -## Next Step - -After managing transactions, the system handles [Incremental Computation](../6.%20incremental-computation/A.%20Overview.md) to efficiently update the codebase graph as changes occur. diff --git a/architecture/6. incremental-computation/A. Overview.md b/architecture/6. incremental-computation/A. Overview.md deleted file mode 100644 index 741cb426f..000000000 --- a/architecture/6. incremental-computation/A. Overview.md +++ /dev/null @@ -1,47 +0,0 @@ -# Incremental Computation - -After we performed some changes to the codebase, we may need to recompute the codebase graph. -This is not a trivial task, because we need to be able to recompute the codebase graph incrementally and efficiently. - -## Use Cases - -### 1. Repeated Moves - -```python -# file1.py -def foo(): - return bar() - - -def bar(): - return 42 -``` - -Let's move symbol `bar` to `file2.py` - -```python -# file2.py -def bar(): - return 42 -``` - -Then we move symbol `foo` to `file3.py` - -```python -# file3.py -from file2 import bar - - -def foo(): - return bar() -``` - -You'll notice we have added an import from file2, not file1. This means that before we can move foo to file3, we need to sync the graph to reflect the changes in file2. - -### 2. Branching - -If we want to checkout a different branch, we need to update the baseline state to the git commit of the new branch and recompute the codebase graph. - -## Next Step - -After understanding the overview of incremental computation, let's look at how we [detect changes](./B.%20Change%20Detection.md) in the codebase. diff --git a/architecture/6. incremental-computation/B. Change Detection.md b/architecture/6. incremental-computation/B. Change Detection.md deleted file mode 100644 index ca3322762..000000000 --- a/architecture/6. incremental-computation/B. Change Detection.md +++ /dev/null @@ -1,58 +0,0 @@ -# Change Detection - -## Lifecycle of an operation on the codebase graph - -Changes will go through 4 states. By default, we do not apply changes to the codebase graph, only to the filesystem. - -### Pending transactions - -After calling an edit or other transaction method, the changes are stored in a pending transaction. Pending transactions will be committed as described in the previous chapter. - -### Pending syncs - -After a transaction is committed, the file is marked as a pending sync. This means the filesystem state has been updated, but the codebase graph has not been updated yet. - -### Applied syncs - -When we sync the graph, we apply all the pending syncs and clear them. The codebase graph is updated to reflect the changes. We track all the applied syncs in the codebase graph. - -### Saved/baseline state - -Finally, we can set the baseline state to a git commit. This is the state we target when we reset the codebase graph. When we checkout branches, we update the baseline state. - -## Change Detection - -When we sync or build the graph, first we build a list of all files in 3 categories: - -- Removed files -- Added files -- Files to repase - -For example, if we move a file, it will be in the added and removed files -If we add a file, it will be in the added files even if we peformed edits on it later. - -## Codebase.commit logic - -We follow the following logic - -1. Commit all pending transactions -1. Write all buffered files to the disk -1. Store this to pending changes (usually we will skip the remaining steps if we commit without syncing the graph) -1. Build list of removed, added and modified files from pending changes -1. For removed files, we need to remove all the edges that point to the file. -1. For added files, we need to add all the edges that point to the file. -1. For modified files, we remove all the edges that point to the file and add all the edges that point to the new file. This is complicated since edges may pass through the modified file and need to be intelligently updated. -1. Mark all pending changes as applied - -## Reset logic - -Reset is just the inverse of commit. We need to - -1. Cancel all pending transactions -1. Restore file state to the state to the target git commit -1. Clear all pending changes to the graph -1. Reverse all applied syncs to the graph - -## Next Step - -After detecting changes, the system performs [Graph Recomputation](./C.%20Graph%20Recomputation.md) to update the dependency graph efficiently. diff --git a/architecture/6. incremental-computation/C. Graph Recomputation.md b/architecture/6. incremental-computation/C. Graph Recomputation.md deleted file mode 100644 index 2e2f378ee..000000000 --- a/architecture/6. incremental-computation/C. Graph Recomputation.md +++ /dev/null @@ -1,40 +0,0 @@ -# Graph Recomputation - -## Node Reparsing - -Some limitations we encounter are: - -- It is non-trivial to update tree sitter nodes, and the SDK has no method to do this. -- Therefore, all existing nodes are invalidated and need to be recomputed every time filesystem state changes. - -Therefore, to recompute the graph, we must first have the filesystem state updated. Then we can remove all nodes in the modified files and create new nodes in the modified files. - -## Edge Recomputation - -- Nodes may either use (out edges) or be used by (in edges) other nodes. - - Recomputing the out-edges is straightforward, we just need to reparse the file and compute dependencies again. - - Recomputing the in-edges is more difficult. - - The basic algorithm of any incremental computation engine is to: - - Detect what changed - - Update that query with the new data - - If the output of the query changed, we need to update all the queries that depend on that query. - -### Detecting what changed - -A difficulty is that the nodes are completely freshed for updated files. Therefore, this by default will include all nodes in updated files. - -### Updating the query - -To do this, we: - -- Wipe the entire cache of the query engine -- Remove all existing out edges of the node -- Recompute dependencies of that node - -### Update what changed - -This part has not been fully implemented yet. Currently, we update all the nodes that are descendants of the changed node and all the nodes in the file. - -## Next Step - -After graph recomputation, the system is ready for the next set of operations. The cycle continues with [File Discovery](../plumbing/file-discovery.md) for any new changes. diff --git a/architecture/architecture.md b/architecture/architecture.md deleted file mode 100644 index dd044e4dc..000000000 --- a/architecture/architecture.md +++ /dev/null @@ -1,113 +0,0 @@ -# Architecture of the Codegen SDK - -This is a technical document explaining the architecture of the Codegen SDK. - -## Purpose of the SDK - -This SDK is designed to accomplish a large set of use cases in one tool: - -- Parsing large, enterprise-scale codebases -- Making syntax aware changes to code while respecting original formatting -- Being user-friendly and easy to use -- Able to quickly execute large scale refactorings against a codebase -- Supporting multiple languages with common abstractions -- Aware of both project structure (tsconfig.json, pyproject.toml, etc.) and language-specific structure (imports, etc.) -- Able to perform type resolution -- Responding to changes to the codebase and updating the graph - -### Performance - -A key problem is performance. We must be able to quickly respond to user requests on enterprise codebases (IE: renaming a symbol). However, we don't know what those requests are in advance and the scope of these requests can be quite massive (They may choose to iterate over a large number of symbols and their usages). To respond to these problems, we introduced codegen cloud. We split operations into two parts: - -- A "parse" step that builds up a graph of the codebase - - This can take a long time to complete, but it only needs to be done once - - This computes the entire graph of the codebase -- A "run" step that performs operations on the codebase - - This can be done quickly, but it needs to be done many times - - This uses the graph to perform operations on the codebase - -This allows us to perform operations on the codebase without having to parse it every time. - -## Existing Solutions - -To accomplish these goals, we can look at existing classes of solutions: - -### Language Server Architecture - -The immediate question is: why not use a language server? They have a lot of the same goals as codegen, but do not address many of our goals: - -- Language servers can handle many of these same use cases, but they are not as performant as we need. -- Generally, language servers compute their results lazily. This doesn't work for us because we need to perform a large number of operations on the codebase. -- While the LSP protocol is powerful, it is not designed to be scriptable the way codegen is. -- In Python, many of the language servers are an aglamation of many different tools and libraries. None are very good at refactoring or offer the comprehensive set of features that codegen does. - -Generally language servers parse codebases in response to user actions. This is not a good fit for us because we need to perform a large number of operations on the codebase without knowing which symbols are being changed or queried. - -### Compiler Architecture - -Many of the same goals can be accomplished with a compiler. C However, compilers are not as user-friendly as we need. - -- They do not generally offer easy-to-use apis -- They do not focus on refactoring code after parsing -- They generally don't handle graph-updates -- They aren't common or complete in python/typescript - -Generally compilers build up knowledge of the entire codebase in a single pass. This is a much better fit for our use case. - -## Architecture - -The codegen SDK combines aspects of both systems to accomplish our goals. -At a high level our architecture is: - -1. We discover files to parse - -## Processing Steps - -The SDK processes code through several distinct steps: - -1. \[File Discovery\](./1. plumbing/file-discovery.md) - - - Project structure analysis - - File system traversal - -1. \[Tree-sitter Parsing\](./2. parsing/A. Tree Sitter.md) - - - Initial syntax tree construction - - Language-specific parsing rules - - Error recovery - -1. \[AST Construction\](./2. parsing/B. AST Construction.md) - - - Abstract syntax tree building - - Node type assignment - - Syntax validation - -1. \[Import & Export Resolution\](./3. imports-exports/A. Imports.md) - - - Module dependency analysis - - \[Export Analysis\](./3. imports-exports/B. Exports.md) - - \[TSConfig Support\](./3. imports-exports/C. TSConfig.md) - - Path resolution - -1. \[Type Analysis\](./4. type-analysis/A. Type Analysis.md) - - - \[Type Analysis\](./4. type-analysis/A. Type Analysis.md) - - \[Tree Walking\](./4. type-analysis/B. Tree Walking.md) - - \[Name Resolution\](./4. type-analysis/C. Name Resolution.md) - - \[Chained Attributes\](./4. type-analysis/D. Chained Attributes.md) - - \[Function Calls\](./4. type-analysis/E. Function Calls.md) - - \[Generics\](./4. type-analysis/F. Generics.md) - - \[Subscript Expression\](./4. type-analysis/G. Subscript Expression.md) - - \[Graph Edges\](./4. type-analysis/H. Graph Edges.md) - -1. \[Performing Edits\](./5. performing-edits/A. Edit Operations.md) - - - \[Transaction Manager\](./5. performing-edits/B. Transaction Manager.md) - - Change validation - - Format preservation - -1. \[Incremental Computation\](./6. incremental-computation/A. Overview.md) - - - \[Detecting Changes\](./6. incremental-computation/B. Change Detection.md) - - \[Recomputing Graph\](./6. incremental-computation/C. Graph Recomputation.md) - - Cache invalidation diff --git a/architecture/external/dependency-manager.md b/architecture/external/dependency-manager.md deleted file mode 100644 index ed8e42a3d..000000000 --- a/architecture/external/dependency-manager.md +++ /dev/null @@ -1,100 +0,0 @@ -# Dependency Manager - -> WARNING: Dependency manager is an experimental feature designed for Codegen Cloud! The current implementation WILL delete any existing `node_modules` folder! - -## Motivation - -A future goal of Codegen is to support resolving symbols directly from dependencies, instead of falling back to `ExternalModule`s. (In fact, some experimental Codegen features such as [Type Engine](./type-engine.md) already parse and use 3rd party dependencies from `node_modules`) - -This requires us to pull and install dependencies from a repository's `package.json`. However, simply installing dependencies from `package.json` is not enough, as many projects require internal dependencies that use custom NPM registries. Others require custom post-install scripts that may not run on our codemod environments. - -Dependency Manager is an experimental solution to this problem. It creates a shadow tree of `package.json` files that includes all core dependencies and settings from the repository's original `package.json` without any custom registries or potentially problematic settings. - -> NOTE: Currently, this is only implemented for TypeScript projects. - -## Implementation - -Given this example codebase structure: - -``` -repo/ -├── package.json -├── node_modules/ -├── src/ -│ ├── frontend/ -│ │ └── package.json -│ └── backend/ -│ └── package.json -└── tests/ - └── package.json -``` - -Dependency Manager first deletes any existing `node_modules` folder in the user's repository. After this step, Dependency Manager initializes itself to use the correct version of NPM, Yarn, or PNPM for the user's repository. - -Dependency Manager then creates a "shadow copy" of the repository's original `package.json` file. This shadow copy is used to later revert any changes made by Codegen before running codemods. With these steps, the codebase structure now looks like this: - -``` -repo/ -├── package.json -├── package.json.gs_internal.bak -├── src/ -│ ├── frontend/ -│ │ └── package.json -│ │ └── package.json.gs_internal.bak -│ └── backend/ -│ └── package.json -│ └── package.json.gs_internal.bak -└── tests/ - └── package.json - └── package.json.gs_internal.bak -``` - -Next, Dependency Manager iterates through all the `package.json` files and creates a "clean" version of each file. This "clean" version only includes a subset of information from the original, including: - -- Name -- Version -- Package Manager Details -- Workspaces - -Most importantly, this step iterates through `dependencies` and `devDependencies` of each `package.json` file and validates them against the npm registry. If a package is not found, it is added to a list of invalid dependencies and removed from the `package.json` file. - -After this step, the codebase structure now looks like this: - -``` -repo/ -├── package.json (modified) -├── package.json.gs_internal.bak -├── src/ -│ ├── frontend/ -│ │ └── package.json (modified) -│ │ └── package.json.gs_internal.bak -│ └── backend/ -│ └── package.json (modified) -│ └── package.json.gs_internal.bak -└── tests/ - └── package.json (modified) - └── package.json.gs_internal.bak -``` - -After the shadow and cleaning steps, Dependency Manager proceeds to install the user's dependencies through NPM, Yarn, or PNPM, depending on the detected installer type. Finally, Dependency Manager restores the original `package.json` files and removes the shadow copies. - -The final codebase structure looks like this: - -``` -repo/ -├── package.json -├── node_modules/ -├── src/ -│ ├── frontend/ -│ │ └── package.json -│ └── backend/ -│ └── package.json -└── tests/ - └── package.json -``` - -If all goes well, Dependency Manager will have successfully installed the user's dependencies and prepared the codebase for codemods. - -## Next Step - -The dependency manager works closely with the [Type Engine](./type-engine.md) to ensure type compatibility across dependencies. diff --git a/architecture/external/type-engine.md b/architecture/external/type-engine.md deleted file mode 100644 index 42b96f643..000000000 --- a/architecture/external/type-engine.md +++ /dev/null @@ -1,25 +0,0 @@ -# Type Engine - -Type Engine is an experimental feature of Codegen that leverages the [TypeScript Compiler API](https://github.com/microsoft/TypeScript/wiki/Using-the-Compiler-API) to provide deeper insight into a user's codebase (such as resolving return types). - -> NOTE: Currently, this is only implemented for TypeScript projects. - -There are currently two experimental implementations of TypeScript's Type Engine: an external process-based implementation and a V8-based implementation. - -## Implementation (External Process) - -During codebase parsing, the Type Engine spawns a type inference subprocess (defined in `src/codegen/sdk/typescript/external/typescript_analyzer/run_full.ts`) that concurrently parses the codebase with the TypeScript API to resolve return types. The final analyzer output is placed in `/tmp/typescript-analysis.json` and is read in by Codegen to resolve return types. - -## Implementation (V8) - -The V8-based implementation is much more flexible and powerful in comparison but is currently not as stable. It uses the [PyMiniRacer](https://github.com/sqreen/py_mini_racer) package to spawn a V8-based JavaScript engine that can parse the codebase with the TypeScript API to resolve return types. - -The entirety of `src/codegen/sdk/typescript/external/typescript_analyzer` is compiled down using [Rollup.js](https://rollupjs.org/) into a single `index.js` file. A couple of patches are applied to the engine source to remove `require` and `export` statements, which are not supported by MiniRacer. - -Then, the entire `index.js` file is loaded into the MiniRacer context. To work around file read limitations with V8, an in-memory shadow filesystem is created that mimics the user's repository's filesystem. These are defined in `fsi.ts` (`FileSystemInterface`) and `fs_proxy.ts` (`ProxyFileSystem`). The TypeScript Compiler then uses the custom `ProxyFileSystem.readFile` function instead of the traditional `fs.readFile`. - -Once the analyzer is initialized and the codebase is parsed, the entire TypeScript Compiler API is available in the MiniRacer context. The analyzer can then be used to resolve return types for any function in the codebase or to parse the codebase and generate a full type analysis. - -## Next Step - -The type engine works in conjunction with the [Dependency Manager](./dependency-manager.md) to ensure type safety across project dependencies. diff --git a/codegen-examples/CONTRIBUTING.md b/codegen-examples/CONTRIBUTING.md deleted file mode 100644 index 752b5d6aa..000000000 --- a/codegen-examples/CONTRIBUTING.md +++ /dev/null @@ -1,19 +0,0 @@ -# Contributing to Codegen Examples - -Thank you for your interest in contributing to `codegen-examples`! This document outlines the process and guidelines for contributing. - -## Contributor License Agreement - -By contributing to Codegen Examples, you agree that: - -1. Your contributions will be licensed under the project's license. -1. You have the right to license your contribution under the project's license. -1. You grant Codegen a perpetual, worldwide, non-exclusive, royalty-free license to use your contribution. - -## Pull Request Process - -1. Fork the repository and create your branch from `main`. -1. Ensure your code passes all tests. -1. Update documentation as needed. -1. Submit a pull request to the `main` branch. -1. Include a clear description of your changes in the PR. diff --git a/codegen-examples/LICENSE b/codegen-examples/LICENSE deleted file mode 100644 index 261eeb9e9..000000000 --- a/codegen-examples/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/codegen-examples/README.md b/codegen-examples/README.md deleted file mode 100644 index 3e430024c..000000000 --- a/codegen-examples/README.md +++ /dev/null @@ -1,60 +0,0 @@ -# Codegen Examples - -[![Documentation](https://img.shields.io/badge/docs-docs.codegen.com-blue)](https://docs.codegen.com) - -This is a collection of examples using [Codegen](https://codegen.com). You can use these examples to learn how to use Codegen and build custom code transformations. - -## Setup - -We recommend using [`uv`](https://github.com/astral-sh/uv) with Python 3.13 for the best experience. - -To install Codegen, please follow the [official installation guide](https://docs.codegen.com/introduction/installation). Once Codegen is installed, use these steps to run the examples in this repository: - -Install the Codegen CLI globally - -```bash -uv tool install codegen -``` - -Initialize Codegen in your project - -```bash -codegen init -``` - -Activate the virtual environment - -```bash -source .codegen/.venv/bin/activate -``` - -Your environment is now ready to run example codemods. - -### IDE Configuration (Optional) - -To configure your IDE for optimal use with Codegen, follow our [IDE setup guide](https://docs.codegen.com/introduction/ide-usage#configuring-your-ide-interpreter). - -## Examples - -Within the examples folder, each subdirectory contains a self-contained example with: - -- An explanation of the transformation (`README.md`) -- A Codegen script that performs the transformation (`run.py`) -- Sample code to transform, if not using a repository (`input_repo/`) - -To see a transformation, simply run the `run.py` script within the desired directory. - -## Learn More - -- [Documentation](https://docs.codegen.com) -- [Getting Started Guide](https://docs.codegen.com/introduction/getting-started) -- [Tutorials](https://docs.codegen.com/tutorials/at-a-glance) -- [API Reference](https://docs.codegen.com/api-reference) - -## Contributing - -Have a useful example to share? We'd love to include it! Please see our [Contributing Guide](CONTRIBUTING.md) for instructions. - -## License - -The [Apache 2.0 license](LICENSE). diff --git a/codegen-examples/STRUCTURE.md b/codegen-examples/STRUCTURE.md deleted file mode 100644 index f4695135d..000000000 --- a/codegen-examples/STRUCTURE.md +++ /dev/null @@ -1,180 +0,0 @@ -# Structuring Codegen Examples - -This guide explains how to structure examples for the Codegen library. A well-structured example helps both humans and AI understand the code's purpose and how to use it effectively. - -## Core Principles - -1. **Single Responsibility**: Each example should demonstrate one clear use case -1. **Self-Contained**: Examples should work independently with minimal setup -1. **Clear Structure**: Follow a consistent file organization pattern -1. **Good Documentation**: Include README.md with clear explanations and examples - -## Standard File Structure - -``` -example-name/ -├── README.md # Documentation and usage examples -├── run.py # Main implementation -└── input_repo/ # (Optional) Sample code for transformation -``` - -## Code Organization in `run.py` - -Your `run.py` should follow this structure, demonstrated well in the `generate_training_data` example: - -1. **Imports at the top** - - ```python - import codegen - from codegen import Codebase - from codegen.sdk.core import Function - # ... other imports - ``` - -1. **Utility functions with clear docstrings** - - ```python - def hop_through_imports(imp: Import) -> Symbol | ExternalModule: - """Finds the root symbol for an import""" - # Implementation... - ``` - -1. **Main Codegen function with decorator** - - ```python - @codegen.function("your-function-name") - def run(codebase: Codebase): - """Clear docstring explaining what the function does. - - Include: - 1. Purpose of the function - 2. Key steps or transformations - 3. Expected output - """ - # Implementation... - ``` - -1. **Entry point at bottom** - - ```python - if __name__ == "__main__": - # Initialize codebase - # Run transformation - # Save/display results - ``` - -## Working with Codebases - -Prefer using public repositories for examples when possible. However, sometimes you need a specific code structure to demonstrate a concept clearly. Here's how to handle both cases: - -```python -# Preferred: Use a well-known public repo that demonstrates the concept well -codebase = Codebase.from_repo("fastapi/fastapi") - -# Alternative: Create a minimal example repo when you need specific code structure -# 1. Create an input_repo/ directory in your example -# 2. Add minimal code that clearly demonstrates the transformation -codebase = Codebase("./input_repo") -``` - -For example: - -``` -example-name/ -├── README.md -├── run.py -└── input_repo/ # Your minimal example code - ├── app.py - └── utils.py -``` - -Choose between these approaches based on: - -1. Can you find a public repo that clearly shows the concept? -1. Is the transformation specific enough that a custom example would be clearer? -1. Would a minimal example be more educational than a complex real-world one? - -## Best Practices - -1. **Function Decorator** - - - Always use `@codegen.function()` with a descriptive name - - Name should match the example's purpose - -1. **Utility Functions** - - - Break down complex logic into smaller, focused functions - - Each utility should demonstrate one clear concept - - Include type hints and docstrings - -1. **Main Function** - - - Name it `run()` for consistency - - Include comprehensive docstring explaining the transformation - - Return meaningful data that can be used programmatically - -1. **Entry Point** - - - Include a `__name__ == "__main__"` block - - Show both initialization and execution - - Add progress messages for better UX - -1. **Error Handling** - - - Include appropriate error handling for common cases - - Provide clear error messages - -## Example Reference Implementation - -The `generate_training_data` example demonstrates these principles well: - -```python -# Focused utility function -def get_function_context(function) -> dict: - """Get the implementation, dependencies, and usages of a function.""" - # Clear, focused implementation... - - -# Main transformation with decorator -@codegen.function("generate-training-data") -def run(codebase: Codebase): - """Generate training data using a node2vec-like approach... - - This codemod: - 1. Finds all functions... - 2. For each function... - 3. Outputs structured JSON... - """ - # Clear implementation with good structure... - - -# Clean entry point -if __name__ == "__main__": - print("Initializing codebase...") - codebase = Codebase.from_repo("fastapi/fastapi") - run(codebase) - # ... rest of execution -``` - -## Documentation Requirements - -Every example should include: - -1. **README.md** - - Clear explanation of purpose - - Explains key syntax and program function - - Code examples showing the transformation (before/after) - - If using `input_repo/`, explain its structure and contents - - Output format (if applicable) - - Setup and running instructions - -## Testing Your Example - -Before submitting: - -1. Test with a fresh environment -1. Verify all dependencies are listed -1. Ensure the example runs with minimal setup -1. Check that documentation is clear and accurate - -Remember: Your example might be used by both humans and AI to understand Codegen's capabilities. Clear structure and documentation help everyone use your code effectively. diff --git a/codegen-examples/examples/ai_impact_analysis/README.md b/codegen-examples/examples/ai_impact_analysis/README.md deleted file mode 100644 index e34e1a8af..000000000 --- a/codegen-examples/examples/ai_impact_analysis/README.md +++ /dev/null @@ -1,124 +0,0 @@ -# AI Impact Analysis - -This script analyzes a codebase to measure and report the impact of AI-generated code contributions. It provides detailed insights about AI vs human contributions, helping teams understand the role of AI in their development process. - -## Features - -- **Repository Analysis**: Automatically detects and analyzes git repositories: - - - Uses current directory if it's a git repo - - - Searches parent directories for a git repo - - - Falls back to cloning a specified repository if needed - - ```python - # Basic repository setup - repo_path = os.getcwd() - repo_config = RepoConfig.from_repo_path(repo_path) - repo_operator = RepoOperator(repo_config=repo_config) - project = ProjectConfig.from_repo_operator(repo_operator=repo_operator, programming_language=ProgrammingLanguage.PYTHON) - codebase = Codebase(projects=[project]) - ``` - -- **Comprehensive Statistics**: - - - Total number of commits and AI vs human contribution percentages - - Files with significant AI contribution (>50%) - - AI-touched symbols and their impact - - Detailed contributor breakdown (human and AI contributors) - - ```python - # Run the analysis - ai_authors = ["github-actions[bot]", "dependabot[bot]"] - results = analyze_ai_impact(codebase, ai_authors) - - # Access statistics - stats = results["stats"] - print(f"Total commits: {stats['total_commits']}") - print(f"AI commits: {stats['ai_commits']} ({stats['ai_percentage']:.1f}%)") - print(f"Files with >50% AI: {stats['ai_file_count']} of {stats['total_file_count']}") - - # View contributors - for author, count in results["contributors"]: - is_ai = any(ai_name in author for ai_name in ai_authors) - print(f"{'🤖' if is_ai else '👤'} {author}: {count} commits") - ``` - -- **High-Impact Code Detection**: - - - Identifies AI-written code that is heavily used by other parts of the codebase - - Shows dependency relationships for AI-contributed code - - ```python - # Access high-impact AI symbols - for symbol in results["high_impact_symbols"]: - print(f"Symbol: {symbol['name']} ({symbol['filepath']})") - print(f"Used by {symbol['usage_count']} other symbols") - print(f"Last edited by: {symbol['last_editor']}") - - # View top AI-contributed files - for file_path, percentage in stats["top_ai_files"]: - print(f"{file_path}: {percentage:.1f}% AI contribution") - ``` - -- **Detailed Attribution**: - - - Maps symbols to git history - - Tracks last editor and complete editor history for each symbol - - Flags AI-authored symbols - - ```python - # Get attribution information for a specific symbol - symbol = codebase.get_symbol("path/to/file.py:MyClass.my_method") - - # Access attribution data - print(f"Last editor: {symbol.last_editor}") - print(f"Editor history: {symbol.editor_history}") - print(f"AI authored: {symbol.is_ai_authored}") - - # Find all AI-authored symbols - ai_symbols = [s for s in codebase.get_symbols() if s.is_ai_authored] - for symbol in ai_symbols: - print(f"AI symbol: {symbol.name}") - ``` - -## Output - -The script generates: - -1. Console output with summary statistics -1. Detailed analysis in `ai_impact_analysis.json` -1. Attribution information added to codebase symbols - -## Usage - -```bash -python run.py -``` - -The script will automatically: - -1. Initialize and analyze the codebase -1. Process git history -1. Generate attribution information -1. Output detailed statistics - -You can also visualize the AI impact analysis results using a dashboard. For setup and usage instructions, please see the documentation in the `/dashboard` subdirectory. - -## Symbol Attribution - -After running the analysis, symbols in the codebase will have the following attribution information: - -- `symbol.last_editor`: The last person who edited the symbol -- `symbol.editor_history`: List of all editors who have touched the symbol -- `symbol.is_ai_authored`: Boolean indicating if the symbol was authored by AI - -## Learn More - -- [Attributions](https://docs.codegen.com/tutorials/attributions) -- [Codegen Documentation](https://docs.codegen.com) - -## Contributing - -Feel free to submit issues and enhancement requests! diff --git a/codegen-examples/examples/ai_impact_analysis/dashboard/README.md b/codegen-examples/examples/ai_impact_analysis/dashboard/README.md deleted file mode 100644 index cde758b55..000000000 --- a/codegen-examples/examples/ai_impact_analysis/dashboard/README.md +++ /dev/null @@ -1,86 +0,0 @@ -# AI Impact Analysis Dashboard - -A web dashboard for visualizing AI-generated code contributions in your codebase. This dashboard provides detailed insights about AI vs human contributions, helping understand the role of AI in a codebase development process. - -## Setup - -### Backend - -1. Install dependencies: - -```bash -uv venv -source .venv/bin/activate -uv pip install modal codegen fastapi -``` - -2. Deploy or serve the Modal endpoint: - -```bash -modal serve backend/api.py -``` - -```bash -modal deploy backend/api.py -``` - -### Frontend - -1. Install dependencies: - -```bash -cd frontend -npm install -``` - -2. Update the API endpoint: - Edit the fetch URL on line 29 in `components/repo-analysis-dashboard.tsx` to point to your Modal endpoint: - -```bash - fetch(`[your-modal-deployment-url]/analyze?repo_full_name=${repoFullName}`, { - method: 'POST', - }) -``` - -3. Start the development server: - -```bash -npm run dev -``` - -## Usage - -1. Visit the dashboard in your browser (default: http://localhost:3000) -1. Enter a GitHub repository name (format: username/repo) -1. Click "Analyze Repo" to generate insights - -The dashboard will display: - -- Summary statistics of AI contributions -- Monthly contribution timeline -- Top files with AI contributions -- High-impact AI-authored symbols -- Contributor breakdown visualization - -## Architecture - -- **Backend**: Modal-deployed FastAPI service that: - - - Clones and analyzes repositories - - Processes git history - - Calculates AI impact metrics - - Returns structured analysis data - -- **Frontend**: Next.js application with: - - - Interactive charts - - Visualized AI impact metrics - -## Learn More - -- [AI Impact Analysis Documentation](https://docs.codegen.com/tutorials/attributions) -- [Codegen Documentation](https://docs.codegen.com) - -## Contributing - -Feel free to submit issues and enhancement requests! diff --git a/codegen-examples/examples/ai_impact_analysis/dashboard/backend/api.py b/codegen-examples/examples/ai_impact_analysis/dashboard/backend/api.py deleted file mode 100644 index ddb08115d..000000000 --- a/codegen-examples/examples/ai_impact_analysis/dashboard/backend/api.py +++ /dev/null @@ -1,54 +0,0 @@ -from codegen import Codebase -from codegen.extensions.attribution.main import ( - add_attribution_to_symbols, - analyze_ai_impact, -) -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -import modal - -image = modal.Image.debian_slim().apt_install("git").pip_install("codegen", "fastapi", "intervaltree", "pygit2", "requests") - -app = modal.App(name="ai-impact-analysis", image=image) - -fastapi_app = FastAPI() - -fastapi_app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -@fastapi_app.post("/analyze") -async def analyze(repo_full_name: str): - codebase = Codebase.from_repo(repo_full_name=repo_full_name, language="python", full_history=True) - - print("🤖 Analyzing AI impact on codebase...") - - ai_authors = [ - "renovate[bot]", - "dependabot[bot]", - "github-actions[bot]", - "devin-ai-integration[bot]", - ] - - results = analyze_ai_impact(codebase, ai_authors) - - print("\n🏷️ Adding attribution information to symbols...") - add_attribution_to_symbols(codebase, ai_authors) - print("✅ Attribution information added to symbols") - - return results - - -@app.function(image=image) -@modal.asgi_app() -def fastapi_modal_app(): - return fastapi_app - - -if __name__ == "__main__": - app.deploy("ai-impact-analysis") diff --git a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/app/favicon.ico b/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/app/favicon.ico deleted file mode 100644 index fd8587746..000000000 Binary files a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/app/favicon.ico and /dev/null differ diff --git a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/app/globals.css b/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/app/globals.css deleted file mode 100644 index 1535f872d..000000000 --- a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/app/globals.css +++ /dev/null @@ -1,76 +0,0 @@ -@tailwind base; -@tailwind components; -@tailwind utilities; - -@layer base { - :root { - --background: 0 0% 100%; - --foreground: 222.2 84% 4.9%; - - --card: 0 0% 100%; - --card-foreground: 222.2 84% 4.9%; - - --popover: 0 0% 100%; - --popover-foreground: 222.2 84% 4.9%; - - --primary: 221.2 83.2% 53.3%; - --primary-foreground: 210 40% 98%; - - --secondary: 210 40% 96.1%; - --secondary-foreground: 222.2 47.4% 11.2%; - - --muted: 210 40% 96.1%; - --muted-foreground: 215.4 16.3% 46.9%; - - --accent: 210 40% 96.1%; - --accent-foreground: 222.2 47.4% 11.2%; - - --destructive: 0 84.2% 60.2%; - --destructive-foreground: 210 40% 98%; - - --border: 214.3 31.8% 91.4%; - --input: 214.3 31.8% 91.4%; - --ring: 221.2 83.2% 53.3%; - - --radius: 0.5rem; - } - - .dark { - --background: 222.2 84% 4.9%; - --foreground: 210 40% 98%; - - --card: 222.2 84% 4.9%; - --card-foreground: 210 40% 98%; - - --popover: 222.2 84% 4.9%; - --popover-foreground: 210 40% 98%; - - --primary: 217.2 91.2% 59.8%; - --primary-foreground: 222.2 47.4% 11.2%; - - --secondary: 217.2 32.6% 17.5%; - --secondary-foreground: 210 40% 98%; - - --muted: 217.2 32.6% 17.5%; - --muted-foreground: 215 20.2% 65.1%; - - --accent: 217.2 32.6% 17.5%; - --accent-foreground: 210 40% 98%; - - --destructive: 0 62.8% 30.6%; - --destructive-foreground: 210 40% 98%; - - --border: 217.2 32.6% 17.5%; - --input: 217.2 32.6% 17.5%; - --ring: 224.3 76.3% 48%; - } -} - -@layer base { - * { - @apply border-border; - } - body { - @apply bg-background text-foreground; - } -} diff --git a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/app/layout.tsx b/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/app/layout.tsx deleted file mode 100644 index 264632940..000000000 --- a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/app/layout.tsx +++ /dev/null @@ -1,34 +0,0 @@ -import type { Metadata } from "next"; -import { Inter } from "next/font/google"; -import type React from "react"; -import "./globals.css"; -import { ThemeProvider } from "@/components/theme-provider"; - -const inter = Inter({ subsets: ["latin"] }); - -export const metadata: Metadata = { - title: "AI Code Impact Analysis", -}; - -export default function RootLayout({ - children, -}: { - children: React.ReactNode; -}) { - return ( - - - - {children} - - - - ); -} - -import "./globals.css"; diff --git a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/app/page.tsx b/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/app/page.tsx deleted file mode 100644 index 5b048bbdf..000000000 --- a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/app/page.tsx +++ /dev/null @@ -1,9 +0,0 @@ -import { RepoAnalysisDashboard } from "@/components/repo-analysis-dashboard"; - -export default function Home() { - return ( -
- -
- ); -} diff --git a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components.json b/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components.json deleted file mode 100644 index 7f48f98e8..000000000 --- a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "$schema": "https://ui.shadcn.com/schema.json", - "style": "default", - "rsc": true, - "tsx": true, - "tailwind": { - "config": "tailwind.config.ts", - "css": "app/globals.css", - "baseColor": "neutral", - "cssVariables": true, - "prefix": "" - }, - "aliases": { - "components": "@/components", - "utils": "@/lib/utils", - "ui": "@/components/ui", - "lib": "@/lib", - "hooks": "@/hooks" - }, - "iconLibrary": "lucide" -} diff --git a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/contribution-timeline.tsx b/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/contribution-timeline.tsx deleted file mode 100644 index e6017ba9c..000000000 --- a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/contribution-timeline.tsx +++ /dev/null @@ -1,83 +0,0 @@ -"use client"; - -import { - Card, - CardContent, - CardDescription, - CardHeader, - CardTitle, -} from "@/components/ui/card"; -import type { Timeline } from "@/lib/types"; -import { - Bar, - BarChart, - ResponsiveContainer, - Tooltip, - XAxis, - YAxis, -} from "recharts"; - -interface ContributionTimelineProps { - timeline: Timeline[]; -} - -export function ContributionTimeline({ timeline }: ContributionTimelineProps) { - return ( - - - AI Contribution Timeline - Monthly AI contributions over time - - - - - - `${value}`} - /> - { - if (active && payload && payload.length) { - return ( -
-
-
- - Date - - - {payload[0].payload.date} - -
-
- - Commits - - - {payload[0].value} - -
-
-
- ); - } - return null; - }} - /> - -
-
-
-
- ); -} diff --git a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/contributors-breakdown.tsx b/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/contributors-breakdown.tsx deleted file mode 100644 index fb54b7fec..000000000 --- a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/contributors-breakdown.tsx +++ /dev/null @@ -1,131 +0,0 @@ -"use client"; - -import { - Card, - CardContent, - CardDescription, - CardHeader, - CardTitle, -} from "@/components/ui/card"; -import { ScrollArea } from "@/components/ui/scroll-area"; -import { - Cell, - Legend, - Pie, - PieChart, - ResponsiveContainer, - Tooltip, -} from "recharts"; - -interface ContributorsBreakdownProps { - contributors: [string, number][]; -} - -export function ContributorsBreakdown({ - contributors, -}: ContributorsBreakdownProps) { - // Take top 5 contributors for the chart - const topContributors = contributors.slice(0, 5); - const otherContributors = contributors.slice(5); - const otherCount = otherContributors.reduce( - (sum, [_, count]) => sum + count, - 0, - ); - - const chartData = [ - ...topContributors.map(([name, count]) => ({ - name: name.split(" ")[0], // Just use first name for chart - fullName: name, - count, - })), - otherContributors.length > 0 - ? { name: "Others", fullName: "Other Contributors", count: otherCount } - : null, - ].filter(Boolean); - - const COLORS = [ - "#3b82f6", - "#10b981", - "#f59e0b", - "#ef4444", - "#8b5cf6", - "#6b7280", - ]; - - return ( - - - Contributors Breakdown - Top contributors by commit count - - -
-
- - - - {chartData.map((entry, index) => ( - - ))} - - [ - value, - props.payload.fullName, - ]} - contentStyle={{ - backgroundColor: "white", - borderColor: "#e2e8f0", - borderRadius: "0.375rem", - }} - /> - - - -
-
- -
- {contributors.slice(0, 10).map(([name, count], index) => ( -
-
-
- - {name.split(" ")[0]} - -
-
{count}
-
- ))} - {contributors.length > 10 && ( -
- +{contributors.length - 10} more contributors -
- )} -
- -
-
- - - ); -} diff --git a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/dashboard-header.tsx b/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/dashboard-header.tsx deleted file mode 100644 index 3fa854491..000000000 --- a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/dashboard-header.tsx +++ /dev/null @@ -1,18 +0,0 @@ -import { Code2 } from "lucide-react"; - -export function DashboardHeader() { - return ( -
-
-
-

- AI Code Impact Analysis -

-
-

- Analyze AI-generated code contributions in your repository -

-
-
- ); -} diff --git a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/high-impact-symbols.tsx b/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/high-impact-symbols.tsx deleted file mode 100644 index db1fe51aa..000000000 --- a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/high-impact-symbols.tsx +++ /dev/null @@ -1,58 +0,0 @@ -import { - Card, - CardContent, - CardDescription, - CardHeader, - CardTitle, -} from "@/components/ui/card"; -import { ScrollArea } from "@/components/ui/scroll-area"; -import type { HighImpactSymbol } from "@/lib/types"; - -interface HighImpactSymbolsProps { - symbols: HighImpactSymbol[]; -} - -export function HighImpactSymbols({ symbols }: HighImpactSymbolsProps) { - return ( - - - High-Impact AI Symbols - - AI-written code with significant usage - - - - -
- {symbols.length > 0 ? ( - symbols.map((symbol) => ( -
-
-
{symbol.name}
-
- Used by {symbol.usage_count} symbols -
-
-
- {symbol.filepath} -
-
- Last edited by:{" "} - {symbol.last_editor} -
-
- )) - ) : ( -
- No high-impact AI symbols found -
- )} -
-
-
-
- ); -} diff --git a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/loading-screen.tsx b/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/loading-screen.tsx deleted file mode 100644 index 089cdf833..000000000 --- a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/loading-screen.tsx +++ /dev/null @@ -1,15 +0,0 @@ -import { Loader2 } from "lucide-react"; - -export function LoadingScreen() { - return ( -
-
- -

- Analyzing Repository -

-

This may take a few seconds...

-
-
- ); -} diff --git a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/repo-analysis-dashboard.tsx b/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/repo-analysis-dashboard.tsx deleted file mode 100644 index 4cf6ddb71..000000000 --- a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/repo-analysis-dashboard.tsx +++ /dev/null @@ -1,113 +0,0 @@ -"use client"; - -import { ContributionTimeline } from "@/components/contribution-timeline"; -import { ContributorsBreakdown } from "@/components/contributors-breakdown"; -import { DashboardHeader } from "@/components/dashboard-header"; -import { HighImpactSymbols } from "@/components/high-impact-symbols"; -import { LoadingScreen } from "@/components/loading-screen"; -import { SummaryCards } from "@/components/summary-cards"; -import { TopAIFiles } from "@/components/top-ai-files"; -import { Button } from "@/components/ui/button"; -import { Card, CardContent } from "@/components/ui/card"; -import { Input } from "@/components/ui/input"; -import type { AnalysisData } from "@/lib/types"; -import { GitBranch, Loader2 } from "lucide-react"; -import { useState } from "react"; - -export function RepoAnalysisDashboard() { - const [data, setData] = useState(null); - const [loading, setLoading] = useState(false); - const [repoUrl, setRepoUrl] = useState(""); - - const handleSubmit = (e: React.FormEvent) => { - e.preventDefault(); - if (repoUrl.trim()) { - setLoading(true); - const match = repoUrl.match(/(?:github\.com\/)?([^/\s]+\/[^/\s]+)/); - if (match) { - const repoFullName = match[1]; - fetch( - `[your-modal-deployment-url]/analyze?repo_full_name=${repoFullName}`, - { - method: "POST", - }, - ) - .then((response) => { - if (!response.ok) { - throw new Error("Network response was not ok"); - } - return response.json(); - }) - .then((analysisData: AnalysisData) => { - setData(analysisData); - setLoading(false); - }) - .catch((error) => { - console.error("Error analyzing repository:", error); - setLoading(false); - }); - } - } - }; - - return ( -
- {loading && } - - - - - -
-
- -
-
- - setRepoUrl(e.target.value)} - disabled={loading} - /> -
- -
-
-
-
-
- - {data && ( -
- - -
- - -
- -
- - -
-
- )} -

-

-
- ); -} diff --git a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/summary-cards.tsx b/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/summary-cards.tsx deleted file mode 100644 index dea866379..000000000 --- a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/summary-cards.tsx +++ /dev/null @@ -1,77 +0,0 @@ -import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; -import type { AnalysisData } from "@/lib/types"; -import { BarChart3, FileCode, GitCommit, Percent } from "lucide-react"; - -interface SummaryCardsProps { - data: AnalysisData; -} - -export function SummaryCards({ data }: SummaryCardsProps) { - const { stats, ai_symbol_count, total_symbol_count } = data; - - return ( -
- - - AI Commits - - - -
- {stats.ai_commits} / {stats.total_commits} -
-

- {stats.ai_percentage.toFixed(1)}% of total commits -

-
-
- - - - AI Files - - - -
- {stats.ai_file_count} / {stats.total_file_count} -
-

- {((stats.ai_file_count / stats.total_file_count) * 100).toFixed(1)}% - of files have >50% AI contribution -

-
-
- - - - AI Symbols - - - -
- {ai_symbol_count} / {total_symbol_count} -
-

- {((ai_symbol_count / total_symbol_count) * 100).toFixed(1)}% of code - symbols -

-
-
- - - - High Impact - - - -
- {data.high_impact_symbols.length} -
-

- AI-written symbols with high usage -

-
-
-
- ); -} diff --git a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/theme-provider.tsx b/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/theme-provider.tsx deleted file mode 100644 index 020003cf9..000000000 --- a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/theme-provider.tsx +++ /dev/null @@ -1,11 +0,0 @@ -"use client"; - -import { - ThemeProvider as NextThemesProvider, - type ThemeProviderProps, -} from "next-themes"; -import * as React from "react"; - -export function ThemeProvider({ children, ...props }: ThemeProviderProps) { - return {children}; -} diff --git a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/top-ai-files.tsx b/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/top-ai-files.tsx deleted file mode 100644 index 67a5472a8..000000000 --- a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/top-ai-files.tsx +++ /dev/null @@ -1,48 +0,0 @@ -import { - Card, - CardContent, - CardDescription, - CardHeader, - CardTitle, -} from "@/components/ui/card"; -import { Progress } from "@/components/ui/progress"; -import { ScrollArea } from "@/components/ui/scroll-area"; - -interface TopAIFilesProps { - files: [string, number][]; -} - -export function TopAIFiles({ files }: TopAIFilesProps) { - return ( - - - Top AI-Contributed Files - - Files with highest AI contribution percentage - - - - -
- {files.map(([filepath, percentage]) => ( -
-
-
- {filepath.split("/").pop()} -
-
- {percentage.toFixed(1)}% -
-
- -
- {filepath} -
-
- ))} -
-
-
-
- ); -} diff --git a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/ui/button.tsx b/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/ui/button.tsx deleted file mode 100644 index 91b784a28..000000000 --- a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/ui/button.tsx +++ /dev/null @@ -1,56 +0,0 @@ -import { Slot } from "@radix-ui/react-slot"; -import { type VariantProps, cva } from "class-variance-authority"; -import * as React from "react"; - -import { cn } from "@/lib/utils"; - -const buttonVariants = cva( - "inline-flex items-center justify-center gap-2 whitespace-nowrap rounded-md text-sm font-medium ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none [&_svg]:size-4 [&_svg]:shrink-0", - { - variants: { - variant: { - default: "bg-primary text-primary-foreground hover:bg-primary/90", - destructive: - "bg-destructive text-destructive-foreground hover:bg-destructive/90", - outline: - "border border-input bg-background hover:bg-accent hover:text-accent-foreground", - secondary: - "bg-secondary text-secondary-foreground hover:bg-secondary/80", - ghost: "hover:bg-accent hover:text-accent-foreground", - link: "text-primary underline-offset-4 hover:underline", - }, - size: { - default: "h-10 px-4 py-2", - sm: "h-9 rounded-md px-3", - lg: "h-11 rounded-md px-8", - icon: "h-10 w-10", - }, - }, - defaultVariants: { - variant: "default", - size: "default", - }, - }, -); - -export interface ButtonProps - extends React.ButtonHTMLAttributes, - VariantProps { - asChild?: boolean; -} - -const Button = React.forwardRef( - ({ className, variant, size, asChild = false, ...props }, ref) => { - const Comp = asChild ? Slot : "button"; - return ( - - ); - }, -); -Button.displayName = "Button"; - -export { Button, buttonVariants }; diff --git a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/ui/card.tsx b/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/ui/card.tsx deleted file mode 100644 index bb368bd00..000000000 --- a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/ui/card.tsx +++ /dev/null @@ -1,86 +0,0 @@ -import * as React from "react"; - -import { cn } from "@/lib/utils"; - -const Card = React.forwardRef< - HTMLDivElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => ( -
-)); -Card.displayName = "Card"; - -const CardHeader = React.forwardRef< - HTMLDivElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => ( -
-)); -CardHeader.displayName = "CardHeader"; - -const CardTitle = React.forwardRef< - HTMLDivElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => ( -
-)); -CardTitle.displayName = "CardTitle"; - -const CardDescription = React.forwardRef< - HTMLDivElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => ( -
-)); -CardDescription.displayName = "CardDescription"; - -const CardContent = React.forwardRef< - HTMLDivElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => ( -
-)); -CardContent.displayName = "CardContent"; - -const CardFooter = React.forwardRef< - HTMLDivElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => ( -
-)); -CardFooter.displayName = "CardFooter"; - -export { - Card, - CardHeader, - CardFooter, - CardTitle, - CardDescription, - CardContent, -}; diff --git a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/ui/chart.tsx b/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/ui/chart.tsx deleted file mode 100644 index aa3d5f99a..000000000 --- a/codegen-examples/examples/ai_impact_analysis/dashboard/frontend/components/ui/chart.tsx +++ /dev/null @@ -1,365 +0,0 @@ -"use client"; - -import * as React from "react"; -import * as RechartsPrimitive from "recharts"; - -import { cn } from "@/lib/utils"; - -// Format: { THEME_NAME: CSS_SELECTOR } -const THEMES = { light: "", dark: ".dark" } as const; - -export type ChartConfig = { - [k in string]: { - label?: React.ReactNode; - icon?: React.ComponentType; - } & ( - | { color?: string; theme?: never } - | { color?: never; theme: Record } - ); -}; - -type ChartContextProps = { - config: ChartConfig; -}; - -const ChartContext = React.createContext(null); - -function useChart() { - const context = React.useContext(ChartContext); - - if (!context) { - throw new Error("useChart must be used within a "); - } - - return context; -} - -const ChartContainer = React.forwardRef< - HTMLDivElement, - React.ComponentProps<"div"> & { - config: ChartConfig; - children: React.ComponentProps< - typeof RechartsPrimitive.ResponsiveContainer - >["children"]; - } ->(({ id, className, children, config, ...props }, ref) => { - const uniqueId = React.useId(); - const chartId = `chart-${id || uniqueId.replace(/:/g, "")}`; - - return ( - -
- - - {children} - -
-
- ); -}); -ChartContainer.displayName = "Chart"; - -const ChartStyle = ({ id, config }: { id: string; config: ChartConfig }) => { - const colorConfig = Object.entries(config).filter( - ([_, config]) => config.theme || config.color, - ); - - if (!colorConfig.length) { - return null; - } - - return ( - - - -

codegen

- - - """ - - async def handle_slack_event(self, request: Request): - """Handle incoming Slack events.""" - payload = await request.json() - return await self.slack.handle(payload) - - async def handle_github_event(self, request: Request): - """Handle incoming GitHub events.""" - payload = await request.json() - return await self.github.handle(payload, request) - - async def handle_linear_event(self, request: Request): - """Handle incoming Linear events.""" - payload = await request.json() - return await self.linear.handle(payload) - - def _setup_routes(self): - """Set up the FastAPI routes for different event types.""" - - @self.app.get("/", response_class=HTMLResponse) - async def _root(): - return await self.root() - - # @self.app.post("/{org}/{repo}/slack/events") - @self.app.post("/slack/events") - async def _handle_slack_event(request: Request): - return await self.handle_slack_event(request) - - # @self.app.post("/{org}/{repo}/github/events") - @self.app.post("/github/events") - async def _handle_github_event(request: Request): - return await self.handle_github_event(request) - - # @self.app.post("/{org}/{repo}/linear/events") - @self.app.post("/linear/events") - async def handle_linear_event(request: Request): - return await self.handle_linear_event(request) - - def run(self, host: str = "0.0.0.0", port: int = 8000, **kwargs): - """Run the FastAPI application.""" - import uvicorn - - uvicorn.run(self.app, host=host, port=port, **kwargs) diff --git a/src/codegen/extensions/events/github.py b/src/codegen/extensions/events/github.py deleted file mode 100644 index d17b16aef..000000000 --- a/src/codegen/extensions/events/github.py +++ /dev/null @@ -1,138 +0,0 @@ -import logging -import os -from typing import Any, Callable, TypeVar - -from fastapi import Request -from github import Github -from pydantic import BaseModel - -from codegen.extensions.events.interface import EventHandlerManagerProtocol -from codegen.extensions.github.types.base import GitHubInstallation, GitHubWebhookPayload -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) -logger.setLevel(logging.DEBUG) - - -# Type variable for event types -T = TypeVar("T", bound=BaseModel) - - -class GitHub(EventHandlerManagerProtocol): - def __init__(self, app): - self.app = app - self.registered_handlers = {} - - @property - def client(self) -> Github: - if not os.getenv("GITHUB_TOKEN"): - msg = "GITHUB_TOKEN is not set" - logger.exception(msg) - raise ValueError(msg) - if not self._client: - self._client = Github(os.getenv("GITHUB_TOKEN")) - return self._client - - def unsubscribe_all_handlers(self): - logger.info("[HANDLERS] Clearing all handlers") - self.registered_handlers.clear() - - def event(self, event_name: str): - """Decorator for registering a GitHub event handler. - - Example: - @app.github.event('push') - def handle_push(event: PushEvent): # Can be typed with Pydantic model - logger.info(f"Received push to {event.ref}") - - @app.github.event('pull_request:opened') - def handle_pr(event: dict): # Or just use dict for raw event - logger.info(f"Received PR") - """ - logger.info(f"[EVENT] Registering handler for {event_name}") - - def register_handler(func: Callable[[T], Any]): - # Get the type annotation from the first parameter - event_type = func.__annotations__.get("event") - func_name = func.__qualname__ - logger.info(f"[EVENT] Registering function {func_name} for {event_name}") - - def new_func(raw_event: dict): - # Only validate if a Pydantic model was specified - if event_type and issubclass(event_type, BaseModel): - try: - parsed_event = event_type.model_validate(raw_event) - return func(parsed_event) - except Exception as e: - logger.exception(f"Error parsing event: {e}") - raise - else: - # Pass through raw dict if no type validation needed - return func(raw_event) - - self.registered_handlers[event_name] = new_func - return new_func - - return register_handler - - async def handle(self, event: dict, request: Request | None = None) -> dict: - """Handle both webhook events and installation callbacks.""" - logger.info("[HANDLER] Handling GitHub event") - - # Check if this is an installation event - if "installation_id" in event and "code" in event: - installation = GitHubInstallation.model_validate(event) - logger.info("=====[GITHUB APP INSTALLATION]=====") - logger.info(f"Code: {installation.code}") - logger.info(f"Installation ID: {installation.installation_id}") - logger.info(f"Setup Action: {installation.setup_action}") - return { - "message": "GitHub app installation details received", - "details": { - "code": installation.code, - "installation_id": installation.installation_id, - "setup_action": installation.setup_action, - }, - } - - # Extract headers for webhook events if request is provided - headers = {} - if request: - headers = { - "x-github-event": request.headers.get("x-github-event"), - "x-github-delivery": request.headers.get("x-github-delivery"), - "x-github-hook-id": request.headers.get("x-github-hook-id"), - "x-github-hook-installation-target-id": request.headers.get("x-github-hook-installation-target-id"), - "x-github-hook-installation-target-type": request.headers.get("x-github-hook-installation-target-type"), - } - - # Handle webhook events - try: - # For simulation, use event data directly - if not request: - event_type = f"pull_request:{event['action']}" if "action" in event else event.get("type", "unknown") - if event_type not in self.registered_handlers: - logger.info(f"[HANDLER] No handler found for event type: {event_type}") - return {"message": "Event type not handled"} - else: - logger.info(f"[HANDLER] Handling event: {event_type}") - handler = self.registered_handlers[event_type] - return handler(event) - - # For actual webhooks, use the full payload - webhook = GitHubWebhookPayload.model_validate({"headers": headers, "event": event}) - event_type = webhook.headers.event_type - action = webhook.event.action - full_event_type = f"{event_type}:{action}" if action else event_type - - if full_event_type not in self.registered_handlers: - logger.info(f"[HANDLER] No handler found for event type: {full_event_type}") - return {"message": "Event type not handled"} - else: - logger.info(f"[HANDLER] Handling event: {full_event_type}") - handler = self.registered_handlers[full_event_type] - return handler(event) - - except Exception as e: - logger.exception(f"Error handling webhook: {e}") - raise diff --git a/src/codegen/extensions/events/github_types.py b/src/codegen/extensions/events/github_types.py deleted file mode 100644 index fd3f62536..000000000 --- a/src/codegen/extensions/events/github_types.py +++ /dev/null @@ -1,62 +0,0 @@ -from datetime import datetime -from typing import Optional - - -class GitHubRepository: - id: int - node_id: str - name: str - full_name: str - private: bool - - -class GitHubAccount: - login: str - id: int - node_id: str - avatar_url: str - type: str - site_admin: bool - # Other URL fields omitted for brevity - user_view_type: str - - -class GitHubInstallation: - id: int - client_id: str - account: GitHubAccount - repository_selection: str - access_tokens_url: str - repositories_url: str - html_url: str - app_id: int - app_slug: str - target_id: int - target_type: str - permissions: dict[str, str] # e.g. {'actions': 'write', 'checks': 'read', ...} - events: list[str] - created_at: datetime - updated_at: datetime - single_file_name: Optional[str] - has_multiple_single_files: bool - single_file_paths: list[str] - suspended_by: Optional[str] - suspended_at: Optional[datetime] - - -class GitHubUser: - login: str - id: int - node_id: str - avatar_url: str - type: str - site_admin: bool - # Other URL fields omitted for brevity - - -class GitHubInstallationEvent: - action: str - installation: GitHubInstallation - repositories: list[GitHubRepository] - requester: Optional[dict] - sender: GitHubUser diff --git a/src/codegen/extensions/events/interface.py b/src/codegen/extensions/events/interface.py deleted file mode 100644 index 998afa9d2..000000000 --- a/src/codegen/extensions/events/interface.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import Protocol - -import modal # deptry: ignore - - -class EventHandlerManagerProtocol(Protocol): - def subscribe_handler_to_webhook(self, func_name: str, modal_app: modal.App, event_name): - pass - - def unsubscribe_handler_to_webhook(self, func_name: str, modal_app: modal.App, event_name): - pass - - def unsubscribe_all_handlers(self): - pass diff --git a/src/codegen/extensions/events/linear.py b/src/codegen/extensions/events/linear.py deleted file mode 100644 index 4fe5b2e91..000000000 --- a/src/codegen/extensions/events/linear.py +++ /dev/null @@ -1,85 +0,0 @@ -import logging -from typing import Any, Callable, TypeVar - -from pydantic import BaseModel - -from codegen.extensions.events.interface import EventHandlerManagerProtocol -from codegen.extensions.linear.types import LinearEvent -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) -logger.setLevel(logging.DEBUG) - -# Type variable for event types -T = TypeVar("T", bound=BaseModel) - - -class Linear(EventHandlerManagerProtocol): - def __init__(self, app): - self.app = app - self.registered_handlers = {} - - def unsubscribe_all_handlers(self): - logger.info("[HANDLERS] Clearing all handlers") - self.registered_handlers.clear() - - def event(self, event_name: str): - """Decorator for registering a Linear event handler. - - Args: - event_name: The type of event to handle (e.g. 'Issue', 'Comment') - """ - logger.info(f"[EVENT] Registering handler for {event_name}") - - def register_handler(func: Callable[[LinearEvent], Any]): - func_name = func.__qualname__ - logger.info(f"[EVENT] Registering function {func_name} for {event_name}") - - def new_func(raw_event: dict): - # Get event type from payload - event_type = raw_event.get("type") - if event_type != event_name: - logger.info(f"[HANDLER] Event type mismatch: expected {event_name}, got {event_type}") - return None - - # Parse event into LinearEvent type - event = LinearEvent.model_validate(raw_event) - return func(event) - - self.registered_handlers[event_name] = new_func - return func - - return register_handler - - async def handle(self, event: dict) -> dict: - """Handle incoming Linear events. - - Args: - event: The event payload from Linear - - Returns: - Response dictionary - """ - logger.info("[HANDLER] Handling Linear event") - - try: - # Extract event type - event_type = event.get("type") - if not event_type: - logger.info("[HANDLER] No event type found in payload") - return {"message": "Event type not found"} - - if event_type not in self.registered_handlers: - logger.info(f"[HANDLER] No handler found for event type: {event_type}") - return {"message": "Event handled successfully"} - else: - logger.info(f"[HANDLER] Handling event: {event_type}") - handler = self.registered_handlers[event_type] - result = handler(event) - if hasattr(result, "__await__"): - result = await result - return result - - except Exception as e: - logger.exception(f"Error handling Linear event: {e}") - return {"error": f"Failed to handle event: {e!s}"} diff --git a/src/codegen/extensions/events/modal/base.py b/src/codegen/extensions/events/modal/base.py deleted file mode 100644 index 64bdf5b28..000000000 --- a/src/codegen/extensions/events/modal/base.py +++ /dev/null @@ -1,169 +0,0 @@ -import logging -import os -from typing import Literal - -import modal -from fastapi import Request - -from codegen.extensions.events.codegen_app import CodegenApp -from codegen.extensions.events.modal.request_util import fastapi_request_adapter -from codegen.git.clients.git_repo_client import GitRepoClient -from codegen.git.schemas.repo_config import RepoConfig - -logging.basicConfig(level=logging.INFO, force=True) -logger = logging.getLogger(__name__) - -# refactor this to be a config -DEFAULT_SNAPSHOT_DICT_ID = "codegen-events-codebase-snapshots" - - -class EventRouterMixin: - """This class is intended to be registered as a modal Class - and will be used to route events to the correct handler. - - Usage: - @codegen_events_app.cls(image=base_image, secrets=[modal.Secret.from_dotenv()]) - class CustomEventAPI(EventRouterMixin): - pass - - """ - - snapshot_index_id: str = DEFAULT_SNAPSHOT_DICT_ID - - def get_event_handler_cls(self) -> modal.Cls: - """Lookup the Modal Class where the event handlers are defined""" - msg = "Subclasses must implement this method" - raise NotImplementedError(msg) - - async def handle_event(self, org: str, repo: str, provider: Literal["slack", "github", "linear"], request: Request): - repo_config = RepoConfig( - name=repo, - full_name=f"{org}/{repo}", - ) - - repo_snapshotdict = modal.Dict.from_name(self.snapshot_index_id, {}, create_if_missing=True) - - last_snapshot_commit = repo_snapshotdict.get(f"{org}/{repo}", None) - - if last_snapshot_commit is None: - git_client = GitRepoClient(repo_config=repo_config, access_token=os.environ["GITHUB_ACCESS_TOKEN"]) - branch = git_client.get_branch_safe(git_client.default_branch) - last_snapshot_commit = branch.commit.sha if branch and branch.commit else None - - Klass = self.get_event_handler_cls() - klass = Klass(repo_org=org, repo_name=repo, commit=last_snapshot_commit) - - request_payload = await request.json() - request_headers = dict(request.headers) - request_headers.pop("host", None) # Remove host header if present - - if provider == "slack": - return klass.proxy_event.remote(f"{org}/{repo}/slack/events", payload=request_payload, headers=request_headers) - elif provider == "github": - return klass.proxy_event.remote(f"{org}/{repo}/github/events", payload=request_payload, headers=request_headers) - elif provider == "linear": - return klass.proxy_event.remote(f"{org}/{repo}/linear/events", payload=request_payload, headers=request_headers) - else: - msg = f"Invalid provider: {provider}" - raise ValueError(msg) - - def refresh_repository_snapshots(self, snapshot_index_id: str): - """Refresh the latest snapshot for all repositories in the dictionary.""" - # Get all repositories from the modal.Dict - repo_dict = modal.Dict.from_name(snapshot_index_id, {}, create_if_missing=True) - - for repo_full_name in repo_dict.keys(): - try: - # Parse the repository full name to get org and repo - org, repo = repo_full_name.split("/") - - # Create a RepoConfig for the repository - repo_config = RepoConfig( - name=repo, - full_name=repo_full_name, - ) - - # Initialize the GitRepoClient to fetch the latest commit - git_client = GitRepoClient(repo_config=repo_config, access_token=os.environ["GITHUB_ACCESS_TOKEN"]) - - # Get the default branch and its latest commit - branch = git_client.get_branch_safe(git_client.default_branch) - commit = branch.commit.sha if branch and branch.commit else None - - if commit: - # Get the CodegenEventsApi class - Klass = self.get_event_handler_cls() - # Create an instance with the latest commit - klass = Klass(repo_org=org, repo_name=repo, commit=commit) - - # Ping the function to refresh the snapshot - result = klass.ping.remote() - - logging.info(f"Refreshed snapshot for {repo_full_name} with commit {commit}: {result}") - else: - logging.warning(f"Could not fetch latest commit for {repo_full_name}") - - except Exception as e: - logging.exception(f"Error refreshing snapshot for {repo_full_name}: {e!s}") - - -class CodebaseEventsApp: - """This class is intended to be registered as a modal Class - and will be used to register event handlers for webhook events. It includes snapshotting behavior - and should be used with CodebaseEventsAPI. - - Usage: - @app.cls(image=base_image, secrets=[modal.Secret.from_dotenv()], enable_memory_snapshot=True, container_idle_timeout=300) - class YourCustomerEventsAPP(CodebaseEventsApp): - pass - """ - - commit: str = modal.parameter(default="") - repo_org: str = modal.parameter(default="") - repo_name: str = modal.parameter(default="") - snapshot_index_id: str = DEFAULT_SNAPSHOT_DICT_ID - - def get_codegen_app(self) -> CodegenApp: - full_repo_name = f"{self.repo_org}/{self.repo_name}" - return CodegenApp(name=f"{full_repo_name}-events", repo=full_repo_name, commit=self.commit) - - @modal.enter(snap=True) - def load(self): - self.cg = self.get_codegen_app() - self.cg.parse_repo() - self.setup_handlers(self.cg) - - # TODO: if multiple snapshots are taken for the same commit, we will need to compare commit timestamps - snapshot_dict = modal.Dict.from_name(self.snapshot_index_id, {}, create_if_missing=True) - snapshot_dict.put(f"{self.repo_org}/{self.repo_name}", self.commit) - - def setup_handlers(self, cg: CodegenApp): - msg = "Subclasses must implement this method" - raise NotImplementedError(msg) - - @modal.method() - async def proxy_event(self, route: str, payload: dict, headers: dict): - logger.info(f"Handling event: {route}") - request = await fastapi_request_adapter(payload=payload, headers=headers, route=route) - - if "slack/events" in route: - response_data = await self.cg.handle_slack_event(request) - elif "github/events" in route: - response_data = await self.cg.handle_github_event(request) - elif "linear/events" in route: - response_data = await self.cg.handle_linear_event(request) - else: - msg = f"Invalid route: {route}" - raise ValueError(msg) - - return response_data - - @modal.method() - def ping(self): - logger.info(f"Pinging function with repo: {self.repo_org}/{self.repo_name} commit: {self.commit}") - return {"status": "ok"} - - @modal.asgi_app() - def fastapi_endpoint(self): - logger.info("Serving FastAPI app from class method") - return self.cg.app diff --git a/src/codegen/extensions/events/modal/request_util.py b/src/codegen/extensions/events/modal/request_util.py deleted file mode 100644 index 029bb86cf..000000000 --- a/src/codegen/extensions/events/modal/request_util.py +++ /dev/null @@ -1,46 +0,0 @@ -import json - -from fastapi import Request as FastAPIRequest - - -async def fastapi_request_adapter(payload: dict, headers: dict, route: str) -> FastAPIRequest: - # Create a FastAPI Request object from the payload and headers - # 1. Create the scope dictionary - scope = { - "type": "http", - "method": "POST", - "path": f"/{route}", - "raw_path": f"/{route}".encode(), - "query_string": b"", - "headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()], - "client": ("127.0.0.1", 0), # Default client address - } - - # 2. Create a receive function that returns the request body - body_bytes = json.dumps(payload).encode() - - async def receive(): - return { - "type": "http.request", - "body": body_bytes, - "more_body": False, - } - - # 3. Create a send function to capture the response - response_body = [] - response_status = None - response_headers = None - - async def send(message): - nonlocal response_status, response_headers - - if message["type"] == "http.response.start": - response_status = message["status"] - response_headers = message["headers"] - elif message["type"] == "http.response.body": - response_body.append(message.get("body", b"")) - - # 4. Create the request object - fastapi_request = FastAPIRequest(scope, receive) - - return fastapi_request diff --git a/src/codegen/extensions/events/slack.py b/src/codegen/extensions/events/slack.py deleted file mode 100644 index 3c184da54..000000000 --- a/src/codegen/extensions/events/slack.py +++ /dev/null @@ -1,75 +0,0 @@ -import logging -import os - -from slack_sdk import WebClient - -from codegen.extensions.events.interface import EventHandlerManagerProtocol -from codegen.extensions.slack.types import SlackWebhookPayload -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) -logger.setLevel(logging.DEBUG) - - -class Slack(EventHandlerManagerProtocol): - _client: WebClient | None = None - - def __init__(self, app): - self.registered_handlers = {} - - @property - def client(self) -> WebClient: - if not self._client: - self._client = WebClient(token=os.environ["SLACK_BOT_TOKEN"]) - return self._client - - def unsubscribe_all_handlers(self): - logger.info("[HANDLERS] Clearing all handlers") - self.registered_handlers.clear() - - async def handle(self, event_data: dict) -> dict: - """Handle incoming Slack events.""" - logger.info("[HANDLER] Handling Slack event") - - try: - # Validate and convert to SlackWebhookPayload - event = SlackWebhookPayload.model_validate(event_data) - - if event.type == "url_verification": - return {"challenge": event.challenge} - elif event.type == "event_callback" and event.event: - if event.event.type not in self.registered_handlers: - logger.info(f"[HANDLER] No handler found for event type: {event.event.type}") - return {"message": "Event handled successfully"} - else: - handler = self.registered_handlers[event.event.type] - # Since the handler might be async, await it - result = handler(event.event) - if hasattr(result, "__await__"): - result = await result - return result - else: - logger.info(f"[HANDLER] No handler found for event type: {event.type}") - return {"message": "Event handled successfully"} - - except Exception as e: - logger.exception(f"Error handling Slack event: {e}") - return {"error": f"Failed to handle event: {e!s}"} - - def event(self, event_name: str): - """Decorator for registering a Slack event handler.""" - logger.info(f"[EVENT] Registering handler for {event_name}") - - def register_handler(func): - # Register the handler with the app's registry - func_name = func.__qualname__ - logger.info(f"[EVENT] Registering function {func_name} for {event_name}") - - async def new_func(event): - # Just pass the event, handler can access client via app.slack.client - return await func(event) - - self.registered_handlers[event_name] = new_func - return func - - return register_handler diff --git a/src/codegen/extensions/github/types/__init__.py b/src/codegen/extensions/github/types/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/extensions/github/types/author.py b/src/codegen/extensions/github/types/author.py deleted file mode 100644 index 2ecdd2e8a..000000000 --- a/src/codegen/extensions/github/types/author.py +++ /dev/null @@ -1,7 +0,0 @@ -from pydantic import BaseModel - - -class GitHubAuthor(BaseModel): - name: str - email: str - username: str diff --git a/src/codegen/extensions/github/types/base.py b/src/codegen/extensions/github/types/base.py deleted file mode 100644 index 8c6bef223..000000000 --- a/src/codegen/extensions/github/types/base.py +++ /dev/null @@ -1,68 +0,0 @@ -from pydantic import BaseModel, Field - - -class GitHubUser(BaseModel): - login: str - id: int - node_id: str - type: str - - -class GitHubRepository(BaseModel): - id: int - node_id: str - name: str - full_name: str - private: bool - owner: GitHubUser - - -class GitHubIssue(BaseModel): - id: int - node_id: str - number: int - title: str - body: str | None - user: GitHubUser - state: str - comments: int - - -class GitHubPullRequest(BaseModel): - id: int - node_id: str - number: int - title: str - body: str | None - user: GitHubUser - state: str - head: dict - base: dict - merged: bool | None = None - - -class GitHubEvent(BaseModel): - action: str | None = None - issue: GitHubIssue | None = None - pull_request: GitHubPullRequest | None = None - repository: GitHubRepository - sender: GitHubUser - - -class GitHubWebhookHeaders(BaseModel): - event_type: str = Field(..., alias="x-github-event") - delivery_id: str = Field(..., alias="x-github-delivery") - hook_id: str = Field(..., alias="x-github-hook-id") - installation_target_id: str = Field(..., alias="x-github-hook-installation-target-id") - installation_target_type: str = Field(..., alias="x-github-hook-installation-target-type") - - -class GitHubWebhookPayload(BaseModel): - headers: GitHubWebhookHeaders - event: GitHubEvent - - -class GitHubInstallation(BaseModel): - code: str - installation_id: str - setup_action: str = "install" diff --git a/src/codegen/extensions/github/types/commit.py b/src/codegen/extensions/github/types/commit.py deleted file mode 100644 index 9fa27052f..000000000 --- a/src/codegen/extensions/github/types/commit.py +++ /dev/null @@ -1,17 +0,0 @@ -from pydantic import BaseModel - -from .author import GitHubAuthor - - -class GitHubCommit(BaseModel): - id: str - tree_id: str - distinct: bool - message: str - timestamp: str - url: str - author: GitHubAuthor - committer: GitHubAuthor - added: list[str] - removed: list[str] - modified: list[str] diff --git a/src/codegen/extensions/github/types/enterprise.py b/src/codegen/extensions/github/types/enterprise.py deleted file mode 100644 index ed4861ad9..000000000 --- a/src/codegen/extensions/github/types/enterprise.py +++ /dev/null @@ -1,14 +0,0 @@ -from pydantic import BaseModel - - -class GitHubEnterprise(BaseModel): - id: int - slug: str - name: str - node_id: str - avatar_url: str - description: str - website_url: str - html_url: str - created_at: str - updated_at: str diff --git a/src/codegen/extensions/github/types/events/pull_request.py b/src/codegen/extensions/github/types/events/pull_request.py deleted file mode 100644 index 7838e8e87..000000000 --- a/src/codegen/extensions/github/types/events/pull_request.py +++ /dev/null @@ -1,73 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel - -from ..base import GitHubRepository, GitHubUser -from ..enterprise import GitHubEnterprise -from ..installation import GitHubInstallation -from ..label import GitHubLabel -from ..organization import GitHubOrganization -from ..pull_request import PullRequest - - -class User(BaseModel): - id: int - login: str - - -class Label(BaseModel): - id: int - node_id: str - url: str - name: str - description: str | None = None - color: str - default: bool - - -class SimplePullRequest(BaseModel): - id: int - number: int - state: str - locked: bool - title: str - user: User - body: str | None = None - labels: list[Label] = [] - created_at: str - updated_at: str - draft: bool = False - - -class PullRequestLabeledEvent(BaseModel): - """Simplified version of the PR labeled event for testing""" - - action: Literal["labeled"] - number: int - pull_request: PullRequest - label: Label - repository: dict # Simplified for now - sender: User - - -class PullRequestOpenedEvent(BaseModel): - action: str = "opened" # Always "opened" for this event - number: int - pull_request: PullRequest - repository: GitHubRepository - organization: GitHubOrganization - enterprise: GitHubEnterprise - sender: GitHubUser - installation: GitHubInstallation - - -class PullRequestUnlabeledEvent(BaseModel): - action: str - number: int - pull_request: PullRequest - label: GitHubLabel - repository: GitHubRepository - organization: GitHubOrganization - enterprise: GitHubEnterprise - sender: GitHubUser - installation: GitHubInstallation diff --git a/src/codegen/extensions/github/types/events/push.py b/src/codegen/extensions/github/types/events/push.py deleted file mode 100644 index c3a7799bf..000000000 --- a/src/codegen/extensions/github/types/events/push.py +++ /dev/null @@ -1,27 +0,0 @@ -from pydantic import BaseModel - -from ..base import GitHubRepository, GitHubUser -from ..commit import GitHubCommit -from ..enterprise import GitHubEnterprise -from ..installation import GitHubInstallation -from ..organization import GitHubOrganization -from ..pusher import GitHubPusher - - -class PushEvent(BaseModel): - ref: str - before: str - after: str - repository: GitHubRepository - pusher: GitHubPusher - organization: GitHubOrganization - enterprise: GitHubEnterprise - sender: GitHubUser - installation: GitHubInstallation - created: bool - deleted: bool - forced: bool - base_ref: str | None = None - compare: str - commits: list[GitHubCommit] - head_commit: GitHubCommit | None = None diff --git a/src/codegen/extensions/github/types/installation.py b/src/codegen/extensions/github/types/installation.py deleted file mode 100644 index 5b8e2b9cf..000000000 --- a/src/codegen/extensions/github/types/installation.py +++ /dev/null @@ -1,6 +0,0 @@ -from pydantic import BaseModel - - -class GitHubInstallation(BaseModel): - id: int - node_id: str diff --git a/src/codegen/extensions/github/types/label.py b/src/codegen/extensions/github/types/label.py deleted file mode 100644 index 1d91f32f9..000000000 --- a/src/codegen/extensions/github/types/label.py +++ /dev/null @@ -1,11 +0,0 @@ -from pydantic import BaseModel - - -class GitHubLabel(BaseModel): - id: int - node_id: str - url: str - name: str - color: str - default: bool - description: str | None diff --git a/src/codegen/extensions/github/types/organization.py b/src/codegen/extensions/github/types/organization.py deleted file mode 100644 index 56b64e950..000000000 --- a/src/codegen/extensions/github/types/organization.py +++ /dev/null @@ -1,16 +0,0 @@ -from pydantic import BaseModel - - -class GitHubOrganization(BaseModel): - login: str - id: int - node_id: str - url: str - repos_url: str - events_url: str - hooks_url: str - issues_url: str - members_url: str - public_members_url: str - avatar_url: str - description: str diff --git a/src/codegen/extensions/github/types/pull_request.py b/src/codegen/extensions/github/types/pull_request.py deleted file mode 100644 index c4b58eed6..000000000 --- a/src/codegen/extensions/github/types/pull_request.py +++ /dev/null @@ -1,89 +0,0 @@ -from typing import Literal, Optional - -from pydantic import BaseModel - -from .base import GitHubRepository, GitHubUser - - -class PullRequestRef(BaseModel): - label: str - ref: str - sha: str - user: GitHubUser - repo: GitHubRepository - - -class PullRequestLinks(BaseModel): - self: dict - html: dict - issue: dict - comments: dict - review_comments: dict - review_comment: dict - commits: dict - statuses: dict - - -class Label(BaseModel): - id: int - node_id: str - url: str - name: str - description: str | None = None - color: str - default: bool - - -class PullRequest(BaseModel): - url: str - id: int - node_id: str - html_url: str - diff_url: str - patch_url: str - issue_url: str - number: int - state: str - locked: bool - title: str - user: GitHubUser - body: Optional[str] - created_at: str - updated_at: str - closed_at: Optional[str] - merged_at: Optional[str] - merge_commit_sha: Optional[str] - assignee: Optional[GitHubUser] - assignees: list[GitHubUser] - requested_reviewers: list[GitHubUser] - requested_teams: list[dict] - labels: list[Label] - milestone: Optional[dict] - draft: bool - head: PullRequestRef - base: PullRequestRef - _links: PullRequestLinks - author_association: str - auto_merge: Optional[dict] - active_lock_reason: Optional[str] - merged: bool - mergeable: Optional[bool] - rebaseable: Optional[bool] - mergeable_state: str - merged_by: Optional[GitHubUser] - comments: int - review_comments: int - maintainer_can_modify: bool - commits: int - additions: int - deletions: int - changed_files: int - - -class PullRequestLabeledEvent(BaseModel): - action: Literal["labeled"] - number: int - pull_request: PullRequest - label: Label - repository: dict # Simplified for now - sender: dict # Simplified for now diff --git a/src/codegen/extensions/github/types/push.py b/src/codegen/extensions/github/types/push.py deleted file mode 100644 index 10f44f5e7..000000000 --- a/src/codegen/extensions/github/types/push.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - -from .base import GitHubRepository, GitHubUser -from .commit import GitHubCommit -from .enterprise import GitHubEnterprise -from .installation import GitHubInstallation -from .organization import GitHubOrganization -from .pusher import GitHubPusher - - -class PushEvent(BaseModel): - ref: str - before: str - after: str - repository: GitHubRepository - pusher: GitHubPusher - organization: GitHubOrganization - enterprise: GitHubEnterprise - sender: GitHubUser - installation: GitHubInstallation - created: bool - deleted: bool - forced: bool - base_ref: Optional[str] - compare: str - commits: list[GitHubCommit] - head_commit: GitHubCommit diff --git a/src/codegen/extensions/github/types/pusher.py b/src/codegen/extensions/github/types/pusher.py deleted file mode 100644 index 2d52056d4..000000000 --- a/src/codegen/extensions/github/types/pusher.py +++ /dev/null @@ -1,6 +0,0 @@ -from pydantic import BaseModel - - -class GitHubPusher(BaseModel): - name: str - email: str diff --git a/src/codegen/extensions/graph/__init__.py b/src/codegen/extensions/graph/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/extensions/graph/create_graph.py b/src/codegen/extensions/graph/create_graph.py deleted file mode 100644 index 442b2dcd6..000000000 --- a/src/codegen/extensions/graph/create_graph.py +++ /dev/null @@ -1,137 +0,0 @@ -from typing import Optional - -from codegen.extensions.graph.utils import Node, NodeLabel, Relation, RelationLabel, SimpleGraph -from codegen.sdk.code_generation.doc_utils.utils import safe_get_class -from codegen.sdk.core.class_definition import Class -from codegen.sdk.core.external_module import ExternalModule -from codegen.sdk.core.function import Function -from codegen.sdk.python.class_definition import PyClass - - -def create_codebase_graph(codebase): - """Create a SimpleGraph representing the codebase structure.""" - # Initialize graph - graph = SimpleGraph() - - # Track existing nodes by name to prevent duplicates - node_registry = {} # name -> node_id mapping - - def get_or_create_node(name: str, label: NodeLabel, parent_name: Optional[str] = None, properties: dict | None = None): - """Get existing node or create new one if it doesn't exist.""" - full_name = f"{parent_name}.{name}" if parent_name and parent_name != "Class" else name - if full_name in node_registry: - return graph.nodes[node_registry[full_name]] - - node = Node(name=name, full_name=full_name, label=label.value, properties=properties or {}) - node_registry[full_name] = node.id - graph.add_node(node) - return node - - def create_class_node(class_def): - """Create a node for a class definition.""" - return get_or_create_node( - name=class_def.name, - label=NodeLabel.CLASS, - properties={ - "filepath": class_def.filepath if hasattr(class_def, "filepath") else "", - "source": class_def.source if hasattr(class_def, "source") else "", - "type": "class", - }, - ) - - def create_function_node(func): - """Create a node for a function/method.""" - class_name = None - if func.is_method: - class_name = func.parent_class.name - - return get_or_create_node( - name=func.name, - label=NodeLabel.METHOD if class_name else NodeLabel.FUNCTION, - parent_name=class_name, - properties={ - "filepath": func.filepath if hasattr(func, "filepath") else "", - "is_async": func.is_async if hasattr(func, "is_async") else False, - "source": func.source if hasattr(func, "source") else "", - "type": "method" if class_name else "function", - }, - ) - - def create_function_call_node(func_call): - """Create a node for a function call.""" - func_def = func_call.function_definition - if not func_def: - return None - if isinstance(func_def, ExternalModule): - parent_class = safe_get_class(codebase, func_def.name) - if parent_class and parent_class.get_method(func_call.name): - return create_function_node(parent_class.get_method(func_call.name)) - else: - return None - - call_node = None - if isinstance(func_def, Function): - call_node = create_function_node(func_def) - - elif isinstance(func_def, Class): - call_node = create_class_node(func_def) - - return call_node - - # Process all classes - for class_def in codebase.classes: - class_node = create_class_node(class_def) - - # Process methods - methods = class_def.methods - for method in methods: - method_node = create_function_node(method) - - # Add DEFINES relation - defines_relation = Relation( - label=RelationLabel.DEFINES.value, source_id=class_node.id, target_id=method_node.id, properties={"relationship_description": "The parent class defines the method."} - ) - graph.add_relation(defines_relation) - - for call in method.function_calls: - call_node = create_function_call_node(call) - if call_node and call_node != method_node: - call_relation = Relation( - label=RelationLabel.CALLS.value, source_id=method_node.id, target_id=call_node.id, properties={"relationship_description": f"The method calls the {call_node.label}."} - ) - graph.add_relation(call_relation) - - # Add inheritance relations - if class_def.parent_classes: - for parent in class_def.parent_classes: - if not isinstance(parent, PyClass): - try: - parent = codebase.get_class(parent.name, optional=True) - if not parent: - continue - except Exception as e: - print(f"parent not found: {e}") - continue - if not hasattr(parent, "name"): - continue - parent_node = create_class_node(parent) - - inherits_relation = Relation( - label=RelationLabel.INHERITS_FROM.value, - source_id=class_node.id, - target_id=parent_node.id, - properties={"relationship_description": "The child class inherits from the parent class."}, - ) - graph.add_relation(inherits_relation) - - for func in codebase.functions: - func_node = create_function_node(func) - for call in func.function_calls: - call_node = create_function_call_node(call) - if call_node and call_node != func_node: - call_relation = Relation( - label=RelationLabel.CALLS.value, source_id=func_node.id, target_id=call_node.id, properties={"relationship_description": f"The function calls the {call_node.label}."} - ) - graph.add_relation(call_relation) - - return graph diff --git a/src/codegen/extensions/graph/main.py b/src/codegen/extensions/graph/main.py deleted file mode 100644 index c2a655b2e..000000000 --- a/src/codegen/extensions/graph/main.py +++ /dev/null @@ -1,42 +0,0 @@ -from codegen.extensions.graph.create_graph import create_codebase_graph -from codegen.extensions.graph.neo4j_exporter import Neo4jExporter -from codegen.sdk.core.codebase import Codebase - - -def visualize_codebase(codebase, neo4j_uri: str, username: str, password: str): - """Create and visualize a codebase graph in Neo4j. - - Args: - codebase: The codebase object to analyze - neo4j_uri: URI for Neo4j database - username: Neo4j username - password: Neo4j password - """ - # Create the graph using your existing function - graph = create_codebase_graph(codebase) - - # Export to Neo4j - exporter = Neo4jExporter(neo4j_uri, username, password) - try: - exporter.export_graph(graph) - print("Successfully exported graph to Neo4j") - - # Print some useful Cypher queries for visualization - print("\nUseful Cypher queries for visualization:") - print("\n1. View all nodes and relationships:") - print("MATCH (n)-[r]->(m) RETURN n, r, m") - - print("\n2. View class hierarchy:") - print("MATCH (c:Class)-[r:INHERITS_FROM]->(parent:Class) RETURN c, r, parent") - - print("\n3. View methods defined by each class:") - print("MATCH (c:Class)-[r:DEFINES]->(m:Method) RETURN c, r, m") - - finally: - exporter.close() - - -if __name__ == "__main__": - # Initialize codebase - codebase = Codebase("../../", language="python") - visualize_codebase(codebase, "bolt://localhost:7687", "neo4j", "password") diff --git a/src/codegen/extensions/graph/neo4j_exporter.py b/src/codegen/extensions/graph/neo4j_exporter.py deleted file mode 100644 index 72a499636..000000000 --- a/src/codegen/extensions/graph/neo4j_exporter.py +++ /dev/null @@ -1,49 +0,0 @@ -from neo4j import GraphDatabase - -from codegen.extensions.graph.utils import SimpleGraph - - -class Neo4jExporter: - """Class to handle exporting the codebase graph to Neo4j.""" - - def __init__(self, uri: str, username: str, password: str): - """Initialize Neo4j connection.""" - self.driver = GraphDatabase.driver(uri, auth=(username, password)) - - def close(self): - """Close the Neo4j connection.""" - self.driver.close() - - def clear_database(self): - """Clear all nodes and relationships in the database.""" - with self.driver.session() as session: - session.run("MATCH (n) DETACH DELETE n") - - def export_graph(self, graph: SimpleGraph): - """Export the SimpleGraph to Neo4j.""" - self.clear_database() - - with self.driver.session() as session: - # Create nodes - for node in graph.nodes.values(): - properties = {"name": node.name, "full_name": node.full_name, **{k: str(v) if isinstance(v, (dict, list)) else v for k, v in node.properties.items()}} - - query = f"CREATE (n:{node.label} {{{', '.join(f'{k}: ${k}' for k in properties.keys())}}})" - session.run(query, properties) - - # Create relationships - for relation in graph.relations: - source_node = graph.nodes[relation.source_id] - target_node = graph.nodes[relation.target_id] - - properties = {**{k: str(v) if isinstance(v, (dict, list)) else v for k, v in relation.properties.items()}} - - query = ( - f"MATCH (source:{source_node.label} {{full_name: $source_name}}), " - f"(target:{target_node.label} {{full_name: $target_name}}) " - f"CREATE (source)-[r:{relation.label} " - f"{{{', '.join(f'{k}: ${k}' for k in properties.keys())}}}]->" - f"(target)" - ) - - session.run(query, {"source_name": source_node.full_name, "target_name": target_node.full_name, **properties}) diff --git a/src/codegen/extensions/graph/utils.py b/src/codegen/extensions/graph/utils.py deleted file mode 100644 index e97eb2357..000000000 --- a/src/codegen/extensions/graph/utils.py +++ /dev/null @@ -1,78 +0,0 @@ -import uuid -from dataclasses import dataclass, field -from enum import Enum - - -class NodeLabel(Enum): - CLASS = "Class" - METHOD = "Method" - FUNCTION = "Func" - - -class RelationLabel(Enum): - DEFINES = "DEFINES" - INHERITS_FROM = "INHERITS_FROM" - CALLS = "CALLS" - - -@dataclass(kw_only=True) -class BaseNode: - label: str - properties: dict = field(default_factory=dict) - id: str = field(default_factory=lambda: str(uuid.uuid4())) - - def __hash__(self): - """Make the relation hashable based on its id.""" - return hash(self.id) - - def __eq__(self, other): - """Define equality based on id.""" - if not isinstance(other, Relation): - return NotImplemented - return self.id == other.id - - -@dataclass(kw_only=True) -class Node(BaseNode): - """Simple node class with label and properties.""" - - name: str - full_name: str - - -@dataclass(kw_only=True) -class Relation(BaseNode): - """Simple relation class connecting two nodes.""" - - source_id: str - target_id: str - - def __hash__(self): - """Make the relation hashable based on its id.""" - return hash(self.id) - - def __eq__(self, other): - """Define equality based on id.""" - if not isinstance(other, Relation): - return NotImplemented - return self.id == other.id - - -class SimpleGraph: - """Basic graph implementation using sets of nodes and relations.""" - - def __init__(self): - self.nodes: dict[str, Node] = {} - self.relations: set[Relation] = set() - self.existing_relations: set[str] = set() - - def add_node(self, node: Node) -> None: - """Add a node to the graph.""" - self.nodes[node.id] = node - - def add_relation(self, relation: Relation) -> None: - """Add a relation to the graph.""" - related_nodes = f"{relation.source_id}->{relation.label}->{relation.target_id}" - if relation.source_id in self.nodes and relation.target_id in self.nodes and related_nodes not in self.existing_relations: - self.relations.add(relation) - self.existing_relations.add(related_nodes) diff --git a/src/codegen/extensions/index/__init__.py b/src/codegen/extensions/index/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/extensions/index/code_index.py b/src/codegen/extensions/index/code_index.py deleted file mode 100644 index 4cf8a5de3..000000000 --- a/src/codegen/extensions/index/code_index.py +++ /dev/null @@ -1,225 +0,0 @@ -"""Abstract base class for code indexing implementations.""" - -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Optional, TypeVar - -import numpy as np - -from codegen.sdk.core.codebase import Codebase - -T = TypeVar("T") # Type of the items being indexed (e.g., File, Symbol) - - -class CodeIndex(ABC): - """Abstract base class for semantic code search indices. - - This class defines the interface for different code indexing implementations. - Implementations can index at different granularities (files, symbols, etc.) - and use different embedding strategies. - - Attributes: - codebase (Codebase): The codebase being indexed - E (Optional[np.ndarray]): The embeddings matrix - items (Optional[np.ndarray]): Array of items corresponding to embeddings - commit_hash (Optional[str]): Git commit hash when index was last updated - """ - - DEFAULT_SAVE_DIR = ".codegen" - - def __init__(self, codebase: Codebase): - """Initialize the code index. - - Args: - codebase: The codebase to index - """ - self.codebase = codebase - self.E: Optional[np.ndarray] = None - self.items: Optional[np.ndarray] = None - self.commit_hash: Optional[str] = None - - @property - @abstractmethod - def save_file_name(self) -> str: - """The filename template for saving the index.""" - pass - - @abstractmethod - def _get_embeddings(self, items: list[T]) -> list[list[float]]: - """Get embeddings for a list of items. - - Args: - items: List of items to get embeddings for - - Returns: - List of embedding vectors - """ - pass - - @abstractmethod - def _get_items_to_index(self) -> list[tuple[T, str]]: - """Get all items that should be indexed and their content. - - Returns: - List of tuples (item, content_to_embed) - """ - pass - - @abstractmethod - def _get_changed_items(self) -> set[T]: - """Get set of items that have changed since last index update. - - Returns: - Set of changed items - """ - pass - - def _get_current_commit(self) -> str: - """Get the current git commit hash.""" - current = self.codebase.current_commit - if current is None: - msg = "No current commit found. Repository may be empty or in a detached HEAD state." - raise ValueError(msg) - return current.hexsha - - def _get_default_save_path(self) -> Path: - """Get the default save path for the index.""" - save_dir = Path(self.codebase.repo_path) / self.DEFAULT_SAVE_DIR - save_dir.mkdir(exist_ok=True) - - if self.commit_hash is None: - self.commit_hash = self._get_current_commit() - - filename = self.save_file_name.format(commit=self.commit_hash[:8]) - return save_dir / filename - - def create(self) -> None: - """Create embeddings for all indexed items.""" - self.commit_hash = self._get_current_commit() - - # Get items and their content - items_with_content = self._get_items_to_index() - if not items_with_content: - self.E = np.array([]) - self.items = np.array([]) - return - - # Split into separate lists - items, contents = zip(*items_with_content) - - # Get embeddings - embeddings = self._get_embeddings(contents) - - # Store embeddings and item identifiers - self.E = np.array(embeddings) - self.items = np.array([str(item) for item in items]) # Store string identifiers - - def update(self) -> None: - """Update embeddings for changed items only.""" - if self.E is None or self.items is None or self.commit_hash is None: - msg = "No index to update. Call create() or load() first." - raise ValueError(msg) - - # Get changed items - changed_items = self._get_changed_items() - if not changed_items: - return - - # Get content for changed items - items_with_content = [(item, content) for item, content in self._get_items_to_index() if item in changed_items] - - if not items_with_content: - return - - items, contents = zip(*items_with_content) - new_embeddings = self._get_embeddings(contents) - - # Create mapping of items to their indices - item_to_idx = {str(item): idx for idx, item in enumerate(self.items)} - - # Update embeddings - for item, embedding in zip(items, new_embeddings): - item_key = str(item) - if item_key in item_to_idx: - # Update existing embedding - self.E[item_to_idx[item_key]] = embedding - else: - # Add new embedding - self.E = np.vstack([self.E, embedding]) - self.items = np.append(self.items, item) - - # Update commit hash - self.commit_hash = self._get_current_commit() - - def save(self, save_path: Optional[str] = None) -> None: - """Save the index to disk.""" - if self.E is None or self.items is None: - msg = "No embeddings to save. Call create() first." - raise ValueError(msg) - - save_path = Path(save_path) if save_path else self._get_default_save_path() - save_path.parent.mkdir(parents=True, exist_ok=True) - - self._save_index(save_path) - - def load(self, load_path: Optional[str] = None) -> None: - """Load the index from disk.""" - load_path = Path(load_path) if load_path else self._get_default_save_path() - - if not load_path.exists(): - msg = f"No index found at {load_path}" - raise FileNotFoundError(msg) - - self._load_index(load_path) - - @abstractmethod - def _save_index(self, path: Path) -> None: - """Save index data to disk.""" - pass - - @abstractmethod - def _load_index(self, path: Path) -> None: - """Load index data from disk.""" - pass - - def _similarity_search_raw(self, query: str, k: int = 5) -> list[tuple[str, float]]: - """Internal method to find the k most similar items by their string identifiers. - - Args: - query: The text to search for - k: Number of results to return - - Returns: - List of tuples (item_identifier, similarity_score) sorted by similarity - """ - if self.E is None or self.items is None: - msg = "No embeddings available. Call create() or load() first." - raise ValueError(msg) - - # Get query embedding - query_embeddings = self._get_embeddings([query]) - query_embedding = query_embeddings[0] - - # Compute cosine similarity - query_norm = query_embedding / np.linalg.norm(query_embedding) - E_norm = self.E / np.linalg.norm(self.E, axis=1)[:, np.newaxis] - similarities = np.dot(E_norm, query_norm) - - # Get top k indices - top_indices = np.argsort(similarities)[-k:][::-1] - - # Return items and similarity scores - return [(str(self.items[idx]), float(similarities[idx])) for idx in top_indices] - - @abstractmethod - def similarity_search(self, query: str, k: int = 5) -> list[tuple[T, float]]: - """Find the k most similar items to a query. - - Args: - query: The text to search for - k: Number of results to return - - Returns: - List of tuples (item, similarity_score) sorted by similarity - """ - pass diff --git a/src/codegen/extensions/index/file_index.py b/src/codegen/extensions/index/file_index.py deleted file mode 100644 index a76e62d5e..000000000 --- a/src/codegen/extensions/index/file_index.py +++ /dev/null @@ -1,367 +0,0 @@ -"""File-level semantic code search index.""" - -import pickle -from pathlib import Path -from typing import Optional - -import modal -import numpy as np -import tiktoken -from openai import OpenAI -from tqdm import tqdm - -from codegen.extensions.index.code_index import CodeIndex -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.core.file import File -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -class FileIndex(CodeIndex): - """A semantic search index over codebase files. - - This implementation indexes entire files, splitting large files into chunks - if they exceed the token limit. - """ - - EMBEDDING_MODEL = "text-embedding-3-small" - MAX_TOKENS = 8000 - BATCH_SIZE = 100 - USE_MODAL_DICT = True # Flag to control whether to use Modal Dict - - def __init__(self, codebase: Codebase): - """Initialize the file index. - - Args: - codebase: The codebase to index - """ - super().__init__(codebase) - self.client = OpenAI() - self.encoding = tiktoken.get_encoding("cl100k_base") - - def set_use_modal_dict(self, use_modal: bool) -> None: - """Set whether to use Modal Dict for storage. - - Args: - use_modal: Whether to use Modal Dict for storage - """ - self.USE_MODAL_DICT = use_modal - logger.info(f"Modal Dict storage {'enabled' if use_modal else 'disabled'}") - - @property - def save_file_name(self) -> str: - return "file_index_{commit}.pkl" - - @property - def modal_dict_id(self) -> str: - """Get the Modal Dict ID based on the same naming convention as the pickle file.""" - if not self.commit_hash: - return "file_index_latest" - return f"file_index_{self.commit_hash}" - - def delete_modal_dict(self) -> bool: - """Delete the Modal Dict storage for this index. - - Returns: - bool: True if successfully deleted, False otherwise - """ - if not self.USE_MODAL_DICT: - logger.warning("Modal Dict storage is disabled") - return False - - try: - dict_id = self.modal_dict_id - logger.info(f"Deleting Modal Dict: {dict_id}") - - # Check if the dict exists before trying to delete - try: - # Use modal.Dict.delete to properly delete the dict - modal.Dict.delete(dict_id) - logger.info(f"Successfully deleted Modal Dict: {dict_id}") - return True - except Exception as e: - logger.info(f"Modal Dict {dict_id} does not exist or cannot be deleted: {e}") - return False - except Exception as e: - logger.exception(f"Failed to delete Modal Dict: {e}") - return False - - def modal_dict_exists(self, commit_hash: Optional[str] = None) -> bool: - """Check if a Modal Dict exists for a specific commit. - - Args: - commit_hash: The commit hash to check, or None to use the current commit - - Returns: - bool: True if the Modal Dict exists, False otherwise - """ - if not self.USE_MODAL_DICT: - return False - - try: - # Use provided commit hash or current one - old_commit = self.commit_hash - if commit_hash is not None: - self.commit_hash = commit_hash - - dict_id = self.modal_dict_id - - # Restore original commit hash - if commit_hash is not None: - self.commit_hash = old_commit - - try: - # Try to access the dict - this will raise an exception if it doesn't exist - modal_dict = modal.Dict.from_name(dict_id, create_if_missing=False) - # Check if our data is in the dict - return "index_data" in modal_dict - except Exception: - return False - except Exception: - return False - - def _split_by_tokens(self, text: str) -> list[str]: - """Split text into chunks that fit within token limit.""" - tokens = self.encoding.encode(text) - chunks = [] - current_chunk = [] - current_size = 0 - - for token in tokens: - if current_size + 1 > self.MAX_TOKENS: - chunks.append(self.encoding.decode(current_chunk)) - current_chunk = [token] - current_size = 1 - else: - current_chunk.append(token) - current_size += 1 - - if current_chunk: - chunks.append(self.encoding.decode(current_chunk)) - - return chunks - - def _get_embeddings(self, texts: list[str]) -> list[list[float]]: - """Get embeddings for a batch of texts using OpenAI's API.""" - # Clean texts - texts = [text.replace("\\n", " ") for text in texts] - - # Process in batches with progress bar - all_embeddings = [] - for i in tqdm(range(0, len(texts), self.BATCH_SIZE), desc="Getting embeddings"): - batch = texts[i : i + self.BATCH_SIZE] - response = self.client.embeddings.create(model=self.EMBEDDING_MODEL, input=batch, encoding_format="float") - all_embeddings.extend(data.embedding for data in response.data) - - return all_embeddings - - def _get_items_to_index_for_files(self, files: list[File]) -> list[tuple[str, str]]: - """Get items to index for specific files.""" - items_to_index = [] - - # Filter out binary files and files without content - files_to_process = [] - for f in files: - try: - if f.content: # This will raise ValueError for binary files - files_to_process.append(f) - except ValueError: - logger.debug(f"Skipping binary file: {f.filepath}") - - if len(files) == 1: - logger.info(f"Processing file: {files[0].filepath}") - else: - logger.info(f"Found {len(files_to_process)} indexable files out of {len(files)} total files") - - # Collect all chunks that need to be processed - for file in files_to_process: - chunks = self._split_by_tokens(file.content) - if len(chunks) == 1: - items_to_index.append((file.filepath, file.content)) - else: - # For multi-chunk files, create virtual items - for i, chunk in enumerate(chunks): - chunk_id = f"{file.filepath}#chunk{i}" - items_to_index.append((chunk_id, chunk)) - - if items_to_index: - logger.info(f"Total chunks to process: {len(items_to_index)}") - return items_to_index - - def _get_items_to_index(self) -> list[tuple[str, str]]: - """Get all files and their content chunks to index.""" - return self._get_items_to_index_for_files(list(self.codebase.files)) - - def _get_changed_items(self) -> set[File]: - """Get set of files that have changed since last index.""" - if not self.commit_hash: - return set() - - # Get diffs between base commit and current state - diffs = self.codebase.get_diffs(self.commit_hash) - changed_files = set() - - for diff in diffs: - if diff.a_path: - file = self.codebase.get_file(diff.a_path) - if file: - changed_files.add(file) - if diff.b_path: - file = self.codebase.get_file(diff.b_path) - if file: - changed_files.add(file) - - return changed_files - - def _save_index(self, path: Path) -> None: - """Save index data to disk and optionally to Modal Dict.""" - # Save to local pickle file - with open(path, "wb") as f: - pickle.dump({"E": self.E, "items": self.items, "commit_hash": self.commit_hash}, f) - - # Save to Modal Dict if enabled - if self.USE_MODAL_DICT: - try: - dict_id = self.modal_dict_id - logger.info(f"Saving index to Modal Dict: {dict_id}") - - # Convert numpy arrays to lists for JSON serialization - modal_data = {"E": self.E.tolist() if self.E is not None else None, "items": self.items.tolist() if self.items is not None else None, "commit_hash": self.commit_hash} - - # Create or update Modal Dict - # Note: from_name is lazy, so we need to explicitly set the data - modal_dict = modal.Dict.from_name(dict_id, create_if_missing=True) - modal_dict["index_data"] = modal_data - - logger.info(f"Successfully saved index to Modal Dict: {dict_id}") - except Exception as e: - logger.exception(f"Failed to save index to Modal Dict: {e}") - - def _load_index(self, path: Path) -> None: - """Load index data from disk or Modal Dict.""" - # Try loading from Modal Dict first if enabled - if self.USE_MODAL_DICT: - try: - dict_id = self.modal_dict_id - logger.info(f"Attempting to load index from Modal Dict: {dict_id}") - - # from_name is lazy, so we need to check if the dict exists first - try: - modal_dict = modal.Dict.from_name(dict_id, create_if_missing=False) - # Check if the dict contains our data - if "index_data" in modal_dict: - data = modal_dict["index_data"] - - # Convert lists back to numpy arrays - self.E = np.array(data["E"]) if data["E"] is not None else None - self.items = np.array(data["items"]) if data["items"] is not None else None - self.commit_hash = data["commit_hash"] - - logger.info(f"Successfully loaded index from Modal Dict: {dict_id}") - return - else: - logger.info(f"No index data found in Modal Dict: {dict_id}") - except Exception as e: - logger.warning(f"Modal Dict {dict_id} not found or error accessing it: {e}") - except Exception as e: - logger.warning(f"Failed to load index from Modal Dict, falling back to local file: {e}") - - # Fall back to loading from local file - try: - with open(path, "rb") as f: - data = pickle.load(f) - self.E = data["E"] - self.items = data["items"] - self.commit_hash = data["commit_hash"] - logger.info(f"Loaded index from local file: {path}") - except Exception as e: - logger.exception(f"Failed to load index from local file: {e}") - raise - - def similarity_search(self, query: str, k: int = 5) -> list[tuple[File, float]]: - """Find the k most similar files to a query. - - Args: - query: The text to search for - k: Number of results to return - - Returns: - List of tuples (File, similarity_score) sorted by similarity - """ - results = [] - for filepath, score in self._similarity_search_raw(query, k): - # Handle chunked files - base_path = filepath.split("#")[0] # Remove chunk identifier if present - try: - if file := self.codebase.get_file(base_path): - results.append((file, score)) - except FileNotFoundError: - pass # Skip files that no longer exist - - return results - - def update(self) -> None: - """Update embeddings for changed files only.""" - if self.E is None or self.items is None or self.commit_hash is None: - msg = "No index to update. Call create() or load() first." - raise ValueError(msg) - - # Get changed files - changed_files = self._get_changed_items() - if not changed_files: - logger.info("No files have changed since last update") - return - - logger.info(f"Found {len(changed_files)} changed files to update") - - # Get content for changed files only - items_with_content = self._get_items_to_index_for_files(list(changed_files)) - - if not items_with_content: - logger.info("No valid content found in changed files") - return - - items, contents = zip(*items_with_content) - logger.info(f"Processing {len(contents)} chunks from changed files") - new_embeddings = self._get_embeddings(contents) - - # Create mapping of items to their indices - item_to_idx = {str(item): idx for idx, item in enumerate(self.items)} - - # Update embeddings - num_updated = 0 - num_added = 0 - for item, embedding in zip(items, new_embeddings): - item_key = str(item) - if item_key in item_to_idx: - # Update existing embedding - self.E[item_to_idx[item_key]] = embedding - num_updated += 1 - else: - # Add new embedding - self.E = np.vstack([self.E, embedding]) - self.items = np.append(self.items, item) - num_added += 1 - - logger.info(f"Updated {num_updated} existing embeddings and added {num_added} new embeddings") - - # Update commit hash - self.commit_hash = self._get_current_commit() - - # Save updated index to Modal Dict if enabled - if self.USE_MODAL_DICT and (num_updated > 0 or num_added > 0): - try: - dict_id = self.modal_dict_id - logger.info(f"Updating index in Modal Dict: {dict_id}") - - # Convert numpy arrays to lists for JSON serialization - modal_data = {"E": self.E.tolist() if self.E is not None else None, "items": self.items.tolist() if self.items is not None else None, "commit_hash": self.commit_hash} - - # Create or update Modal Dict - modal_dict = modal.Dict.from_name(dict_id, create_if_missing=True) - modal_dict["index_data"] = modal_data - - logger.info(f"Successfully updated index in Modal Dict: {dict_id}") - except Exception as e: - logger.exception(f"Failed to update index in Modal Dict: {e}") diff --git a/src/codegen/extensions/index/symbol_index.py b/src/codegen/extensions/index/symbol_index.py deleted file mode 100644 index d59abb921..000000000 --- a/src/codegen/extensions/index/symbol_index.py +++ /dev/null @@ -1,160 +0,0 @@ -"""Symbol-level semantic code search index.""" - -import pickle -from pathlib import Path - -import tiktoken -from openai import OpenAI -from tqdm import tqdm - -from codegen.extensions.index.code_index import CodeIndex -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.core.symbol import Symbol -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -# TODO: WIP! -class SymbolIndex(CodeIndex): - """A semantic search index over codebase symbols. - - This implementation indexes individual symbols (functions, classes, etc.) - rather than entire files. This allows for more granular search results. - """ - - EMBEDDING_MODEL = "text-embedding-3-small" - MAX_TOKENS_PER_TEXT = 8000 # Max tokens per individual text - MAX_BATCH_TOKENS = 32000 # Max total tokens per API call - BATCH_SIZE = 100 # Max number of texts per API call - - def __init__(self, codebase: Codebase): - """Initialize the symbol index.""" - super().__init__(codebase) - self.client = OpenAI() - self.encoding = tiktoken.get_encoding("cl100k_base") - - @property - def save_file_name(self) -> str: - return "symbol_index_{commit}.pkl" - - def _batch_texts_by_tokens(self, texts: list[str]) -> list[list[str]]: - """Batch texts to maximize tokens per API call while respecting limits. - - This tries to pack as many texts as possible into each batch while ensuring: - 1. No individual text exceeds MAX_TOKENS_PER_TEXT - 2. Total tokens in batch doesn't exceed MAX_BATCH_TOKENS - 3. Number of texts doesn't exceed BATCH_SIZE - """ - batches = [] - current_batch = [] - current_tokens = 0 - - for text in texts: - # Get token count for this text - tokens = self.encoding.encode(text) - n_tokens = len(tokens) - - # If text is too long, truncate it - if n_tokens > self.MAX_TOKENS_PER_TEXT: - tokens = tokens[: self.MAX_TOKENS_PER_TEXT] - text = self.encoding.decode(tokens) - n_tokens = self.MAX_TOKENS_PER_TEXT - - # Check if adding this text would exceed batch limits - if len(current_batch) + 1 > self.BATCH_SIZE or current_tokens + n_tokens > self.MAX_BATCH_TOKENS: - # Current batch is full, start a new one - if current_batch: - batches.append(current_batch) - current_batch = [] - current_tokens = 0 - - # Add text to current batch - current_batch.append(text) - current_tokens += n_tokens - - # Add the last batch if not empty - if current_batch: - batches.append(current_batch) - - return batches - - def _get_embeddings(self, texts: list[str]) -> list[list[float]]: - """Get embeddings for a batch of texts using OpenAI's API.""" - # Clean texts - texts = [text.replace("\\n", " ") for text in texts] - - # Batch texts efficiently - batches = self._batch_texts_by_tokens(texts) - logger.info(f"Processing {len(texts)} texts in {len(batches)} batches") - - # Process batches with progress bar - all_embeddings = [] - for batch in tqdm(batches, desc="Getting embeddings"): - response = self.client.embeddings.create(model=self.EMBEDDING_MODEL, input=batch, encoding_format="float") - all_embeddings.extend(data.embedding for data in response.data) - - return all_embeddings - - def _get_items_to_index(self) -> list[tuple[str, str]]: - """Get all symbols and their content to index.""" - items_to_index = [] - symbols_to_process = [s for s in self.codebase.symbols if s.source] - logger.info(f"Found {len(symbols_to_process)} symbols to index") - - # Process each symbol - no need to pre-truncate since _batch_texts_by_tokens handles it - for symbol in symbols_to_process: - symbol_id = f"{symbol.file.filepath}::{symbol.name}" - items_to_index.append((symbol_id, symbol.source)) - - logger.info(f"Total symbols to process: {len(items_to_index)}") - return items_to_index - - def _get_changed_items(self) -> set[Symbol]: - """Get set of symbols that have changed since last index.""" - if not self.commit_hash: - return set() - - # Get diffs between base commit and current state - diffs = self.codebase.get_diffs(self.commit_hash) - changed_symbols = set() - - # Get all symbols from changed files - for diff in diffs: - for path in [diff.a_path, diff.b_path]: - if not path: - continue - file = self.codebase.get_file(path) - if file: - changed_symbols.update(s for s in file.symbols if s.source) - - logger.info(f"Found {len(changed_symbols)} changed symbols") - return changed_symbols - - def _save_index(self, path: Path) -> None: - """Save index data to disk.""" - with open(path, "wb") as f: - pickle.dump({"E": self.E, "items": self.items, "commit_hash": self.commit_hash}, f) - - def _load_index(self, path: Path) -> None: - """Load index data from disk.""" - with open(path, "rb") as f: - data = pickle.load(f) - self.E = data["E"] - self.items = data["items"] - self.commit_hash = data["commit_hash"] - - def similarity_search(self, query: str, k: int = 5) -> list[tuple[Symbol, float]]: - """Find the k most similar symbols to a query.""" - results = [] - for symbol_id, score in self._similarity_search_raw(query, k): - # Parse the symbol identifier - filepath, symbol_name = symbol_id.split("::") - # Get the file and find the symbol - if file := self.codebase.get_file(filepath): - for symbol in file.symbols: - if symbol.name == symbol_name and symbol.source: - results.append((symbol, score)) - break - - return results diff --git a/src/codegen/extensions/linear/__init__.py b/src/codegen/extensions/linear/__init__.py deleted file mode 100644 index 8ba060245..000000000 --- a/src/codegen/extensions/linear/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .linear_client import LinearClient - -__all__ = ["LinearClient"] diff --git a/src/codegen/extensions/linear/linear_client.py b/src/codegen/extensions/linear/linear_client.py deleted file mode 100644 index 0c3803153..000000000 --- a/src/codegen/extensions/linear/linear_client.py +++ /dev/null @@ -1,295 +0,0 @@ -import os -from typing import Optional - -import requests -from requests.adapters import HTTPAdapter -from urllib3.util.retry import Retry - -from codegen.extensions.linear.types import LinearComment, LinearIssue, LinearTeam, LinearUser -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -class LinearClient: - api_headers: dict - api_endpoint = "https://api.linear.app/graphql" - - def __init__(self, access_token: Optional[str] = None, team_id: Optional[str] = None, max_retries: int = 3, backoff_factor: float = 0.5): - if not access_token: - access_token = os.getenv("LINEAR_ACCESS_TOKEN") - if not access_token: - msg = "access_token is required" - raise ValueError(msg) - self.access_token = access_token - - if not team_id: - team_id = os.getenv("LINEAR_TEAM_ID") - self.team_id = team_id - - self.api_headers = { - "Content-Type": "application/json", - "Authorization": self.access_token, - } - - # Set up a session with retry logic - self.session = requests.Session() - retry_strategy = Retry( - total=max_retries, - backoff_factor=backoff_factor, - status_forcelist=[429, 500, 502, 503, 504], - allowed_methods=["POST", "GET"], # POST is important for GraphQL - ) - adapter = HTTPAdapter(max_retries=retry_strategy) - self.session.mount("https://", adapter) - self.session.mount("http://", adapter) - - def get_issue(self, issue_id: str) -> LinearIssue: - query = """ - query getIssue($issueId: String!) { - issue(id: $issueId) { - id - title - description - } - } - """ - variables = {"issueId": issue_id} - response = self.session.post(self.api_endpoint, headers=self.api_headers, json={"query": query, "variables": variables}) - data = response.json() - issue_data = data["data"]["issue"] - return LinearIssue(id=issue_data["id"], title=issue_data["title"], description=issue_data["description"]) - - def get_issue_comments(self, issue_id: str) -> list[LinearComment]: - query = """ - query getIssueComments($issueId: String!) { - issue(id: $issueId) { - comments { - nodes { - id - body - user { - id - name - } - } - - } - } - } - """ - variables = {"issueId": issue_id} - response = self.session.post(self.api_endpoint, headers=self.api_headers, json={"query": query, "variables": variables}) - data = response.json() - comments = data["data"]["issue"]["comments"]["nodes"] - - # Parse comments into list of LinearComment objects - parsed_comments = [] - for comment in comments: - user = comment.get("user", None) - parsed_comment = LinearComment(id=comment["id"], body=comment["body"], user=LinearUser(id=user.get("id"), name=user.get("name")) if user else None) - parsed_comments.append(parsed_comment) - - # Convert raw comments to LinearComment objects - return parsed_comments - - def comment_on_issue(self, issue_id: str, body: str) -> LinearComment: - """Add a comment to an issue.""" - query = """mutation makeComment($issueId: String!, $body: String!) { - commentCreate(input: {issueId: $issueId, body: $body}) { - comment { - id - body - url - user { - id - name - } - } - } - } - """ - variables = {"issueId": issue_id, "body": body} - response = self.session.post( - self.api_endpoint, - headers=self.api_headers, - json={"query": query, "variables": variables}, - ) - data = response.json() - try: - comment_data = data["data"]["commentCreate"]["comment"] - user_data = comment_data.get("user", None) - user = LinearUser(id=user_data["id"], name=user_data["name"]) if user_data else None - - return LinearComment(id=comment_data["id"], body=comment_data["body"], user=user) - except Exception as e: - msg = f"Error creating comment\n{data}\n{e}" - raise ValueError(msg) - - def register_webhook(self, webhook_url: str, team_id: str, secret: str, enabled: bool, resource_types: list[str]): - mutation = """ - mutation createWebhook($input: WebhookCreateInput!) { - webhookCreate(input: $input) { - success - webhook { - id - enabled - } - } - } - """ - - variables = { - "input": { - "url": webhook_url, - "teamId": team_id, - "resourceTypes": resource_types, - "enabled": enabled, - "secret": secret, - } - } - - response = self.session.post(self.api_endpoint, headers=self.api_headers, json={"query": mutation, "variables": variables}) - body = response.json() - return body - - def search_issues(self, query: str, limit: int = 10) -> list[LinearIssue]: - """Search for issues using a query string. - - Args: - query: Search query string - limit: Maximum number of issues to return (default: 10) - - Returns: - List of LinearIssue objects matching the search query - """ - graphql_query = """ - query searchIssues($query: String!, $limit: Int!) { - issueSearch(query: $query, first: $limit) { - nodes { - id - title - description - } - } - } - """ - variables = {"query": query, "limit": limit} - response = self.session.post( - self.api_endpoint, - headers=self.api_headers, - json={"query": graphql_query, "variables": variables}, - ) - data = response.json() - - try: - issues_data = data["data"]["issueSearch"]["nodes"] - return [ - LinearIssue( - id=issue["id"], - title=issue["title"], - description=issue["description"], - ) - for issue in issues_data - ] - except Exception as e: - msg = f"Error searching issues\n{data}\n{e}" - raise Exception(msg) - - def create_issue(self, title: str, description: str | None = None, team_id: str | None = None) -> LinearIssue: - """Create a new issue. - - Args: - title: Title of the issue - description: Optional description of the issue - team_id: Optional team ID. If not provided, uses the client's configured team_id - - Returns: - The created LinearIssue object - - Raises: - ValueError: If no team_id is provided or configured - """ - if not team_id: - team_id = self.team_id - if not team_id: - msg = "team_id must be provided either during client initialization or in the create_issue call" - raise ValueError(msg) - - mutation = """ - mutation createIssue($input: IssueCreateInput!) { - issueCreate(input: $input) { - success - issue { - id - title - description - } - } - } - """ - - variables = { - "input": { - "teamId": team_id, - "title": title, - "description": description, - } - } - - response = self.session.post( - self.api_endpoint, - headers=self.api_headers, - json={"query": mutation, "variables": variables}, - ) - data = response.json() - - try: - issue_data = data["data"]["issueCreate"]["issue"] - return LinearIssue( - id=issue_data["id"], - title=issue_data["title"], - description=issue_data["description"], - ) - except Exception as e: - msg = f"Error creating issue\n{data}\n{e}" - raise Exception(msg) - - def get_teams(self) -> list[LinearTeam]: - """Get all teams the authenticated user has access to. - - Returns: - List of LinearTeam objects - """ - query = """ - query { - teams { - nodes { - id - name - key - } - } - } - """ - - response = self.session.post( - self.api_endpoint, - headers=self.api_headers, - json={"query": query}, - ) - data = response.json() - - try: - teams_data = data["data"]["teams"]["nodes"] - return [ - LinearTeam( - id=team["id"], - name=team["name"], - key=team["key"], - ) - for team in teams_data - ] - except Exception as e: - msg = f"Error getting teams\n{data}\n{e}" - raise Exception(msg) diff --git a/src/codegen/extensions/linear/types.py b/src/codegen/extensions/linear/types.py deleted file mode 100644 index fb9439399..000000000 --- a/src/codegen/extensions/linear/types.py +++ /dev/null @@ -1,40 +0,0 @@ -from pydantic import BaseModel - - -class LinearUser(BaseModel): - id: str - name: str - - -class LinearTeam(BaseModel): - """Represents a Linear team.""" - - id: str - name: str - key: str - - -class LinearComment(BaseModel): - id: str - body: str - user: LinearUser | None = None - - -class LinearIssue(BaseModel): - id: str - title: str - description: str | None = None - priority: int | None = None - team_id: str | None = None - - -class LinearEvent(BaseModel): - """Represents a Linear webhook event.""" - - action: str # e.g. "create", "update", "remove" - type: str # e.g. "Issue", "Comment", "Project" - data: LinearIssue | LinearComment # The actual event data - url: str # URL to the resource in Linear - created_at: str | None = None # ISO timestamp - organization_id: str | None = None - team_id: str | None = None diff --git a/src/codegen/extensions/lsp/codemods/__init__.py b/src/codegen/extensions/lsp/codemods/__init__.py deleted file mode 100644 index feb8117b7..000000000 --- a/src/codegen/extensions/lsp/codemods/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from codegen.extensions.lsp.codemods.base import CodeAction -from codegen.extensions.lsp.codemods.split_tests import SplitTests - -ACTIONS: list[CodeAction] = [SplitTests()] diff --git a/src/codegen/extensions/lsp/codemods/base.py b/src/codegen/extensions/lsp/codemods/base.py deleted file mode 100644 index ced434217..000000000 --- a/src/codegen/extensions/lsp/codemods/base.py +++ /dev/null @@ -1,41 +0,0 @@ -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, ClassVar - -from lsprotocol import types - -from codegen.sdk.core.interfaces.editable import Editable - -if TYPE_CHECKING: - from codegen.extensions.lsp.server import CodegenLanguageServer - - -class CodeAction(ABC): - name: str - kind: ClassVar[types.CodeActionKind] = types.CodeActionKind.Refactor - - def __init__(self): - pass - - @abstractmethod - def execute(self, server: "CodegenLanguageServer", node: Editable) -> None: ... - - @abstractmethod - def is_applicable(self, server: "CodegenLanguageServer", node: Editable) -> bool: ... - - def to_command(self, uri: str, range: types.Range) -> types.Command: - return types.Command( - title=self.name, - command=self.command_name(), - arguments=[uri, range], - ) - - def to_lsp(self, uri: str, range: types.Range) -> types.CodeAction: - return types.CodeAction( - title=self.name, - kind=self.kind, - data=[self.command_name(), uri, range], - ) - - @classmethod - def command_name(cls) -> str: - return f"codegen-{cls.__name__}" diff --git a/src/codegen/extensions/lsp/codemods/move_symbol_to_file.py b/src/codegen/extensions/lsp/codemods/move_symbol_to_file.py deleted file mode 100644 index c0c98f66e..000000000 --- a/src/codegen/extensions/lsp/codemods/move_symbol_to_file.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import TYPE_CHECKING - -from codegen.extensions.lsp.codemods.base import CodeAction -from codegen.sdk.core.interfaces.editable import Editable - -if TYPE_CHECKING: - from codegen.extensions.lsp.server import CodegenLanguageServer - - -class MoveSymbolToFile(CodeAction): - name = "Move Symbol to File" - - def is_applicable(self, server: "CodegenLanguageServer", node: Editable) -> bool: - return True - - def execute(self, server: "CodegenLanguageServer", node: Editable) -> None: - target_file = server.window_show_message_request( - "Select the file to move the symbol to", - server.codebase.files, - ).result(timeout=10) - if target_file is None: - return - server.codebase.move_symbol(node.parent_symbol, target_file) diff --git a/src/codegen/extensions/lsp/codemods/split_tests.py b/src/codegen/extensions/lsp/codemods/split_tests.py deleted file mode 100644 index 3b17bfeff..000000000 --- a/src/codegen/extensions/lsp/codemods/split_tests.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import TYPE_CHECKING - -from codegen.extensions.lsp.codemods.base import CodeAction -from codegen.sdk.core.function import Function -from codegen.sdk.core.interfaces.editable import Editable - -if TYPE_CHECKING: - from codegen.extensions.lsp.server import CodegenLanguageServer - - -class SplitTests(CodeAction): - name = "Split Tests" - - def _get_targets(self, server: "CodegenLanguageServer", node: Editable) -> dict[Function, str]: - targets = {} - for function in node.file.functions: - if function.name.startswith("test_"): - target = f"{node.file.directory.path}/{function.name}.py" - if not server.codebase.has_file(target): - targets[function] = target - return targets - - def is_applicable(self, server: "CodegenLanguageServer", node: Editable) -> bool: - if "tests" in str(node.file.path): - return len(self._get_targets(server, node)) > 1 - return False - - def execute(self, server: "CodegenLanguageServer", node: Editable) -> None: - targets = self._get_targets(server, node) - for function, target in targets.items(): - new_file = server.codebase.create_file(target) - function.move_to_file(new_file, strategy="duplicate_dependencies") - # node.file.remove() diff --git a/src/codegen/extensions/lsp/completion.py b/src/codegen/extensions/lsp/completion.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/extensions/lsp/definition.py b/src/codegen/extensions/lsp/definition.py deleted file mode 100644 index acecc7256..000000000 --- a/src/codegen/extensions/lsp/definition.py +++ /dev/null @@ -1,35 +0,0 @@ -from lsprotocol.types import Position - -from codegen.sdk.core.assignment import Assignment -from codegen.sdk.core.detached_symbols.function_call import FunctionCall -from codegen.sdk.core.expressions.chained_attribute import ChainedAttribute -from codegen.sdk.core.expressions.expression import Expression -from codegen.sdk.core.expressions.name import Name -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -def go_to_definition(node: Editable | None, uri: str, position: Position) -> Editable | None: - if node is None or not isinstance(node, (Expression)): - logger.warning(f"No node found at {uri}:{position}") - return None - if isinstance(node, Name) and isinstance(node.parent, ChainedAttribute) and node.parent.attribute == node: - node = node.parent - if isinstance(node.parent, FunctionCall) and node.parent.get_name() == node: - node = node.parent - logger.info(f"Resolving definition for {node}") - if isinstance(node, FunctionCall): - resolved = node.function_definition - else: - resolved = node.resolved_value - if resolved is None: - logger.warning(f"No resolved value found for {node.name} at {uri}:{position}") - return None - if isinstance(resolved, (HasName,)): - resolved = resolved.get_name() - if isinstance(resolved.parent, Assignment) and resolved.parent.value == resolved: - resolved = resolved.parent.get_name() - return resolved diff --git a/src/codegen/extensions/lsp/document_symbol.py b/src/codegen/extensions/lsp/document_symbol.py deleted file mode 100644 index 01000755a..000000000 --- a/src/codegen/extensions/lsp/document_symbol.py +++ /dev/null @@ -1,26 +0,0 @@ -from lsprotocol.types import DocumentSymbol - -from codegen.extensions.lsp.kind import get_kind -from codegen.extensions.lsp.range import get_range -from codegen.sdk.core.class_definition import Class -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.extensions.sort import sort_editables - - -def get_document_symbol(node: Editable) -> DocumentSymbol: - children = [] - nodes = [] - if isinstance(node, Class): - nodes.extend(node.methods) - nodes.extend(node.attributes) - nodes.extend(node.nested_classes) - nodes = sort_editables(nodes) - for child in nodes: - children.append(get_document_symbol(child)) - return DocumentSymbol( - name=node.name, - kind=get_kind(node), - range=get_range(node), - selection_range=get_range(node.get_name()), - children=children, - ) diff --git a/src/codegen/extensions/lsp/execute.py b/src/codegen/extensions/lsp/execute.py deleted file mode 100644 index 5e34121d1..000000000 --- a/src/codegen/extensions/lsp/execute.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import TYPE_CHECKING, Any, Callable - -from lsprotocol import types -from lsprotocol.types import Position, Range - -from codegen.extensions.lsp.codemods.base import CodeAction -from codegen.shared.logging.get_logger import get_logger - -if TYPE_CHECKING: - from codegen.extensions.lsp.server import CodegenLanguageServer - -logger = get_logger(__name__) - - -def process_args(args: Any) -> tuple[str, Range]: - uri = args[0] - range = args[1] - range = Range(start=Position(line=range["start"]["line"], character=range["start"]["character"]), end=Position(line=range["end"]["line"], character=range["end"]["character"])) - return uri, range - - -def execute_action(server: "CodegenLanguageServer", action: CodeAction, args: Any) -> None: - uri, range = process_args(args) - node = server.get_node_under_cursor(uri, range.start, range.end) - if node is None: - logger.warning(f"No node found for range {range}") - return - action.execute(server, node, *args[2:]) - server.codebase.commit() - - -def get_execute_action(action: CodeAction) -> Callable[["CodegenLanguageServer", Any], None]: - def execute_action(server: "CodegenLanguageServer", args: Any) -> None: - logger.info(f"Executing action {action.command_name()} with args {args}") - execute_action(server, action, args) - server.workspace_apply_edit(types.ApplyWorkspaceEditParams(edit=server.io.get_workspace_edit())).result() - - return execute_action diff --git a/src/codegen/extensions/lsp/io.py b/src/codegen/extensions/lsp/io.py deleted file mode 100644 index b3f02b4e5..000000000 --- a/src/codegen/extensions/lsp/io.py +++ /dev/null @@ -1,152 +0,0 @@ -import pprint -from dataclasses import dataclass -from pathlib import Path - -from attr import asdict -from lsprotocol import types -from lsprotocol.types import CreateFile, CreateFileOptions, DeleteFile, Position, Range, RenameFile, TextEdit -from pygls.workspace import TextDocument, Workspace - -from codegen.sdk.codebase.io.file_io import FileIO -from codegen.sdk.codebase.io.io import IO -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -@dataclass -class File: - doc: TextDocument | None - path: Path - change: TextEdit | None = None - other_change: CreateFile | RenameFile | DeleteFile | None = None - version: int = 0 - - @property - def deleted(self) -> bool: - return self.other_change is not None and self.other_change.kind == "delete" - - @property - def created(self) -> bool: - return self.other_change is not None and self.other_change.kind == "create" - - @property - def identifier(self) -> types.OptionalVersionedTextDocumentIdentifier: - return types.OptionalVersionedTextDocumentIdentifier(uri=self.path.as_uri(), version=self.version) - - -class LSPIO(IO): - base_io: FileIO - workspace: Workspace - files: dict[Path, File] - - def __init__(self, workspace: Workspace): - self.workspace = workspace - self.base_io = FileIO() - self.files = {} - - def _get_doc(self, path: Path) -> TextDocument: - uri = path.as_uri() - logger.info(f"Getting document for {uri}") - return self.workspace.get_text_document(uri) - - def _get_file(self, path: Path) -> File: - if path not in self.files: - doc = self._get_doc(path) - self.files[path] = File(doc=doc, path=path, version=doc.version or 0) - return self.files[path] - - def read_text(self, path: Path) -> str: - file = self._get_file(path) - if file.deleted: - msg = f"File {path} has been deleted" - raise FileNotFoundError(msg) - if file.change: - return file.change.new_text - if file.created: - return "" - if file.doc is None: - return self.base_io.read_text(path) - return file.doc.source - - def read_bytes(self, path: Path) -> bytes: - file = self._get_file(path) - if file.deleted: - msg = f"File {path} has been deleted" - raise FileNotFoundError(msg) - if file.change: - return file.change.new_text.encode("utf-8") - if file.created: - return b"" - if file.doc is None: - return self.base_io.read_bytes(path) - return file.doc.source.encode("utf-8") - - def write_bytes(self, path: Path, content: bytes) -> None: - logger.info(f"Writing bytes to {path}") - start = Position(line=0, character=0) - file = self._get_file(path) - if self.file_exists(path): - lines = self.read_text(path).splitlines() - if len(lines) == 0: - end = Position(line=0, character=0) - else: - end = Position(line=len(lines) - 1, character=len(lines[-1])) - file.change = TextEdit(range=Range(start=start, end=end), new_text=content.decode("utf-8")) - else: - file.other_change = CreateFile(uri=path.as_uri(), options=CreateFileOptions()) - file.change = TextEdit(range=Range(start=start, end=start), new_text=content.decode("utf-8")) - - def save_files(self, files: set[Path] | None = None) -> None: - logger.info(f"Saving files {files}") - - def check_changes(self) -> None: - self.base_io.check_changes() - - def delete_file(self, path: Path) -> None: - file = self._get_file(path) - file.other_change = DeleteFile(uri=path.as_uri()) - self.base_io.delete_file(path) - - def file_exists(self, path: Path) -> bool: - file = self._get_file(path) - if file.deleted: - return False - if file.change: - return True - if file.created: - return True - if file.doc is None: - return self.base_io.file_exists(path) - try: - file.doc.source - return True - except FileNotFoundError: - return False - - def untrack_file(self, path: Path) -> None: - self.base_io.untrack_file(path) - - def get_workspace_edit(self) -> types.WorkspaceEdit: - document_changes = [] - for _, file in self.files.items(): - id = file.identifier - if file.other_change: - document_changes.append(file.other_change) - file.other_change = None - if file.change: - document_changes.append(types.TextDocumentEdit(text_document=id, edits=[file.change])) - file.version += 1 - file.change = None - logger.info(f"Workspace edit: {pprint.pformat(list(map(asdict, document_changes)))}") - return types.WorkspaceEdit(document_changes=document_changes) - - def update_file(self, path: Path, version: int | None = None) -> None: - file = self._get_file(path) - file.doc = self.workspace.get_text_document(path.as_uri()) - if version is not None: - file.version = version - - def close_file(self, path: Path) -> None: - file = self._get_file(path) - file.doc = None diff --git a/src/codegen/extensions/lsp/kind.py b/src/codegen/extensions/lsp/kind.py deleted file mode 100644 index 609885164..000000000 --- a/src/codegen/extensions/lsp/kind.py +++ /dev/null @@ -1,31 +0,0 @@ -from lsprotocol.types import SymbolKind - -from codegen.sdk.core.assignment import Assignment -from codegen.sdk.core.class_definition import Class -from codegen.sdk.core.file import File -from codegen.sdk.core.function import Function -from codegen.sdk.core.interface import Interface -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.statements.attribute import Attribute -from codegen.sdk.typescript.namespace import TSNamespace - -kinds = { - File: SymbolKind.File, - Class: SymbolKind.Class, - Function: SymbolKind.Function, - Assignment: SymbolKind.Variable, - Interface: SymbolKind.Interface, - TSNamespace: SymbolKind.Namespace, - Attribute: SymbolKind.Variable, -} - - -def get_kind(node: Editable) -> SymbolKind: - if isinstance(node, Function): - if node.is_method: - return SymbolKind.Method - for kind in kinds: - if isinstance(node, kind): - return kinds[kind] - msg = f"No kind found for {node}, {type(node)}" - raise ValueError(msg) diff --git a/src/codegen/extensions/lsp/lsp.py b/src/codegen/extensions/lsp/lsp.py deleted file mode 100644 index ac716fe9f..000000000 --- a/src/codegen/extensions/lsp/lsp.py +++ /dev/null @@ -1,140 +0,0 @@ -import logging - -from lsprotocol import types - -import codegen -from codegen.extensions.lsp.definition import go_to_definition -from codegen.extensions.lsp.document_symbol import get_document_symbol -from codegen.extensions.lsp.protocol import CodegenLanguageServerProtocol -from codegen.extensions.lsp.range import get_range -from codegen.extensions.lsp.server import CodegenLanguageServer -from codegen.extensions.lsp.utils import get_path -from codegen.sdk.codebase.diff_lite import ChangeType, DiffLite -from codegen.sdk.core.file import SourceFile -from codegen.shared.logging.get_logger import get_logger - -version = getattr(codegen, "__version__", "v0.1") -server = CodegenLanguageServer("codegen", version, protocol_cls=CodegenLanguageServerProtocol) -logger = get_logger(__name__) - - -@server.feature(types.TEXT_DOCUMENT_DID_OPEN) -def did_open(server: CodegenLanguageServer, params: types.DidOpenTextDocumentParams) -> None: - """Handle document open notification.""" - logger.info(f"Document opened: {params.text_document.uri}") - # The document is automatically added to the workspace by pygls - # We can perform any additional processing here if needed - path = get_path(params.text_document.uri) - server.io.update_file(path, params.text_document.version) - file = server.codebase.get_file(str(path), optional=True) - if not isinstance(file, SourceFile) and path.suffix in server.codebase.ctx.extensions: - sync = DiffLite(change_type=ChangeType.Added, path=path) - server.codebase.ctx.apply_diffs([sync]) - - -@server.feature(types.TEXT_DOCUMENT_DID_CHANGE) -def did_change(server: CodegenLanguageServer, params: types.DidChangeTextDocumentParams) -> None: - """Handle document change notification.""" - logger.info(f"Document changed: {params.text_document.uri}") - # The document is automatically updated in the workspace by pygls - # We can perform any additional processing here if needed - path = get_path(params.text_document.uri) - server.io.update_file(path, params.text_document.version) - sync = DiffLite(change_type=ChangeType.Modified, path=path) - server.codebase.ctx.apply_diffs([sync]) - - -@server.feature(types.WORKSPACE_TEXT_DOCUMENT_CONTENT) -def workspace_text_document_content(server: CodegenLanguageServer, params: types.TextDocumentContentParams) -> types.TextDocumentContentResult: - """Handle workspace text document content notification.""" - logger.debug(f"Workspace text document content: {params.uri}") - path = get_path(params.uri) - if not server.io.file_exists(path): - logger.warning(f"File does not exist: {path}") - return types.TextDocumentContentResult( - text="", - ) - content = server.io.read_text(path) - return types.TextDocumentContentResult( - text=content, - ) - - -@server.feature(types.TEXT_DOCUMENT_DID_CLOSE) -def did_close(server: CodegenLanguageServer, params: types.DidCloseTextDocumentParams) -> None: - """Handle document close notification.""" - logger.info(f"Document closed: {params.text_document.uri}") - # The document is automatically removed from the workspace by pygls - # We can perform any additional cleanup here if needed - path = get_path(params.text_document.uri) - server.io.close_file(path) - - -@server.feature( - types.TEXT_DOCUMENT_RENAME, - options=types.RenameOptions(work_done_progress=True), -) -def rename(server: CodegenLanguageServer, params: types.RenameParams) -> types.RenameResult: - symbol = server.get_symbol(params.text_document.uri, params.position) - if symbol is None: - logger.warning(f"No symbol found at {params.text_document.uri}:{params.position}") - return - logger.info(f"Renaming symbol {symbol.name} to {params.new_name}") - task = server.progress_manager.begin_with_token(f"Renaming symbol {symbol.name} to {params.new_name}", params.work_done_token) - symbol.rename(params.new_name) - task.update("Committing changes") - server.codebase.commit() - task.end() - return server.io.get_workspace_edit() - - -@server.feature( - types.TEXT_DOCUMENT_DOCUMENT_SYMBOL, - options=types.DocumentSymbolOptions(work_done_progress=True), -) -def document_symbol(server: CodegenLanguageServer, params: types.DocumentSymbolParams) -> types.DocumentSymbolResult: - file = server.get_file(params.text_document.uri) - symbols = [] - task = server.progress_manager.begin_with_token(f"Getting document symbols for {params.text_document.uri}", params.work_done_token, count=len(file.symbols)) - for idx, symbol in enumerate(file.symbols): - task.update(f"Getting document symbols for {params.text_document.uri}", count=idx) - symbols.append(get_document_symbol(symbol)) - task.end() - return symbols - - -@server.feature( - types.TEXT_DOCUMENT_DEFINITION, - options=types.DefinitionOptions(work_done_progress=True), -) -def definition(server: CodegenLanguageServer, params: types.DefinitionParams): - node = server.get_node_under_cursor(params.text_document.uri, params.position) - task = server.progress_manager.begin_with_token(f"Getting definition for {params.text_document.uri}", params.work_done_token) - resolved = go_to_definition(node, params.text_document.uri, params.position) - task.end() - return types.Location( - uri=resolved.file.path.as_uri(), - range=get_range(resolved), - ) - - -@server.feature( - types.TEXT_DOCUMENT_CODE_ACTION, - options=types.CodeActionOptions(resolve_provider=True, work_done_progress=True), -) -def code_action(server: CodegenLanguageServer, params: types.CodeActionParams) -> types.CodeActionResult: - logger.info(f"Received code action: {params}") - actions = server.get_actions_for_range(params) - return actions - - -@server.feature( - types.CODE_ACTION_RESOLVE, -) -def code_action_resolve(server: CodegenLanguageServer, params: types.CodeAction) -> types.CodeAction: - return server.resolve_action(params) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - server.start_io() diff --git a/src/codegen/extensions/lsp/progress.py b/src/codegen/extensions/lsp/progress.py deleted file mode 100644 index 70eb365e5..000000000 --- a/src/codegen/extensions/lsp/progress.py +++ /dev/null @@ -1,60 +0,0 @@ -import uuid - -from lsprotocol import types -from lsprotocol.types import ProgressToken -from pygls.lsp.server import LanguageServer - -from codegen.sdk.codebase.progress.progress import Progress -from codegen.sdk.codebase.progress.stub_task import StubTask -from codegen.sdk.codebase.progress.task import Task - - -class LSPTask(Task): - count: int | None - - def __init__(self, server: LanguageServer, message: str, token: ProgressToken, count: int | None = None, create_token: bool = True) -> None: - self.token = token - if create_token: - server.work_done_progress.begin(self.token, types.WorkDoneProgressBegin(title=message)) - self.server = server - self.message = message - self.count = count - self.create_token = create_token - - def update(self, message: str, count: int | None = None) -> None: - if self.count is not None and count is not None: - percent = int(count * 100 / self.count) - else: - percent = None - self.server.work_done_progress.report(self.token, types.WorkDoneProgressReport(message=message, percentage=percent)) - - def end(self) -> None: - if self.create_token: - self.server.work_done_progress.end(self.token, value=types.WorkDoneProgressEnd()) - - -class LSPProgress(Progress[LSPTask | StubTask]): - initialized = False - - def __init__(self, server: LanguageServer, initial_token: ProgressToken | None = None): - self.server = server - self.initial_token = initial_token - if initial_token is not None: - self.server.work_done_progress.begin(initial_token, types.WorkDoneProgressBegin(title="Parsing codebase...")) - - def begin_with_token(self, message: str, token: ProgressToken | None = None, *, count: int | None = None, create_token: bool = True) -> LSPTask | StubTask: - if token is None: - return StubTask() - return LSPTask(self.server, message, token, count, create_token=create_token) - - def begin(self, message: str, count: int | None = None) -> LSPTask | StubTask: - if self.initialized: - token = str(uuid.uuid4()) - self.server.work_done_progress.create(token).result() - return LSPTask(self.server, message, token, count, create_token=False) - return self.begin_with_token(message, self.initial_token, count=None, create_token=False) - - def finish_initialization(self) -> None: - self.initialized = False # We can't initiate server work during syncs - if self.initial_token is not None: - self.server.work_done_progress.end(self.initial_token, value=types.WorkDoneProgressEnd()) diff --git a/src/codegen/extensions/lsp/protocol.py b/src/codegen/extensions/lsp/protocol.py deleted file mode 100644 index cc0d55c29..000000000 --- a/src/codegen/extensions/lsp/protocol.py +++ /dev/null @@ -1,40 +0,0 @@ -import os -from pathlib import Path -from typing import TYPE_CHECKING - -from lsprotocol.types import INITIALIZE, InitializeParams, InitializeResult -from pygls.protocol import LanguageServerProtocol, lsp_method - -from codegen.configs.models.codebase import CodebaseConfig -from codegen.extensions.lsp.io import LSPIO -from codegen.extensions.lsp.progress import LSPProgress -from codegen.extensions.lsp.utils import get_path -from codegen.sdk.core.codebase import Codebase - -if TYPE_CHECKING: - from codegen.extensions.lsp.server import CodegenLanguageServer - - -class CodegenLanguageServerProtocol(LanguageServerProtocol): - _server: "CodegenLanguageServer" - - def _init_codebase(self, params: InitializeParams) -> None: - progress = LSPProgress(self._server, params.work_done_token) - if params.root_path: - root = Path(params.root_path) - elif params.root_uri: - root = get_path(params.root_uri) - else: - root = os.getcwd() - config = CodebaseConfig().model_copy(update={"full_range_index": True}) - io = LSPIO(self.workspace) - self._server.codebase = Codebase(repo_path=str(root), config=config, io=io, progress=progress) - self._server.progress_manager = progress - self._server.io = io - progress.finish_initialization() - - @lsp_method(INITIALIZE) - def lsp_initialize(self, params: InitializeParams) -> InitializeResult: - ret = super().lsp_initialize(params) - self._init_codebase(params) - return ret diff --git a/src/codegen/extensions/lsp/range.py b/src/codegen/extensions/lsp/range.py deleted file mode 100644 index 9762e9d00..000000000 --- a/src/codegen/extensions/lsp/range.py +++ /dev/null @@ -1,32 +0,0 @@ -import tree_sitter -from lsprotocol.types import Position, Range -from pygls.workspace import TextDocument - -from codegen.sdk.core.interfaces.editable import Editable - - -def get_range(node: Editable) -> Range: - start_point = node.start_point - end_point = node.end_point - for extended_node in node.extended_nodes: - if extended_node.start_point.row < start_point.row: - start_point = extended_node.start_point - if extended_node.end_point.row > end_point.row: - end_point = extended_node.end_point - return Range( - start=Position(line=start_point.row, character=start_point.column), - end=Position(line=end_point.row, character=end_point.column), - ) - - -def get_tree_sitter_range(range: Range, document: TextDocument) -> tree_sitter.Range: - start_pos = tree_sitter.Point(row=range.start.line, column=range.start.character) - end_pos = tree_sitter.Point(row=range.end.line, column=range.end.character) - start_byte = document.offset_at_position(range.start) - end_byte = document.offset_at_position(range.end) - return tree_sitter.Range( - start_point=start_pos, - end_point=end_pos, - start_byte=start_byte, - end_byte=end_byte, - ) diff --git a/src/codegen/extensions/lsp/server.py b/src/codegen/extensions/lsp/server.py deleted file mode 100644 index 4d24cc7f2..000000000 --- a/src/codegen/extensions/lsp/server.py +++ /dev/null @@ -1,100 +0,0 @@ -from typing import Any, Optional - -from lsprotocol import types -from lsprotocol.types import Position, Range -from pygls.lsp.server import LanguageServer - -from codegen.extensions.lsp.codemods import ACTIONS -from codegen.extensions.lsp.codemods.base import CodeAction -from codegen.extensions.lsp.execute import execute_action -from codegen.extensions.lsp.io import LSPIO -from codegen.extensions.lsp.progress import LSPProgress -from codegen.extensions.lsp.range import get_tree_sitter_range -from codegen.extensions.lsp.utils import get_path -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.core.file import File, SourceFile -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.symbol import Symbol -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -class CodegenLanguageServer(LanguageServer): - codebase: Optional[Codebase] - io: Optional[LSPIO] - progress_manager: Optional[LSPProgress] - actions: dict[str, CodeAction] - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.actions = {action.command_name(): action for action in ACTIONS} - # for action in self.actions.values(): - # self.command(action.command_name())(get_execute_action(action)) - - def get_file(self, uri: str) -> SourceFile | File: - path = get_path(uri) - return self.codebase.get_file(str(path)) - - def get_symbol(self, uri: str, position: Position) -> Symbol | None: - node = self.get_node_under_cursor(uri, position) - if node is None: - logger.warning(f"No node found for {uri} at {position}") - return None - return node.parent_of_type(Symbol) - - def get_node_under_cursor(self, uri: str, position: Position, end_position: Position | None = None) -> Editable | None: - file = self.get_file(uri) - resolved_uri = file.path.absolute().as_uri() - logger.info(f"Getting node under cursor for {resolved_uri} at {position}") - document = self.workspace.get_text_document(resolved_uri) - candidates = [] - target_byte = document.offset_at_position(position) - end_byte = document.offset_at_position(end_position) if end_position is not None else None - for node in file._range_index.nodes: - if node.start_byte <= target_byte and node.end_byte >= target_byte: - if end_position is not None: - if node.end_byte < end_byte: - continue - candidates.append(node) - if not candidates: - return None - return min(candidates, key=lambda node: abs(node.end_byte - node.start_byte)) - - def get_node_for_range(self, uri: str, range: Range) -> Editable | None: - file = self.get_file(uri) - document = self.workspace.get_text_document(uri) - ts_range = get_tree_sitter_range(range, document) - for node in file._range_index.get_all_for_range(ts_range): - return node - return None - - def get_actions_for_range(self, params: types.CodeActionParams) -> list[types.CodeAction]: - if params.context.only is not None: - only = [types.CodeActionKind(kind) for kind in params.context.only] - else: - only = None - node = self.get_node_under_cursor(params.text_document.uri, params.range.start) - if node is None: - logger.warning(f"No node found for range {params.range} in {params.text_document.uri}") - return [] - actions = [] - task = self.progress_manager.begin_with_token(f"Getting code actions for {params.text_document.uri}", params.work_done_token, count=len(self.actions)) - for idx, action in enumerate(self.actions.values()): - task.update(f"Checking action {action.name}", idx) - if only and action.kind not in only: - logger.warning(f"Skipping action {action.kind} because it is not in {only}") - continue - if action.is_applicable(self, node): - actions.append(action.to_lsp(params.text_document.uri, params.range)) - task.end() - return actions - - def resolve_action(self, action: types.CodeAction) -> types.CodeAction: - name = action.data[0] - action_codemod = self.actions.get(name, None) - if action_codemod is None: - return action - execute_action(self, action_codemod, action.data[1:]) - action.edit = self.io.get_workspace_edit() - return action diff --git a/src/codegen/extensions/lsp/utils.py b/src/codegen/extensions/lsp/utils.py deleted file mode 100644 index 3dce5f751..000000000 --- a/src/codegen/extensions/lsp/utils.py +++ /dev/null @@ -1,7 +0,0 @@ -from pathlib import Path - -from pygls.uris import to_fs_path - - -def get_path(uri: str) -> Path: - return Path(to_fs_path(uri)).absolute() diff --git a/src/codegen/extensions/mcp/README.md b/src/codegen/extensions/mcp/README.md deleted file mode 100644 index 0f5d3e2a7..000000000 --- a/src/codegen/extensions/mcp/README.md +++ /dev/null @@ -1,43 +0,0 @@ -# Codegen MCP Servers - -This directory contains reference implementations of MCP (Machine Control Protocol) servers that extend AI Agent capabilities using the Codegen SDK. These servers enable AI Agents to: - -- Query and analyze your codebase (`codebase_agent.py`) -- Run deterministic codemods (`codebase_mods.py`) -- Invoke tools built with Codegen SDK (`codebase_tools.py`) - -## What is MCP? - -MCP (Model Context Protocol) allows AI Agents to interact with local tools and services through a standardized interface. The servers in this directory demonstrate how you might write an MCP server that leverages Codegen's capabilities. - -## Setup Instructions - -### Cline - -Add this to your `cline_mcp_settings.json` file to get started: - -``` -{ - "mcpServers": { - "codegen-cli": { - "command": "uv", - "args": [ - "--directory", - "/codegen-sdk/src/codegen/extensions/mcp", - "run", - "codebase_agent.py | codebase_mods | codebase_tools" - ] - } - } -} -``` - -### Cursor: - -Under the `Settings` > `Feature` > `MCP Servers` section, click "Add New MCP Server" and add the following: - -``` -Name: codegen-mcp -Type: Command -Command: uv --directory /codegen-sdk/src/codegen/cli/mcp run -``` diff --git a/src/codegen/extensions/mcp/codebase_mods.py b/src/codegen/extensions/mcp/codebase_mods.py deleted file mode 100644 index b47055945..000000000 --- a/src/codegen/extensions/mcp/codebase_mods.py +++ /dev/null @@ -1,47 +0,0 @@ -import json -import os -from typing import Annotated - -from mcp.server.fastmcp import FastMCP - -from codegen.sdk.core.codebase import Codebase -from codegen.shared.enums.programming_language import ProgrammingLanguage - -mcp = FastMCP( - "codebase-mods-mcp", - instructions="Use this server to invoke deterministic codemods for your codebase. This implements a variety of codemods to be used to modify your codebase to your satisfaction", -) - - -@mcp.tool(name="split_files_by_function", description="split out the functions in defined in the provided file into new files") -def split_files_by_function( - target_file: Annotated[str, "file path to the target file to split"], - codebase_dir: Annotated[str, "Absolute path to the codebase root directory. It is highly encouraged to provide the root codebase directory and not a sub directory"], - codebase_language: Annotated[ProgrammingLanguage, "The language the codebase is written in"], -): - if not os.path.exists(codebase_dir): - return {"error": f"Codebase directory '{codebase_dir}' does not exist. Please provide a valid directory path."} - codebase = Codebase(repo_path=codebase_dir, language=codebase_language) - new_files = {} - file = codebase.get_file(target_file) - # for each test_function in the file - for function in file.functions: - # Create a new file for each test function using its name - new_file = codebase.create_file(f"{file.directory.path}/{function.name}.py", sync=False) - - print(f"🚠 🚠 Moving `{function.name}` to new file `{new_file.name}`") - # Move the test function to the newly created file - function.move_to_file(new_file) - new_files[new_file.filepath] = [function.name] - - codebase.commit() - - result = {"description": "the following new files have been created with each with containing the function specified", "new_files": new_files} - - return json.dumps(result, indent=2) - - -if __name__ == "__main__": - # Initialize and run the server - print("Starting codebase mods server...") - mcp.run(transport="stdio") diff --git a/src/codegen/extensions/mcp/codebase_tools.py b/src/codegen/extensions/mcp/codebase_tools.py deleted file mode 100644 index 52a25b1d6..000000000 --- a/src/codegen/extensions/mcp/codebase_tools.py +++ /dev/null @@ -1,59 +0,0 @@ -import json -from typing import Annotated, Optional - -from mcp.server.fastmcp import FastMCP - -from codegen.extensions.tools import reveal_symbol -from codegen.extensions.tools.search import search -from codegen.sdk.core.codebase import Codebase -from codegen.shared.enums.programming_language import ProgrammingLanguage - -mcp = FastMCP( - "codebase-tools-mcp", - instructions="""Use this server to access any information from your codebase. This tool can provide information ranging from AST Symbol details and information from across the codebase. - Use this tool for all questions, queries regarding your codebase.""", -) - - -@mcp.tool(name="reveal_symbol", description="Reveal the dependencies and usages of a symbol up to N degrees") -def reveal_symbol_tool( - symbol_name: Annotated[str, "Name of the symbol to inspect"], - target_file: Annotated[Optional[str], "The file path of the file containing the symbol to inspect"], - codebase_dir: Annotated[str, "The root directory of your codebase"], - codebase_language: Annotated[ProgrammingLanguage, "The language the codebase is written in"], - max_depth: Annotated[Optional[int], "depth up to which symbol information is retrieved"], - collect_dependencies: Annotated[Optional[bool], "includes dependencies of symbol"], - collect_usages: Annotated[Optional[bool], "includes usages of symbol"], -): - codebase = Codebase(repo_path=codebase_dir, language=codebase_language) - result = reveal_symbol( - codebase=codebase, - symbol_name=symbol_name, - filepath=target_file, - max_depth=max_depth, - collect_dependencies=collect_dependencies, - collect_usages=collect_usages, - ) - return json.dumps(result, indent=2) - - -@mcp.tool(name="search_codebase", description="The search query to find in the codebase. When ripgrep is available, this will be passed as a ripgrep pattern. For regex searches, set use_regex=True") -def search_codebase_tool( - query: Annotated[str, "The search query to find in the codebase. When ripgrep is available, this will be passed as a ripgrep pattern. For regex searches, set use_regex=True."], - codebase_dir: Annotated[str, "The root directory of your codebase"], - codebase_language: Annotated[ProgrammingLanguage, "The language the codebase is written in"], - target_directories: Annotated[Optional[list[str]], "list of directories to search within"] = None, - file_extensions: Annotated[Optional[list[str]], "list of file extensions to search (e.g. ['.py', '.ts'])"] = None, - page: Annotated[int, "page number to return (1-based)"] = 1, - files_per_page: Annotated[int, "number of files to return per page"] = 10, - use_regex: Annotated[bool, "use regex for the search query"] = False, -): - codebase = Codebase(repo_path=codebase_dir, language=codebase_language) - result = search(codebase, query, target_directories=target_directories, file_extensions=file_extensions, page=page, files_per_page=files_per_page, use_regex=use_regex) - return json.dumps(result, indent=2) - - -if __name__ == "__main__": - # Initialize and run the server - print("Starting codebase tools server...") - mcp.run(transport="stdio") diff --git a/src/codegen/extensions/slack/types.py b/src/codegen/extensions/slack/types.py deleted file mode 100644 index a7203c526..000000000 --- a/src/codegen/extensions/slack/types.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel, Field - - -class RichTextElement(BaseModel): - type: str - user_id: str | None = None - text: str | None = None - style: dict | None = None - url: str | None = None - channel_id: str | None = None - - -class RichTextSection(BaseModel): - type: Literal["rich_text_section", "rich_text_list", "rich_text_quote", "rich_text_preformatted", "text", "channel", "user", "emoji", "link"] - elements: list[RichTextElement] - style: dict | str | None = None # Can be either a dict for rich text styling or a string for list styles (e.g. "bullet") - - -class Block(BaseModel): - type: Literal["rich_text", "section", "divider", "header", "context", "actions", "image"] - block_id: str - elements: list[RichTextSection] - - -class SlackEvent(BaseModel): - user: str - type: str - ts: str - client_msg_id: str | None = None - text: str - team: str | None = None - blocks: list[Block] | None = None - channel: str - event_ts: str - thread_ts: str | None = None - - -class SlackWebhookPayload(BaseModel): - token: str | None = Field(None) - team_id: str | None = Field(None) - api_app_id: str | None = Field(None) - event: SlackEvent | None = Field(None) - type: str | None = Field(None) - event_id: str | None = Field(None) - event_time: int | None = Field(None) - challenge: str | None = Field(None) - subtype: str | None = Field(None) - - -class SlackMessageReaction(BaseModel): - """Model for a reaction on a Slack message.""" - - name: str - users: list[str] - count: int - - -class SlackMessage(BaseModel): - """Model for a message in a Slack conversation.""" - - user: str - type: str - ts: str - client_msg_id: str | None = None - text: str - team: str | None = None - blocks: list[Block] | None = None - language: dict | None = None - reactions: list[SlackMessageReaction] | None = None - thread_ts: str | None = None - reply_count: int | None = None - reply_users_count: int | None = None - latest_reply: str | None = None - reply_users: list[str] | None = None - is_locked: bool | None = None - subscribed: bool | None = None - parent_user_id: str | None = None diff --git a/src/codegen/extensions/swebench/README.md b/src/codegen/extensions/swebench/README.md deleted file mode 100644 index 12063180d..000000000 --- a/src/codegen/extensions/swebench/README.md +++ /dev/null @@ -1,29 +0,0 @@ -## Codegen Harness and Evaluator for SWE Bennch Development Tool - -This folder contains a harness and evaluator for the SWE Bench leaderboard, and enables developers to test and evaluate their codegen models on the SWE Bench leaderboard. - -It integrates directly into the Codegen agentic framework and can be built on top of. - -### Setup - -Remember to install all the dependencies for the environment. - -### Usage - -#### Edit agent.py, your codegen agent - -This file contains the main logic for the agent. - -The agent taps into the tree sitter using codegen. You can modify this by adding additional tools, extending its capabilities, prompts, and more. - -It is invoked in the harness script. - -#### Run harness.py to run the agent - -This script will gather the correct dataset, run the agent, and save the results. - -#### Run report.py to generate a report - -This script will generate a report from the results. It will loop through all the results and generate a report to evaluate each. Currently, there is an error in the docker image. - -There are currently example predictions in the `predictions/results` folder. diff --git a/src/codegen/extensions/swebench/__init__.py b/src/codegen/extensions/swebench/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/extensions/swebench/enums.py b/src/codegen/extensions/swebench/enums.py deleted file mode 100644 index 0cf3a484a..000000000 --- a/src/codegen/extensions/swebench/enums.py +++ /dev/null @@ -1,13 +0,0 @@ -from enum import Enum - - -class SWEBenchDataset(Enum): - LITE = "princeton-nlp/SWE-bench_Lite" - FULL = "princeton-nlp/SWE-bench" - VERIFIED = "princeton-nlp/SWE-bench-verified" - - -class SWEBenchLiteSubset(Enum): - LITE_SMALL = "lite_small" - LITE_MEDIUM = "lite_medium" - LITE_LARGE = "lite_large" diff --git a/src/codegen/extensions/swebench/harness.py b/src/codegen/extensions/swebench/harness.py deleted file mode 100644 index 456c52fca..000000000 --- a/src/codegen/extensions/swebench/harness.py +++ /dev/null @@ -1,200 +0,0 @@ -"""This is the harness for running an AI agent on the SWE Bench dataset.""" - -#!/usr/bin/env python -import json -import pprint -import random -import subprocess -import sys -from pathlib import Path - -import lox - -from codegen import Codebase -from codegen.configs.models.codebase import CodebaseConfig -from codegen.extensions.swebench.utils import ( - SweBenchExample, - get_swe_bench_examples, - load_predictions, -) - -PARENT_DIR = Path(__file__).parent - -PREDS_DNAME = PARENT_DIR / "predictions" - - -def diff_versus_commit(git_dname, commit): - """Take a diff of `git_dname` current contents versus the `commit`.""" - diff_cmd = f"git -C {git_dname} diff {commit}" - diff_output = subprocess.check_output(diff_cmd.split()).decode() - return diff_output - - -def files_in_patch(patch): - """Extract the list of modified files from a unified diff patch string.""" - files = [] - for line in patch.split("\n"): - if line.startswith("--- a/") or line.startswith("+++ b/"): - fname = line.split("/", 1)[1] - if fname not in files: - files.append(fname) - return files - - -def show_problems(dataset): - """Print out all the instance_id and problem_descriptions.""" - for inst, entry in dataset.items(): - problem = entry.problem_statement.splitlines()[0] - print(f"{inst}: {problem}") - - -def run_agent_on_entry(entry: SweBenchExample, model: str, codebase: Codebase | None = None, run_id: str | None = None): - """Process one `entry` from SWE Bench using the LLM `models` at the - given `temperature`. Set `model_name_or_path` in the result json. - """ - instance_id = entry.instance_id - base_commit = entry.base_commit - - print("=" * 60) - pprint.pprint(instance_id) - print("=" * 60) - problem_statement = entry.problem_statement - print(problem_statement) - - gold_files = files_in_patch(entry.patch) - - if codebase is None: - config = CodebaseConfig( - disable_file_parse=True, # Disable the graph AND disable file parsing (file.edit only) - ) - codebase = Codebase.from_repo(repo_full_name=entry.repo, commit=base_commit, language="python", config=config) # check out the repo - - metadata = {"run_id": run_id, "instance_id": instance_id, "difficulty": f"difficulty_{entry.difficulty}"} - tags = [str(value) for value in metadata.values()] - # agent = CodeAgent(codebase=codebase, tags=tags, metadata=metadata) - - pprint.pprint(instance_id) - pprint.pprint(gold_files) - - message = """Below is a real GitHub issue from a popular GitHub repository. -The issue was filed some time ago. -The repo has been checked out at the commit that existed at the moment the issue was filed. -If you are already familiar with this repo, be cautious! -You are working with an old version of the repo! -Filenames, directory names, file contents, etc may be different than what you're used to. - -Propose changes to update the repo to fix the problem below. -*** IMPORTANT: *** DO NOT MODIFY ANY TESTS! -*** IMPORTANT: *** DO NOT ADD ANY TESTS! - -Before commiting to do any modifications, double check your work with the Reflection tool. -you can also use that tool to check your work after you think you are done. -if you ever get stuck using other tools, use the Reflection tool to re asses your situation. -after every file edit, use the Reflection tool to check your work and sanity check yourself. -after editing a file you need to double check your work and use the ViewFiles tool to make sure you didn't break anything and that your edits are indeed correct. - -You should follow the advices of the Reflection tool when ever they seem reasonable. - -Also DO NOT ADD OR EDIT ANY TESTS! - -""" - message += problem_statement - - try: - pass - # result = agent.run(prompt=message) - except Exception as agent_error: - pprint.pprint(f"Instance ID: {instance_id} terminated with error: {agent_error}") - raise agent_error - - # Get the diff between the current state and the original commit - model_patch = codebase.get_diff(base=base_commit) - pprint.pprint(model_patch) - - # Record the results for the logs - result = dict( - # Required args for running eval tests - instance_id=instance_id, - model_patch=model_patch, - # For computing stats - gold_files=gold_files, - edited_files=files_in_patch(model_patch), - ) - - # Did we get a successful patch? - if not model_patch: - pprint.pprint("=" * 60) - pprint.pprint("Failed to generate a patch") - pprint.pprint("=" * 60) - - return result - - -def process_instances(dataset: dict[str, SweBenchExample], threads: int): - """Dataset - The subset of the SWE Bench dataset to process. - threads - How many problems to attempt concurrently. - prior_dnames - Names of predictions/ dirnames from previous runs. - If they contain a plausible solution for an instance, - don't continue looking. - """ - # Create the predictions directory if it doesn't exist - PREDS_DNAME.mkdir(exist_ok=True) - out_dname = PREDS_DNAME / "results" - out_dname.mkdir(exist_ok=True) - - pprint.pprint(out_dname) - - # If we are restarting this run, figure out which instances are already done. - done_preds = load_predictions([out_dname]) - done_instances = set(done_preds.keys()) - pprint.pprint(len(done_instances)) - - all_instances = set(dataset.keys()) - - remaining_instances = set(all_instances) - remaining_instances -= done_instances - - remaining_instances = list(remaining_instances) - random.shuffle(remaining_instances) - - pprint.pprint(sorted(remaining_instances)) - pprint.pprint(len(remaining_instances)) - - print() - print("press enter...") - input() - - if threads > 1: - process_one_instance_lox = lox.process(threads)(run_agent_on_entry) - process_one_instance_func = process_one_instance_lox.scatter - gather = process_one_instance_lox.gather - else: - process_one_instance_func = run_agent_on_entry - - for instance_id in remaining_instances: - if instance_id in done_instances: - print("skipping", instance_id) - continue - - result = process_one_instance_func( - dataset[instance_id], - ) - with open(out_dname / f"{instance_id}.json", "w") as f: - json.dump(result, f) - - print("#" * 60) - # input() - - if threads > 1: - gather() - - -def main(): - # Load the SWE Bench dataset - dataset = {example.instance_id: example for example in get_swe_bench_examples()} - process_instances(dataset, threads=10) - - -if __name__ == "__main__": - status = main() - sys.exit(status) diff --git a/src/codegen/extensions/swebench/report.py b/src/codegen/extensions/swebench/report.py deleted file mode 100755 index f8100e36d..000000000 --- a/src/codegen/extensions/swebench/report.py +++ /dev/null @@ -1,154 +0,0 @@ -#!/usr/bin/env python - -import json -import subprocess -from collections import defaultdict -from pathlib import Path - -from codegen.extensions.swebench.enums import SWEBenchDataset -from codegen.extensions.swebench.tests import remove_patches_to_tests - -NUM_EVAL_PROCS = 5 - - -def run_evals(predictions_jsonl, logs_dir: Path, dataset: SWEBenchDataset, run_id: str): - """Run the evaluations on the predictions on modal.""" - run_evals_cmd = f""" -python -m swebench.harness.run_evaluation - --predictions_path {predictions_jsonl} - --run_id {run_id} - --dataset_name {dataset.value} - --cache_level instance - --report_dir {logs_dir} - --modal true -""" - run_evals_cmd = " ".join([line.strip() for line in run_evals_cmd.split() if line.strip()]) - print("Running evaluation command:", run_evals_cmd) - - subprocess.run(run_evals_cmd.split(), check=True) - - -def get_report(predictions_jsonl, logs_dir: Path): - # Load and parse the evaluation results directly from the predictions file - results = defaultdict(list) - - with open(predictions_jsonl) as f: - for line in f: - pred = json.loads(line) - instance_id = pred["instance_id"] - - # Track basic stats - results["generated"].append(instance_id) - - # Check for evaluation logs - log_file = logs_dir / f"{instance_id}.eval.log" - if log_file.exists(): - results["with_logs"].append(instance_id) - log_content = log_file.read_text() - - if "PASS" in log_content: - results["resolved"].append(instance_id) - results["applied"].append(instance_id) - elif "FAIL" in log_content: - results["applied"].append(instance_id) - else: - results["no_apply"].append(instance_id) - else: - results["no_logs"].append(instance_id) - - # Convert lists to sets for compatibility with existing code - return {k: set(v) for k, v in results.items()} - - -def update_pred_json(predictions, report, predictions_dir: Path): - all_instances = set(report.get("generated", [])) - all_instances.update(set(report.get("no_generation", []))) - - for instance_id, pred in predictions.items(): - # Use get() to handle missing 'resolved' key, defaulting to empty set - was_resolved = instance_id in report.get("resolved", set()) - if "resolved" in pred and pred["resolved"] == was_resolved: - continue - - assert instance_id in all_instances, instance_id - - pred["resolved"] = was_resolved - save = dict(pred) - - # Construct json_fname if it doesn't exist - if "json_fname" not in pred: - json_fname = predictions_dir / f"{instance_id}.json" - else: - json_fname = pred["json_fname"] - del save["json_fname"] # Remove from save data if it exists - - Path(json_fname).write_text(json.dumps(save, indent=4)) - - return predictions - - -def preds_to_jsonl(predictions, predictions_dir: Path): - dname = predictions_dir - - predictions_jsonl = str(dname / "all_preds.jsonl") - print(f"Creating JSONL file: {predictions_jsonl}") - - # Use a default model name since it's not in the predictions - model_name = "results" - - with open(predictions_jsonl, "w") as fh: - for inst, pred in predictions.items(): - minimal_pred = { - "model_name_or_path": model_name, # Use default model name - "model_patch": remove_patches_to_tests(pred["model_patch"]) if "model_patch" in pred else pred.get("patch", ""), - "instance_id": pred["instance_id"], - } - fh.write(json.dumps(minimal_pred) + "\n") - return predictions_jsonl - - -def generate_report(predictions_dir: Path, logs_dir: Path, dataset: SWEBenchDataset, run_id: str): - # Automatically find all JSON files in predictions/results - if not predictions_dir.exists(): - print(f"Directory does not exist: {predictions_dir}") - return 1 - - predictions_jsonl = predictions_dir / "all_preds.jsonl" - existing_preds = predictions_jsonl.exists() - prediction_files = list(predictions_dir.glob("*.json")) - print(f"Found {len(prediction_files)} prediction files") - - predictions = {} - for file_path in prediction_files: - try: - with open(file_path) as f: - prediction = json.load(f) - if isinstance(prediction, dict) and "instance_id" in prediction: - predictions[prediction["instance_id"]] = prediction - except json.JSONDecodeError: - print(f"Error reading JSON from {file_path}") - continue - if not existing_preds: - if not predictions: - print("No valid predictions found") - return 1 - - print(f"Successfully loaded {len(predictions)} predictions") - - predictions_jsonl = preds_to_jsonl(predictions, predictions_dir) - - # Setup log directory - log_dir = logs_dir / "results" - log_dir.mkdir(exist_ok=True, parents=True) - print(f"Using log directory: {log_dir}") - - # Run evaluations - run_evals(predictions_jsonl, logs_dir, dataset, run_id) - - # Get and display report - report = get_report(predictions_jsonl, logs_dir) - - # Update prediction JSONs with results - predictions = update_pred_json(predictions, report, predictions_dir) - - return 0 diff --git a/src/codegen/extensions/swebench/subsets.py b/src/codegen/extensions/swebench/subsets.py deleted file mode 100644 index a2f522ffe..000000000 --- a/src/codegen/extensions/swebench/subsets.py +++ /dev/null @@ -1,146 +0,0 @@ -from codegen.extensions.swebench.enums import SWEBenchLiteSubset - -SMALL_LITE_SUBSET = [ - "mwaskom__seaborn-2848", - "sphinx-doc__sphinx-8627", - "sphinx-doc__sphinx-7975", - "django__django-17087", - "sympy__sympy-17655", - "matplotlib__matplotlib-26020", - "sympy__sympy-20154", - "scikit-learn__scikit-learn-13439", - "pytest-dev__pytest-7373", - "django__django-16527", -] - -MEDIUM_LITE_SUBSET = [ - "sympy__sympy-15346", - "sympy__sympy-16281", - "sympy__sympy-22840", - "pytest-dev__pytest-7220", - "django__django-12284", - "pytest-dev__pytest-7490", - "matplotlib__matplotlib-25442", - "django__django-13757", - "django__django-15790", - "sympy__sympy-18532", - "sympy__sympy-13471", - "scikit-learn__scikit-learn-15535", - "django__django-13447", - "django__django-15789", - "scikit-learn__scikit-learn-14894", - "django__django-14238", - "django__django-10914", - "pytest-dev__pytest-11143", - "django__django-16255", - "django__django-13658", -] - -LARGE_LITE_SUBSET = [ - "pytest-dev__pytest-5495", - "django__django-11797", - "django__django-14730", - "scikit-learn__scikit-learn-25500", - "sphinx-doc__sphinx-8506", - "django__django-16408", - "django__django-16910", - "sympy__sympy-12236", - "matplotlib__matplotlib-24265", - "django__django-15320", - "matplotlib__matplotlib-25311", - "django__django-12125", - "django__django-12747", - "matplotlib__matplotlib-24334", - "scikit-learn__scikit-learn-14983", - "scikit-learn__scikit-learn-13497", - "django__django-14580", - "pylint-dev__pylint-6506", - "matplotlib__matplotlib-23987", - "scikit-learn__scikit-learn-13497", - "django__django-14017", - "django__django-15213", - "django__django-12284", - "pylint-dev__pylint-7114", - "django__django-11422", - "django__django-11620", - "django__django-12284", - "sympy__sympy-13971", - "django__django-12284", - "sphinx-doc__sphinx-7975", - "scikit-learn__scikit-learn-15512", - "scikit-learn__scikit-learn-15512", - "pylint-dev__pylint-7993", - "django__django-12184", - "django__django-13315", - "sympy__sympy-15609", - "pylint-dev__pylint-7993", - "sympy__sympy-17022", - "pylint-dev__pylint-7993", - "sympy__sympy-15678", - "sympy__sympy-18057", - "sympy__sympy-17655", - "sympy__sympy-17655", - "django__django-13028", - "sympy__sympy-17139", - "django__django-14999", - "django__django-15790", - "scikit-learn__scikit-learn-11281", - "astropy__astropy-12907", - "django__django-11815", - "sympy__sympy-18621", - "django__django-11999", - "sphinx-doc__sphinx-8721", - "matplotlib__matplotlib-23314", - "sphinx-doc__sphinx-8721", - "sympy__sympy-18621", - "django__django-12497", - "scikit-learn__scikit-learn-13584", - "matplotlib__matplotlib-24970", - "scikit-learn__scikit-learn-13584", - "django__django-12453", - "sympy__sympy-20154", - "django__django-13447", - "sphinx-doc__sphinx-8595", - "sympy__sympy-20154", - "sympy__sympy-20154", - "django__django-12700", - "psf__requests-2317", - "django__django-16046", - "sympy__sympy-20154", - "sympy__sympy-20212", - "django__django-13710", - "sympy__sympy-13647", - "django__django-15851", - "scikit-learn__scikit-learn-14894", - "sympy__sympy-24213", - "scikit-learn__scikit-learn-13779", - "django__django-13710", - "django__django-13933", - "sympy__sympy-20212", - "django__django-14855", - "django__django-11039", - "django__django-16379", - "pydata__xarray-5131", - "pytest-dev__pytest-7373", - "django__django-16139", - "django__django-14382", - "pytest-dev__pytest-5227", - "django__django-16595", - "django__django-16379", - "django__django-16527", - "django__django-13658", - "django__django-16255", - "django__django-16527", - "django__django-13658", - "django__django-13658", - "django__django-13658", - "django__django-11099", - "django__django-16527", - "django__django-11099", -] - -LITE_SUBSETS = { - SWEBenchLiteSubset.LITE_SMALL: SMALL_LITE_SUBSET, - SWEBenchLiteSubset.LITE_MEDIUM: MEDIUM_LITE_SUBSET, - SWEBenchLiteSubset.LITE_LARGE: LARGE_LITE_SUBSET, -} diff --git a/src/codegen/extensions/swebench/success_rates.py b/src/codegen/extensions/swebench/success_rates.py deleted file mode 100644 index 2d3cbbdf1..000000000 --- a/src/codegen/extensions/swebench/success_rates.py +++ /dev/null @@ -1,302 +0,0 @@ -LITE_SUCCESS_RATES = { - "pallets__flask-5063": 0.0, - "sphinx-doc__sphinx-8282": 0.0, - "django__django-14667": 0.0, - "sphinx-doc__sphinx-8474": 0.0, - "sympy__sympy-11400": 0.0, - "sympy__sympy-11870": 0.0, - "sympy__sympy-11897": 0.0, - "sympy__sympy-12171": 0.0, - "sympy__sympy-12236": 0.0, - "sympy__sympy-13146": 0.0, - "sympy__sympy-13773": 0.0, - "sympy__sympy-13895": 0.0, - "django__django-13220": 0.0, - "sympy__sympy-13915": 0.0, - "sympy__sympy-14024": 0.0, - "sympy__sympy-14308": 0.0, - "django__django-14730": 0.0, - "sphinx-doc__sphinx-7738": 0.0, - "sphinx-doc__sphinx-7686": 0.0, - "django__django-14997": 0.0, - "matplotlib__matplotlib-25079": 0.0, - "pydata__xarray-4493": 0.0, - "matplotlib__matplotlib-22835": 0.0, - "matplotlib__matplotlib-18869": 0.0, - "pylint-dev__pylint-7228": 0.0, - "pytest-dev__pytest-5103": 0.0, - "pytest-dev__pytest-5221": 0.0, - "sympy__sympy-14317": 0.0, - "django__django-16820": 0.0, - "django__django-16229": 0.0, - "pytest-dev__pytest-9359": 0.0, - "scikit-learn__scikit-learn-10508": 0.0, - "scikit-learn__scikit-learn-10949": 0.0, - "scikit-learn__scikit-learn-11040": 0.0, - "django__django-15695": 0.0, - "scikit-learn__scikit-learn-25638": 0.0, - "django__django-16816": 0.0, - "sympy__sympy-15308": 0.0, - "matplotlib__matplotlib-25433": 0.0, - "sympy__sympy-18087": 0.0, - "astropy__astropy-7746": 0.0, - "django__django-11630": 0.0, - "sympy__sympy-18199": 0.0, - "sympy__sympy-23191": 0.0, - "sympy__sympy-17630": 0.0, - "sympy__sympy-19254": 0.0, - "sympy__sympy-21627": 0.0, - "sympy__sympy-16281": 0.0, - "sympy__sympy-16106": 0.0, - "sympy__sympy-24102": 0.0, - "django__django-11905": 0.0, - "sympy__sympy-21171": 0.0, - "sympy__sympy-20639": 0.0, - "django__django-12589": 0.0, - "sympy__sympy-20322": 0.0, - "django__django-11564": 0.0, - "django__django-11019": 0.0, - "django__django-16910": 0.02, - "django__django-15252": 0.02, - "pytest-dev__pytest-5413": 0.02, - "django__django-11742": 0.02, - "sphinx-doc__sphinx-8273": 0.02, - "pytest-dev__pytest-8906": 0.02, - "django__django-15996": 0.02, - "sympy__sympy-19007": 0.02, - "django__django-11910": 0.02, - "matplotlib__matplotlib-22711": 0.02, - "django__django-13768": 0.02, - "astropy__astropy-14182": 0.02, - "mwaskom__seaborn-3407": 0.02, - "pallets__flask-4045": 0.02, - "django__django-12908": 0.02, - "pallets__flask-4992": 0.02, - "pydata__xarray-3364": 0.02, - "sympy__sympy-16503": 0.02, - "django__django-15738": 0.02, - "pydata__xarray-4248": 0.02, - "django__django-13265": 0.02, - "sympy__sympy-13177": 0.02, - "django__django-13448": 0.02, - "django__django-12113": 0.02, - "sympy__sympy-13043": 0.02, - "sympy__sympy-12454": 0.02, - "sympy__sympy-13437": 0.02, - "django__django-16408": 0.03, - "pytest-dev__pytest-6116": 0.03, - "pytest-dev__pytest-8365": 0.03, - "psf__requests-2148": 0.03, - "sympy__sympy-21612": 0.03, - "astropy__astropy-14365": 0.03, - "matplotlib__matplotlib-23299": 0.03, - "django__django-11283": 0.03, - "django__django-14155": 0.03, - "sphinx-doc__sphinx-8506": 0.03, - "django__django-11797": 0.03, - "sympy__sympy-18698": 0.03, - "django__django-15320": 0.03, - "sphinx-doc__sphinx-10451": 0.03, - "django__django-15388": 0.03, - "sympy__sympy-20049": 0.03, - "django__django-15781": 0.05, - "django__django-13321": 0.05, - "sympy__sympy-18835": 0.05, - "django__django-14534": 0.05, - "matplotlib__matplotlib-24265": 0.05, - "django__django-15202": 0.05, - "django__django-12856": 0.05, - "matplotlib__matplotlib-23476": 0.05, - "django__django-15061": 0.05, - "sphinx-doc__sphinx-11445": 0.06, - "django__django-12470": 0.06, - "django__django-16400": 0.06, - "sympy__sympy-15346": 0.06, - "pytest-dev__pytest-5495": 0.06, - "sphinx-doc__sphinx-8801": 0.08, - "matplotlib__matplotlib-23563": 0.08, - "sympy__sympy-21379": 0.08, - "django__django-15819": 0.08, - "mwaskom__seaborn-2848": 0.08, - "scikit-learn__scikit-learn-25500": 0.08, - "sympy__sympy-12419": 0.08, - "django__django-12308": 0.09, - "sympy__sympy-14396": 0.09, - "sympy__sympy-15345": 0.09, - "sympy__sympy-19487": 0.09, - "pytest-dev__pytest-7168": 0.09, - "scikit-learn__scikit-learn-25747": 0.09, - "matplotlib__matplotlib-25498": 0.11, - "sympy__sympy-22840": 0.11, - "sphinx-doc__sphinx-8627": 0.11, - "pydata__xarray-4094": 0.11, - "pytest-dev__pytest-7220": 0.11, - "django__django-12747": 0.11, - "sympy__sympy-13031": 0.12, - "django__django-13660": 0.12, - "scikit-learn__scikit-learn-14983": 0.12, - "sphinx-doc__sphinx-8435": 0.14, - "sympy__sympy-20590": 0.14, - "scikit-learn__scikit-learn-14087": 0.14, - "sympy__sympy-24909": 0.14, - "django__django-15400": 0.14, - "matplotlib__matplotlib-25311": 0.14, - "pylint-dev__pylint-6506": 0.15, - "django__django-12125": 0.15, - "matplotlib__matplotlib-24334": 0.15, - "scikit-learn__scikit-learn-13497": 0.17, - "sympy__sympy-16792": 0.17, - "django__django-14580": 0.17, - "pylint-dev__pylint-7080": 0.18, - "matplotlib__matplotlib-25332": 0.18, - "sympy__sympy-22005": 0.18, - "sympy__sympy-20442": 0.2, - "django__django-13551": 0.2, - "sympy__sympy-14817": 0.2, - "matplotlib__matplotlib-23987": 0.2, - "django__django-13033": 0.21, - "sphinx-doc__sphinx-7975": 0.21, - "django__django-13925": 0.23, - "sphinx-doc__sphinx-10325": 0.23, - "sympy__sympy-16988": 0.23, - "pytest-dev__pytest-7490": 0.24, - "django__django-15213": 0.24, - "django__django-12284": 0.24, - "pytest-dev__pytest-11148": 0.24, - "django__django-11964": 0.24, - "pylint-dev__pylint-7114": 0.26, - "django__django-11422": 0.26, - "django__django-14017": 0.27, - "django__django-15902": 0.27, - "django__django-10924": 0.27, - "django__django-13158": 0.29, - "django__django-11620": 0.29, - "sympy__sympy-13971": 0.29, - "django__django-15498": 0.3, - "django__django-12184": 0.3, - "django__django-13964": 0.3, - "psf__requests-1963": 0.3, - "matplotlib__matplotlib-25442": 0.3, - "django__django-13757": 0.32, - "scikit-learn__scikit-learn-15512": 0.32, - "sympy__sympy-21614": 0.33, - "sympy__sympy-15609": 0.33, - "matplotlib__matplotlib-23562": 0.33, - "django__django-13315": 0.33, - "django__django-11848": 0.35, - "django__django-17087": 0.35, - "matplotlib__matplotlib-26011": 0.36, - "sympy__sympy-21055": 0.36, - "sympy__sympy-17022": 0.36, - "pylint-dev__pylint-7993": 0.36, - "astropy__astropy-6938": 0.38, - "sympy__sympy-15678": 0.38, - "django__django-17051": 0.38, - "scikit-learn__scikit-learn-14092": 0.38, - "pylint-dev__pylint-5859": 0.39, - "django__django-14411": 0.39, - "django__django-11001": 0.41, - "astropy__astropy-12907": 0.41, - "sympy__sympy-18057": 0.42, - "sympy__sympy-23262": 0.44, - "sympy__sympy-18189": 0.44, - "sympy__sympy-17139": 0.45, - "django__django-15790": 0.45, - "django__django-14999": 0.45, - "sympy__sympy-18532": 0.47, - "scikit-learn__scikit-learn-11281": 0.47, - "django__django-12915": 0.47, - "sympy__sympy-12481": 0.47, - "sympy__sympy-24066": 0.48, - "django__django-11815": 0.48, - "django__django-13028": 0.48, - "sympy__sympy-17655": 0.48, - "django__django-12708": 0.48, - "matplotlib__matplotlib-24970": 0.5, - "mwaskom__seaborn-3190": 0.52, - "scikit-learn__scikit-learn-13142": 0.52, - "matplotlib__matplotlib-26020": 0.53, - "scikit-learn__scikit-learn-15535": 0.53, - "sympy__sympy-13471": 0.53, - "sympy__sympy-15011": 0.53, - "psf__requests-3362": 0.55, - "matplotlib__matplotlib-24149": 0.55, - "matplotlib__matplotlib-23314": 0.55, - "django__django-14608": 0.56, - "scikit-learn__scikit-learn-13241": 0.56, - "scikit-learn__scikit-learn-25570": 0.56, - "sympy__sympy-18621": 0.56, - "scikit-learn__scikit-learn-13584": 0.56, - "django__django-13401": 0.58, - "pytest-dev__pytest-5692": 0.58, - "django__django-14787": 0.58, - "django__django-15814": 0.58, - "sphinx-doc__sphinx-8721": 0.58, - "django__django-14016": 0.58, - "django__django-11999": 0.59, - "django__django-12497": 0.59, - "psf__requests-2674": 0.59, - "matplotlib__matplotlib-23913": 0.59, - "pytest-dev__pytest-7432": 0.59, - "django__django-11049": 0.59, - "sympy__sympy-22714": 0.62, - "scikit-learn__scikit-learn-12471": 0.62, - "psf__requests-863": 0.62, - "django__django-14672": 0.62, - "sympy__sympy-20154": 0.62, - "django__django-13590": 0.64, - "django__django-12700": 0.64, - "sphinx-doc__sphinx-8595": 0.64, - "django__django-15789": 0.65, - "django__django-12453": 0.68, - "django__django-13447": 0.68, - "psf__requests-2317": 0.7, - "django__django-11583": 0.7, - "django__django-16046": 0.7, - "django__django-14238": 0.71, - "django__django-15851": 0.71, - "django__django-13710": 0.73, - "sympy__sympy-21847": 0.73, - "sympy__sympy-23117": 0.73, - "django__django-12983": 0.73, - "scikit-learn__scikit-learn-13779": 0.74, - "sympy__sympy-13647": 0.74, - "django__django-16041": 0.74, - "scikit-learn__scikit-learn-10297": 0.74, - "django__django-15347": 0.74, - "scikit-learn__scikit-learn-13496": 0.74, - "sympy__sympy-20212": 0.76, - "scikit-learn__scikit-learn-13439": 0.76, - "django__django-13933": 0.76, - "django__django-12286": 0.76, - "django__django-13230": 0.77, - "astropy__astropy-14995": 0.77, - "django__django-11179": 0.77, - "sphinx-doc__sphinx-8713": 0.77, - "sympy__sympy-24213": 0.77, - "matplotlib__matplotlib-23964": 0.79, - "scikit-learn__scikit-learn-14894": 0.79, - "django__django-10914": 0.8, - "pydata__xarray-5131": 0.8, - "django__django-11039": 0.82, - "pytest-dev__pytest-7373": 0.82, - "django__django-14915": 0.82, - "django__django-16595": 0.83, - "pytest-dev__pytest-11143": 0.85, - "sympy__sympy-14774": 0.85, - "pytest-dev__pytest-5227": 0.85, - "django__django-16873": 0.85, - "django__django-16139": 0.85, - "mwaskom__seaborn-3010": 0.86, - "django__django-14382": 0.86, - "django__django-14752": 0.86, - "sympy__sympy-13480": 0.86, - "django__django-16379": 0.86, - "sympy__sympy-24152": 0.88, - "django__django-14855": 0.88, - "django__django-11133": 0.88, - "django__django-11099": 0.91, - "django__django-13658": 0.91, - "django__django-16255": 0.91, - "django__django-16527": 0.91, -} diff --git a/src/codegen/extensions/swebench/tests.py b/src/codegen/extensions/swebench/tests.py deleted file mode 100755 index 9233f0c07..000000000 --- a/src/codegen/extensions/swebench/tests.py +++ /dev/null @@ -1,33 +0,0 @@ -#!/usr/bin/env python - -# A no-op patch which creates an empty file is used to stand in for -# the `model_patch` and/or `test_patch` when running SWE Bench tests -# without one or both of those patches. -NOOP_PATCH = "diff --git a/empty.file.{nonce}.ignore b/empty.file.{nonce}.ignore\nnew file mode 100644\nindex 0000000..e69de29\n" - - -def remove_patches_to_tests(model_patch): - """Remove any changes to the tests directory from the provided patch. - This is to ensure that the model_patch does not disturb the repo's - tests when doing acceptance testing with the `test_patch`. - """ - if not model_patch: - return model_patch - - lines = model_patch.splitlines(keepends=True) - filtered_lines = [] - is_tests = False - - for line in lines: - if line.startswith("diff --git a/"): - pieces = line.split() - to = pieces[-1] - if to.startswith("b/") and ("/test/" in to or "/tests/" in to or "/testing/" in to or "/test_" in to or "/tox.ini" in to): - is_tests = True - else: - is_tests = False - - if not is_tests: - filtered_lines.append(line) - - return "".join(filtered_lines) diff --git a/src/codegen/extensions/swebench/utils.py b/src/codegen/extensions/swebench/utils.py deleted file mode 100644 index c5054b2d0..000000000 --- a/src/codegen/extensions/swebench/utils.py +++ /dev/null @@ -1,135 +0,0 @@ -import json -from dataclasses import dataclass -from pathlib import Path -from pprint import pprint -from typing import Literal, Optional - -from datasets import load_dataset - -from codegen.extensions.swebench.enums import SWEBenchDataset, SWEBenchLiteSubset -from codegen.extensions.swebench.subsets import LITE_SUBSETS -from codegen.extensions.swebench.success_rates import LITE_SUCCESS_RATES - - -@dataclass -class SweBenchExample: - """A single example from the SWE-bench dataset.""" - - repo: str - instance_id: str - base_commit: str - patch: str - test_patch: str - problem_statement: str - hints_text: Optional[str] - created_at: str - version: str - fail_to_pass: str - pass_to_pass: Optional[str] - environment_setup_commit: Optional[str] - difficulty: Optional[int] - - -def load_predictions(paths): - prediction_paths = [] - for path in paths: - path = Path(path) - if path.is_file(): - prediction_paths.append(path) - elif path.is_dir(): - prediction_paths += list(path.glob("*.json")) - else: - assert False, path - - # prediction_paths.sort(key=lambda p: p.stat().st_mtime) - - predictions = dict() - for fname in prediction_paths: - try: - pred = json.loads(fname.read_text()) - except json.decoder.JSONDecodeError as err: - pprint(fname) - raise err - - if "instance_id" not in pred: - print("Skipping json without instance_id", fname) - continue - - inst = pred["instance_id"] - pred["json_fname"] = str(fname) - predictions[inst] = pred - - return predictions - - -def get_difficulty(instance_id: str) -> int | None: - if instance_id in LITE_SUCCESS_RATES: - return 10 - int(LITE_SUCCESS_RATES[instance_id] * 10) - return None - - -def get_swe_bench_examples( - dataset: SWEBenchDataset | SWEBenchLiteSubset = SWEBenchLiteSubset.LITE_SMALL, - split: Literal["train", "dev", "test"] = "test", - length: int | None = None, - instance_id: str | None = None, - instance_ids: list[str] = [], - repo: str | None = None, -) -> list[SweBenchExample]: - """Fetch examples from the SWE-bench dataset using the datasets library. - - Args: - dataset: The dataset to use ("lite", "full", or "verified") - split: The dataset split to use - length: Number of examples to fetch - instance_id: Optional specific instance ID to fetch - instance_ids: Optional list of instance IDs to fetch - repo: Optional specific repo to fetch - - Returns: - List of SweBenchExample objects - """ - # Load the dataset with caching enabled - if isinstance(dataset, SWEBenchLiteSubset): - if instance_ids: - msg = "instance_ids is not supported for lite subsets. Please pass a list of instance IDs instead." - raise ValueError(msg) - swe_bench_dataset = load_dataset(SWEBenchDataset.LITE.value, download_mode="reuse_dataset_if_exists") - instance_ids = LITE_SUBSETS[dataset] - else: - swe_bench_dataset = load_dataset(dataset.value, download_mode="reuse_dataset_if_exists") - - # Get the requested split - split_data = swe_bench_dataset[split] - - # Convert to SweBenchExample objects - examples = [] - for row in split_data: - if instance_id and row["instance_id"] != instance_id: - continue - if repo and row["repo"] != repo: - continue - if instance_ids and row["instance_id"] not in instance_ids: - continue - - example = SweBenchExample( - repo=row["repo"], - instance_id=row["instance_id"], - base_commit=row["base_commit"], - patch=row["patch"], - test_patch=row["test_patch"], - problem_statement=row["problem_statement"], - hints_text=row.get("hints_text"), - created_at=row["created_at"], - version=row["version"], - fail_to_pass=row["FAIL_TO_PASS"], - pass_to_pass=row.get("PASS_TO_PASS"), - environment_setup_commit=row.get("environment_setup_commit"), - difficulty=get_difficulty(row["instance_id"]), - ) - examples.append(example) - - if length: - examples = examples[:length] - - return examples diff --git a/src/codegen/gsbuild/README.md b/src/codegen/gsbuild/README.md deleted file mode 100644 index f2a4c9a02..000000000 --- a/src/codegen/gsbuild/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Codegen GS Build - -A codegen module that builds the codegen SDK. diff --git a/src/codegen/gsbuild/build.py b/src/codegen/gsbuild/build.py deleted file mode 100644 index 55b695ffc..000000000 --- a/src/codegen/gsbuild/build.py +++ /dev/null @@ -1,24 +0,0 @@ -import sys -from pathlib import Path -from typing import Any - -from hatchling.builders.hooks.plugin.interface import BuildHookInterface - - -def update_init_file(file: Path) -> None: - path = Path(__file__).parent.parent.parent - sys.path.append(str(path)) - from codegen.gscli.generate.runner_imports import generate_exported_modules, get_runner_imports - - content = file.read_text() - content = get_runner_imports(include_codegen=False) + "\n" + content + "\n" + generate_exported_modules() - file.write_text(content) - - -class SpecialBuildHook(BuildHookInterface): - PLUGIN_NAME = "codegen_build" - - def initialize(self, version: str, build_data: dict[str, Any]) -> None: - file = Path(self.root) / "src" / "codegen" / "sdk" / "__init__.py" - update_init_file(file) - build_data["artifacts"].append(f"/{file}") diff --git a/src/codegen/gscli/README.md b/src/codegen/gscli/README.md deleted file mode 100644 index 9bcf652c3..000000000 --- a/src/codegen/gscli/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Codegen GS CLI - -This module to be moved out into `src/code_generation` diff --git a/src/codegen/gscli/__init__.py b/src/codegen/gscli/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/gscli/backend/__init__.py b/src/codegen/gscli/backend/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/gscli/backend/typestub_utils.py b/src/codegen/gscli/backend/typestub_utils.py deleted file mode 100644 index 4bab38674..000000000 --- a/src/codegen/gscli/backend/typestub_utils.py +++ /dev/null @@ -1,138 +0,0 @@ -import ast -import os -import re -from collections.abc import Callable -from concurrent.futures import ThreadPoolExecutor - -import astor - -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -class MethodRemover(ast.NodeTransformer): - def __init__(self, conditions: list[Callable[[ast.FunctionDef], bool]]): - self.conditions = conditions - - def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: - body = [] - - for child in node.body: - if not self.should_remove(child): - body.append(child) - else: - logger.debug("removing", child.name) - node.body = body - return self.generic_visit(node) - - def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef | None: - body = [] - for child in node.body: - if not (isinstance(child, ast.FunctionDef) and any(cond(child) for cond in self.conditions)): - body.append(child) - else: - logger.debug("removing", child.name) - node.body = body - return self.generic_visit(node) - - def should_remove(self, node: ast.FunctionDef | ast.AnnAssign) -> bool: - if isinstance(node, ast.FunctionDef): - return any(cond(node) for cond in self.conditions) - - return False - - -class FieldRemover(ast.NodeTransformer): - def __init__(self, conditions: list[Callable[[ast.FunctionDef], bool]]): - self.conditions = conditions - - def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: - body = [] - for child in node.body: - if not self.should_remove(child): - body.append(child) - else: - if isinstance(child, ast.AnnAssign): - logger.debug("removing", child.target.id) - if isinstance(child, ast.Assign): - for target in child.targets: - logger.debug("removing", target.id) - node.body = body - return self.generic_visit(node) - - def should_remove(self, node: ast.AnnAssign | ast.Assign) -> bool: - if isinstance(node, ast.AnnAssign): - return any(cond(node) for cond in self.conditions) - - elif isinstance(node, ast.Assign): - if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): - # Check if it's a property annotation (e.g., var: property) - return any(cond(node) for cond in self.conditions) - return False - - -def _remove_methods(source: str, conditions: list[Callable[[ast.FunctionDef], bool]]) -> str: - tree = ast.parse(source) - transformer = MethodRemover(conditions) - modified_tree = transformer.visit(tree) - return astor.to_source(modified_tree) - - -def _remove_fields(source: str, conditions: list[Callable[[ast.FunctionDef], bool]]) -> str: - tree = ast.parse(source) - transformer = FieldRemover(conditions) - modified_tree = transformer.visit(tree) - return astor.to_source(modified_tree) - - -def _starts_with_underscore(node: ast.FunctionDef | ast.AnnAssign | ast.Assign) -> bool: - if isinstance(node, ast.FunctionDef): - return node.name.startswith("_") and (not node.name.startswith("__") and not node.name.endswith("__")) - elif isinstance(node, ast.Assign): - return node.targets[0].id.startswith("_") - elif isinstance(node, ast.AnnAssign): - return node.target.id.startswith("_") - return False - - -def _has_decorator(decorator_name: str) -> Callable[[ast.FunctionDef], bool]: - def test(node): - has = any(isinstance(d, ast.Name) and d.id == decorator_name for d in node.decorator_list) - # if (has): - # logger.debug(node.name, 'has decorator', [d.id for d in node.decorator_list]) - return has - - return test - - -def _matches_regex(pattern: str) -> Callable[[ast.FunctionDef], bool]: - return lambda node: re.match(pattern, node.name) is not None - - -def _strip_internal_symbols(file: str, root: str) -> None: - if file.endswith(".pyi"): - file_path = os.path.join(root, file) - with open(file_path) as f: - original_content = f.read() - - conditions = [ - _starts_with_underscore, - _has_decorator("noapidoc"), - ] - - modified_content = _remove_fields(original_content, [_starts_with_underscore]) - modified_content = _remove_methods(modified_content, conditions) - - if modified_content.strip().endswith(":"): - modified_content += " pass\n" - with open(file_path, "w") as f: - f.write(modified_content) - logger.debug(f"Typestub file {file_path} has been modified.") - - -def strip_internal_symbols(typing_directory: str) -> None: - with ThreadPoolExecutor() as exec: - for root, _, files in os.walk(typing_directory): - for file in files: - exec.submit(_strip_internal_symbols, file, root) diff --git a/src/codegen/gscli/backend/utils.py b/src/codegen/gscli/backend/utils.py deleted file mode 100644 index 39340c27b..000000000 --- a/src/codegen/gscli/backend/utils.py +++ /dev/null @@ -1,9 +0,0 @@ -######################################################################################################################## -# MISC -######################################################################################################################## - - -def filepath_to_modulename(filepath: str) -> str: - """Used to convert a an app ref passed in as a filepath to a module""" - module = filepath.removesuffix(".py") - return module.replace("/", ".") diff --git a/src/codegen/gscli/cli.py b/src/codegen/gscli/cli.py deleted file mode 100644 index 0724c8702..000000000 --- a/src/codegen/gscli/cli.py +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/python - -import click - -from codegen.gscli.generate.commands import generate - - -@click.group() -def main() -> None: - pass - - -# ============= Import all command groups ============= -main.add_command(generate) - - -if __name__ == "__main__": - main() diff --git a/src/codegen/gscli/generate/__init__.py b/src/codegen/gscli/generate/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/gscli/generate/commands.py b/src/codegen/gscli/generate/commands.py deleted file mode 100644 index 3f96419ff..000000000 --- a/src/codegen/gscli/generate/commands.py +++ /dev/null @@ -1,242 +0,0 @@ -import json -import os -import re -import shutil - -import click -from termcolor import colored - -import codegen.sdk as sdk -from codegen.gscli.generate.runner_imports import _generate_runner_imports -from codegen.gscli.generate.system_prompt import get_system_prompt -from codegen.gscli.generate.utils import LanguageType, generate_builtins_file -from codegen.sdk.ai.client import get_openai_client -from codegen.sdk.code_generation.changelog_generation import generate_changelog -from codegen.sdk.code_generation.codegen_sdk_codebase import get_codegen_sdk_codebase -from codegen.sdk.code_generation.doc_utils.generate_docs_json import generate_docs_json -from codegen.sdk.code_generation.mdx_docs_generation import render_mdx_page_for_class -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - -AUTO_GENERATED_COMMENT = "THE CODE BELOW IS AUTO GENERATED. UPDATE THE SNIPPET BY UPDATING THE SKILL" -CODE_SNIPPETS_REGEX = r"(?:```python\n(?:(?!```)[\s\S])*?\n```|(?:(?!)[\s\S])*?)" - - -@click.group() -def generate() -> None: - """Commands for running auto-generate commands, currently for typestubs, imports to include in runners, and docs""" - ... - - -@generate.command() -@click.argument("docs_dir", default="docs", required=False) -def docs(docs_dir: str) -> None: - """Compile new .MDX files for the auto-generated docs pages and write them to the file system. - To actually deploy these changes, you must commit and merge the changes into develop - - This will generate docs using the codebase locally, including any unstaged changes - """ - docs_dir = os.path.join(os.getcwd(), docs_dir) - generate_docs(docs_dir) - - -@generate.command() -@click.argument("imports_file", default="function_imports.py", required=False) -def runner_imports(imports_file: str) -> None: - """Generate imports to include in runner execution environment""" - _generate_runner_imports(imports_file) - - -@generate.command() -def typestubs() -> None: - """Generate typestubs for the the graphsitter Codebase module - The Codebase class and it's constituents contain methods that should not be exposed, i.e we have private methods - and private properties that we'd like to keep internal. So the way this works is we generate the typestubs and the remove - the "internal" symbols. For example we'll remove: - - "_" prefixed methods and properties - - methods with `@noapidocs` decorator - """ - _generate_codebase_typestubs() - - -def _generate_codebase_typestubs() -> None: - initial_dir = os.getcwd() - - # right now this command expects you to run it from here - if not initial_dir.endswith("codegen/codegen-backend"): - print(colored("Error: Must be in a directory ending with 'codegen/codegen-backend'", "red")) - exit(1) - - out_dir = os.path.abspath(os.path.join(initial_dir, "typings")) - frontend_typestubs_dir = os.path.abspath(os.path.join(initial_dir, os.pardir, "codegen-frontend/assets/typestubs/graphsitter")) - if os.path.isdir(out_dir): - # remove typings dir if it exists - shutil.rmtree(out_dir) - if os.path.isdir(frontend_typestubs_dir): - # remove typings dir if it exists - shutil.rmtree(frontend_typestubs_dir) - # generate typestubs in codegen-frontend/assets/typestubs/graphsitter using pyright - os.system("uv run pyright -p . --createstub codegen.sdk.core.codebase") - os.system("uv run pyright -p . --createstub codegen.git") - os.system("uv run pyright -p . --createstub networkx") - # also generate for codemod context model and all its nested models - os.system("uv run pyright -p . --createstub app.codemod.compilation.models.context") - os.system("uv run pyright -p . --createstub app.codemod.compilation.models.pr_options") - os.system("uv run pyright -p . --createstub app.codemod.compilation.models.github_named_user_context") - os.system("uv run pyright -p . --createstub app.codemod.compilation.models.pull_request_context") - os.system("uv run pyright -p . --createstub app.codemod.compilation.models.pr_part_context") - - # TODO fix this, to remove noapidoc and hidden methods - # right now it uses astor.to_source, which doesn't respect the generics, and breaks things - # strip_internal_symbols(frontend_typestubs_dir) - - # Autogenerate the builtins file based on what has apidoc, we use the same logic here as we do to generate the runner imports - generate_builtins_file(frontend_typestubs_dir + "/__builtins__.pyi", LanguageType.BOTH) - generate_builtins_file(frontend_typestubs_dir + "/__builtins__python__.pyi", LanguageType.PYTHON) - generate_builtins_file(frontend_typestubs_dir + "/__builtins__typescript__.pyi", LanguageType.TYPESCRIPT) - - if os.path.isdir(out_dir): - # remove typings dir if it exists - shutil.rmtree(out_dir) - - -def generate_docs(docs_dir: str) -> None: - """Compile new .MDX files for the auto-generated docs pages and write them to the file system. - To actually deploy these changes, you must commit and merge the changes into develop - - This will generate docs using the codebase locally, including any unstaged changes - """ - generate_codegen_sdk_docs(docs_dir) - - -@generate.command() -@click.argument("filepath", default=sdk.__path__[0] + "/system-prompt.txt", required=False) -def system_prompt(filepath: str) -> None: - print(f"Generating system prompt and writing to {filepath}...") - new_system_prompt = get_system_prompt() - with open(filepath, "w") as f: - f.write(new_system_prompt) - print(f"Successfully wrote system prompt to {filepath}.") - - -def get_snippet_pattern(target_name: str) -> str: - pattern = rf"\[//\]: # \(--{re.escape(target_name)}--\)\s*(?:\[//\]: # \(--{re.escape(AUTO_GENERATED_COMMENT)}--\)\s*)?" - pattern += CODE_SNIPPETS_REGEX - return pattern - - -def generate_codegen_sdk_docs(docs_dir: str) -> None: - """Generate the docs for the codegen_sdk API and update the mint.json""" - print(colored("Generating codegen_sdk docs", "green")) - - # Generate docs page for codebase api and write to the file system - codebase = get_codegen_sdk_codebase() - gs_docs = generate_docs_json(codebase, "HEAD") - - # Prepare the directories for the new docs - # Delete existing documentation directories if they exist - # So we remove generated docs for any classes which no longer exist - python_docs_dir = os.path.join(docs_dir, "api-reference", "python") - typescript_docs_dir = os.path.join(docs_dir, "api-reference", "typescript") - core_dir = os.path.join(docs_dir, "api-reference", "core") - - for dir_path in [python_docs_dir, typescript_docs_dir, core_dir]: - if os.path.exists(dir_path): - shutil.rmtree(dir_path) - - os.makedirs(python_docs_dir, exist_ok=True) - os.makedirs(typescript_docs_dir, exist_ok=True) - os.makedirs(core_dir, exist_ok=True) - - # Generate the docs pages for core, python, and typescript classes - - # Write the generated docs to the file system, splitting between core, python, and typescript - # keep track of where we put each one so we can update the mint.json - python_set = set() - typescript_set = set() - core_set = set() - # TODO replace this with new `get_mdx_for_class` function - for class_doc in gs_docs.classes: - class_name = class_doc.title - lower_class_name = class_name.lower() - if lower_class_name.startswith("py"): - file_path = os.path.join(python_docs_dir, f"{class_name}.mdx") - python_set.add(f"api-reference/python/{class_name}") - elif lower_class_name.startswith(("ts", "jsx")): - file_path = os.path.join(typescript_docs_dir, f"{class_name}.mdx") - typescript_set.add(f"api-reference/typescript/{class_name}") - else: - file_path = os.path.join(core_dir, f"{class_name}.mdx") - core_set.add(f"api-reference/core/{class_name}") - - mdx_page = render_mdx_page_for_class(cls_doc=class_doc) - with open(file_path, "w") as f: - f.write(mdx_page) - print(colored("Finished writing new .mdx files", "green")) - - # Update the core, python, and typescript page sets in mint.json - mint_file_path = os.path.join(docs_dir, "mint.json") - with open(mint_file_path) as mint_file: - mint_data = json.load(mint_file) - - # Find the "Codebase SDK" group where we want to add the pages - codebase_sdk_group = next(group for group in mint_data["navigation"] if group["group"] == "API Reference") - - # Update the pages for each language group - for group in codebase_sdk_group["pages"]: - if isinstance(group, dict): # Ensure group is a dictionary - if group["group"] == "Core": - group["pages"] = sorted(core_set) - elif group["group"] == "Python": - group["pages"] = sorted(python_set) - elif group["group"] == "Typescript": - group["pages"] = sorted(typescript_set) - - with open(mint_file_path, "w") as mint_file: - json.dump(mint_data, mint_file, indent=2) - - print(colored("Updated mint.json with new page sets", "green")) - - -@generate.command() -@click.option("--docs-dir", default="docs", required=False) -@click.option("--openai-key", required=True) -@click.option("--complete", is_flag=True, help="Generate a complete changelog for the codegen_sdk API") -def changelog(docs_dir: str, openai_key: str, complete: bool = False) -> None: - """Generate the changelog for the codegen_sdk API and update the changelog.mdx file""" - print(colored("Generating changelog", "green")) - header = """--- -title: "Codegen Updates" -icon: "clock" -iconType: "solid" ---- -""" - - client = get_openai_client(openai_key) - - if complete: - entire_release_history = generate_changelog(client) - new_changelog = header + entire_release_history - else: - # Read existing changelog and append new releases - with open(os.path.join(docs_dir, "changelog/changelog.mdx")) as f: - # read the existing changelog - existing_changelog = f.read() - # Remove header from existing changelog - existing_changelog = existing_changelog.split(header)[1] - # find the latest existing version - latest_existing_version = re.search(r'label="(v[\d.]+)"', existing_changelog) - # if there is a latest existing version, generate new releases - if latest_existing_version: - # generate new releases - new_releases = generate_changelog(client, latest_existing_version.group(1)) - # append new releases to the existing changelog - new_changelog = header + new_releases + existing_changelog - else: - # if there is no latest existing version, generate a complete changelog - new_releases = generate_changelog(client) - new_changelog = header + new_releases - - with open(os.path.join(docs_dir, "changelog/changelog.mdx"), "w") as f: - f.write(new_changelog) diff --git a/src/codegen/gscli/generate/runner_imports.py b/src/codegen/gscli/generate/runner_imports.py deleted file mode 100644 index d07b86062..000000000 --- a/src/codegen/gscli/generate/runner_imports.py +++ /dev/null @@ -1,108 +0,0 @@ -from itertools import chain -from pathlib import Path - -import tomlkit -from termcolor import colored - -from codegen.git.utils.file_utils import split_git_path -from codegen.sdk.code_generation.current_code_codebase import get_documented_objects -from codegen.shared.decorators.docs import DocumentedObject - -EXTERNAL_IMPORTS = """ -import os -import re -from pathlib import Path -import networkx as nx -import plotly -""".strip() -CODEGEN_IMPORTS = """ -from codegen.git.models.codemod_context import CodemodContext -from codegen.git.models.github_named_user_context import GithubNamedUserContext -from codegen.git.models.pr_options import PROptions -from codegen.git.models.pr_part_context import PRPartContext -from codegen.git.models.pull_request_context import PullRequestContext -""" -# TODO: these should also be made public (i.e. included in the docs site) -GS_PRIVATE_IMPORTS = """ -from codegen.shared.exceptions.control_flow import StopCodemodException -""".strip() - -IMPORT_STRING_TEMPLATE = """ -# External imports -{external_imports} - -# GraphSitter imports (private) -{codegen_imports} -{gs_private_imports} - -# GraphSitter imports (public) -{gs_public_imports} -""".strip() - -IMPORT_FILE_TEMPLATE = ( - ''' -# This file is auto-generated, do not modify manually. Edit this in src/codegen/gscli/generate/runner_imports.py. -def get_generated_imports(): - return """ -{import_str} -""" -'''.strip() - + "\n" -) - - -def fix_ruff_imports(objects: list[DocumentedObject]): - root, _ = split_git_path(str(Path(__file__))) - to_add = [] - for obj in objects: - to_add.append(f"{obj.module}.{obj.name}") - generics = tomlkit.array() - for val in dict.fromkeys(to_add): - generics.add_line(val, indent=" ") - generics.add_line(indent="") - config = Path(root) / "ruff.toml" - toml_config = tomlkit.parse(config.read_text()) - toml_config["lint"]["pyflakes"]["extend-generics"] = generics - config.write_text(tomlkit.dumps(toml_config)) - - -def get_runner_imports(include_codegen=True, include_private_imports: bool = True) -> str: - # get the imports from the apidoc, py_apidoc, and ts_apidoc - gs_objects = get_documented_objects() - gs_public_objects = list(chain(gs_objects["apidoc"], gs_objects["py_apidoc"], gs_objects["ts_apidoc"])) - fix_ruff_imports(gs_public_objects) - gs_public_imports = {f"from {obj.module} import {obj.name}" for obj in gs_public_objects} - - # construct import string with all imports - ret = IMPORT_STRING_TEMPLATE.format( - codegen_imports=CODEGEN_IMPORTS if include_codegen else "", - external_imports=EXTERNAL_IMPORTS, - gs_private_imports=GS_PRIVATE_IMPORTS if include_private_imports else "", - gs_public_imports="\n".join(sorted(gs_public_imports)), - ) - return ret - - -EXPORT_TEMPLATE = """ -__all__ = [ - "__version__", - "__version_tuple__", - "StopCodemodException", -{modules} -] -""".strip() - - -def generate_exported_modules() -> str: - gs_objects = get_documented_objects() - gs_public_objects = list(chain(gs_objects["apidoc"], gs_objects["py_apidoc"], gs_objects["ts_apidoc"])) - return EXPORT_TEMPLATE.format(modules=",\n".join(dict.fromkeys(' "' + obj.name + '"' for obj in sorted(gs_public_objects, key=lambda x: x.name)))) - - -def _generate_runner_imports(imports_file: str) -> None: - print(colored(f"Generating runner imports string in {imports_file}", "green")) - - import_str = get_runner_imports() - # write the imports to the file - with open(imports_file, "w") as f: - f.write(IMPORT_FILE_TEMPLATE.format(import_str=import_str)) diff --git a/src/codegen/gscli/generate/system_prompt.py b/src/codegen/gscli/generate/system_prompt.py deleted file mode 100644 index 33b4a18a5..000000000 --- a/src/codegen/gscli/generate/system_prompt.py +++ /dev/null @@ -1,29 +0,0 @@ -import json -from pathlib import Path - -docs = Path("./docs") -mint = json.load(open(docs / "mint.json")) - - -def render_page(page_str: str): - return open(docs / (page_str + ".mdx")).read() - - -def render_group(page_strs: list[str]): - return "\n\n".join([render_page(x) for x in page_strs]) - - -def get_group(name) -> list[str]: - group = next((x for x in mint["navigation"] if x.get("group") == name), None) - if group: - return group["pages"] - - -def render_groups(group_names: list[str]) -> str: - groups = [get_group(x) for x in group_names] - return "\n\n".join([render_group(g) for g in groups]) - - -def get_system_prompt() -> str: - """Generates a string system prompt based on the docs""" - return render_groups(["Introduction", "Building with Codegen", "Tutorials"]) diff --git a/src/codegen/gscli/generate/utils.py b/src/codegen/gscli/generate/utils.py deleted file mode 100644 index d579f9288..000000000 --- a/src/codegen/gscli/generate/utils.py +++ /dev/null @@ -1,55 +0,0 @@ -import inspect -from enum import StrEnum -from itertools import chain - -from codegen.sdk.code_generation.current_code_codebase import get_documented_objects -from codegen.sdk.core import codebase - - -class LanguageType(StrEnum): - PYTHON = "PYTHON" - TYPESCRIPT = "TYPESCRIPT" - BOTH = "BOTH" - - -def generate_builtins_file(path_to_builtins: str, language_type: LanguageType): - """Generates and writes the builtins file""" - documented_imports = get_documented_objects() - all_objects = chain(documented_imports["apidoc"], documented_imports["py_apidoc"], documented_imports["ts_apidoc"]) - unique_imports = {f"from {obj.module} import {obj.name} as {obj.name}" for obj in all_objects} - all_imports = "\n".join(sorted(unique_imports)) - # TODO: re-use code with runner_imports list - # TODO: also auto generate import string for CodemodContext + MessageType - - if language_type == LanguageType.PYTHON: - codebase_type = "PyCodebaseType" - elif language_type == LanguageType.TYPESCRIPT: - codebase_type = "TSCodebaseType" - else: # BOTH - codebase_type = "PyCodebaseType | TSCodebaseType" - - BUILTINS_FILE_TEMPLATE = f""" -# This file is auto-generated, do not modify manually - -{{all_imports}} -from codegen.git.models.codemod_context import CodemodContext -from codegen.git.models.pr_options import PROptions -from codegen.git.models.github_named_user_context import GithubNamedUserContext -from codegen.git.models.pr_part_context import PRPartContext -from codegen.git.models.pull_request_context import PullRequestContext -from codegen.sdk.codebase.flagging.code_flag import MessageType as MessageType - -{"\n".join(inspect.getsource(codebase).splitlines()[-2:])} -CodebaseType = {codebase_type} - -# declare global type for 'codebase' -codebase: CodebaseType - -# declare global type for 'context' -context: CodemodContext - -pr_options: PROptions -""" - - with open(path_to_builtins, "w") as f: - f.write(BUILTINS_FILE_TEMPLATE.format(all_imports=all_imports)) diff --git a/src/codegen/runner/README.md b/src/codegen/runner/README.md deleted file mode 100644 index facb14e63..000000000 --- a/src/codegen/runner/README.md +++ /dev/null @@ -1,9 +0,0 @@ -# Codegen Runner - -A module to run functions with managed state + lifecycle. - -### Dependencies - -- [codegen.sdk](https://github.com/codegen-sh/codegen-sdk/tree/develop/src/codegen/sdk) -- [codegen.git](https://github.com/codegen-sh/codegen-sdk/tree/develop/src/codegen/git) -- [codegen.shared](https://github.com/codegen-sh/codegen-sdk/tree/develop/src/codegen/shared) diff --git a/src/codegen/runner/__init__.py b/src/codegen/runner/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/runner/clients/client.py b/src/codegen/runner/clients/client.py deleted file mode 100644 index 2e2e4e132..000000000 --- a/src/codegen/runner/clients/client.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Client used to abstract the weird stdin/stdout communication we have with the sandbox""" - -import requests -from fastapi import params - -from codegen.runner.models.apis import ServerInfo -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - -DEFAULT_SERVER_PORT = 4002 - -EPHEMERAL_SERVER_PATH = "codegen.runner.sandbox.ephemeral_server:app" - - -class Client: - """Client for interacting with the sandbox server.""" - - host: str - port: int - base_url: str - - def __init__(self, host: str, port: int) -> None: - self.host = host - self.port = port - self.base_url = f"http://{host}:{port}" - - def is_running(self) -> bool: - try: - self.get("/") - return True - except requests.exceptions.ConnectionError: - return False - - def server_info(self, raise_on_error: bool = False) -> ServerInfo: - try: - response = self.get("/") - return ServerInfo.model_validate(response.json()) - except requests.exceptions.ConnectionError: - if raise_on_error: - raise - return ServerInfo() - - def get(self, endpoint: str, data: dict | None = None) -> requests.Response: - url = f"{self.base_url}{endpoint}" - response = requests.get(url, json=data) - response.raise_for_status() - return response - - def post(self, endpoint: str, data: dict | None = None, authorization: str | params.Header | None = None) -> requests.Response: - url = f"{self.base_url}{endpoint}" - headers = {"Authorization": str(authorization)} if authorization else None - response = requests.post(url, json=data, headers=headers) - response.raise_for_status() - return response diff --git a/src/codegen/runner/clients/codebase_client.py b/src/codegen/runner/clients/codebase_client.py deleted file mode 100644 index 7b4bf16ce..000000000 --- a/src/codegen/runner/clients/codebase_client.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Client used to abstract the weird stdin/stdout communication we have with the sandbox""" - -import os -import subprocess -import time - -from codegen.configs.models.secrets import SecretsConfig -from codegen.git.schemas.repo_config import RepoConfig -from codegen.runner.clients.client import Client -from codegen.runner.models.apis import SANDBOX_SERVER_PORT -from codegen.shared.logging.get_logger import get_logger - -DEFAULT_SERVER_PORT = 4002 -EPHEMERAL_SERVER_PATH = "codegen.runner.sandbox.ephemeral_server:app" -RUNNER_SERVER_PATH = "codegen.runner.sandbox.server:app" - - -logger = get_logger(__name__) - - -class CodebaseClient(Client): - """Client for interacting with the locally hosted sandbox server.""" - - repo_config: RepoConfig - - def __init__(self, repo_config: RepoConfig, host: str = "127.0.0.1", port: int = SANDBOX_SERVER_PORT, server_path: str = RUNNER_SERVER_PATH): - super().__init__(host=host, port=port) - self.repo_config = repo_config - self._process = None - self._start_server(server_path) - - def __del__(self): - """Cleanup the subprocess when the client is destroyed""" - if self._process is not None: - self._process.terminate() - self._process.wait() - - def _start_server(self, server_path: str) -> None: - """Start the FastAPI server in a subprocess""" - envs = self._get_envs() - logger.info(f"Starting local server on {self.base_url} with envvars: {envs}") - - self._process = subprocess.Popen( - [ - "uvicorn", - server_path, - "--host", - self.host, - "--port", - str(self.port), - ], - env=envs, - ) - self._wait_for_server() - - def _wait_for_server(self, timeout: int = 30, interval: float = 0.3) -> None: - """Wait for the server to start by polling the health endpoint""" - start_time = time.time() - while (time.time() - start_time) < timeout: - if self.is_running(): - return - time.sleep(interval) - msg = "Server failed to start within timeout period" - raise TimeoutError(msg) - - def _get_envs(self) -> dict: - envs = os.environ.copy() - codebase_envs = { - "REPOSITORY_PATH": str(self.repo_config.repo_path), - "REPOSITORY_OWNER": self.repo_config.organization_name, - "REPOSITORY_LANGUAGE": self.repo_config.language.value, - "GITHUB_TOKEN": SecretsConfig().github_token, - } - - envs.update(codebase_envs) - return envs - - -if __name__ == "__main__": - test_config = RepoConfig.from_repo_path("/Users/caroljung/git/codegen/codegen-agi") - test_config.full_name = "codegen-sh/codegen-agi" - client = CodebaseClient(test_config) - print(client.is_running()) diff --git a/src/codegen/runner/clients/docker_client.py b/src/codegen/runner/clients/docker_client.py deleted file mode 100644 index 89b8a1844..000000000 --- a/src/codegen/runner/clients/docker_client.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Client for interacting with the locally hosted sandbox server hosted on a docker container.""" - -from codegen.cli.commands.start.docker_container import DockerContainer -from codegen.cli.commands.start.docker_fleet import DockerFleet -from codegen.cli.utils.function_finder import DecoratedFunction -from codegen.runner.clients.client import Client -from codegen.runner.models.apis import RUN_FUNCTION_ENDPOINT, RunFunctionRequest -from codegen.runner.models.codemod import CodemodRunResult - - -class DockerClient(Client): - """Client for interacting with the locally hosted sandbox server hosted on a docker container.""" - - def __init__(self, container: DockerContainer): - if not container.is_running() or container.host is None or container.port is None: - msg = f"Container {container.name} is not running." - raise Exception(msg) - super().__init__(container.host, container.port) - - def run(self, codemod_source: str, commit: bool | None = None) -> CodemodRunResult: - req = RunFunctionRequest(function_name="unnamed", codemod_source=codemod_source, commit=commit) - response = self.post(RUN_FUNCTION_ENDPOINT, req.model_dump()) - return CodemodRunResult.model_validate(response.json()) - - def run_function(self, function: DecoratedFunction, commit: bool) -> CodemodRunResult: - req = RunFunctionRequest(function_name=function.name, codemod_source=function.source, commit=commit) - response = self.post(RUN_FUNCTION_ENDPOINT, req.model_dump()) - return CodemodRunResult.model_validate(response.json()) - - -if __name__ == "__main__": - fleet = DockerFleet.load() - cur = next((container for container in fleet.containers if container.is_running()), None) - if cur is None: - msg = "No running container found. Run `codegen start` from a git repo first." - raise Exception(msg) - client = DockerClient(cur) - print(f"healthcheck: {client.is_running()}") - result = client.run("print(codebase)") - print(result) diff --git a/src/codegen/runner/constants/envvars.py b/src/codegen/runner/constants/envvars.py deleted file mode 100644 index 8d47fd6a4..000000000 --- a/src/codegen/runner/constants/envvars.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Environment variables used in the sandbox.""" - -# ==== [ Environment variable names ] ==== -FEATURE_FLAGS_BASE64 = "FEATURE_FLAGS_BASE64" -REPO_CONFIG_BASE64 = "REPO_CONFIG_BASE64" -GITHUB_TOKEN = "GITHUB_TOKEN" diff --git a/src/codegen/runner/diff/get_raw_diff.py b/src/codegen/runner/diff/get_raw_diff.py deleted file mode 100644 index 463584249..000000000 --- a/src/codegen/runner/diff/get_raw_diff.py +++ /dev/null @@ -1,94 +0,0 @@ -import io - -from unidiff import LINE_TYPE_CONTEXT, Hunk, PatchedFile, PatchSet -from unidiff.patch import Line - -from codegen.sdk.core.codebase import Codebase -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -def append_flag(file: PatchedFile, append_at: int, line_no: int, codebase: Codebase) -> None: - added_hunk = Hunk( - src_start=line_no, - src_len=1, - tgt_start=line_no, - tgt_len=1, - ) - line = codebase.get_file(file.path).content.split("\n")[line_no - 1] - added_hunk.append(Line(f"{line}\n", line_type=LINE_TYPE_CONTEXT)) - file.insert(append_at, added_hunk) - - -def patch_to_limited_diff_string(patch, codebase: Codebase, max_lines=10000): - diff_lines = [] - total_lines = 0 - - # Add flags that are not in the diff - filenames = [patched_file.path for patched_file in patch] - flags_not_in_diff = list(filter(lambda flag: flag.symbol.filepath not in filenames, codebase.ctx.flags._flags)) - - for flag in flags_not_in_diff: - filename = flag.symbol.filepath - patched_file = PatchedFile( - patch_info=f"diff --git a/{filename} b/{filename}\n", - source=f"a/{filename}", - target=f"b/{filename}", - ) - patch.append(patched_file) - - for patched_file in patch: - filtered_flags = filter(lambda flag: flag.symbol.filepath == patched_file.path, codebase.ctx.flags._flags) - sorted_flags = list(map(lambda flag: flag.symbol.start_point.row + 1, filtered_flags)) - sorted_flags.sort() - - for flag in sorted_flags: - is_in_diff = False - - for i, hunk in enumerate(patched_file): - contains_flag = hunk.source_start <= flag <= hunk.source_start + hunk.source_length - - if contains_flag: - is_in_diff = True - break - - is_after_flag = hunk.source_start > flag - - if is_after_flag: - is_in_diff = True - append_flag(patched_file, i, flag, codebase) - break - - if not is_in_diff: - append_flag(patched_file, len(patched_file), flag, codebase) - - # Add file header - raw_diff = str(patched_file) - diff_length = len(raw_diff.splitlines()) - - total_lines += diff_length - diff_lines.append(raw_diff) - - if total_lines >= max_lines: - break - - return "\n".join(diff_lines) - - -def get_raw_diff(codebase: Codebase, base: str = "HEAD", max_lines: int = 10000) -> str: - raw_diff = codebase.get_diff(base) - patch_set = PatchSet(io.StringIO(raw_diff)) - - raw_diff_length = len(raw_diff.split("\n")) - logger.info(f"Truncating diff (total: {raw_diff_length}) to {max_lines} lines ...") - raw_diff_trunc = patch_to_limited_diff_string(patch=patch_set, max_lines=max_lines, codebase=codebase) - - return raw_diff_trunc - - -def get_filenames_from_diff(diff: str) -> list[str]: - patch_set = PatchSet(io.StringIO(diff)) - filenames = [patched_file.path for patched_file in patch_set] - - return filenames diff --git a/src/codegen/runner/enums/warmup_state.py b/src/codegen/runner/enums/warmup_state.py deleted file mode 100644 index c75c6f553..000000000 --- a/src/codegen/runner/enums/warmup_state.py +++ /dev/null @@ -1,7 +0,0 @@ -from enum import StrEnum - - -class WarmupState(StrEnum): - PENDING = "PENDING" - COMPLETED = "COMPLETED" - FAILED = "FAILED" diff --git a/src/codegen/runner/models/apis.py b/src/codegen/runner/models/apis.py deleted file mode 100644 index 961a5c93b..000000000 --- a/src/codegen/runner/models/apis.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Dataclasses used by the sandboxes server APIs""" - -from pydantic import BaseModel - -from codegen.runner.enums.warmup_state import WarmupState -from codegen.runner.models.codemod import BranchConfig, Codemod, CodemodRunResult, CreatedBranch, GroupingConfig - -SANDBOX_SERVER_PORT = 4000 -EPHEMERAL_SANDBOX_SERVER_PORT = 4001 - -# APIs -DIFF_ENDPOINT = "/diff" -BRANCH_ENDPOINT = "/branch" -RUN_FUNCTION_ENDPOINT = "/run" - -# Ephemeral sandbox apis -RUN_ON_STRING_ENDPOINT = "/run_on_string" - - -class ServerInfo(BaseModel): - repo_name: str | None = None - synced_commit: str | None = None - warmup_state: WarmupState = WarmupState.PENDING - - -class GetDiffRequest(BaseModel): - codemod: Codemod - max_transactions: int | None = None - max_seconds: int | None = None - - -class GetDiffResponse(BaseModel): - result: CodemodRunResult - - -class CreateBranchRequest(BaseModel): - codemod: Codemod - commit_msg: str - grouping_config: GroupingConfig - branch_config: BranchConfig - - -class CreateBranchResponse(BaseModel): - results: list[CodemodRunResult] | None = None - branches: list[CreatedBranch] | None = None - num_flags: int | None = None - group_segments: list[str] | None = None - - -class GetRunOnStringRequest(BaseModel): - codemod_source: str - language: str - files: dict[str, str] - - -class GetRunOnStringResult(BaseModel): - result: CodemodRunResult - - -class RunFunctionRequest(BaseModel): - codemod_source: str - function_name: str - commit: bool = False diff --git a/src/codegen/runner/models/codemod.py b/src/codegen/runner/models/codemod.py deleted file mode 100644 index ac15389a1..000000000 --- a/src/codegen/runner/models/codemod.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Dataclasses used by the sandbox runners""" - -from datetime import datetime - -from pydantic import BaseModel - -from codegen.git.models.codemod_context import CodemodContext -from codegen.git.models.pr_options import PROptions -from codegen.sdk.codebase.flagging.groupers.enums import GroupBy - - -class Codemod(BaseModel): - user_code: str - codemod_context: CodemodContext = CodemodContext() - - -class GroupingConfig(BaseModel): - subdirectories: list[str] | None = None - group_by: GroupBy | None = None - max_prs: int | None = None - - -class BranchConfig(BaseModel): - branch_name: str | None = None - custom_base_branch: str | None = None - custom_head_branch: str | None = None - force_push_head_branch: bool = False - - -class CodemodRunResult(BaseModel): - is_complete: bool = False - observation: str | None = None - visualization: dict | None = None - observation_meta: dict | None = None - base_commit: str | None = None - logs: str | None = None - error: str | None = None - completed_at: datetime | None = None - highlighted_diff: str | None = None - pr_options: PROptions | None = None - flags: list[dict] | None = None - - -class CreatedBranch(BaseModel): - base_branch: str - head_ref: str | None = None - - -class SandboxRunnerTag(BaseModel): - repo_id: str - runner_id: str diff --git a/src/codegen/runner/sandbox/ephemeral_server.py b/src/codegen/runner/sandbox/ephemeral_server.py deleted file mode 100644 index 6f67e8c30..000000000 --- a/src/codegen/runner/sandbox/ephemeral_server.py +++ /dev/null @@ -1,56 +0,0 @@ -import tempfile -from contextlib import asynccontextmanager - -from fastapi import FastAPI - -from codegen.runner.enums.warmup_state import WarmupState -from codegen.runner.models.apis import ( - RUN_ON_STRING_ENDPOINT, - GetRunOnStringRequest, - GetRunOnStringResult, - ServerInfo, -) -from codegen.runner.sandbox.executor import SandboxExecutor -from codegen.sdk.codebase.factory.get_session import get_codebase_session -from codegen.shared.compilation.string_to_code import create_execute_function_from_codeblock -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - -server_info: ServerInfo - - -@asynccontextmanager -async def lifespan(server: FastAPI): - global server_info - server_info = ServerInfo(warmup_state=WarmupState.COMPLETED) - logger.info("Ephemeral server is ready to accept requests") - yield - logger.info("Shutting down fastapi server") - - -app = FastAPI(lifespan=lifespan) - - -@app.get("/") -def health() -> ServerInfo: - return server_info - - -@app.post(RUN_ON_STRING_ENDPOINT) -async def run_on_string(request: GetRunOnStringRequest) -> GetRunOnStringResult: - logger.info(f"====[ run_on_string ]====\n> Codemod source: {request.codemod_source}\n> Input: {request.files}\n> Language: {request.language}\n") - language = ProgrammingLanguage(request.language.upper()) - with get_codebase_session(tmpdir=tempfile.mkdtemp(), files=request.files, programming_language=language) as codebase: - executor = SandboxExecutor(codebase) - code_to_exec = create_execute_function_from_codeblock(codeblock=request.codemod_source) - result = await executor.execute(code_to_exec) - logger.info(f"Result: {result}") - return GetRunOnStringResult(result=result) - - -if __name__ == "__main__": - import uvicorn - - uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/src/codegen/runner/sandbox/executor.py b/src/codegen/runner/sandbox/executor.py deleted file mode 100644 index 393ba93cb..000000000 --- a/src/codegen/runner/sandbox/executor.py +++ /dev/null @@ -1,175 +0,0 @@ -from collections.abc import Callable -from datetime import UTC, datetime - -from github.PullRequest import PullRequest - -from codegen.git.models.pr_options import PROptions -from codegen.runner.diff.get_raw_diff import get_raw_diff -from codegen.runner.models.codemod import BranchConfig, CodemodRunResult, CreatedBranch, GroupingConfig -from codegen.runner.sandbox.repo import SandboxRepo -from codegen.runner.utils.branch_name import get_head_branch_name -from codegen.runner.utils.exception_utils import update_observation_meta -from codegen.sdk.codebase.config import SessionOptions -from codegen.sdk.codebase.factory.codebase_factory import CodebaseType -from codegen.sdk.codebase.flagging.code_flag import CodeFlag -from codegen.sdk.codebase.flagging.group import Group -from codegen.sdk.codebase.flagging.groupers.utils import get_grouper_by_group_by -from codegen.shared.exceptions.control_flow import StopCodemodException -from codegen.shared.logging.get_logger import get_logger -from codegen.shared.performance.stopwatch_utils import stopwatch -from codegen.visualizations.viz_utils import get_graph_json - -logger = get_logger(__name__) - - -class SandboxExecutor: - """Responsible for executing the user defined codemod in the sandbox.""" - - codebase: CodebaseType - remote_repo: SandboxRepo - - def __init__(self, codebase: CodebaseType): - self.codebase = codebase - self.remote_repo = SandboxRepo(self.codebase) - - async def find_flags(self, execute_func: Callable) -> list[CodeFlag]: - """Runs the execute_func in find_mode to find flags""" - self.codebase.set_find_mode(True) - await self._execute_with_try_catch(execute_func, commit=False) - code_flags = self.codebase.ctx.flags._flags - logger.info(f"> Found {len(self.codebase.ctx.flags._flags)} CodeFlags") - return code_flags - - async def find_flag_groups(self, code_flags: list[CodeFlag], grouping_config: GroupingConfig) -> list[Group]: - """Groups the code flags as specified by grouping_config""" - if grouping_config.subdirectories and len(grouping_config.subdirectories) > 0: - logger.info(f"> Filtering flags by subdirectories: {grouping_config.subdirectories}") - code_flags = [flag for flag in code_flags if any([flag.filepath.startswith(x) for x in grouping_config.subdirectories])] - logger.info(f"> Flags remaining: {len(code_flags)}") - - # =====[ Group the code flags ]===== - logger.info(f"> Grouping CodeFlags by config: {grouping_config}") - grouper = get_grouper_by_group_by(grouping_config.group_by) - groups = grouper.create_all_groups(flags=code_flags, repo_operator=self.codebase.op) - logger.info(f"> Created {len(groups)} groups") - return groups - - async def execute_flag_groups(self, commit_msg: str, execute_func: Callable, flag_groups: list[Group], branch_config: BranchConfig) -> tuple[list[CodemodRunResult], list[CreatedBranch]]: - run_results = [] - head_branches = [] - for idx, group in enumerate(flag_groups): - if idx > 0 and run_results[-1].error: - logger.info("Skipping remaining groups because of error in previous group") - break - if group: - logger.info(f"Running group {group.segment} ({idx + 1} out of {len(flag_groups)})...") - - head_branch = branch_config.custom_head_branch or get_head_branch_name(branch_config.branch_name, group) - logger.info(f"Running with head branch: {head_branch}") - self.remote_repo.reset_branch(branch_config.custom_base_branch, head_branch) - - run_result = await self.execute(execute_func, group=group) - created_branch = CreatedBranch(base_branch=branch_config.custom_base_branch, head_ref=None) - if self.remote_repo.push_changes_to_remote(commit_msg, head_branch, branch_config.force_push_head_branch): - created_branch.head_ref = head_branch - - self.codebase.reset() - run_results.append(run_result) - head_branches.append(created_branch) - - self.codebase.ctx.flags._flags.clear() - return run_results, head_branches - - async def execute(self, execute_func: Callable, group: Group | None = None, session_options: SessionOptions = SessionOptions()) -> CodemodRunResult: - """Runs the execute_func in edit_mode and returns the saved the result""" - self.codebase.set_find_mode(False) - if group: - self.codebase.set_active_group(group) - result = await self._execute_with_try_catch(execute_func, session_options=session_options) - return await self._get_structured_run_output(result) - - async def execute_on_pr(self, execute_func: Callable, pr: PullRequest, session_options: SessionOptions = SessionOptions()) -> CodemodRunResult: - """Runs the execute_func in edit_mode and returns the saved the result""" - # TODO: only difference is this sets `set_find_mode` to True to capture flags. Shouldn't need to do this, flags should always appear. - self.codebase.set_find_mode(True) - result = await self._execute_with_try_catch(execute_func, session_options=session_options, pr=pr) - return await self._get_structured_run_output(result) - - @stopwatch - async def _execute_with_try_catch( - self, - execute_func: Callable, - *, - sync_graph: bool = False, - commit: bool = True, - session_options: SessionOptions = SessionOptions(), - pr: PullRequest | None = None, - ) -> CodemodRunResult: - """Runs the execute_func in a try/catch with a codebase session""" - logger.info(f"Running safe execute with sync_graph: {sync_graph} commit: {commit} session_options: {session_options}") - result = CodemodRunResult() - pr_options = PROptions() - try: - with self.codebase.session(sync_graph, commit, session_options=session_options): - execute_func(self.codebase, pr_options, pr=pr) - result.is_complete = True - - except StopCodemodException as e: - logger.info(f"Stopping codemod due to {e.__class__.__name__}: {e}") - result.observation_meta = update_observation_meta(e, result.observation_meta) - result.is_complete = True - - except Exception as e: - error_message = str(e) - logger.exception(e) - result.error = error_message - result.is_complete = False - - finally: - # =====[ Capture completed_at ]===== - result.completed_at = datetime.now(tz=UTC) - - # =====[ Capture PR options ]===== - result.pr_options = pr_options - - # =====[ Build graph.json ]===== - viz_results = get_graph_json(self.codebase.op) - if viz_results is not None: - result.visualization = viz_results - - return result - - async def _get_structured_run_output(self, result: CodemodRunResult) -> CodemodRunResult: - """Formats output into a CodemodRunResult""" - # =====[ Save flags ]===== - # Note: I think we should just store this on the CodemodRunResult.flags, not meta - # Also note: we should type this object, since we end up using it in several locations - flags = [ - { - "filepath": flag.symbol.filepath, - "startLine": flag.symbol.start_point.row, - "startColumn": flag.symbol.start_point.column, - "endLine": flag.symbol.end_point.row, - "endColumn": flag.symbol.end_point.column, - "message": flag.message, - "messageType": str(flag.message_type), - "messageRecipient": flag.message_recipient, - } - for flag in self.codebase.ctx.flags._flags - ] - result.flags = flags - if result.observation_meta is None: - result.observation_meta = {} - result.observation_meta["flags"] = flags - - # =====[ Get and store raw diff ]===== - logger.info("> Extracting diff") - raw_diff = get_raw_diff(codebase=self.codebase) - result.observation = raw_diff - result.base_commit = self.codebase.current_commit.hexsha if self.codebase.current_commit else "HEAD" - - # =====[ Finalize CodemodRun state ]===== - # Include logs etc. - logger.info("> Extracting/formatting logs") - result.logs = self.codebase.get_finalized_logs() - return result diff --git a/src/codegen/runner/sandbox/middlewares.py b/src/codegen/runner/sandbox/middlewares.py deleted file mode 100644 index 8edea49b9..000000000 --- a/src/codegen/runner/sandbox/middlewares.py +++ /dev/null @@ -1,50 +0,0 @@ -import traceback -from http import HTTPStatus -from typing import Callable, TypeVar - -from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint -from starlette.requests import Request -from starlette.responses import JSONResponse, Response - -from codegen.runner.sandbox.runner import SandboxRunner -from codegen.shared.exceptions.compilation import UserCodeException -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - -TRequest = TypeVar("TRequest", bound=Request) -TResponse = TypeVar("TResponse", bound=Response) - - -class CodemodRunMiddleware[TRequest, TResponse](BaseHTTPMiddleware): - def __init__(self, app, path: str, runner_fn: Callable[[], SandboxRunner]) -> None: - super().__init__(app) - self.path = path - self.runner_fn = runner_fn - - @property - def runner(self) -> SandboxRunner: - return self.runner_fn() - - async def dispatch(self, request: TRequest, call_next: RequestResponseEndpoint) -> TResponse: - if request.url.path == self.path: - return await self.process_request(request, call_next) - return await call_next(request) - - async def process_request(self, request: TRequest, call_next: RequestResponseEndpoint) -> TResponse: - try: - logger.info(f"> (CodemodRunMiddleware) Request: {request.url.path}") - self.runner.codebase.viz.clear_graphviz_data() - response = await call_next(request) - return response - - except UserCodeException as e: - message = f"Invalid user code for {request.url.path}" - logger.info(message) - return JSONResponse(status_code=HTTPStatus.BAD_REQUEST, content={"detail": message, "error": str(e), "traceback": traceback.format_exc()}) - - except Exception as e: - message = f"Unexpected error for {request.url.path}" - logger.exception(message) - res = JSONResponse(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, content={"detail": message, "error": str(e), "traceback": traceback.format_exc()}) - return res diff --git a/src/codegen/runner/sandbox/repo.py b/src/codegen/runner/sandbox/repo.py deleted file mode 100644 index dc938361f..000000000 --- a/src/codegen/runner/sandbox/repo.py +++ /dev/null @@ -1,70 +0,0 @@ -from codegen.sdk.codebase.factory.codebase_factory import CodebaseType -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -class SandboxRepo: - """Responsible for managing the state of the git repo stored in the sandbox runner""" - - codebase: CodebaseType - - def __init__(self, codebase: CodebaseType) -> None: - self.codebase = codebase - - def set_up_base_branch(self, base_branch: str | None) -> None: - """Set-up base branch by pushing latest highside branch to lowside and checking out the branch.""" - # If base branch is already checked out, do nothing - if self.codebase.op.is_branch_checked_out(base_branch): - return - - # fetch the base branch from highside (do not checkout yet) - highside_remote = self.codebase.op.git_cli.remote(name="origin") - self.codebase.op.fetch_remote(highside_remote.name, refspec=f"{base_branch}:{base_branch}") - - # checkout the base branch (and possibly sync graph) - self.codebase.checkout(branch=base_branch) - - def set_up_head_branch(self, head_branch: str, force_push_head_branch: bool): - """Set-up head branch by pushing latest highside branch to lowside and fetching the branch (so that it can be checked out later).""" - # If head branch is not specified, do nothing - if head_branch is None: - return - - if head_branch and head_branch == self.codebase.default_branch: - # NOTE: assuming that the main branch is always protected, instead should pull this from github (but it requires admin permissions) - error = f"Branch {head_branch} is protected and cannot be used as the head branch!" - logger.error(error) - raise ValueError(error) - - # If are force pushing the head branch, don't checkout the remote. - # This will cause set-up group to create a new branch off of master by the same name - if force_push_head_branch: - return - - # fetch the head branch from highside (do not checkout yet) - highside_remote = self.codebase.op.git_cli.remote(name="origin") - self.codebase.op.fetch_remote(highside_remote.name, refspec=f"{head_branch}:{head_branch}") - - def reset_branch(self, base_branch: str, head_branch: str) -> None: - logger.info(f"Checking out base branch {base_branch} ...") - self.codebase.checkout(branch=base_branch, create_if_missing=True) - # =====[ Checkout head branch ]===== - logger.info(f"Checking out head branch {head_branch} ...") - self.codebase.checkout(branch=head_branch, create_if_missing=True) - - def push_changes_to_remote(self, commit_msg: str, head_branch: str, force_push: bool) -> bool: - """Takes current state of repo and pushes it""" - # =====[ Stage changes ]===== - has_staged_commit = self.codebase.git_commit(f"[Codegen] {commit_msg}") - if not has_staged_commit: - logger.info("Skipping opening pull request for cm_run b/c the codemod produced no changes") - return False - - # =====[ Push changes highside ]===== - highside_remote = self.codebase.op.git_cli.remote(name="origin") - highside_res = self.codebase.op.push_changes(remote=highside_remote, refspec=f"{head_branch}:{head_branch}", force=force_push) - return not any(push_info.flags & push_info.ERROR for push_info in highside_res) - - # TODO: move bunch of codebase git operations into this class. - # The goal is to make the codebase class ONLY allow RepoOperator. diff --git a/src/codegen/runner/sandbox/runner.py b/src/codegen/runner/sandbox/runner.py deleted file mode 100644 index 4a86bc618..000000000 --- a/src/codegen/runner/sandbox/runner.py +++ /dev/null @@ -1,87 +0,0 @@ -import sys - -from codegen.configs.models.codebase import CodebaseConfig -from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.git.schemas.enums import SetupOption -from codegen.git.schemas.repo_config import RepoConfig -from codegen.runner.models.apis import CreateBranchRequest, CreateBranchResponse, GetDiffRequest, GetDiffResponse -from codegen.runner.sandbox.executor import SandboxExecutor -from codegen.sdk.codebase.config import ProjectConfig, SessionOptions -from codegen.sdk.codebase.factory.codebase_factory import CodebaseType -from codegen.sdk.core.codebase import Codebase -from codegen.shared.compilation.string_to_code import create_execute_function_from_codeblock -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -class SandboxRunner: - """Responsible for orchestrating the lifecycle of a warmed sandbox""" - - # =====[ __init__ instance attributes ]===== - repo: RepoConfig - op: RepoOperator | None - - # =====[ computed instance attributes ]===== - codebase: CodebaseType - executor: SandboxExecutor - - def __init__(self, repo_config: RepoConfig, op: RepoOperator | None = None) -> None: - self.repo = repo_config - self.op = op or RepoOperator(repo_config=self.repo, setup_option=SetupOption.PULL_OR_CLONE, bot_commit=True) - - async def warmup(self, codebase_config: CodebaseConfig | None = None) -> None: - """Warms up this runner by cloning the repo and parsing the graph.""" - logger.info(f"===== Warming runner for {self.repo.full_name or self.repo.name} =====") - sys.setrecursionlimit(10000) # for graph parsing - - self.codebase = await self._build_graph(codebase_config) - self.executor = SandboxExecutor(self.codebase) - - async def _build_graph(self, codebase_config: CodebaseConfig | None = None) -> Codebase: - logger.info("> Building graph...") - projects = [ProjectConfig(programming_language=self.repo.language, repo_operator=self.op, base_path=self.repo.base_path, subdirectories=self.repo.subdirectories)] - return Codebase(projects=projects, config=codebase_config) - - async def get_diff(self, request: GetDiffRequest) -> GetDiffResponse: - custom_scope = {"context": request.codemod.codemod_context} if request.codemod.codemod_context else {} - code_to_exec = create_execute_function_from_codeblock(codeblock=request.codemod.user_code, custom_scope=custom_scope) - session_options = SessionOptions(max_transactions=request.max_transactions, max_seconds=request.max_seconds) - - res = await self.executor.execute(code_to_exec, session_options=session_options) - - return GetDiffResponse(result=res) - - async def create_branch(self, request: CreateBranchRequest) -> CreateBranchResponse: - custom_scope = {"context": request.codemod.codemod_context} if request.codemod.codemod_context else {} - code_to_exec = create_execute_function_from_codeblock(codeblock=request.codemod.user_code, custom_scope=custom_scope) - branch_config = request.branch_config - - branch_config.custom_base_branch = branch_config.custom_base_branch or self.codebase.default_branch - self.executor.remote_repo.set_up_base_branch(branch_config.custom_base_branch) - self.executor.remote_repo.set_up_head_branch(branch_config.custom_head_branch, branch_config.force_push_head_branch) - - response = CreateBranchResponse() - if "codebase.flag_instance" in request.codemod.user_code: - flags = await self.executor.find_flags(code_to_exec) - flag_groups = await self.executor.find_flag_groups(flags, request.grouping_config) - response.num_flags = len(flags) - response.group_segments = [group.segment for group in flag_groups] - if len(flag_groups) == 0: - logger.info("No flag groups found. Running without flagging.") - flag_groups = [None] - else: - flag_groups = [None] - - # TODO: do this as part of find_flag_groups? - max_prs = request.grouping_config.max_prs - if max_prs and len(flag_groups) >= max_prs: - logger.info(f"Max PRs limit reached: {max_prs}. Skipping remaining groups.") - flag_groups = flag_groups[:max_prs] - - run_results, branches = await self.executor.execute_flag_groups(request.commit_msg, code_to_exec, flag_groups, branch_config) - response.results = run_results - response.branches = branches - - self.codebase.ctx.flags._flags.clear() - return response diff --git a/src/codegen/runner/sandbox/server.py b/src/codegen/runner/sandbox/server.py deleted file mode 100644 index a6a346fcf..000000000 --- a/src/codegen/runner/sandbox/server.py +++ /dev/null @@ -1,84 +0,0 @@ -import os -from contextlib import asynccontextmanager - -from fastapi import FastAPI - -from codegen.configs.models.repository import RepositoryConfig -from codegen.git.schemas.repo_config import RepoConfig -from codegen.runner.enums.warmup_state import WarmupState -from codegen.runner.models.apis import ( - BRANCH_ENDPOINT, - DIFF_ENDPOINT, - CreateBranchRequest, - CreateBranchResponse, - GetDiffRequest, - GetDiffResponse, - ServerInfo, -) -from codegen.runner.sandbox.middlewares import CodemodRunMiddleware -from codegen.runner.sandbox.runner import SandboxRunner -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - -server_info: ServerInfo -runner: SandboxRunner - - -@asynccontextmanager -async def lifespan(server: FastAPI): - global server_info - global runner - - default_repo_config = RepositoryConfig() - repo_name = default_repo_config.full_name or default_repo_config.name - server_info = ServerInfo(repo_name=repo_name) - try: - logger.info(f"Starting up sandbox fastapi server for repo_name={repo_name}") - repo_config = RepoConfig( - name=default_repo_config.name, - full_name=default_repo_config.full_name, - base_dir=os.path.dirname(default_repo_config.path), - language=ProgrammingLanguage(default_repo_config.language.upper()), - ) - runner = SandboxRunner(repo_config=repo_config) - server_info.warmup_state = WarmupState.PENDING - await runner.warmup() - server_info.synced_commit = runner.op.git_cli.head.commit.hexsha - server_info.warmup_state = WarmupState.COMPLETED - except Exception: - logger.exception("Failed to build graph during warmup") - server_info.warmup_state = WarmupState.FAILED - - logger.info("Sandbox fastapi server is ready to accept requests") - yield - logger.info("Shutting down sandbox fastapi server") - - -app = FastAPI(lifespan=lifespan) -app.add_middleware( - CodemodRunMiddleware[GetDiffRequest, GetDiffResponse], - path=DIFF_ENDPOINT, - runner_fn=lambda: runner, -) -app.add_middleware( - CodemodRunMiddleware[CreateBranchRequest, CreateBranchResponse], - path=BRANCH_ENDPOINT, - runner_fn=lambda: runner, -) - - -@app.get("/") -def health() -> ServerInfo: - return server_info - - -@app.post(DIFF_ENDPOINT) -async def get_diff(request: GetDiffRequest) -> GetDiffResponse: - return await runner.get_diff(request=request) - - -@app.post(BRANCH_ENDPOINT) -async def create_branch(request: CreateBranchRequest) -> CreateBranchResponse: - return await runner.create_branch(request=request) diff --git a/src/codegen/runner/servers/local_daemon.py b/src/codegen/runner/servers/local_daemon.py deleted file mode 100644 index 1d24006ae..000000000 --- a/src/codegen/runner/servers/local_daemon.py +++ /dev/null @@ -1,97 +0,0 @@ -import logging -from contextlib import asynccontextmanager - -from fastapi import FastAPI - -from codegen.configs.models.codebase import DefaultCodebaseConfig -from codegen.git.configs.constants import CODEGEN_BOT_EMAIL, CODEGEN_BOT_NAME -from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.git.schemas.enums import SetupOption -from codegen.git.schemas.repo_config import RepoConfig -from codegen.runner.enums.warmup_state import WarmupState -from codegen.runner.models.apis import ( - RUN_FUNCTION_ENDPOINT, - GetDiffRequest, - RunFunctionRequest, - ServerInfo, -) -from codegen.runner.models.codemod import Codemod, CodemodRunResult -from codegen.runner.sandbox.runner import SandboxRunner -from codegen.shared.logging.get_logger import get_logger - -# Configure logging at module level -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - force=True, -) -logger = get_logger(__name__) - -server_info: ServerInfo -runner: SandboxRunner - - -@asynccontextmanager -async def lifespan(server: FastAPI): - global server_info - global runner - - try: - repo_config = RepoConfig.from_envs() - server_info = ServerInfo(repo_name=repo_config.full_name or repo_config.name) - - # Set the bot email and username - op = RepoOperator(repo_config=repo_config, setup_option=SetupOption.SKIP, bot_commit=True) - runner = SandboxRunner(repo_config=repo_config, op=op) - logger.info(f"Configuring git user config to {CODEGEN_BOT_EMAIL} and {CODEGEN_BOT_NAME}") - runner.op.git_cli.git.config("user.email", CODEGEN_BOT_EMAIL) - runner.op.git_cli.git.config("user.name", CODEGEN_BOT_NAME) - - # Parse the codebase with sync enabled - logger.info(f"Starting up fastapi server for repo_name={repo_config.name}") - server_info.warmup_state = WarmupState.PENDING - codebase_config = DefaultCodebaseConfig.model_copy(update={"sync_enabled": True}) - await runner.warmup(codebase_config=codebase_config) - server_info.synced_commit = runner.op.head_commit.hexsha - server_info.warmup_state = WarmupState.COMPLETED - - except Exception: - logger.exception("Failed to build graph during warmup") - server_info.warmup_state = WarmupState.FAILED - - logger.info("Local daemon is ready to accept requests!") - yield - logger.info("Shutting down local daemon server") - - -app = FastAPI(lifespan=lifespan) - - -@app.get("/") -def health() -> ServerInfo: - return server_info - - -@app.post(RUN_FUNCTION_ENDPOINT) -async def run(request: RunFunctionRequest) -> CodemodRunResult: - _save_uncommitted_changes_and_sync() - diff_req = GetDiffRequest(codemod=Codemod(user_code=request.codemod_source)) - diff_response = await runner.get_diff(request=diff_req) - if request.commit: - if commit_sha := runner.codebase.git_commit(f"[Codegen] {request.function_name}", exclude_paths=[".codegen/*"]): - logger.info(f"Committed changes to {commit_sha.hexsha}") - return diff_response.result - - -def _save_uncommitted_changes_and_sync() -> None: - if commit := runner.codebase.git_commit("[Codegen] Save uncommitted changes", exclude_paths=[".codegen/*"]): - logger.info(f"Saved uncommitted changes to {commit.hexsha}") - - cur_commit = runner.op.head_commit - if cur_commit != runner.codebase.ctx.synced_commit: - logger.info(f"Syncing codebase to head commit: {cur_commit.hexsha}") - runner.codebase.sync_to_commit(target_commit=cur_commit) - else: - logger.info("Codebase is already synced to head commit") - - server_info.synced_commit = cur_commit.hexsha diff --git a/src/codegen/runner/utils/branch_name.py b/src/codegen/runner/utils/branch_name.py deleted file mode 100644 index 2b31db709..000000000 --- a/src/codegen/runner/utils/branch_name.py +++ /dev/null @@ -1,11 +0,0 @@ -from uuid import uuid4 - -from codegen.sdk.codebase.flagging.group import Group - - -def get_head_branch_name(branch_name: str | None, group: Group | None = None) -> str: - if branch_name is None: - branch_name = f"codegen-{uuid4()}" - if group: - return f"{branch_name}-group-{group.id}" - return branch_name diff --git a/src/codegen/runner/utils/exception_utils.py b/src/codegen/runner/utils/exception_utils.py deleted file mode 100644 index b7f3c9e56..000000000 --- a/src/codegen/runner/utils/exception_utils.py +++ /dev/null @@ -1,15 +0,0 @@ -from codegen.shared.exceptions.control_flow import StopCodemodException - - -def update_observation_meta( - e: StopCodemodException, - observation_meta: dict | None = None, -) -> dict: - observation_meta = observation_meta or {} - observation_meta.update( - { - "stop_codemod_exception_type": e.__class__.__name__, - "threshold": e.threshold, - }, - ) - return observation_meta diff --git a/src/codegen/sdk/README.md b/src/codegen/sdk/README.md deleted file mode 100644 index f9e94756b..000000000 --- a/src/codegen/sdk/README.md +++ /dev/null @@ -1,117 +0,0 @@ -
- -

- - - -

- -

- Scriptable interface to a powerful, multi-lingual language server. -

- -
- -[![PyPI](https://img.shields.io/badge/PyPi-codegen-gray?style=flat-square&color=blue)](https://pypi.org/project/codegen/) -[![Documentation](https://img.shields.io/badge/Docs-docs.codegen.com-purple?style=flat-square)](https://docs.codegen.com) -[![Slack Community](https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&style=flat-square)](https://community.codegen.com) -[![License](https://img.shields.io/badge/Code%20License-Apache%202.0-gray?&color=gray)](https://github.com/codegen-sh/codegen-sdk/tree/develop?tab=Apache-2.0-1-ov-file) -[![Follow on X](https://img.shields.io/twitter/follow/codegen?style=social)](https://x.com/codegen) - -
- -
- -[Codegen](https://docs.codegen.com) is a python library for manipulating codebases. - -```python -from codegen import Codebase - -# Codegen builds a complete graph connecting -# functions, classes, imports and their relationships -codebase = Codebase("./") - -# Work with code without dealing with syntax trees or parsing -for function in codebase.functions: - # Comprehensive static analysis for references, dependencies, etc. - if not function.usages: - # Auto-handles references and imports to maintain correctness - function.move_to_file("deprecated.py") -``` - -Write code that transforms code. Codegen combines the parsing power of [Tree-sitter](https://tree-sitter.github.io/tree-sitter/) with the graph algorithms of [rustworkx](https://github.com/Qiskit/rustworkx) to enable scriptable, multi-language code manipulation at scale. - -## Installation and Usage - -We support - -- Running Codegen in Python 3.12 - 3.13 (recommended: Python 3.13+) -- macOS and Linux - - macOS is supported - - Linux is supported on x86_64 and aarch64 with glibc 2.34+ - - Windows is supported via WSL. See [here](https://docs.codegen.com/building-with-codegen/codegen-with-wsl) for more details. -- Python, Typescript, Javascript and React codebases - -``` -# Install inside existing project -uv pip install codegen - -# Install global CLI -uv tool install codegen --python 3.13 - -# Create a codemod for a given repo -cd path/to/repo -codegen init -codegen create test-function - -# Run the codemod -codegen run test-function - -# Create an isolated venv with codegen => open jupyter -codegen notebook -``` - -## Usage - -See [Getting Started](https://docs.codegen.com/introduction/getting-started) for a full tutorial. - -``` -from codegen import Codebase -``` - -## Troubleshooting - -Having issues? Here are some common problems and their solutions: - -- **I'm hitting an UV error related to `[[ packages ]]`**: This means you're likely using an outdated version of UV. Try updating to the latest version with: `uv self update`. -- **I'm hitting an error about `No module named 'codegen.sdk.extensions.utils'`**: The compiled cython extensions are out of sync. Update them with `uv sync --reinstall-package codegen`. -- **I'm hitting a `RecursionError: maximum recursion depth exceeded` error while parsing my codebase**: If you are using python 3.12, try upgrading to 3.13. If you are already on 3.13, try upping the recursion limit with `sys.setrecursionlimit(10000)`. - -If you run into additional issues not listed here, please [join our slack community](https://community.codegen.com) and we'll help you out! - -## Resources - -- [Docs](https://docs.codegen.com) -- [Getting Started](https://docs.codegen.com/introduction/getting-started) -- [Contributing](CONTRIBUTING.md) -- [Contact Us](https://codegen.com/contact) - -## Why Codegen? - -Software development is fundamentally programmatic. Refactoring a codebase, enforcing patterns, or analyzing control flow - these are all operations that can (and should) be expressed as programs themselves. - -We built Codegen backwards from real-world refactors performed on enterprise codebases. Instead of starting with theoretical abstractions, we focused on creating APIs that match how developers actually think about code changes: - -- **Natural mental model**: Write transforms that read like your thought process - "move this function", "rename this variable", "add this parameter". No more wrestling with ASTs or manual import management. - -- **Battle-tested on complex codebases**: Handle Python, TypeScript, and React codebases with millions of lines of code. - -- **Built for advanced intelligences**: As AI developers become more sophisticated, they need expressive yet precise tools to manipulate code. Codegen provides a programmatic interface that both humans and AI can use to express complex transformations through code itself. - -## Contributing - -Please see our [Contributing Guide](CONTRIBUTING.md) for instructions on how to set up the development environment and submit contributions. - -## Enterprise - -For more information on enterprise engagements, please [contact us](https://codegen.com/contact) or [request a demo](https://codegen.com/request-demo). diff --git a/src/codegen/sdk/_proxy.py b/src/codegen/sdk/_proxy.py deleted file mode 100644 index f50f49766..000000000 --- a/src/codegen/sdk/_proxy.py +++ /dev/null @@ -1,30 +0,0 @@ -import functools -from collections.abc import Callable -from typing import Generic, ParamSpec, TypeVar - -from lazy_object_proxy import Proxy -from lazy_object_proxy.simple import make_proxy_method - -try: - from codegen.sdk.extensions.utils import cached_property -except ModuleNotFoundError: - from functools import cached_property - -T = TypeVar("T") -P = ParamSpec("P") - - -class ProxyProperty(Proxy, Generic[P, T]): - """Lazy proxy that can behave like a method or a property depending on how its used. The class it's proxying should not implement __call__""" - - __factory__: Callable[P, T] - - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: - return self.__factory__(*args, **kwargs) - - __repr__ = make_proxy_method(repr) - - -def proxy_property(func: Callable[P, T]) -> cached_property[ProxyProperty[P, T]]: - """Proxy a property so it behaves like a method and property simultaneously. When invoked as a property, results are cached and invalidated using uncache_all""" - return cached_property(lambda obj: ProxyProperty(functools.partial(func, obj))) diff --git a/src/codegen/sdk/ai/client.py b/src/codegen/sdk/ai/client.py deleted file mode 100644 index 8902a2fa1..000000000 --- a/src/codegen/sdk/ai/client.py +++ /dev/null @@ -1,5 +0,0 @@ -from openai import OpenAI - - -def get_openai_client(key: str) -> OpenAI: - return OpenAI(api_key=key) diff --git a/src/codegen/sdk/ai/utils.py b/src/codegen/sdk/ai/utils.py deleted file mode 100644 index b903a9a1a..000000000 --- a/src/codegen/sdk/ai/utils.py +++ /dev/null @@ -1,17 +0,0 @@ -import tiktoken - -ENCODERS = { - "gpt-4o": tiktoken.encoding_for_model("gpt-4o"), -} - - -def count_tokens(s: str, model_name: str = "gpt-4o") -> int: - """Uses tiktoken""" - if s is None: - return 0 - enc = ENCODERS.get(model_name, None) - if not enc: - ENCODERS[model_name] = tiktoken.encoding_for_model(model_name) - enc = ENCODERS[model_name] - tokens = enc.encode(s) - return len(tokens) diff --git a/src/codegen/sdk/code_generation/__init__.py b/src/codegen/sdk/code_generation/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/sdk/code_generation/changelog_generation.py b/src/codegen/sdk/code_generation/changelog_generation.py deleted file mode 100644 index 1f982e04c..000000000 --- a/src/codegen/sdk/code_generation/changelog_generation.py +++ /dev/null @@ -1,142 +0,0 @@ -import json -from dataclasses import dataclass -from pathlib import Path - -from git import Repo -from openai import OpenAI -from semantic_release import ParsedCommit, ParseError -from semantic_release.changelog.release_history import Release, ReleaseHistory -from semantic_release.cli.cli_context import CliContextObj -from semantic_release.cli.config import GlobalCommandLineOptions - -import codegen -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - -SYSTEM_PROMPT = """ -## Role -You are a Release Manager for an open source project and have a gift for gleaning the most important and relevant changes from a list of commits. - -## Objective -You will be given a list of commits for a specifc release and you will need to write a high level summary of the changes in 1 to 5 bullet points and generate a very concise description of the release. -The description should be a maximum of 60 characters and should only highlight the most important change(s). -Please do not include specific details about pull requests or commits, only summarize the changes in the context of the release. - -## Instructions -- Do not include specific details about pull requests or commits, only summarize the changes in the context of the release. -- Do not include any other text than the bullet points and the one sentence description of the release.f -- Do not include pull request links or numbers. -- Only include information that is relevant to users and contributors. -- The description should be a maximum of 60 characters. - -## Output -- Output the bullet points and the one sentence description of the release, no other text. The output should be a json object with the following keys: - - `bullet_points`: A list of bullet points - - `description`: A one sentence description of the release - -## Example Output -{ - "bullet_points": [ - "Add new feature X", - "Fix bug Y", - "Improve performance" - ], - "description": "adds a new feature, fixes a bug, and improves performance." -} - -## Things to exclude -- Removed development package publishing to AWS -- Updated various dependencies and pre-commit hooks -- Do not wrap the output in ```json ```. The output should be a json object that can be parsed with json.loads() - -## Poor Release Descriptions -- "This release includes platform support updates, file handling improvements, and module resolution adjustments." -- "This release adds ARM support for Linux, enhances documentation, and includes dependency updates." - -## Better Release Descriptions -- "Platform support updates" -- "ARM support for Linux" -""" - - -@dataclass -class ContextMock: - config_file = "/Users/jesusmeza/Documents/codegen-sdk/pyproject.toml" - - def get_parameter_source(self, param_name): - if hasattr(self, param_name): - return getattr(self, param_name) - return None - - -def generate_release_summary_context(release: Release): - release_summary_context = {"version": release["version"].tag_format, "date": release["tagged_date"].strftime("%B %d, %Y"), "commits": dict()} - elements = release["elements"] - for title, commits in elements.items(): - release_summary_context["commits"][title] = [] - for parsed_commit in commits: - if isinstance(parsed_commit, ParsedCommit): - release_summary_context["commits"][title].append(parsed_commit.descriptions[0]) - elif isinstance(parsed_commit, ParseError): - release_summary_context["commits"][title].append(parsed_commit.message) - return release_summary_context - - -def generate_release_summary(client: OpenAI, release: Release) -> str: - release_summary_context = generate_release_summary_context(release) - response = client.chat.completions.create( - model="gpt-4o", - max_tokens=1000, - messages=[ - { - "role": "system", - "content": SYSTEM_PROMPT, - }, - { - "role": "user", - "content": f""" -Here is some context on the release: - -{json.dumps(release_summary_context)} - -Please write a high level summary of the changes in 1 to 5 bullet points. -""", - }, - ], - ) - - return json.loads(response.choices[0].message.content) - - -def generate_changelog(client: OpenAI, latest_existing_version: str | None = None): - ctx = CliContextObj(ContextMock(), logger=logger, global_opts=GlobalCommandLineOptions()) - runtime = ctx.runtime_ctx - translator = runtime.version_translator - with Repo(Path(codegen.__file__).parents[2]) as codegen_sdk_repo: - release_history = ReleaseHistory.from_git_history( - repo=codegen_sdk_repo, - translator=translator, - commit_parser=runtime.commit_parser, - exclude_commit_patterns=runtime.changelog_excluded_commit_patterns, - ) - - releases = [] - parsed_releases: list[Release] = release_history.released.values() - parsed_releases = sorted(parsed_releases, key=lambda x: x["tagged_date"], reverse=True) - for release in parsed_releases: - version = f"v{release['version']!s}" - if latest_existing_version and version == latest_existing_version: - break - - tag_url = f"https://github.com/codegen-sh/codegen-sdk/releases/tag/{version}" - release_summary = generate_release_summary(client, release) - release_content = f""" - -### [{release_summary["description"]}]({tag_url}) -- {"\n- ".join(release_summary["bullet_points"])} - -""" - releases.append(release_content) - - return "\n".join(releases) diff --git a/src/codegen/sdk/code_generation/codegen_sdk_codebase.py b/src/codegen/sdk/code_generation/codegen_sdk_codebase.py deleted file mode 100644 index d6ede5175..000000000 --- a/src/codegen/sdk/code_generation/codegen_sdk_codebase.py +++ /dev/null @@ -1,15 +0,0 @@ -import os.path - -from codegen.sdk.code_generation.current_code_codebase import get_codegen_codebase_base_path, get_current_code_codebase -from codegen.sdk.core.codebase import Codebase - - -def get_codegen_sdk_subdirectories() -> list[str]: - base = get_codegen_codebase_base_path() - return [os.path.join(base, "codegen/sdk"), os.path.join(base, "codemods")] - - -def get_codegen_sdk_codebase() -> Codebase: - """Grabs a Codebase w/ GraphSitter content. Responsible for figuring out where it is, e.g. in Modal or local""" - codebase = get_current_code_codebase(subdirectories=get_codegen_sdk_subdirectories()) - return codebase diff --git a/src/codegen/sdk/code_generation/current_code_codebase.py b/src/codegen/sdk/code_generation/current_code_codebase.py deleted file mode 100644 index bfcee2232..000000000 --- a/src/codegen/sdk/code_generation/current_code_codebase.py +++ /dev/null @@ -1,94 +0,0 @@ -# TODO: move out of graph sitter, useful for other projects - -import importlib -from pathlib import Path -from typing import TypedDict - -from codegen.configs.models.codebase import CodebaseConfig -from codegen.configs.models.secrets import SecretsConfig -from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.git.schemas.repo_config import RepoConfig -from codegen.sdk.codebase.config import ProjectConfig -from codegen.sdk.core.codebase import Codebase, CodebaseType -from codegen.shared.decorators.docs import DocumentedObject, apidoc_objects, no_apidoc_objects, py_apidoc_objects, ts_apidoc_objects -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -def get_graphsitter_repo_path() -> str: - """Points to base directory of the Codegen repo (.git) that is currently running""" - import codegen.sdk as sdk - - filepath = sdk.__file__ - codegen_base_dir = filepath.replace("/codegen/sdk/__init__.py", "") - codegen_base_dir = codegen_base_dir.replace("/src", "") - return codegen_base_dir - - -def get_codegen_codebase_base_path() -> str: - import codegen.sdk as sdk - - filepath = sdk.__file__ - codegen_base_dir = filepath.replace("/codegen/sdk/__init__.py", "") - return "src" if "src" in codegen_base_dir else "" - - -def get_current_code_codebase(config: CodebaseConfig | None = None, secrets: SecretsConfig | None = None, subdirectories: list[str] | None = None) -> CodebaseType: - """Returns a Codebase for the code that is *currently running* (i.e. the Codegen repo)""" - codegen_repo_path = get_graphsitter_repo_path() - base_dir = get_codegen_codebase_base_path() - logger.info(f"Creating codebase from repo at: {codegen_repo_path} with base_path {base_dir}") - - repo_config = RepoConfig.from_repo_path(codegen_repo_path) - repo_config.respect_gitignore = False - op = RepoOperator(repo_config=repo_config, bot_commit=False) - - config = (config or CodebaseConfig()).model_copy(update={"base_path": base_dir}) - projects = [ProjectConfig(repo_operator=op, programming_language=ProgrammingLanguage.PYTHON, subdirectories=subdirectories, base_path=base_dir)] - codebase = Codebase(projects=projects, config=config, secrets=secrets) - return codebase - - -def import_all_codegen_sdk_modules(): - # for file in codegen.sdk: - - CODEGEN_SDK_DIR = Path(get_graphsitter_repo_path()) - if base := get_codegen_codebase_base_path(): - CODEGEN_SDK_DIR /= base - CODEGEN_SDK_DIR /= "codegen/sdk" - for file in CODEGEN_SDK_DIR.rglob("*.py"): - relative_path = file.relative_to(CODEGEN_SDK_DIR) - # ignore braintrust_evaluator because it runs stuff on import - if "__init__" in file.name or "braintrust_evaluator" in file.name: - continue - module_name = "codegen.sdk." + str(relative_path).replace("/", ".").removesuffix(".py") - try: - importlib.import_module(module_name) - except Exception as e: - print(f"Error importing {module_name}: {e}") - - -class DocumentedObjects(TypedDict): - apidoc: list[DocumentedObject] - ts_apidoc: list[DocumentedObject] - py_apidoc: list[DocumentedObject] - no_apidoc: list[DocumentedObject] - - -def get_documented_objects() -> DocumentedObjects: - """Get all the objects decorated with apidoc, py_apidoc, ts_apidoc, and no_apidoc decorators, - by importing all codegen.sdk modules and keeping track of decorated objects at import time using - the respective decorators - """ - import_all_codegen_sdk_modules() - from codegen.sdk.core.codebase import CodebaseType, PyCodebaseType, TSCodebaseType - - if PyCodebaseType not in apidoc_objects: - apidoc_objects.append(DocumentedObject(name="PyCodebaseType", module="codegen.sdk.core.codebase", object=PyCodebaseType)) - if TSCodebaseType not in apidoc_objects: - apidoc_objects.append(DocumentedObject(name="TSCodebaseType", module="codegen.sdk.core.codebase", object=TSCodebaseType)) - if CodebaseType not in apidoc_objects: - apidoc_objects.append(DocumentedObject(name="CodebaseType", module="codegen.sdk.core.codebase", object=CodebaseType)) - return {"apidoc": apidoc_objects, "py_apidoc": py_apidoc_objects, "ts_apidoc": ts_apidoc_objects, "no_apidoc": no_apidoc_objects} diff --git a/src/codegen/sdk/code_generation/doc_utils/__init__.py b/src/codegen/sdk/code_generation/doc_utils/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/sdk/code_generation/doc_utils/generate_docs_json.py b/src/codegen/sdk/code_generation/doc_utils/generate_docs_json.py deleted file mode 100644 index 3370f5686..000000000 --- a/src/codegen/sdk/code_generation/doc_utils/generate_docs_json.py +++ /dev/null @@ -1,183 +0,0 @@ -from tqdm import tqdm - -from codegen.sdk.code_generation.doc_utils.parse_docstring import parse_docstring -from codegen.sdk.code_generation.doc_utils.schemas import ClassDoc, GSDocs, MethodDoc -from codegen.sdk.code_generation.doc_utils.utils import create_path, extract_class_description, get_type, get_type_str, has_documentation, is_settter, replace_multiple_types -from codegen.sdk.core.class_definition import Class -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.core.placeholder.placeholder_type import TypePlaceholder - -ATTRIBUTES_TO_IGNORE = [ - "ctx", - "node_id", - "angular", - "model_config", - "constructor_keyword", - "viz", - "console", - "items", - "node_type", - "ts_node", - "file_node_id", - "statement_type", - "assignment_types", -] - - -def generate_docs_json(codebase: Codebase, head_commit: str, raise_on_missing_docstring: bool = False) -> GSDocs: - """Update documentation table for classes, methods and attributes in the codebase. - - Args: - codebase (Codebase): the codebase to update the docs for - head_commit (str): the head commit hash - Returns: - dict[str, dict[str, Any]]: the documentation for the codebase - """ - codegen_sdk_docs = GSDocs(classes=[]) - types_cache = {} - - def process_class_doc(cls): - """Update or create documentation for a class.""" - description = cls.docstring.source.strip('"""') if cls.docstring else None - parent_classes = [f"<{create_path(parent)}>" for parent in cls.superclasses if isinstance(parent, Class) and has_documentation(parent)] - - cls_doc = ClassDoc( - title=cls.name, - description=extract_class_description(description), - content=" ", - path=create_path(cls), - inherits_from=parent_classes, - version=str(head_commit), - github_url=cls.github_url, - ) - - return cls_doc - - def process_method(method, cls, cls_doc, seen_methods): - """Process a single method and update its documentation.""" - if any(dec.name == "noapidoc" for dec in method.decorators): - return - - if method.name in seen_methods and not is_settter(method): - return - - if not method.docstring: - msg = f"Method {cls.name}.{method.name} does not have a docstring" - raise ValueError(msg) - - method_path = create_path(method, cls) - parameters = [] - - parsed = parse_docstring(method.docstring.source) - if parsed is None: - msg = f"Method {cls.name}.{method.name} docstring does not exist or has incorrect format." - raise ValueError(msg) - - # Update parameter types - for param, parsed_param in zip(method.parameters[1:], parsed["arguments"]): - if param.name == parsed_param.name: - if isinstance(param.type, TypePlaceholder): - resolved_types = [] - else: - resolved_types = param.type.resolved_types - - parsed_param.type = replace_multiple_types( - codebase=codebase, input_str=parsed_param.type, resolved_types=resolved_types, parent_class=cls, parent_symbol=method, types_cache=types_cache - ) - if param.default: - parsed_param.default = param.default - - parameters.append(parsed_param) - # Update return type - - if not isinstance(method.return_type, TypePlaceholder): - return_type = replace_multiple_types( - codebase=codebase, input_str=method.return_type.source, resolved_types=method.return_type.resolved_types, parent_class=cls, parent_symbol=method, types_cache=types_cache - ) - else: - return_type = None - parsed["return_types"] = [return_type] - - meta_data = {"parent": create_path(method.parent_class), "path": method.file.filepath} - return MethodDoc( - name=method.name, - description=parsed["description"], - parameters=parsed["arguments"], - return_type=parsed["return_types"], - return_description=parsed["return_description"], - method_type=get_type(method), - code=method.function_signature, - path=method_path, - raises=parsed["raises"], - metainfo=meta_data, - version=str(head_commit), - github_url=method.github_url, - ) - - def process_attribute(attr, cls, cls_doc, seen_methods): - """Process a single attribute and update its documentation.""" - if attr.name in seen_methods or attr.name in ATTRIBUTES_TO_IGNORE: - return - - attr_path = create_path(attr, cls) - - description = attr.docstring(attr.parent_class) - if raise_on_missing_docstring and not description: - msg = f"Attribute {attr.parent_class.name}.{attr.name} does not have a docstring" - raise ValueError(msg) - attr_return_type = [] - if r_type := get_type_str(attr): - if isinstance(r_type, TypePlaceholder): - resolved_types = [] - else: - resolved_types = r_type.resolved_types - r_type_source = replace_multiple_types(codebase=codebase, input_str=r_type.source, resolved_types=resolved_types, parent_class=cls, parent_symbol=attr, types_cache=types_cache) - attr_return_type.append(r_type_source) - - attr_info = {"description": description, "attr_return_type": attr_return_type} - - meta_data = {"parent": create_path(attr.parent_class), "path": attr.file.filepath} - - return MethodDoc( - name=attr.name, - description=attr_info["description"], - parameters=[], - return_type=attr_info["attr_return_type"], - return_description=None, - method_type="attribute", - code=attr.attribute_docstring, - path=attr_path, - raises=[], - metainfo=meta_data, - version=str(head_commit), - github_url=attr.github_url, - ) - - # Process all documented classes - documented_classes = [cls for cls in codebase.classes if has_documentation(cls)] - - for cls in tqdm(documented_classes): - cls_doc = process_class_doc(cls) - codegen_sdk_docs.classes.append(cls_doc) - seen_methods = set() - - # Process methods - for method in cls.methods(max_depth=None, private=False, magic=False): - method_doc = process_method(method, cls, cls_doc, seen_methods) - if not method_doc: - continue - seen_methods.add(method_doc.name) - cls_doc.methods.append(method_doc) - - # Process attributes - for attr in cls.attributes(max_depth=None, private=False): - if attr.name in ATTRIBUTES_TO_IGNORE: - continue - - attr_doc = process_attribute(attr, cls, cls_doc, seen_methods) - if not attr_doc: - continue - seen_methods.add(attr_doc.name) - cls_doc.attributes.append(attr_doc) - - return codegen_sdk_docs diff --git a/src/codegen/sdk/code_generation/doc_utils/parse_docstring.py b/src/codegen/sdk/code_generation/doc_utils/parse_docstring.py deleted file mode 100644 index d367ad7b3..000000000 --- a/src/codegen/sdk/code_generation/doc_utils/parse_docstring.py +++ /dev/null @@ -1,68 +0,0 @@ -import re - -from codegen.sdk.code_generation.doc_utils.schemas import ParameterDoc - -SECTION_PATTERN = re.compile(r"(Args|Returns|Raises|Note):\s*(.+?)(?=(?:Args|Returns|Raises|Note):|$)", re.DOTALL) -ARG_PATTERN = re.compile(r"\s*(\w+)\s*\(([^)]+)\):\s*([^\n]+)") - - -def parse_docstring(docstring: str) -> dict | None: - """Parse a docstring into its components with optimized performance. - - Args: - docstring (str): The docstring to parse - - Returns: - dict | None: Parsed docstring components or None if parsing fails - """ - # Strip once at the start - docstring = docstring.strip().strip('"""').strip("'''") - - # Initialize result dictionary - result = {"description": "", "arguments": [], "return_description": None, "raises": [], "note": None} - - # Find all sections - sections = {match.group(1): match.group(2).strip() for match in SECTION_PATTERN.finditer(docstring)} - - # Get description (everything before first section) - first_section = docstring.find(":") - if first_section != -1: - result["description"] = docstring[:first_section].split("\n")[0].strip() - else: - result["description"] = docstring.split("\n")[0].strip() - - # Parse Args section - if "Args" in sections: - args_text = sections["Args"] - if args_text.lower() != "none": - result["arguments"] = [ParameterDoc(name=m.group(1), type=m.group(2), description=m.group(3).strip()) for m in ARG_PATTERN.finditer(args_text)] - - # Parse Returns section - if "Returns" in sections: - returns_text = sections["Returns"] - # Split on colon to separate type and description - parts = returns_text.split(":", 1) - if len(parts) > 1: - # Only keep the description part after the colon - result["return_description"] = " ".join(line.strip() for line in parts[1].split("\n") if line.strip()) - else: - # If there's no colon, check if it's just a plain description without types - # Remove any type-like patterns (words followed by brackets or vertical bars) - cleaned_text = re.sub(r"^[^:]*?(?=\s*[A-Za-z].*:|\s*$)", "", returns_text) - if cleaned_text: - result["return_description"] = " ".join(line.strip() for line in cleaned_text.split("\n") if line.strip()) - - # Parse Raises section - if "Raises" in sections: - raises_text = sections["Raises"] - for line in raises_text.split("\n"): - if ":" in line: - exc_type, desc = line.split(":", 1) - if exc_type.strip(): - result["raises"].append({"type": exc_type.strip(), "description": desc.strip()}) - - # Parse Note section - if "Note" in sections: - result["note"] = " ".join(line.strip() for line in sections["Note"].split("\n") if line.strip()) - - return result diff --git a/src/codegen/sdk/code_generation/doc_utils/schemas.py b/src/codegen/sdk/code_generation/doc_utils/schemas.py deleted file mode 100644 index d1243a9c6..000000000 --- a/src/codegen/sdk/code_generation/doc_utils/schemas.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel -from pydantic.fields import Field - - -class ParameterDoc(BaseModel): - name: str = Field(..., description="The name of the parameter") - description: str = Field(..., description="The description of the parameter") - type: str = Field(..., description="The type of the parameter") - default: str = Field(default="", description="The default value of the parameter") - - -class MethodDoc(BaseModel): - name: str = Field(..., description="The name of the method") - description: str | None = Field(..., description="The description of the method") - parameters: list[ParameterDoc] = Field(..., description="The parameters of the method") - return_type: list[str] | None = Field(default=None, description="The return types of the method") - return_description: str | None = Field(default=None, description="The return description of the method") - method_type: Literal["method", "property", "attribute"] = Field(..., description="The type of the method") - code: str = Field(..., description="The signature of the method or attribute") - path: str = Field(..., description="The path of the method that indicates its parent class //") - raises: list[dict] | None = Field(..., description="The raises of the method") - metainfo: dict = Field(..., description="Information about the method's true parent class and path") - version: str = Field(..., description="The commit hash of the git commit that generated the docs") - github_url: str = Field(..., description="The github url of the method") - - -class ClassDoc(BaseModel): - title: str = Field(..., description="The title of the class") - description: str = Field(..., description="The description of the class") - content: str = Field(..., description="The content of the class") - path: str = Field(..., description="The path of the class") - inherits_from: list[str] = Field(..., description="The classes that the class inherits from") - version: str = Field(..., description="The commit hash of the git commit that generated the docs") - methods: list[MethodDoc] = Field(default=[], description="The methods of the class") - attributes: list[MethodDoc] = Field(default=[], description="The attributes of the class") - github_url: str = Field(..., description="The github url of the class") - - -class GSDocs(BaseModel): - classes: list[ClassDoc] = Field(..., description="The classes to document") diff --git a/src/codegen/sdk/code_generation/doc_utils/utils.py b/src/codegen/sdk/code_generation/doc_utils/utils.py deleted file mode 100644 index b074de009..000000000 --- a/src/codegen/sdk/code_generation/doc_utils/utils.py +++ /dev/null @@ -1,408 +0,0 @@ -import re -import textwrap - -from codegen.sdk.core.class_definition import Class -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.core.detached_symbols.function_call import FunctionCall -from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.function import Function -from codegen.sdk.core.interfaces.callable import Callable -from codegen.sdk.core.symbol import Symbol -from codegen.sdk.python.statements.attribute import PyAttribute -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - -# These are the classes that are not language specific, but have language specific subclasses with different names -SPECIAL_BASE_CLASSES = {"SourceFile": "File"} - - -def sanitize_docstring_for_markdown(docstring: str | None) -> str: - """Sanitize the docstring for MDX""" - if docstring is None: - return "" - docstring_lines = docstring.splitlines() - if len(docstring_lines) > 1: - docstring_lines[1:] = [textwrap.dedent(line) for line in docstring_lines[1:]] - docstring = "\n".join(docstring_lines) - if docstring.startswith('"""'): - docstring = docstring[3:] - if docstring.endswith('"""'): - docstring = docstring[:-3] - return docstring - - -def sanitize_mdx_mintlify_description(content: str) -> str: - """Mintlify description field needs to have string escaped, which content doesn't need. - the must be parsing the description differently or something - """ - content = sanitize_docstring_for_markdown(content) - # make sure all `< />` components are properly escaped with a `` inline-block - # if the string already has the single-quote then this is a no-op - content = re.sub(r"(?]+>)(?!`)", r"`\1`", content) - - # escape double quote characters - if re.search(r'\\"', content): - return content # No-op if already escaped - return re.sub(r'(")', r"\\\1", content) - - -def sanitize_html_for_mdx(html_string: str) -> str: - """Sanitize HTML string for MDX by escaping double quotes in attribute values. - - Args: - html_string (str): The input HTML string to sanitize - - Returns: - str: The sanitized HTML string with escaped quotes - """ - # Replace double quotes with " but only in HTML attributes - return re.sub(r'"', """, html_string) - - -def get_type_str(parent, curr_depth=0, max_depth=5): - """Returns the type node for an attribute.""" - if curr_depth >= max_depth: - return None - if isinstance(parent, Type): - return parent - for child in parent.children: - if attr_type := get_type_str(child, curr_depth=curr_depth + 1): - return attr_type - return None - - -def is_language_base_class(cls_obj: Class): - """Returns true if `cls_obj` is a direct parent of a language specific class. - - For example, `Symbol` which is a direct parent of `PySymbol` and `TsSymbol` is a language base class - and `Editable` is not. - - Args: - cls_obj (Class): the class object to check - - Returns: - bool: if `cls_obj` is a language base class - """ - if cls_obj.name in SPECIAL_BASE_CLASSES: - return True - - sub_classes = cls_obj.subclasses(max_depth=1) - base_name = cls_obj.name.lower() - return any(sub_class.name.lower() in [f"py{base_name}", f"ts{base_name}"] for sub_class in sub_classes) - - -def get_section(symbol: Symbol, parent_class: Class | None = None): - if parent_class: - doc_section = parent_class.filepath.split("/")[1] - else: - doc_section = symbol.filepath.split("/")[1] - return doc_section - - -def get_language(symbol: Class | Function | PyAttribute) -> str: - """Gets the language of which the symbol is an abstract representation. - - Args: - symbol (Class | Function | PyAttribute): the symbol to get the langauge of - Returns: - str: the language of the symbol - """ - if ProgrammingLanguage.PYTHON.value.lower() in symbol.filepath: - return ProgrammingLanguage.PYTHON.value - elif ProgrammingLanguage.TYPESCRIPT.value.lower() in symbol.filepath: - return ProgrammingLanguage.TYPESCRIPT.value - elif isinstance(symbol, Class) and is_language_base_class(symbol): - return "NONE" - elif isinstance(symbol.parent_class, Class) and is_language_base_class(symbol.parent_class): - return "NONE" - else: - return "ALL" - - -def get_type(method: Function): - """Return the type of method. - - Args: - method (Function): the method to check the type of. - - Returns: - str: `property` if the method is a property, `method` otherwise. - """ - if method.is_property: - return "property" - else: - return "method" - - -def is_settter(m: Function): - """Checks if `m` is a setter method - Args: - m (Function): the function (method) to check - Returns: - bool: `True` if `m` is a setter method, `False` otherwise - """ - return any([dec.name == f"{m.name}.setter" for dec in m.decorators]) - - -def create_path(symbol: Class | Function | PyAttribute, parent_class: Class | None = None) -> str: - """Creates a route path for `symbol` that will be used in the frontend - - Args: - symbol (Class | Function | PyAttribute): the object for which a path should be created - parent_class (Class | None): optional parent class of the method - Returns: - str: route path of `symbol` - """ - name = symbol.name - language = get_language(symbol) - - if language == ProgrammingLanguage.PYTHON.value: - doc_section = ProgrammingLanguage.PYTHON.value.lower() - elif language == ProgrammingLanguage.TYPESCRIPT.value: - doc_section = ProgrammingLanguage.TYPESCRIPT.value.lower() - else: - doc_section = "core" - - if isinstance(symbol, Class): - return f"api-reference/{doc_section}/{name}" - - if parent_class: - parent_name = parent_class.name - else: - parent_name = symbol.parent_class.name - - if isinstance(symbol, Function) and is_settter(symbol): - return f"api-reference/{doc_section}/{parent_name}/set_{name}" - - return f"api-reference/{doc_section}/{parent_name}/{name}" - - -def has_documentation(c: Class): - """If the class c is meant to be documented. - - Args: - c (Class): the class to check - Returns: - bool: `True` if the class is meant to be documented, `False` otherwise - """ - return any([dec.name == "ts_apidoc" or dec.name == "py_apidoc" or dec.name == "apidoc" for dec in c.decorators]) - - -def safe_get_class(codebase: Codebase, class_name: str, language: str | None = None) -> Class | None: - """Find the class in the codebase. - - Args: - codebase (Codebase): the codebase to search in - class_name (str): the name of the class to resolve - language (str | None): the language of the class to resolve - Returns: - Class | None: the class if found, None otherwise - """ - if '"' in class_name: - class_name = class_name.strip('"') - if "'" in class_name: - class_name = class_name.strip("'") - - symbols = [] - try: - class_obj = codebase.get_class(class_name, optional=True) - if not class_obj: - return None - - except Exception: - symbols = codebase.get_symbols(class_name) - possible_classes = [s for s in symbols if isinstance(s, Class) and has_documentation(s)] - if not possible_classes: - return None - if len(possible_classes) > 1: - msg = f"Found {len(possible_classes)} classes with name {class_name}" - raise ValueError(msg) - class_obj = possible_classes[0] - - if language and is_language_base_class(class_obj): - sub_classes = class_obj.subclasses(max_depth=1) - - if class_name in SPECIAL_BASE_CLASSES: - class_name = SPECIAL_BASE_CLASSES[class_name] - - if language == ProgrammingLanguage.PYTHON.value: - sub_classes = [s for s in sub_classes if s.name == f"Py{class_name}"] - elif language == ProgrammingLanguage.TYPESCRIPT.value: - sub_classes = [s for s in sub_classes if s.name == f"TS{class_name}"] - if len(sub_classes) == 1: - class_obj = sub_classes[0] - return class_obj - - -def resolve_type_symbol(codebase: Codebase, symbol_name: str, resolved_types: list[Type], parent_class: Class, parent_symbol: Symbol, types_cache: dict): - """Find the symbol in the codebase. - - Args: - codebase (Codebase): the codebase to search in - symbol_name (str): the name of the symbol to resolve - resolved_types (list[Type]): the resolved types of the symbol - parent_class (Class): the parent class of the symbol - types_cache (dict): the cache to store the results in - Returns: - str: the route path of the symbol - """ - if symbol_name in ["list", "tuple", "int", "str", "dict", "set", "None", "bool", "optional", "Union"]: - return symbol_name - if symbol_name.lower() == "self": - return f"<{create_path(parent_class)}>" - - language = get_language(parent_class) - if (symbol_name, language) in types_cache: - return types_cache[(symbol_name, language)] - - trgt_symbol = None - cls_obj = safe_get_class(codebase=codebase, class_name=symbol_name, language=language) - if cls_obj: - trgt_symbol = cls_obj - - if not trgt_symbol: - if symbol := parent_symbol.file.get_symbol(symbol_name): - for resolved_type in symbol.resolved_types: - if isinstance(resolved_type, FunctionCall) and len(resolved_type.args) >= 2: - bound_arg = resolved_type.args[1] - bound_name = bound_arg.value.source - if cls_obj := safe_get_class(codebase, bound_name, language=get_language(parent_class)): - trgt_symbol = cls_obj - break - - elif symbol := codebase.get_symbol(symbol_name, optional=True): - if len(symbol.resolved_types) == 1: - trgt_symbol = symbol.resolved_types[0] - - if trgt_symbol and isinstance(trgt_symbol, Callable) and has_documentation(trgt_symbol): - trgt_path = f"<{create_path(trgt_symbol)}>" - types_cache[(symbol_name, language)] = trgt_path - return trgt_path - - return symbol_name - - -def replace_multiple_types(codebase: Codebase, input_str: str, resolved_types: list[Type], parent_class: Class, parent_symbol: Symbol, types_cache: dict) -> str: - """Replace multiple types in a string. - - Args: - codebase (Codebase): the codebase to search in - input_str (str): the string to replace the types in - parent_class (Class): the parent class of the symbol - types_cache (dict): the cache to store the results in - Returns: - str: the string with the types replaced - """ - # Remove outer quotes if present - input_str = input_str.replace('"', "") - - def process_parts(content): - # Handle nested brackets recursively - stack = [] - current = "" - parts = [] - separators = [] - in_quotes = False - quote_char = None - - i = 0 - while i < len(content): - char = content[i] - - # Handle quotes - if char in "\"'": - if not in_quotes: - in_quotes = True - quote_char = char - elif char == quote_char: - in_quotes = False - current += char - # Only process special characters if we're not in quotes - elif not in_quotes: - if char == "[": - stack.append("[") - current += char - elif char == "]": - if stack: - stack.pop() - current += char - elif (char in ",|") and not stack: # Only split when not inside brackets - if current.strip(): - parts.append(current.strip()) - separators.append(char) - current = "" - else: - current += char - else: - current += char - i += 1 - - if current.strip(): - parts.append(current.strip()) - - # Process each part - processed_parts = [] - for part in parts: - # Check if the part is quoted - if part.startswith('"') and part.endswith('"'): - processed_parts.append(part) # Keep quoted parts as-is - continue - - # Check if the part itself contains brackets - if "[" in part: - base_type = part[: part.index("[")] - bracket_content = part[part.index("[") :].strip("[]") - processed_bracket = process_parts(bracket_content) - replacement = resolve_type_symbol( - codebase=codebase, symbol_name=base_type, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache - ) - processed_part = replacement + "[" + processed_bracket + "]" - else: - replacement = resolve_type_symbol(codebase=codebase, symbol_name=part, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache) - processed_part = replacement - processed_parts.append(processed_part) - - # Reconstruct with original separators - result = processed_parts[0] - for i in range(len(separators)): - result += f"{separators[i]} {processed_parts[i + 1]}" - - return result - - # Check if the input contains any separators - if any(sep in input_str for sep in ",|"): - return process_parts(input_str) - # Handle bracketed input - elif "[" in input_str: - base_type = input_str[: input_str.index("[")] - bracket_content = input_str[input_str.index("[") :].strip("[]") - processed_content = process_parts(bracket_content) - replacement = resolve_type_symbol(codebase=codebase, symbol_name=base_type, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache) - return replacement + "[" + processed_content + "]" - # Handle simple input - else: - replacement = resolve_type_symbol(codebase=codebase, symbol_name=input_str, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache) - return replacement - - -def extract_class_description(docstring): - """Extract the class description from a docstring, excluding the attributes section. - - Args: - docstring (str): The class docstring to parse - - Returns: - str: The class description with whitespace normalized - """ - if not docstring: - return "" - - # Split by "Attributes:" and take only the first part - parts = docstring.split("Attributes:") - description = parts[0] - - # Normalize whitespace - lines = [line.strip() for line in description.strip().splitlines()] - return " ".join(filter(None, lines)) diff --git a/src/codegen/sdk/code_generation/enums.py b/src/codegen/sdk/code_generation/enums.py deleted file mode 100644 index 9905db6b9..000000000 --- a/src/codegen/sdk/code_generation/enums.py +++ /dev/null @@ -1,8 +0,0 @@ -from enum import StrEnum - - -class DocumentationDecorators(StrEnum): - PYTHON = "py_apidoc" - TYPESCRIPT = "ts_apidoc" - CODEMOD = "canonical" - GENERAL_API = "apidoc" diff --git a/src/codegen/sdk/code_generation/mdx_docs_generation.py b/src/codegen/sdk/code_generation/mdx_docs_generation.py deleted file mode 100644 index 648a3b68e..000000000 --- a/src/codegen/sdk/code_generation/mdx_docs_generation.py +++ /dev/null @@ -1,204 +0,0 @@ -import re - -from codegen.sdk.code_generation.doc_utils.schemas import ClassDoc, MethodDoc, ParameterDoc -from codegen.sdk.code_generation.doc_utils.utils import sanitize_html_for_mdx, sanitize_mdx_mintlify_description - - -def render_mdx_page_for_class(cls_doc: ClassDoc) -> str: - """Renders the MDX for a single class""" - return f"""{render_mdx_page_title(cls_doc)} -{render_mdx_inheritence_section(cls_doc)} -{render_mdx_attributes_section(cls_doc)} -{render_mdx_methods_section(cls_doc)} -""" - - -def render_mdx_page_title(cls_doc: ClassDoc, icon: str | None = None) -> str: - """Renders the MDX for the page title""" - page_desc = cls_doc.description if hasattr(cls_doc, "description") else "" - - return f"""--- -title: "{cls_doc.title}" -sidebarTitle: "{cls_doc.title}" -icon: "{icon if icon else ""}" -description: "{sanitize_mdx_mintlify_description(page_desc)}" ---- -import {{Parameter}} from '/snippets/Parameter.mdx'; -import {{ParameterWrapper}} from '/snippets/ParameterWrapper.mdx'; -import {{Return}} from '/snippets/Return.mdx'; -import {{HorizontalDivider}} from '/snippets/HorizontalDivider.mdx'; -import {{GithubLinkNote}} from '/snippets/GithubLinkNote.mdx'; -import {{Attribute}} from '/snippets/Attribute.mdx'; - - -""" - - -def render_mdx_inheritence_section(cls_doc: ClassDoc) -> str: - """Renders the MDX for the inheritence section""" - # Filter on parents who we have docs for - parents = cls_doc.inherits_from - if not parents: - return "" - parents_string = ", ".join([parse_link(parent) for parent in parents]) - return f"""### Inherits from -{parents_string} -""" - - -def render_mdx_attributes_section(cls_doc: ClassDoc) -> str: - """Renders the MDX for the attributes section""" - sorted_attributes = sorted(cls_doc.attributes + [method for method in cls_doc.methods if method.method_type == "property"], key=lambda x: x.name) - if len(sorted_attributes) <= 0: - return "" - attributes_mdx_string = "\n".join([render_mdx_for_attribute(attribute) for attribute in sorted_attributes]) - - return f"""## Attributes - -{attributes_mdx_string} -""" - - -def render_mdx_methods_section(cls_doc: ClassDoc) -> str: - """Renders the MDX for the methods section""" - sorted_methods = sorted(cls_doc.methods, key=lambda x: x.name) - if len(sorted_methods) <= 0: - return "" - methods_mdx_string = "\n".join([render_mdx_for_method(method) for method in sorted_methods if method.method_type == "method"]) - - return f"""## Methods - -{methods_mdx_string} -""" - - -def render_mdx_for_attribute(attribute: MethodDoc) -> str: - """Renders the MDX for a single attribute""" - attribute_docstring = sanitize_mdx_mintlify_description(attribute.description) - if len(attribute.return_type) > 0: - return_type = f"{resolve_type_string(attribute.return_type[0])}" - else: - return_type = "" - if not attribute_docstring: - attribute_docstring = "\n" - return f"""### {attribute.name} - -"} }} description="{attribute_docstring}" /> -""" - - -######################################################################################################################## -# METHODS -######################################################################################################################## - - -def format_parameter_for_mdx(parameter: ParameterDoc) -> str: - type_string = resolve_type_string(parameter.type) - return f""" - -""".strip() - - -def format_parameters_for_mdx(parameters: list[ParameterDoc]) -> str: - return "\n".join([format_parameter_for_mdx(parameter) for parameter in parameters]) - - -def format_return_for_mdx(return_type: list[str], return_description: str) -> str: - description = sanitize_html_for_mdx(return_description) if return_description else "" - return_type = resolve_type_string(return_type[0]) - - return f""" - -""" - - -def render_mdx_for_method(method: MethodDoc) -> str: - description = sanitize_mdx_mintlify_description(method.description) - # =====[ RENDER ]===== - # TODO add links here - # TODO add inheritence info here - mdx_string = f"""### {method.name} -{description} - -""" - if method.parameters: - mdx_string += f""" - -{format_parameters_for_mdx(method.parameters)} - -""" - if method.return_type: - mdx_string += f""" -{format_return_for_mdx(method.return_type, method.return_description)} -""" - - return mdx_string - - -def get_mdx_route_for_class(cls_doc: ClassDoc) -> str: - """Get the expected MDX route for a class - split by /core, /python, and /typescript - """ - lower_class_name = cls_doc.title.lower() - if lower_class_name.startswith("py"): - return f"codebase-sdk/python/{cls_doc.title}" - elif lower_class_name.startswith(("ts", "jsx")): - return f"codebase-sdk/typescript/{cls_doc.title}" - else: - return f"codebase-sdk/core/{cls_doc.title}" - - -def format_type_string(type_string: str) -> str: - type_string = type_string.split("|") - return " | ".join([type_str.strip() for type_str in type_string]) - - -def resolve_type_string(type_string: str) -> str: - if "<" in type_string: - return f"<>{parse_link(type_string, href=True)}" - else: - return f'{format_type_string(type_string)}' - - -def format_builtin_type_string(type_string: str) -> str: - if "|" in type_string: - type_strings = type_string.split("|") - return " | ".join([type_str.strip() for type_str in type_strings]) - return type_string - - -def span_type_string_by_pipe(type_string: str) -> str: - if "|" in type_string: - type_strings = type_string.split("|") - return " | ".join([f"{type_str.strip()}" for type_str in type_strings]) - return type_string - - -def parse_link(type_string: str, href: bool = False) -> str: - # Match components with angle brackets, handling nested structures - - parts = [p for p in re.split(r"(<[^>]+>)", type_string) if p] - - result = [] - for part in parts: - if part.startswith("<") and part.endswith(">"): - # Extract the path from between angle brackets - path = part[1:-1] - symbol = path.split("/")[-1] - - # Create a Link object - link = f'{symbol}' if href else f"[{symbol}](/{path})" - result.append(link) - else: - part = format_builtin_type_string(part) - if href: - result.append(f"{part.strip()}") - else: - result.append(part.strip()) - - return " ".join(result) diff --git a/src/codegen/sdk/code_generation/prompts/__init__.py b/src/codegen/sdk/code_generation/prompts/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/sdk/code_generation/prompts/api_docs.py b/src/codegen/sdk/code_generation/prompts/api_docs.py deleted file mode 100644 index 7cf2e87c4..000000000 --- a/src/codegen/sdk/code_generation/prompts/api_docs.py +++ /dev/null @@ -1,235 +0,0 @@ -from codegen.sdk.code_generation.codegen_sdk_codebase import get_codegen_sdk_codebase -from codegen.sdk.code_generation.prompts.utils import get_api_classes_by_decorator, get_codegen_sdk_class_docstring -from codegen.sdk.core.codebase import Codebase -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - -# TODO: the agent in codegen-backend and codegen-frontend does not use any of this. we have api_docs.py in codegen-backend!!! - -######################################################################################################################## -# UTILS -######################################################################################################################## - - -def get_docstrings_for_classes(codebase: Codebase, language: ProgrammingLanguage, classnames: list[str]) -> dict[str, str]: - """Returns map of ClassName -> Docstring""" - classes = get_api_classes_by_decorator(codebase=codebase, language=language) - class_docstrings = {k: get_codegen_sdk_class_docstring(cls=v, codebase=codebase) for k, v in classes.items()} - return {k: class_docstrings[k] for k in classnames} - - -######################################################################################################################## -# API STUBS -######################################################################################################################## -# This is like `PyFile` definition etc. - - -def get_codebase_docstring(codebase: Codebase, language: ProgrammingLanguage) -> str: - """Returns the docstring for the `Codebase` class.""" - docstrings = get_docstrings_for_classes(codebase, language, ["Codebase"]) - docstring = docstrings["Codebase"] - return f""" -The `Codebase` class is the main entrypoint to manipulating a codebase with GraphSitter. It implements the core methods that allow you to identify important symbols, make changes to the codebase, and commit those changes. - -{docstring} -""" # noqa: E501 - - -def get_behavior_docstring(codebase: Codebase, language: ProgrammingLanguage) -> str: - """These are the core classes in GraphSitter - they define things like `HasName` etc.""" - behavior_classnames = [ - "Editable", - "Typeable", - "HasBlock", - "Name", - "HasName", - "Value", - "HasValue", - "Importable", - "Exportable", - "Callable", - # "GraphSitterBase", - ] - docstrings = get_docstrings_for_classes(codebase, language, behavior_classnames) - cls_sections = "\n\n".join([docstrings[cls] for cls in behavior_classnames]) - return f""" -The following classes represent "behaviors" in GraphSitter that apply to potentially many entities. For example, many types will inherit `HasName` and will then support `x.name`, `x.set_name(new_name)`, etc. - -Look in the type inheritance of the core symbols to see which behaviors they support. - -{cls_sections} -""" # noqa: E501 - - -######################################################################################################################## -# CORE SYMBOLS -######################################################################################################################## - - -def get_core_symbol_docstring(codebase: Codebase, language: ProgrammingLanguage) -> str: - """This should return the docstrings for the symbol types in GraphSitter. Also should include the language-specific extensions.""" - symbol_types = [ - "File", - "Statement", - "CodeBlock", - "AssignmentStatement", - "ImportStatement", - "Import", - # "ImportResolution", - "Export", - "Symbol", - "Usage", - "Assignment", - "Function", - "Parameter", - "Argument", - "FunctionCall", - "Class", - "Attribute", - "Decorator", - "Comment", - "ReturnStatement", - "ExternalModule", - ] - if language == ProgrammingLanguage.TYPESCRIPT: - symbol_types.extend(["JSXElement", "JSXExpression", "JSXProp"]) - - docstrings = get_docstrings_for_classes(codebase, language, symbol_types) - cls_sections = "\n\n".join([docstrings[cls] for cls in symbol_types]) - return f""" -The following classes represent the core symbol types in GraphSitter. These classes are used to represent the various entities in a codebase, such as files, functions, classes, etc. - -Most codemods will begin by identifying the symbols in the codebase that need to be modified by searching and filtering through these symbol types, then calling various edit methods on them or their sub-components - -{cls_sections} -""" # noqa: E501 - - -######################################################################################################################## -# LANGUAGE SPECIFIC -######################################################################################################################## - - -def get_language_specific_docstring(codebase: Codebase, language: ProgrammingLanguage) -> str: - # =====[ Get language prefix ]===== - if language == ProgrammingLanguage.PYTHON: - prefix = "Py" - else: - prefix = "TS" - - # =====[ Grab docstrings ]===== - classes = get_api_classes_by_decorator(codebase=codebase, language=language) - class_docstrings = {k: get_codegen_sdk_class_docstring(cls=v, codebase=codebase) for k, v in classes.items()} - docstrings = {k: v for k, v in class_docstrings.items() if k.startswith(prefix)} - - # =====[ Get mapping from e.g. File => PyFile and TFile => PyFile ]===== - names = list(docstrings.keys()) - # stripped_names = [name.replace(prefix, "") for name in names] - # inherit_mapping = {k: v for k, v in zip(stripped_names, names)} - # type_mapping = {f"T{k}": v for k, v in inherit_mapping.items()} - # name_mapping = {**inherit_mapping, **type_mapping} - - cls_docstrings = "\n\n".join([docstrings[name] for name in names]) - return f""" -Here are language-specific extensions of some of the classes above. Anywhere you see TFile as a type, that's the generic type that corresponds to these classes. - -For example, all `File` that you encounter will be of type {prefix}File, {prefix}File inherits all methods from `File`. - -{cls_docstrings} -""" - - -######################################################################################################################## -# FULL DOCS -######################################################################################################################## - - -def get_codegen_sdk_docs(language: ProgrammingLanguage = ProgrammingLanguage.PYTHON, codebase: Codebase | None = None) -> str: - """Computes the GraphSitter docs from scratch""" - codebase = codebase or get_codegen_sdk_codebase() - with codebase.session(sync_graph=False, commit=False): - return f""" -# Codegen SDK Docs - -Codegen SDK is a Python SDK for writing powerful programs that operate on codebases. In essence, it is a scriptable, multi-language language server that is optimized for fast code transformations and analytics. - -Consider the following: -```python -# Sets docstring to "hello world!" for all classes that end with `Resource` -num_edited = 0 -for cls in codebase.classes: - if cls.name.endswith("Resource"): - cls.set_docstring("hello, world!") # Handles all edge cases + formatting for properly setting docstrings - num_edited += 1 -# Provide developer-facing analytics on the output -print(f'⚡ Edited: {{num_edited}}') -``` -As demonstrated, you can concisely express powerful transformations and analytics on codebases with GraphSitter. - -## Motivation - -Traditional "codemods" are difficult to write and maintain due to the complexities of parsing, import resolution, and more. - -GraphSitter is specifically designed to enable AI agents to efficiently write code transformations and analytics. It enables agents to "act via code" and make powerful changes with guaranteed correctness and with minimal effort. Future additions to this library will enable agents to interact with other systems besides code via the GraphSitter API. - -## Architecture Overview - -GraphSitter enables manipulations of a codebase via a Python SDK. - -The SDK provides a set of classes that represent the various entities in a codebase, such as files, directories, functions, types, etc. These classes are designed to be used in conjunction with the `Codebase` class, which is the entrypoint to most operations. - -These classes and the `Codebase` object enable common transformations like `move_to_file`, `set_docstring`, `add_type_annotation`, etc. - -A GraphSitter codemod is implemented as a Python script that operates on a global `Codebase`, such as the following: -```python -# Sets return type to `None` for all functions that do not have a return type. (This is on a Python Codebase) -file = codebase.get_file("src/app/main.py") # or .ts, if you are operating on a TypeScript codebase -for function in file.functions: - if function.name != "main": - if len(function.return_statements) == 0 and not function.return_type: - function.set_return_type("None") # or `null` if you are operating on a TypeScript codebase -``` - -As you can see, a typical codemod will: -- Identify the symbols, files, etc. in a codebase that need to be modified (typically a set of for loops and nested conditionals) -- Make the necessary changes to the codebase by interacting with GraphSitter classes and methods (typically calling `.edit(...)` or other methods, that will call this under the hood.) - -Given a Codemod like so, the Codegen infrastructure will: -- Run the codemod efficiently -- Visualise the diff, log or other artifacts created for the developer -- Split up the changes into logical PRs, e.g. by CODEOWNER or by file (according to the developer's request) -- Upload results to version control (e.g. GitHub) - - -## `Codebase` Class Documentation - -{get_codebase_docstring(codebase=codebase, language=language)} - -## Core Symbol Type Classes Documentation - -{get_core_symbol_docstring(codebase=codebase, language=language)} - -## Language-specific Extensions Documentation - -{get_language_specific_docstring(codebase=codebase, language=language)} - -## Behaviors and Common Classes Documentation - -{get_behavior_docstring(codebase=codebase, language=language)} - -## Best Practices -- Take inspiration on best practices from the provided, curated examples - - These have been vetted by human experts and are known to be correct -- When applicable, include aesthetic and instructive logs for developers via the `print` statement, such as: - - A title - - Emoji - - Hierarchical logging, with filenames in single quotes -- You do not need to explain to the developer the code you are going to write before calling CREATE_CODEMOD - - This will just make the code -- You *DO NOT* need to import `codegen.sdk`, (this module does not exist) `codebase` or any types. - - All types in the library are available in the global namespace and are automatically imported, as is the `codebase` object. -- You *DO NOT* need to do anything to parse the codebase. - - This is done automatically by the Codegen infrastructure and pre-cached for fast execution. Just interact with the `codebase` object. -""" # noqa: E501 diff --git a/src/codegen/sdk/code_generation/prompts/utils.py b/src/codegen/sdk/code_generation/prompts/utils.py deleted file mode 100644 index fc583c73b..000000000 --- a/src/codegen/sdk/code_generation/prompts/utils.py +++ /dev/null @@ -1,110 +0,0 @@ -from codegen.sdk.code_generation.enums import DocumentationDecorators -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.enums import NodeType -from codegen.sdk.python.class_definition import PyClass -from codegen.shared.enums.programming_language import ProgrammingLanguage - - -def get_decorator_for_language( - language: ProgrammingLanguage = ProgrammingLanguage.PYTHON, -) -> DocumentationDecorators: - if language == ProgrammingLanguage.PYTHON: - return DocumentationDecorators.PYTHON - elif language == ProgrammingLanguage.TYPESCRIPT: - return DocumentationDecorators.TYPESCRIPT - - -def get_api_classes_by_decorator( - codebase: Codebase, - language: ProgrammingLanguage = ProgrammingLanguage.PYTHON, -) -> dict[str, PyClass]: - """Returns all classes in a directory that have a specific decorator.""" - classes = {} - language_specific_decorator = get_decorator_for_language(language).value - general_decorator = DocumentationDecorators.GENERAL_API.value - # get language specific classes - for cls in codebase.classes: - class_decorators = [decorator.name for decorator in cls.decorators] - if language_specific_decorator in class_decorators: - classes[cls.name] = cls - for cls in codebase.classes: - class_decorators = [decorator.name for decorator in cls.decorators] - if general_decorator in class_decorators and cls.name not in classes.keys(): - classes[cls.name] = cls - return classes - - -def format_python_codeblock(source: str) -> str: - """A python codeblock in markdown format.""" - # USE 4 backticks instead of 3 so backticks inside the codeblock are handled properly - cb = f"````python\n{source}\n````" - return cb - - -def set_indent(string: str, indent: int) -> str: - """Sets the indentation of a string.""" - tab = "\t" - return "\n".join([f"{tab * indent}{line}" for line in string.split("\n")]) - - -def get_codegen_sdk_class_docstring(cls: PyClass, codebase: Codebase) -> str: - """Get the documentation for a single GraphSitter class and its methods.""" - # =====[ Parent classes ]===== - parent_classes = cls.parent_class_names - parent_class_names = [parent.source for parent in parent_classes if parent.source not in ("Generic", "ABC", "Expression")] - superclasses = ", ".join([name for name in parent_class_names]) - if len(superclasses) > 0: - superclasses = f"({superclasses})" - - # =====[ Name + docstring ]===== - source = f"class {cls.name}{superclasses}:" - if cls.docstring is not None: - source += set_indent(string=f'\n"""{cls.docstring.text}"""', indent=1) - source += "\n" - - # =====[ Attributes ]===== - if cls.is_subclass_of("Enum"): - for attribute in cls.attributes: - source += set_indent(f"\n{attribute.source}", 1) - else: - for attribute in cls.attributes(private=False, max_depth=None): - # Only document attributes which have docstrings - if docstring := attribute.docstring(cls): - source += set_indent(f"\n{attribute.attribute_docstring}", 1) - source += set_indent(string=f'\n"""{docstring}"""', indent=2) - source += set_indent("\n...\n", 2) - - # =====[ Get inherited method ]===== - def get_inherited_method(superclasses, method): - """Returns True if the method is inherited""" - for s in superclasses: - for m in s.methods: - if m.name == method.name: - if m.docstring == method.docstring or method.docstring is None: - return m - return None - - # =====[ Get superclasses ]===== - superclasses = cls.superclasses - superclasses = list({s.name: s for s in superclasses}.values()) - superclasses = [x for x in superclasses if x.node_type != NodeType.EXTERNAL] - - # TODO use new filter_methods_list function here - # =====[ Get methods to be documented ]===== - doc_methods = cls.methods - doc_methods = [m for m in doc_methods if not m.name.startswith("_")] - doc_methods = [m for m in doc_methods if not any("noapidoc" in d.name for d in m.decorators)] - doc_methods = [m for m in doc_methods if get_inherited_method(superclasses, m) is None] - - # =====[ Methods ]===== - for method in doc_methods: - if "property" in [decorator.name for decorator in method.decorators]: - source += set_indent(f"\n@property\n{method.function_signature}", 1) - else: - source += set_indent(f"\n{method.function_signature}", 1) - if method.docstring is not None: - source += set_indent(string=f'\n"""{method.docstring.text}"""', indent=2) - source += set_indent("\n...\n", 2) - - # =====[ Format markdown ]===== - return f"""### {cls.name}\n\n{format_python_codeblock(source)}""" diff --git a/src/codegen/sdk/codebase/__init__.py b/src/codegen/sdk/codebase/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/sdk/codebase/codebase_ai.py b/src/codegen/sdk/codebase/codebase_ai.py deleted file mode 100644 index 658dbe23c..000000000 --- a/src/codegen/sdk/codebase/codebase_ai.py +++ /dev/null @@ -1,208 +0,0 @@ -from codegen.sdk.core.file import File -from codegen.sdk.core.interfaces.editable import Editable - - -def generate_system_prompt(target: Editable | None = None, context: None | str | Editable | list[Editable] | dict[str, str | Editable | list[Editable]] = None) -> str: - prompt = """Hey CodegenBot! -You are an incredibly precise and thoughtful AI who helps developers accomplish complex transformations on their codebase. -You always provide clear, concise, and accurate responses. -When dealing with code, you maintain the original structure and style unless explicitly asked to change it. -""" - if target: - prompt += f""" -The user has just requested a response on the following code snippet: - -[[[CODE SNIPPET BEGIN]]] -{target.extended_source} -[[[CODE SNIPPET END]]] - -Your job is to follow the instructions of the user, given the context provided. -""" - else: - prompt += """ -Your job is to follow the instructions of the user. -""" - - if context: - prompt += """ -The user has provided some additional context that you can use to assist with your response. -You may use this context to inform your answer, but you're not required to directly include it in your response. - -Here is the additional context: -""" - prompt += generate_context(context) - - prompt += """ -Please ensure your response is accurate and relevant to the user's request. You may think out loud in the response. - - -Generally, when responding with an an answer, try to follow these general "ground rules": -Remember, these are just rules you should follow by default. If the user explicitly asks for something else, you should follow their instructions instead. - -> When generating new code or new classes, such as "create me a new function that does XYZ" or "generate a helper function that does XYZ", try to: - -- Do not include extra indentation that is not necessary, unless the user explicitly asks for something else. -- Include as much information as possible. Do not write things like "# the rest of the class" or "# the rest of the method", unless the user explicitly asks for something else. -- Do try to include comments and well-documented code, unless the user explicitly asks for something else. -- Only return the NEW code without re-iterating any existing code that the user has provided to you, unless the user explicitly asks for something else. -- Do not include any code that the user has explicitly asked you to remove, unless the user explicitly asks for something else. - - -> When changing existing code, such as "change this method to do XYZ" or "update this function to do XYZ" or "remove all instances of XYZ from this class", try to: - -- Do not include extra indentation that is not necessary, unless the user explicitly asks for something else. -- Include the entire context of the code that the user has provided to you, unless the user explicitly asks for something else. -- Include as much information as possible. Do not write things like "# the rest of the class" or "# the rest of the method", unless the user explicitly asks for something else. -- Do try to include comments and well-documented code, unless the user explicitly asks for something else. -- Avoid edit existing code that does not need editing, unless the user explicitly asks for something else. -- When asked to modify a very small or trivial part of the code, try to only modify the part that the user has asked you to modify, unless the user explicitly asks for something else. -- If asked to make improvements, try not to change existing function signatures, decorators, or returns, unless the user explicitly asks for something else. - - -> When dealing with anything related to docstrings, for example "Generate a google style docstring for this method." or "Convert these existing docs to google style docstrings.", try to: - -- Do not include extra indentation that is not necessary, unless the user explicitly asks for something else. -- Use the google style docstring format first, unless the user explicitly asks for something else. -- If doing google style docstrings, do not include the "self" or "cls" argument in the list of arguments, unless the user explicitly asks for something else. -- Try to have at least one line of the docstring to be a summary line, unless the user explicitly asks for something else. -- Try to keep each line of the docstring to be less than 80 characters, unless the user explicitly asks for something else. -- Try to keep any existing before and after examples in the docstring, unless the user explicitly asks for something else. -- Only respond with the content of the docstring, without any additional context like the function signature, return type, or parameter types, unless the user explicitly asks for something else. -- Do not include formatting like tripple quotes in your response, unless the user explicitly asks for something else. -- Do not include any markdown formatting, unless the user explicitly asks for something else. - -If you need a refresher on what google-style docstrings are: -- The first line is a summary line. -- The second line is a description of the method. -- The third line is a list of arguments. -- The fourth line is a list of returns. -Google docstrings may also include other information like exceptions and examples. -When generating NEW code or NEW classes, also try to generate docstrings alongside the code with the google style docstring format, -unless the user explicitly asks for something else. - - -> When dealing with anything related to comments, such as "write me a comment for this method" or "change this existing comment to be more descriptive", try to: - -- Do not include extra indentation that is not necessary, unless the user explicitly asks for something else. -- Do not include any comment delimiters like "#" or "//" unless the user explicitly asks for something else. -- Do not include any markdown formatting, unless the user explicitly asks for something else. -- Try to keep each line of the comment to be less than 80 characters, unless the user explicitly asks for something else. -- If you are only requested to edit or create a comment, do not include any code or other context that the user has provided to you, unless the user explicitly asks for something else. - - -> When dealing with single-word or single-phrase answers, like "what is a better name for this function" or "what is a better name for this class", try to: - -- Only respond with the content of the new name, without any additional context like the function signature, return type, or parameter types, unless the user explicitly asks for something else. -- Do not include formatting like tripple quotes in your response, unless the user explicitly asks for something else. -- Do not include any markdown formatting, unless the user explicitly asks for something else. -- Do not include any code or other context that the user has provided to you, unless the user explicitly asks for something else. - -REMEMBER: When giving the final answer, you must use the set_answer tool to provide the final answer that will be used in subsequent operations such as writing to a file, renaming, or editing. - """ - - return prompt - - -def generate_flag_system_prompt(target: Editable, context: None | str | Editable | list[Editable] | dict[str, str | Editable | list[Editable]] = None) -> str: - prompt = f"""Hey CodegenBot! -You are an incredibly precise and thoughtful AI who helps developers accomplish complex transformations on their codebase. - -You are now tasked with determining whether to flag the symbol, file, attribute, or message using AI. -Flagging a symbol means to mark it as a chunk of code that should be modified in a later step. -You will be given the user prompt, and the code snippet that the user is requesting a response on. -Use the should_flag tool to return either a true or false answer to the question of whether to flag the symbol, file, attribute, or message. - -Here is the code snippet that the user is requesting a response on: - -[[[CODE SNIPPET BEGIN]]] -{target.extended_source} -[[[CODE SNIPPET END]]] -""" - - if context: - prompt += """ -The user has provided some additional context that you can use to assist with your response. -You may use this context to inform your answer, but you're not required to directly include it in your response. - -Here is the additional context: -""" - prompt += generate_context(context) - - prompt += """ -Please intelligently determine whether the user's request on the given code snippet should be flagged. -Remember, use the should_flag tool to return either a true or false answer to the question of whether to flag the symbol, file, attribute, or message -as a chunk of code that should be modified, edited, or changed in a later step. - """ - - return prompt - - -def generate_context(context: None | str | Editable | list[Editable | File] | dict[str, str | Editable | list[Editable] | File] | File = None) -> str: - output = "" - if not context: - return output - else: - if isinstance(context, str): - output += f"====== Context ======\n{context}\n====================\n\n" - elif isinstance(context, Editable): - # Get class name - output += f"====== {context.__class__.__name__} ======\n" - output += f"{context.extended_source}\n" - output += "====================\n\n" - elif isinstance(context, File): - output += f"====== {context.__class__.__name__}======\n" - output += f"{context.source}\n" - output += "====================\n\n" - elif isinstance(context, list): - for item in context: - output += generate_context(item) - elif isinstance(context, dict): - for key, value in context.items(): - output += f"[[[ {key} ]]]\n" - output += generate_context(value) - output += "\n\n" - return output - - -def generate_tools() -> list: - return [ - { - "type": "function", - "function": { - "name": "set_answer", - "description": "Use this function to set the final answer to the given prompt. This answer will be used in subsequent operations such as writing to a file, renaming, or editing.", - "parameters": { - "type": "object", - "properties": { - "answer": { - "type": "string", - "description": "The final answer to the given prompt. Do not include any uneccesary context or commentary in your response.", - }, - }, - "required": ["answer"], - }, - }, - } - ] - - -def generate_flag_tools() -> list: - return [ - { - "type": "function", - "function": { - "name": "should_flag", - "description": "Use this function to determine whether to flag the symbol, file, attribute, or message using AI.", - "parameters": { - "type": "object", - "properties": { - "flag": { - "type": "boolean", - "description": "Whether to flag the symbol, file, attribute, or message.", - }, - }, - "required": ["flag"], - }, - }, - } - ] diff --git a/src/codegen/sdk/codebase/codebase_analysis.py b/src/codegen/sdk/codebase/codebase_analysis.py deleted file mode 100644 index 8a33707db..000000000 --- a/src/codegen/sdk/codebase/codebase_analysis.py +++ /dev/null @@ -1,87 +0,0 @@ -from codegen.sdk.core.class_definition import Class -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.core.external_module import ExternalModule -from codegen.sdk.core.file import SourceFile -from codegen.sdk.core.function import Function -from codegen.sdk.core.import_resolution import Import -from codegen.sdk.core.symbol import Symbol -from codegen.sdk.enums import EdgeType, SymbolType - - -def get_codebase_summary(codebase: Codebase) -> str: - node_summary = f"""Contains {len(codebase.ctx.get_nodes())} nodes -- {len(list(codebase.files))} files -- {len(list(codebase.imports))} imports -- {len(list(codebase.external_modules))} external_modules -- {len(list(codebase.symbols))} symbols -\t- {len(list(codebase.classes))} classes -\t- {len(list(codebase.functions))} functions -\t- {len(list(codebase.global_vars))} global_vars -\t- {len(list(codebase.interfaces))} interfaces -""" - edge_summary = f"""Contains {len(codebase.ctx.edges)} edges -- {len([x for x in codebase.ctx.edges if x[2].type == EdgeType.SYMBOL_USAGE])} symbol -> used symbol -- {len([x for x in codebase.ctx.edges if x[2].type == EdgeType.IMPORT_SYMBOL_RESOLUTION])} import -> used symbol -- {len([x for x in codebase.ctx.edges if x[2].type == EdgeType.EXPORT])} export -> exported symbol - """ - - return f"{node_summary}\n{edge_summary}" - - -def get_file_summary(file: SourceFile) -> str: - return f"""==== [ `{file.name}` (SourceFile) Dependency Summary ] ==== -- {len(file.imports)} imports -- {len(file.symbols)} symbol references -\t- {len(file.classes)} classes -\t- {len(file.functions)} functions -\t- {len(file.global_vars)} global variables -\t- {len(file.interfaces)} interfaces - -==== [ `{file.name}` Usage Summary ] ==== -- {len(file.imports)} importers -""" - - -def get_class_summary(cls: Class) -> str: - return f"""==== [ `{cls.name}` (Class) Dependency Summary ] ==== -- parent classes: {cls.parent_class_names} -- {len(cls.methods)} methods -- {len(cls.attributes)} attributes -- {len(cls.decorators)} decorators -- {len(cls.dependencies)} dependencies - -{get_symbol_summary(cls)} - """ - - -def get_function_summary(func: Function) -> str: - return f"""==== [ `{func.name}` (Function) Dependency Summary ] ==== -- {len(func.return_statements)} return statements -- {len(func.parameters)} parameters -- {len(func.function_calls)} function calls -- {len(func.call_sites)} call sites -- {len(func.decorators)} decorators -- {len(func.dependencies)} dependencies - -{get_symbol_summary(func)} - """ - - -def get_symbol_summary(symbol: Symbol) -> str: - usages = symbol.symbol_usages - imported_symbols = [x.imported_symbol for x in usages if isinstance(x, Import)] - - return f"""==== [ `{symbol.name}` ({type(symbol).__name__}) Usage Summary ] ==== -- {len(usages)} usages -\t- {len([x for x in usages if isinstance(x, Symbol) and x.symbol_type == SymbolType.Function])} functions -\t- {len([x for x in usages if isinstance(x, Symbol) and x.symbol_type == SymbolType.Class])} classes -\t- {len([x for x in usages if isinstance(x, Symbol) and x.symbol_type == SymbolType.GlobalVar])} global variables -\t- {len([x for x in usages if isinstance(x, Symbol) and x.symbol_type == SymbolType.Interface])} interfaces -\t- {len(imported_symbols)} imports -\t\t- {len([x for x in imported_symbols if isinstance(x, Symbol) and x.symbol_type == SymbolType.Function])} functions -\t\t- {len([x for x in imported_symbols if isinstance(x, Symbol) and x.symbol_type == SymbolType.Class])} classes -\t\t- {len([x for x in imported_symbols if isinstance(x, Symbol) and x.symbol_type == SymbolType.GlobalVar])} global variables -\t\t- {len([x for x in imported_symbols if isinstance(x, Symbol) and x.symbol_type == SymbolType.Interface])} interfaces -\t\t- {len([x for x in imported_symbols if isinstance(x, ExternalModule)])} external modules -\t\t- {len([x for x in imported_symbols if isinstance(x, SourceFile)])} files - """ diff --git a/src/codegen/sdk/codebase/codebase_context.py b/src/codegen/sdk/codebase/codebase_context.py deleted file mode 100644 index 506f4f52d..000000000 --- a/src/codegen/sdk/codebase/codebase_context.py +++ /dev/null @@ -1,845 +0,0 @@ -from __future__ import annotations - -import os -from collections import Counter, defaultdict -from contextlib import contextmanager -from enum import IntEnum, auto, unique -from functools import lru_cache -from os import PathLike -from pathlib import Path -from typing import TYPE_CHECKING, Any - -from rustworkx import PyDiGraph, WeightedEdgeList - -from codegen.configs.models.codebase import CodebaseConfig, PinkMode -from codegen.configs.models.secrets import SecretsConfig -from codegen.sdk.codebase.config import ProjectConfig, SessionOptions -from codegen.sdk.codebase.config_parser import ConfigParser, get_config_parser_for_language -from codegen.sdk.codebase.diff_lite import ChangeType, DiffLite -from codegen.sdk.codebase.flagging.flags import Flags -from codegen.sdk.codebase.io.file_io import FileIO -from codegen.sdk.codebase.progress.stub_progress import StubProgress -from codegen.sdk.codebase.transaction_manager import TransactionManager -from codegen.sdk.codebase.validation import get_edges, post_reset_validation -from codegen.sdk.core.autocommit import AutoCommit, commiter -from codegen.sdk.core.directory import Directory -from codegen.sdk.core.external.dependency_manager import DependencyManager, get_dependency_manager -from codegen.sdk.core.external.language_engine import LanguageEngine, get_language_engine -from codegen.sdk.enums import Edge, EdgeType, NodeType -from codegen.sdk.extensions.sort import sort_editables -from codegen.sdk.extensions.utils import uncache_all -from codegen.sdk.typescript.external.ts_declassify.ts_declassify import TSDeclassify -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codegen.shared.exceptions.control_flow import StopCodemodException -from codegen.shared.logging.get_logger import get_logger -from codegen.shared.performance.stopwatch_utils import stopwatch - -if TYPE_CHECKING: - from collections.abc import Generator, Mapping, Sequence - - from codeowners import CodeOwners as CodeOwnersParser - from git import Commit as GitCommit - - from codegen.git.repo_operator.repo_operator import RepoOperator - from codegen.sdk.codebase.io.io import IO - from codegen.sdk.codebase.node_classes.node_classes import NodeClasses - from codegen.sdk.codebase.progress.progress import Progress - from codegen.sdk.core.dataclasses.usage import Usage - from codegen.sdk.core.expressions import Expression - from codegen.sdk.core.external_module import ExternalModule - from codegen.sdk.core.file import File, SourceFile - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.parser import Parser - -logger = get_logger(__name__) - - -# src/vs/platform/contextview/browser/contextMenuService.ts is ignored as there is a parsing error with tree-sitter -GLOBAL_FILE_IGNORE_LIST = [ - ".git/*", - "*/.git/*", - "node_modules/*", - "*/node_modules/*", - ".yarn/releases/*", - ".*/tests/static/chunk-.*.js", - ".*/ace/.*.js", - "src/vs/platform/contextview/browser/contextMenuService.ts", - "*/semver.js", - "*/compiled/*", - "*.min.js", - "*@*.js", -] - - -@unique -class SyncType(IntEnum): - DELETE = auto() - REPARSE = auto() - ADD = auto() - - -def get_node_classes(programming_language: ProgrammingLanguage) -> NodeClasses: - if programming_language == ProgrammingLanguage.PYTHON: - from codegen.sdk.codebase.node_classes.py_node_classes import PyNodeClasses - - return PyNodeClasses - elif programming_language == ProgrammingLanguage.TYPESCRIPT: - from codegen.sdk.codebase.node_classes.ts_node_classes import TSNodeClasses - - return TSNodeClasses - else: - from codegen.sdk.codebase.node_classes.generic_node_classes import GenericNodeClasses - - return GenericNodeClasses - - -class CodebaseContext: - """MultiDiGraph Wrapper with TransactionManager""" - - # =====[ __init__ attributes ]===== - node_classes: NodeClasses - programming_language: ProgrammingLanguage - repo_path: str - repo_name: str - codeowners_parser: CodeOwnersParser | None - config: CodebaseConfig - secrets: SecretsConfig - - # =====[ computed attributes ]===== - transaction_manager: TransactionManager - pending_syncs: list[DiffLite] # Diffs that have been applied to disk, but not the graph (to be used for sync graph) - all_syncs: list[DiffLite] # All diffs that have been applied to the graph (to be used for graph reset) - _autocommit: AutoCommit - generation: int - parser: Parser[Expression] - synced_commit: GitCommit | None - directories: dict[Path, Directory] - base_url: str | None - extensions: list[str] - config_parser: ConfigParser | None - dependency_manager: DependencyManager | None - language_engine: LanguageEngine | None - _computing = False - _graph: PyDiGraph[Importable, Edge] - filepath_idx: dict[str, NodeId] - _ext_module_idx: dict[str, NodeId] - flags: Flags - session_options: SessionOptions = SessionOptions() - projects: list[ProjectConfig] - unapplied_diffs: list[DiffLite] - io: IO - progress: Progress - - def __init__( - self, - projects: list[ProjectConfig], - config: CodebaseConfig | None = None, - secrets: SecretsConfig | None = None, - io: IO | None = None, - progress: Progress | None = None, - ) -> None: - """Initializes codebase graph and TransactionManager""" - from codegen.sdk.core.parser import Parser - - self.progress = progress or StubProgress() - self.__graph = PyDiGraph() - self.__graph_ready = False - self.filepath_idx = {} - self._ext_module_idx = {} - self.generation = 0 - - # NOTE: The differences between base_path, repo_name, and repo_path - # /home/codegen/projects/my-project/src - # ^^^ <- Base Path (Optional) - # ^^^^^^^^^^ <----- Repo Name - # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ <----- Repo Path - # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ <- Full Path - # (full_path is unused for CGB, but is used elsewhere.) - - # =====[ __init__ attributes ]===== - self.projects = projects - context = projects[0] - self.node_classes = get_node_classes(context.programming_language) - self.config = config or CodebaseConfig() - self.secrets = secrets or SecretsConfig() - self.repo_name = context.repo_operator.repo_name - self.repo_path = str(Path(context.repo_operator.repo_path).resolve()) - self.full_path = os.path.join(self.repo_path, context.base_path) if context.base_path else self.repo_path - self.codeowners_parser = context.repo_operator.codeowners_parser - self.base_url = context.repo_operator.base_url - if not self.config.allow_external: - # TODO: Fix this to be more robust with multiple projects - self.io = io or FileIO(allowed_paths=[Path(self.repo_path).resolve()]) - else: - self.io = io or FileIO() - # =====[ computed attributes ]===== - self.transaction_manager = TransactionManager() - self._autocommit = AutoCommit(self) - self.init_nodes = None - self.init_edges = None - self.directories = dict() - self.parser = Parser.from_node_classes(self.node_classes, log_parse_warnings=self.config.debug) - self.extensions = self.node_classes.file_cls.get_extensions() - # ORDER IS IMPORTANT HERE! - self.config_parser = get_config_parser_for_language(context.programming_language, self) - self.dependency_manager = get_dependency_manager(context.programming_language, self) - self.language_engine = get_language_engine(context.programming_language, self) - self.programming_language = context.programming_language - - # Raise warning if language is not supported - if self.programming_language is ProgrammingLanguage.UNSUPPORTED or self.programming_language is ProgrammingLanguage.OTHER: - logger.warning("WARNING: The codebase is using an unsupported language!") - logger.warning("Some features may not work as expected. Advanced static analysis will be disabled but simple file IO will still work.") - - # Assert config assertions - # External import resolution must be enabled if syspath is enabled - if self.config.py_resolve_syspath: - if not self.config.allow_external: - msg = "allow_external must be set to True when py_resolve_syspath is enabled" - raise ValueError(msg) - - # Build the graph - if not self.config.exp_lazy_graph and self.config.use_pink != PinkMode.ALL_FILES: - self.build_graph(context.repo_operator) - try: - self.synced_commit = context.repo_operator.head_commit - except ValueError as e: - logger.exception("Error getting commit head %s", e) - self.synced_commit = None - self.pending_syncs = [] - self.all_syncs = [] - self.unapplied_diffs = [] - self.flags = Flags() - - def __repr__(self): - return self.__class__.__name__ - - @property - def _graph(self) -> PyDiGraph[Importable, Edge]: - if not self.__graph_ready: - logger.info("Lazily Computing Graph") - self.build_graph(self.projects[0].repo_operator) - return self.__graph - - @_graph.setter - def _graph(self, value: PyDiGraph[Importable, Edge]) -> None: - self.__graph = value - - @stopwatch - @commiter - def build_graph(self, repo_operator: RepoOperator) -> None: - """Builds a codebase graph based on the current file state of the given repo operator""" - self.__graph_ready = True - self._graph.clear() - - # =====[ Add all files to the graph in parallel ]===== - syncs = defaultdict(lambda: []) - if self.config.disable_file_parse: - logger.warning("WARNING: File parsing is disabled!") - else: - for filepath, _ in repo_operator.iter_files(subdirs=self.projects[0].subdirectories, extensions=self.extensions, ignore_list=GLOBAL_FILE_IGNORE_LIST): - syncs[SyncType.ADD].append(self.to_absolute(filepath)) - logger.info(f"> Parsing {len(syncs[SyncType.ADD])} files in {self.projects[0].subdirectories or 'ALL'} subdirectories with {self.extensions} extensions") - self._process_diff_files(syncs, incremental=False) - files: list[SourceFile] = self.get_nodes(NodeType.FILE) - logger.info(f"> Found {len(files)} files") - logger.info(f"> Found {len(self.nodes)} nodes and {len(self.edges)} edges") - if self.config.track_graph: - self.old_graph = self._graph.copy() - - @stopwatch - @commiter - def apply_diffs(self, diff_list: list[DiffLite]) -> None: - """Applies the given set of diffs to the graph in order to match the current file system content""" - if self.session_options: - self.session_options = self.session_options.model_copy(update={"max_seconds": None}) - logger.info(f"Applying {len(diff_list)} diffs to graph") - files_to_sync: dict[Path, SyncType] = {} - # Gather list of deleted files, new files to add, and modified files to reparse - file_cls = self.node_classes.file_cls - extensions = file_cls.get_extensions() - for diff in diff_list: - filepath = Path(diff.path) - if extensions is not None and filepath.suffix not in extensions: - continue - if self.projects[0].subdirectories is not None and not any(filepath.relative_to(subdir) for subdir in self.projects[0].subdirectories): - continue - - if diff.change_type == ChangeType.Added: - # Sync by adding the added file to the graph - files_to_sync[filepath] = SyncType.ADD - elif diff.change_type == ChangeType.Modified: - files_to_sync[filepath] = SyncType.REPARSE - elif diff.change_type == ChangeType.Renamed: - files_to_sync[diff.rename_from] = SyncType.DELETE - files_to_sync[diff.rename_to] = SyncType.ADD - elif diff.change_type == ChangeType.Removed: - files_to_sync[filepath] = SyncType.DELETE - else: - logger.warning(f"Unhandled diff change type: {diff.change_type}") - by_sync_type = defaultdict(lambda: []) - if self.config.disable_file_parse: - logger.warning("WARNING: File parsing is disabled!") - else: - for filepath, sync_type in files_to_sync.items(): - if self.get_file(filepath) is None: - if sync_type is SyncType.DELETE: - # SourceFile is already deleted, nothing to do here - continue - elif sync_type is SyncType.REPARSE: - # SourceFile needs to be parsed for the first time - sync_type = SyncType.ADD - elif sync_type is SyncType.ADD: - # If the file was deleted earlier, we need to reparse so we can remove old edges - sync_type = SyncType.REPARSE - - by_sync_type[sync_type].append(filepath) - self.generation += 1 - self._process_diff_files(by_sync_type) - - def _reset_files(self, syncs: list[DiffLite]) -> None: - files_to_write = [] - files_to_remove = [] - modified_files = set() - for sync in syncs: - if sync.path in modified_files: - continue - if sync.change_type == ChangeType.Removed: - files_to_write.append((sync.path, sync.old_content)) - modified_files.add(sync.path) - logger.info(f"Removing {sync.path} from disk") - elif sync.change_type == ChangeType.Modified: - files_to_write.append((sync.path, sync.old_content)) - modified_files.add(sync.path) - elif sync.change_type == ChangeType.Renamed: - files_to_write.append((sync.rename_from, sync.old_content)) - files_to_remove.append(sync.rename_to) - modified_files.add(sync.rename_from) - modified_files.add(sync.rename_to) - elif sync.change_type == ChangeType.Added: - files_to_remove.append(sync.path) - modified_files.add(sync.path) - logger.info(f"Writing {len(files_to_write)} files to disk and removing {len(files_to_remove)} files") - for file in files_to_remove: - self.io.delete_file(file) - to_save = set() - for file, content in files_to_write: - self.io.write_file(file, content) - to_save.add(file) - self.io.save_files(to_save) - - @stopwatch - def reset_codebase(self) -> None: - self._reset_files(self.all_syncs + self.pending_syncs + self.unapplied_diffs) - self.unapplied_diffs.clear() - - @stopwatch - def undo_applied_diffs(self) -> None: - self.transaction_manager.clear_transactions() - self.reset_codebase() - self.io.check_changes() - self.pending_syncs.clear() # Discard pending changes - if len(self.all_syncs) > 0: - logger.info(f"Unapplying {len(self.all_syncs)} diffs to graph. Current graph commit: {self.synced_commit}") - self._revert_diffs(list(reversed(self.all_syncs))) - self.all_syncs.clear() - - @stopwatch - @commiter(reset=True) - def _revert_diffs(self, diff_list: list[DiffLite]) -> None: - """Resets the graph to its initial solve branch file state""" - reversed_diff_list = list(DiffLite.from_reverse_diff(diff) for diff in diff_list) - self._autocommit.reset() - self.apply_diffs(reversed_diff_list) - # ====== [ Re-resolve lost edges from previous syncs ] ====== - self.prune_graph() - if self.config.verify_graph: - post_reset_validation(self.old_graph.nodes(), self._graph.nodes(), get_edges(self.old_graph), get_edges(self._graph), self.repo_name, self.projects[0].subdirectories) - - def save_commit(self, commit: GitCommit) -> None: - if commit is not None: - logger.info(f"Saving commit {commit.hexsha} to graph") - self.all_syncs.clear() - self.unapplied_diffs.clear() - self.synced_commit = commit - if self.config.verify_graph: - self.old_graph = self._graph.copy() - - @stopwatch - def prune_graph(self) -> None: - # ====== [ Remove orphaned external modules ] ====== - external_modules = self.get_nodes(NodeType.EXTERNAL) - for module in external_modules: - if not any(self.predecessors(module.node_id)): - self.remove_node(module.node_id) - self._ext_module_idx.pop(module._idx_key, None) - - def build_directory_tree(self) -> None: - """Builds the directory tree for the codebase""" - # Reset and rebuild the directory tree - self.directories = dict() - - for file_path, _ in self.projects[0].repo_operator.iter_files( - subdirs=self.projects[0].subdirectories, - ignore_list=GLOBAL_FILE_IGNORE_LIST, - skip_content=True, - ): - file_path = Path(file_path) - directory = self.get_directory(file_path.parent, create_on_missing=True) - directory._add_file(file_path.name) - - def get_directory(self, directory_path: PathLike, create_on_missing: bool = False, ignore_case: bool = False) -> Directory | None: - """Returns the directory object for the given path, or None if the directory does not exist. - - If create_on_missing is set, use a recursive strategy to create the directory object and all subdirectories. - """ - # If not part of repo path, return None - absolute_path = self.to_absolute(directory_path) - if not self.is_subdir(absolute_path) and not self.config.allow_external: - assert False, f"Directory {absolute_path} is not part of repo path {self.repo_path}" - return None - - # Get the directory - if dir := self.directories.get(absolute_path, None): - return dir - if ignore_case: - for path, directory in self.directories.items(): - if str(absolute_path).lower() == str(path).lower(): - return directory - - # If the directory does not exist, create it - if create_on_missing: - # Get the parent directory and create it if it does not exist - parent_path = absolute_path.parent - - # Base Case - if str(absolute_path) == str(self.repo_path) or str(absolute_path) == str(parent_path): - root_directory = Directory(ctx=self, path=absolute_path, dirpath="") - self.directories[absolute_path] = root_directory - return root_directory - - # Recursively create the parent directory - parent = self.get_directory(parent_path, create_on_missing=True) - # Create the directory - directory = Directory(ctx=self, path=absolute_path, dirpath=str(self.to_relative(absolute_path))) - # Add the directory to the parent - parent._add_subdirectory(directory.name) - # Add the directory to the tree - self.directories[absolute_path] = directory - return directory - return None - - def _process_diff_files(self, files_to_sync: Mapping[SyncType, list[Path]], incremental: bool = True) -> None: - # If all the files are empty, don't uncache - assert self._computing is False - skip_uncache = incremental and ((len(files_to_sync[SyncType.DELETE]) + len(files_to_sync[SyncType.REPARSE])) == 0) - if not skip_uncache: - uncache_all() - # Step 0: Start the dependency manager and language engine if they exist - # Start the dependency manager. This may or may not run asynchronously, depending on the implementation - if self.dependency_manager is not None: - # Check if its inital start or a reparse - if not self.dependency_manager.ready() and not self.dependency_manager.error(): - # TODO: We do not reparse dependencies during syncs as it is expensive. We should probably add a flag for this - logger.info("> Starting dependency manager") - self.dependency_manager.start(async_start=False) - - # Start the language engine. This may or may not run asynchronously, depending on the implementation - if self.language_engine is not None: - # Check if its inital start or a reparse - if not self.language_engine.ready() and not self.language_engine.error(): - logger.info("> Starting language engine") - self.language_engine.start(async_start=False) - else: - logger.info("> Reparsing language engine") - self.language_engine.reparse(async_start=False) - - # Step 1: Wait for dependency manager and language engines to finish before graph construction - if self.dependency_manager is not None: - self.dependency_manager.wait_until_ready(ignore_error=self.config.ignore_process_errors) - if self.language_engine is not None: - self.language_engine.wait_until_ready(ignore_error=self.config.ignore_process_errors) - - # ====== [ Refresh the graph] ======== - # Step 2: For any files that no longer exist, remove them during the sync - add_to_remove = [] - if incremental: - for file_path in files_to_sync[SyncType.ADD]: - if not self.io.file_exists(self.to_absolute(file_path)): - add_to_remove.append(file_path) - logger.warning(f"SYNC: SourceFile {file_path} no longer exists! Removing from graph") - reparse_to_remove = [] - for file_path in files_to_sync[SyncType.REPARSE]: - if not self.io.file_exists(self.to_absolute(file_path)): - reparse_to_remove.append(file_path) - logger.warning(f"SYNC: SourceFile {file_path} no longer exists! Removing from graph") - files_to_sync[SyncType.ADD] = [f for f in files_to_sync[SyncType.ADD] if f not in add_to_remove] - files_to_sync[SyncType.REPARSE] = [f for f in files_to_sync[SyncType.REPARSE] if f not in reparse_to_remove] - for file_path in add_to_remove + reparse_to_remove: - if self.get_file(file_path) is not None: - files_to_sync[SyncType.DELETE].append(file_path) - else: - logger.warning(f"SYNC: SourceFile {file_path} does not exist and also not found on graph!") - - # Step 3: Remove files to delete from graph - to_resolve = [] - for file_path in files_to_sync[SyncType.DELETE]: - file = self.get_file(file_path) - file.remove_internal_edges() - to_resolve.extend(file.unparse()) - to_resolve = list(filter(lambda node: self.has_node(node.node_id) and node is not None, to_resolve)) - for file_path in files_to_sync[SyncType.REPARSE]: - file = self.get_file(file_path) - file.remove_internal_edges() - - task = self.progress.begin("Reparsing updated files", count=len(files_to_sync[SyncType.REPARSE])) - files_to_resolve = [] - # Step 4: Reparse updated files - for idx, file_path in enumerate(files_to_sync[SyncType.REPARSE]): - task.update(f"Reparsing {self.to_relative(file_path)}", count=idx) - file = self.get_file(file_path) - to_resolve.extend(file.unparse(reparse=True)) - to_resolve = list(filter(lambda node: self.has_node(node.node_id) and node is not None, to_resolve)) - file.sync_with_file_content() - files_to_resolve.append(file) - task.end() - # Step 5: Add new files as nodes to graph (does not yet add edges) - task = self.progress.begin("Adding new files", count=len(files_to_sync[SyncType.ADD])) - for idx, filepath in enumerate(files_to_sync[SyncType.ADD]): - task.update(f"Adding {self.to_relative(filepath)}", count=idx) - try: - content = self.io.read_text(filepath) - except UnicodeDecodeError as e: - logger.warning(f"Can't read file at:{filepath} since it contains non-unicode characters. File will be ignored!") - continue - # TODO: this is wrong with context changes - if filepath.suffix in self.extensions: - file_cls = self.node_classes.file_cls - new_file = file_cls.from_content(filepath, content, self, sync=False, verify_syntax=False) - if new_file is not None: - files_to_resolve.append(new_file) - task.end() - for file in files_to_resolve: - to_resolve.append(file) - to_resolve.extend(file.get_nodes()) - - to_resolve = list(filter(lambda node: self.has_node(node.node_id) and node is not None, to_resolve)) - counter = Counter(node.node_type for node in to_resolve) - - # Step 6: Build directory tree - logger.info("> Building directory tree") - self.build_directory_tree() - - # Step 7: Build configs - if self.config_parser is not None: - self.config_parser.parse_configs() - - # Step 8: Add internal import resolution edges for new and updated files - if not skip_uncache: - uncache_all() - - if self.config.disable_graph: - logger.warning("Graph generation is disabled. Skipping import and symbol resolution") - self._computing = False - else: - self._computing = True - try: - logger.info(f"> Computing import resolution edges for {counter[NodeType.IMPORT]} imports") - task = self.progress.begin("Resolving imports", count=counter[NodeType.IMPORT]) - for node in to_resolve: - if node.node_type == NodeType.IMPORT: - task.update(f"Resolving imports in {node.filepath}", count=idx) - node._remove_internal_edges(EdgeType.IMPORT_SYMBOL_RESOLUTION) - node.add_symbol_resolution_edge() - to_resolve.extend(node.symbol_usages) - task.end() - if counter[NodeType.EXPORT] > 0: - logger.info(f"> Computing export dependencies for {counter[NodeType.EXPORT]} exports") - task = self.progress.begin("Computing export dependencies", count=counter[NodeType.EXPORT]) - for node in to_resolve: - if node.node_type == NodeType.EXPORT: - task.update(f"Computing export dependencies for {node.filepath}", count=idx) - node._remove_internal_edges(EdgeType.EXPORT) - node.compute_export_dependencies() - to_resolve.extend(node.symbol_usages) - task.end() - if counter[NodeType.SYMBOL] > 0: - from codegen.sdk.core.interfaces.inherits import Inherits - - logger.info("> Computing superclass dependencies") - task = self.progress.begin("Computing superclass dependencies", count=counter[NodeType.SYMBOL]) - for symbol in to_resolve: - if isinstance(symbol, Inherits): - task.update(f"Computing superclass dependencies for {symbol.filepath}", count=idx) - symbol._remove_internal_edges(EdgeType.SUBCLASS) - symbol.compute_superclass_dependencies() - task.end() - if not skip_uncache: - uncache_all() - self._compute_dependencies(to_resolve, incremental) - finally: - self._computing = False - - def _compute_dependencies(self, to_update: list[Importable], incremental: bool): - seen = set() - while to_update: - task = self.progress.begin("Computing dependencies", count=len(to_update)) - step = to_update.copy() - to_update.clear() - logger.info(f"> Incrementally computing dependencies for {len(step)} nodes") - for idx, current in enumerate(step): - task.update(f"Computing dependencies for {current.filepath}", count=idx) - if current not in seen: - seen.add(current) - to_update.extend(current.recompute(incremental)) - if not incremental: - for node in self._graph.nodes(): - if node not in seen: - to_update.append(node) - task.end() - seen.clear() - - def build_subgraph(self, nodes: list[NodeId]) -> PyDiGraph[Importable, Edge]: - """Builds a subgraph from the given set of nodes""" - subgraph = PyDiGraph() - subgraph.add_nodes_from(self._graph.nodes()) - subgraph.add_edges_from(self._graph.weighted_edge_list()) - return subgraph.subgraph(nodes) - - def get_node(self, node_id: int) -> Any: - return self._graph.get_node_data(node_id) - - def get_nodes(self, node_type: NodeType | None = None, exclude_type: NodeType | None = None) -> list[Importable]: - if node_type is not None and exclude_type is not None: - msg = "node_type and exclude_type cannot both be specified" - raise ValueError(msg) - if node_type is not None: - return [self.get_node(node_id) for node_id in self._graph.filter_nodes(lambda node: node.node_type == node_type)] - if exclude_type is not None: - return [self.get_node(node_id) for node_id in self._graph.filter_nodes(lambda node: node.node_type != node_type)] - return self._graph.nodes() - - def get_edges(self) -> list[tuple[NodeId, NodeId, EdgeType, Usage | None]]: - return [(x[0], x[1], x[2].type, x[2].usage) for x in self._graph.weighted_edge_list()] - - def get_file(self, file_path: os.PathLike, ignore_case: bool = False) -> SourceFile | None: - # If not part of repo path, return None - absolute_path = self.to_absolute(file_path) - if not self.is_subdir(absolute_path) and not self.config.allow_external: - assert False, f"File {file_path} is not part of the repository path" - - # Check if file exists in graph - node_id = self.filepath_idx.get(str(self.to_relative(file_path)), None) - if node_id is not None: - return self.get_node(node_id) - if ignore_case: - # Using `get_directory` so that the case insensitive lookup works - parent = self.get_directory(self.to_absolute(file_path).parent, ignore_case=ignore_case).path - for file in parent.iterdir(): - if str(file_path).lower() == str(self.to_relative(file)).lower(): - return self.get_file(file, ignore_case=False) - - def _get_raw_file_from_path(self, path: Path) -> File | None: - from codegen.sdk.core.file import File - - try: - return File.from_content(path, self.io.read_text(path), self, sync=False) - except UnicodeDecodeError: - # Handle when file is a binary file - return File.from_content(path, self.io.read_bytes(path), self, sync=False, binary=True) - - def get_external_module(self, module: str, import_name: str) -> ExternalModule | None: - node_id = self._ext_module_idx.get(module + "::" + import_name, None) - if node_id is not None: - return self.get_node(node_id) - - def add_node(self, node: Importable) -> int: - if self.config.debug: - if self._graph.find_node_by_weight(node.__eq__): - msg = "Node already exists" - raise Exception(msg) - if self.config.debug and self._computing and node.node_type != NodeType.EXTERNAL: - assert False, f"Adding node during compute dependencies: {node!r}" - return self._graph.add_node(node) - - def add_child(self, parent: NodeId, node: Importable, type: EdgeType, usage: Usage | None = None) -> int: - if self.config.debug: - if self._graph.find_node_by_weight(node.__eq__): - msg = "Node already exists" - raise Exception(msg) - if self.config.debug and self._computing and node.node_type != NodeType.EXTERNAL: - assert False, f"Adding node during compute dependencies: {node!r}" - return self._graph.add_child(parent, node, Edge(type, usage)) - - def has_node(self, node_id: NodeId): - return isinstance(node_id, int) and self._graph.has_node(node_id) - - def has_edge(self, u: NodeId, v: NodeId, edge: Edge): - return self._graph.has_edge(u, v) and edge in self._graph.get_all_edge_data(u, v) - - def add_edge(self, u: NodeId, v: NodeId, type: EdgeType, usage: Usage | None = None) -> None: - edge = Edge(type, usage) - if self.config.debug: - assert self._graph.has_node(u) - assert self._graph.has_node(v), v - assert not self.has_edge(u, v, edge), (u, v, edge) - self._graph.add_edge(u, v, edge) - - def add_edges(self, edges: list[tuple[NodeId, NodeId, Edge]]) -> None: - if self.config.debug: - for u, v, edge in edges: - assert self._graph.has_node(u) - assert self._graph.has_node(v), v - assert not self.has_edge(u, v, edge), (self.get_node(u), self.get_node(v), edge) - self._graph.add_edges_from(edges) - - @property - def nodes(self): - return self._graph.nodes() - - @property - def edges(self) -> WeightedEdgeList[Edge]: - return self._graph.weighted_edge_list() - - def predecessor(self, n: NodeId, *, edge_type: EdgeType | None) -> Importable: - return self._graph.find_predecessor_node_by_edge(n, lambda edge: edge.type == edge_type) - - def predecessors(self, n: NodeId, edge_type: EdgeType | None = None) -> Sequence[Importable]: - if edge_type is not None: - return sort_editables(self._graph.find_predecessors_by_edge(n, lambda edge: edge.type == edge_type), by_id=True) - return self._graph.predecessors(n) - - def successors(self, n: NodeId, *, edge_type: EdgeType | None = None, sort: bool = True) -> Sequence[Importable]: - if edge_type is not None: - res = self._graph.find_successors_by_edge(n, lambda edge: edge.type == edge_type) - else: - res = self._graph.successors(n) - if sort: - return sort_editables(res, by_id=True, dedupe=False) - return res - - def get_edge_data(self, *args, **kwargs) -> set[Edge]: - return set(self._graph.get_all_edge_data(*args, **kwargs)) - - def in_edges(self, n: NodeId) -> WeightedEdgeList[Edge]: - return self._graph.in_edges(n) - - def out_edges(self, n: NodeId) -> WeightedEdgeList[Edge]: - return self._graph.out_edges(n) - - def remove_node(self, n: NodeId): - return self._graph.remove_node(n) - - def remove_edge(self, u: NodeId, v: NodeId, *, edge_type: EdgeType | None = None): - for edge in self._graph.edge_indices_from_endpoints(u, v): - if edge_type is not None: - if self._graph.get_edge_data_by_index(edge).type != edge_type: - continue - self._graph.remove_edge_from_index(edge) - - @lru_cache(maxsize=10000) - def to_absolute(self, filepath: PathLike | str) -> Path: - path = Path(filepath) - if not path.is_absolute(): - path = Path(self.repo_path) / path - return path.resolve() - - @lru_cache(maxsize=10000) - def to_relative(self, filepath: PathLike | str) -> Path: - path = self.to_absolute(filepath) - if path == Path(self.repo_path) or Path(self.repo_path) in path.parents: - return path.relative_to(self.repo_path) - return path - - def is_subdir(self, path: PathLike | str) -> bool: - path = self.to_absolute(path) - return path == Path(self.repo_path) or path.is_relative_to(self.repo_path) or Path(self.repo_path) in path.parents - - @commiter - def commit_transactions(self, sync_graph: bool = True, sync_file: bool = True, files: set[Path] | None = None) -> None: - """Commits all transactions to the codebase, and syncs the graph to match the latest file changes. - Should be called at the end of `execute` for every codemod group run. - - Arguments: - sync_graph (bool): If True, syncs the graph with the latest set of file changes - sync_file (bool): If True, writes any pending file edits to the file system - files (set[str] | None): If provided, only commits transactions for the given set of files - """ - # Commit transactions for all contexts - files_to_lock = self.transaction_manager.to_commit(files) - diffs = self.transaction_manager.commit(files_to_lock) - for diff in diffs: - if self.get_file(diff.path) is None: - self.unapplied_diffs.append(diff) - else: - self.pending_syncs.append(diff) - - # Write files if requested - if sync_file: - self.io.save_files(files) - - # Sync the graph if requested - if sync_graph and len(self.pending_syncs) > 0: - self.apply_diffs(self.pending_syncs) - self.all_syncs.extend(self.pending_syncs) - self.pending_syncs.clear() - - @commiter - def add_single_file(self, filepath: PathLike) -> None: - """Adds a file to the graph and computes it's dependencies""" - sync = DiffLite(ChangeType.Added, self.to_absolute(filepath)) - self.all_syncs.append(sync) - self.apply_diffs([sync]) - self.transaction_manager.check_limits() - - @contextmanager - def session(self, sync_graph: bool = True, commit: bool = True, session_options: SessionOptions = SessionOptions()) -> Generator[None, None, None]: - self.session_options = session_options - self.transaction_manager.set_max_transactions(self.session_options.max_transactions) - self.transaction_manager.reset_stopwatch(self.session_options.max_seconds) - try: - yield None - except StopCodemodException as e: - logger.info(f"{e}, committing transactions and resetting graph") - raise - finally: - if commit: - self.commit_transactions(sync_graph) - - def remove_directory(self, directory_path: PathLike, force: bool = False, cleanup: bool = True) -> None: - """Removes a directory from the graph""" - # Get the directory - directory = self.get_directory(directory_path) - - # Check errors - if directory is None: - msg = f"Directory {directory_path} does not exist" - raise ValueError(msg) - if not force and len(directory.items) > 0: - msg = f"Directory {directory_path} is not empty" - raise ValueError(msg) - - # Remove the directory from the tree - if str(directory_path) in self.directories: - del self.directories[str(directory_path)] - - # Remove the directory from the parent - if directory.parent is not None: - directory.parent.remove_subdirectory(directory) - # Cleanup - if cleanup and len(directory.parent.items) == 0: - self.remove_directory(directory.parent.path, cleanup=cleanup) - - #################################################################################################################### - # EXTERNAL UTILS - #################################################################################################################### - - _ts_declassify: TSDeclassify | None = None - - @property - def ts_declassify(self) -> TSDeclassify: - if self._ts_declassify is None: - self._ts_declassify = TSDeclassify(self.repo_path, self.projects[0].base_path) - self._ts_declassify.start() # Install react-declassify - return self._ts_declassify diff --git a/src/codegen/sdk/codebase/config.py b/src/codegen/sdk/codebase/config.py deleted file mode 100644 index 25f3c2e0e..000000000 --- a/src/codegen/sdk/codebase/config.py +++ /dev/null @@ -1,68 +0,0 @@ -import os -from typing import Self - -from pydantic import BaseModel -from pydantic.config import ConfigDict -from pydantic.fields import Field - -from codegen.configs.models.codebase import DefaultCodebaseConfig -from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.git.schemas.repo_config import RepoConfig -from codegen.git.utils.file_utils import split_git_path -from codegen.git.utils.language import determine_project_language -from codegen.shared.enums.programming_language import ProgrammingLanguage - -HARD_MAX_AI_LIMIT = 500 # Global limit for AI requests - - -class SessionOptions(BaseModel): - """Options for a session. A session is a single codemod run.""" - - model_config = ConfigDict(frozen=True) - max_seconds: int | None = None - max_transactions: int | None = None - max_ai_requests: int = Field(default=150, le=HARD_MAX_AI_LIMIT) - - -TestFlags = DefaultCodebaseConfig.model_copy(update=dict(debug=True, track_graph=True, verify_graph=True, full_range_index=True, sync_enabled=True)) -LintFlags = DefaultCodebaseConfig.model_copy(update=dict(method_usages=False, sync_enabled=True)) -ParseTestFlags = DefaultCodebaseConfig.model_copy(update=dict(debug=False, track_graph=False, sync_enabled=True)) - - -class ProjectConfig(BaseModel): - """Context for a codebase. A codebase is a set of files in a directory.""" - - model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) - repo_operator: RepoOperator - - # TODO: clean up these fields. Duplicated across RepoConfig and CodebaseContext - base_path: str | None = None - subdirectories: list[str] | None = None - programming_language: ProgrammingLanguage = ProgrammingLanguage.PYTHON - - @classmethod - def from_path(cls, path: str, programming_language: ProgrammingLanguage | None = None) -> Self: - # Split repo_path into (git_root, base_path) - repo_path = os.path.abspath(path) - git_root, base_path = split_git_path(repo_path) - subdirectories = [base_path] if base_path else None - programming_language = programming_language or determine_project_language(repo_path) - repo_config = RepoConfig.from_repo_path(repo_path=git_root) - repo_config.language = programming_language - repo_config.subdirectories = subdirectories - # Create main project - return cls( - repo_operator=RepoOperator(repo_config=repo_config), - programming_language=programming_language, - base_path=base_path, - subdirectories=subdirectories, - ) - - @classmethod - def from_repo_operator(cls, repo_operator: RepoOperator, programming_language: ProgrammingLanguage | None = None, base_path: str | None = None) -> Self: - return cls( - repo_operator=repo_operator, - programming_language=programming_language or determine_project_language(repo_operator.repo_path), - base_path=base_path, - subdirectories=[base_path] if base_path else None, - ) diff --git a/src/codegen/sdk/codebase/config_parser.py b/src/codegen/sdk/codebase/config_parser.py deleted file mode 100644 index f6ee88238..000000000 --- a/src/codegen/sdk/codebase/config_parser.py +++ /dev/null @@ -1,24 +0,0 @@ -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -from codegen.shared.enums.programming_language import ProgrammingLanguage - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - - -class ConfigParser(ABC): - def __init__(self): - pass - - @abstractmethod - def parse_configs(self, codebase_context: "CodebaseContext"): ... - - -def get_config_parser_for_language(language: ProgrammingLanguage, codebase_context: "CodebaseContext") -> ConfigParser | None: - from codegen.sdk.typescript.config_parser import TSConfigParser - - if language == ProgrammingLanguage.TYPESCRIPT: - return TSConfigParser(codebase_context) - - return None diff --git a/src/codegen/sdk/codebase/diff_lite.py b/src/codegen/sdk/codebase/diff_lite.py deleted file mode 100644 index b923488ef..000000000 --- a/src/codegen/sdk/codebase/diff_lite.py +++ /dev/null @@ -1,85 +0,0 @@ -from enum import IntEnum, auto -from os import PathLike -from pathlib import Path -from typing import NamedTuple, Self - -from git import Diff -from watchfiles import Change - - -class ChangeType(IntEnum): - Modified = auto() - Removed = auto() - Renamed = auto() - Added = auto() - - @staticmethod - def from_watch_change_type(change_type: Change): - if change_type is Change.added: - return ChangeType.Added - elif change_type is Change.deleted: - return ChangeType.Removed - elif change_type is Change.modified: - return ChangeType.Modified - - @staticmethod - def from_git_change_type(change_type: str | None): - if change_type == "M": - return ChangeType.Modified - if change_type == "D": - return ChangeType.Removed - if change_type == "R": - return ChangeType.Renamed - if change_type == "A": - return ChangeType.Added - msg = f"Invalid change type: {change_type}" - raise ValueError(msg) - - -class DiffLite(NamedTuple): - """Simple diff for recomputing the graph""" - - change_type: ChangeType - path: Path - rename_from: Path | None = None - rename_to: Path | None = None - old_content: bytes | None = None - - @classmethod - def from_watch_change(cls, change: Change, path: PathLike) -> Self: - return cls( - change_type=ChangeType.from_watch_change_type(change), - path=Path(path), - ) - - @classmethod - def from_git_diff(cls, git_diff: Diff): - old = None - if git_diff.a_blob: - old = git_diff.a_blob.data_stream.read() - return cls( - change_type=ChangeType.from_git_change_type(git_diff.change_type), - path=Path(git_diff.a_path) if git_diff.a_path else None, - rename_from=Path(git_diff.rename_from) if git_diff.rename_from else None, - rename_to=Path(git_diff.rename_to) if git_diff.rename_to else None, - old_content=old, - ) - - @classmethod - def from_reverse_diff(cls, diff_lite: "DiffLite"): - if diff_lite.change_type == ChangeType.Added: - change_type = ChangeType.Removed - elif diff_lite.change_type == ChangeType.Removed: - change_type = ChangeType.Added - else: - change_type = diff_lite.change_type - - if diff_lite.change_type == ChangeType.Renamed: - return cls( - change_type=change_type, - path=diff_lite.path, - rename_from=diff_lite.rename_to, - rename_to=diff_lite.rename_from, - ) - - return cls(change_type=change_type, path=diff_lite.path) diff --git a/src/codegen/sdk/codebase/factory/codebase_factory.py b/src/codegen/sdk/codebase/factory/codebase_factory.py deleted file mode 100644 index 009992311..000000000 --- a/src/codegen/sdk/codebase/factory/codebase_factory.py +++ /dev/null @@ -1,28 +0,0 @@ -from codegen.configs.models.codebase import CodebaseConfig -from codegen.configs.models.secrets import SecretsConfig -from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.sdk.codebase.config import ProjectConfig -from codegen.sdk.core.codebase import ( - Codebase, - CodebaseType, -) -from codegen.shared.enums.programming_language import ProgrammingLanguage - - -class CodebaseFactory: - #################################################################################################################### - # CREATE CODEBASE - #################################################################################################################### - - @staticmethod - def get_codebase_from_files( - repo_path: str = "/tmp/codegen_run_on_str", - files: dict[str, str] = {}, - bot_commit: bool = True, - programming_language: ProgrammingLanguage = ProgrammingLanguage.PYTHON, - config: CodebaseConfig | None = None, - secrets: SecretsConfig | None = None, - ) -> CodebaseType: - op = RepoOperator.create_from_files(repo_path=repo_path, files=files, bot_commit=bot_commit) - projects = [ProjectConfig(repo_operator=op, programming_language=programming_language)] - return Codebase(projects=projects, config=config, secrets=secrets) diff --git a/src/codegen/sdk/codebase/factory/get_session.ipynb b/src/codegen/sdk/codebase/factory/get_session.ipynb deleted file mode 100644 index e703e7b7e..000000000 --- a/src/codegen/sdk/codebase/factory/get_session.ipynb +++ /dev/null @@ -1,41 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# test get_codebase_session\n", - "from codegen.sdk.codebase.factory.get_session import get_codebase_session\n", - "import time\n", - "\n", - "with get_codebase_session(tmpdir=f\"/tmp/{int(time.time())}\", files={\"file.py\": \"a = 1 + 2\"}) as codebase:\n", - " file = codebase.get_file(\"file.py\")\n", - " owners = file.owners\n", - " assert len(owners) == 0" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.7" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/src/codegen/sdk/codebase/factory/get_session.py b/src/codegen/sdk/codebase/factory/get_session.py deleted file mode 100644 index 189eec6e6..000000000 --- a/src/codegen/sdk/codebase/factory/get_session.py +++ /dev/null @@ -1,122 +0,0 @@ -import os -import sys -from collections.abc import Generator -from contextlib import AbstractContextManager, contextmanager -from typing import Literal, overload - -from codegen.configs.models.codebase import CodebaseConfig -from codegen.configs.models.secrets import SecretsConfig -from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.sdk.codebase.codebase_context import CodebaseContext -from codegen.sdk.codebase.config import ProjectConfig, SessionOptions, TestFlags -from codegen.sdk.codebase.factory.codebase_factory import CodebaseFactory -from codegen.sdk.core.codebase import Codebase, PyCodebaseType, TSCodebaseType -from codegen.sdk.core.file import SourceFile -from codegen.sdk.tree_sitter_parser import print_errors -from codegen.shared.enums.programming_language import ProgrammingLanguage - - -@overload -def get_codebase_session( - tmpdir: str | os.PathLike[str], - programming_language: None = None, - files: dict[str, str] = {}, - commit: bool = True, - sync_graph: bool = True, - verify_input: bool = True, - verify_output: bool = True, - config: CodebaseConfig = TestFlags, - session_options: SessionOptions = SessionOptions(), - secrets: SecretsConfig | None = None, -) -> AbstractContextManager[PyCodebaseType]: ... - - -@overload -def get_codebase_session( - tmpdir: str | os.PathLike[str], - programming_language: Literal[ProgrammingLanguage.PYTHON], - files: dict[str, str] = {}, - commit: bool = True, - sync_graph: bool = True, - verify_input: bool = True, - verify_output: bool = True, - config: CodebaseConfig = TestFlags, - session_options: SessionOptions = SessionOptions(), - secrets: SecretsConfig | None = None, -) -> AbstractContextManager[PyCodebaseType]: ... - - -@overload -def get_codebase_session( - tmpdir: str | os.PathLike[str], - programming_language: Literal[ProgrammingLanguage.TYPESCRIPT], - files: dict[str, str] = {}, - commit: bool = True, - sync_graph: bool = True, - verify_input: bool = True, - verify_output: bool = True, - config: CodebaseConfig = TestFlags, - session_options: SessionOptions = SessionOptions(), - secrets: SecretsConfig | None = None, -) -> AbstractContextManager[TSCodebaseType]: ... - - -@contextmanager -def get_codebase_session( - tmpdir: str | os.PathLike[str], - programming_language: ProgrammingLanguage = ProgrammingLanguage.PYTHON, - files: dict[str, str] = {}, - commit: bool = True, - sync_graph: bool = True, - verify_input: bool = True, - verify_output: bool = True, - config: CodebaseConfig = TestFlags, - session_options: SessionOptions = SessionOptions(), - secrets: SecretsConfig | None = None, -) -> Generator[Codebase, None, None]: - """Gives you a Codebase operating on the files you provided as a dict""" - codebase = CodebaseFactory.get_codebase_from_files(repo_path=str(tmpdir), files=files, config=config, secrets=secrets, programming_language=programming_language) - with codebase.session( - commit=commit, - sync_graph=sync_graph, - session_options=session_options, - ): - if verify_input: - for file in codebase.files: - # NOTE: We only check SourceFiles for syntax errors - abs_filepath = os.path.join(tmpdir, file.filepath) - if os.path.exists(abs_filepath): - if isinstance(file, SourceFile): - # Check for syntax errors - print_errors(abs_filepath, file.content) - if file.ts_node.has_error: - msg = "Invalid syntax in test case" - raise SyntaxError(msg) - yield codebase - - if verify_output: - for file in codebase.files: - if os.path.exists(file.filepath): - if file.ts_node.has_error and len(file.content.splitlines()) < 10: - print(file.content, file=sys.stderr) - print_errors(file.filepath, file.content) - assert not file.ts_node.has_error, "Invalid syntax in file after commiting" - - -@contextmanager -def get_codebase_graph_session( - tmpdir: str, - programming_language: ProgrammingLanguage = ProgrammingLanguage.PYTHON, - files: dict[str, str] = {}, - sync_graph: bool = True, - session_options: SessionOptions = SessionOptions(), -) -> Generator[CodebaseContext, None, None]: - """Gives you a Codebase2 operating on the files you provided as a dict""" - op = RepoOperator.create_from_files(repo_path=tmpdir, files=files) - projects = [ProjectConfig(repo_operator=op, programming_language=programming_language)] - graph = CodebaseContext(projects=projects, config=TestFlags) - with graph.session(sync_graph=sync_graph, session_options=session_options): - try: - yield graph - finally: - pass diff --git a/src/codegen/sdk/codebase/flagging/code_flag.py b/src/codegen/sdk/codebase/flagging/code_flag.py deleted file mode 100644 index 1b1a92fc5..000000000 --- a/src/codegen/sdk/codebase/flagging/code_flag.py +++ /dev/null @@ -1,35 +0,0 @@ -from dataclasses import dataclass -from typing import Generic, TypeVar - -from codegen.sdk.codebase.flagging.enums import MessageType -from codegen.sdk.core.interfaces.editable import Editable - -Symbol = TypeVar("Symbol", bound=Editable | None) - - -@dataclass -class CodeFlag(Generic[Symbol]): - symbol: Symbol - message: str | None = None # a short desc of the code flag/violation. ex: enums should be ordered alphabetically - message_type: MessageType = MessageType.GITHUB | MessageType.CODEGEN # where to send the message (either Github or Slack) - message_recipient: str | None = None # channel ID or user ID to send the message (if message_type is SLACK) - - @property - def hash(self) -> str: - return self.symbol.span.model_dump_json() - - @property - def filepath(self) -> str: - return self.symbol.file.filepath if self.symbol else "" - - def __eq__(self, other): - if self.symbol != other.symbol: - return False - if self.message != other.message: - return False - if self.message_type != other.message_type: - return False - return True - - def __repr__(self): - return f"" diff --git a/src/codegen/sdk/codebase/flagging/enums.py b/src/codegen/sdk/codebase/flagging/enums.py deleted file mode 100644 index 949f09ce5..000000000 --- a/src/codegen/sdk/codebase/flagging/enums.py +++ /dev/null @@ -1,36 +0,0 @@ -from enum import IntFlag, auto -from typing import TypedDict - -from typing_extensions import ReadOnly - -from codegen.shared.decorators.docs import apidoc - - -@apidoc -class MessageType(IntFlag): - """Destination of the message - - Attributes: - CODEGEN: Rendered in the diff preview - GITHUB: Posted as a comment on the PR - SLACK: Sent over slack - """ - - CODEGEN = auto() - GITHUB = auto() - SLACK = auto() - - -@apidoc -class FlagKwargs(TypedDict, total=False): - """Kwargs for the flag_instance method of the Codebase class. - - Attributes: - message: The message to be displayed in the diff preview or posted as a comment on the PR. - message_type: Where the message will be sent (CODEGEN, GITHUB, SLACK) - message_recipient: The recipient of the message. - """ - - message: ReadOnly[str | None] - message_type: ReadOnly[MessageType] - message_recipient: ReadOnly[str | None] diff --git a/src/codegen/sdk/codebase/flagging/flags.py b/src/codegen/sdk/codebase/flagging/flags.py deleted file mode 100644 index 636d5145a..000000000 --- a/src/codegen/sdk/codebase/flagging/flags.py +++ /dev/null @@ -1,77 +0,0 @@ -from dataclasses import dataclass, field -from typing import TypeVar - -from codegen.sdk.codebase.flagging.code_flag import CodeFlag -from codegen.sdk.codebase.flagging.enums import MessageType -from codegen.sdk.codebase.flagging.group import Group -from codegen.sdk.core.interfaces.editable import Editable -from codegen.shared.decorators.docs import noapidoc - -Symbol = TypeVar("Symbol", bound=Editable) - - -@dataclass -class Flags: - _flags: list[CodeFlag] = field(default_factory=list) - _find_mode: bool = False - _active_group: list[CodeFlag] | None = None - - def flag_instance( - self, - symbol: Symbol | None = None, - message: str | None = None, - message_type: MessageType = MessageType.GITHUB | MessageType.CODEGEN, - message_recipient: str | None = None, - ) -> CodeFlag[Symbol]: - """Flags a symbol, file or import to enable enhanced tracking of changes and splitting into - smaller PRs. - - This method should be called once per flaggable entity and should be called before any edits are made to the entity. - Flags enable tracking of changes and can be used for various purposes like generating pull requests or applying changes selectively. - - Args: - symbol (TSymbol | None): The symbol to flag. Can be None if just flagging a message. - message (str | None): Optional message to associate with the flag. - message_type (MessageType): The type of message. Defaults to MessageType.GITHUB and MessageType.CODEGEN. - message_recipient (str | None): Optional recipient for the message. - - Returns: - CodeFlag: A flag object representing the flagged entity. - """ - flag = CodeFlag(symbol=symbol, message=message, message_type=message_type, message_recipient=message_recipient) - if self._find_mode: - self._flags.append(flag) - return flag - - def should_fix(self, flag: CodeFlag) -> bool: - """Returns True if the flag should be fixed based on the current mode and active group. - - Used to filter out flags that are not in the active group and determine if the flag should be processed or ignored. - - Args: - flag (CodeFlag): The code flag to check. - - Returns: - bool: True if the flag should be fixed, False if it should be ignored. - Returns False in find mode. - Returns True if no active group is set. - Returns True if the flag's hash exists in the active group hashes. - """ - if self._find_mode: - return False - elif self._active_group is None: - return True - else: - return flag.hash in self._active_group_hashes - - @noapidoc - def set_find_mode(self, find_mode: bool) -> None: - self._find_mode = find_mode - - @noapidoc - def set_active_group(self, group: Group) -> None: - """Will only fix these flags.""" - # TODO - flesh this out more with Group datatype and GroupBy - self._active_group = group.flags - self._find_mode = False - self._active_group_hashes = set(flag.hash for flag in group.flags) diff --git a/src/codegen/sdk/codebase/flagging/group.py b/src/codegen/sdk/codebase/flagging/group.py deleted file mode 100644 index 58f6f9e95..000000000 --- a/src/codegen/sdk/codebase/flagging/group.py +++ /dev/null @@ -1,17 +0,0 @@ -from dataclasses import dataclass - -from dataclasses_json import dataclass_json - -from codegen.sdk.codebase.flagging.code_flag import CodeFlag -from codegen.sdk.codebase.flagging.groupers.enums import GroupBy - -DEFAULT_GROUP_ID = 0 - - -@dataclass_json -@dataclass -class Group: - group_by: GroupBy - segment: str - flags: list[CodeFlag] | None = None - id: int = DEFAULT_GROUP_ID diff --git a/src/codegen/sdk/codebase/flagging/groupers/all_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/all_grouper.py deleted file mode 100644 index d2fccd474..000000000 --- a/src/codegen/sdk/codebase/flagging/groupers/all_grouper.py +++ /dev/null @@ -1,22 +0,0 @@ -from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.sdk.codebase.flagging.code_flag import CodeFlag -from codegen.sdk.codebase.flagging.group import Group -from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper -from codegen.sdk.codebase.flagging.groupers.enums import GroupBy - - -class AllGrouper(BaseGrouper): - """Group all flags into one group.""" - - type: GroupBy = GroupBy.ALL - - @staticmethod - def create_all_groups(flags: list[CodeFlag], repo_operator: RepoOperator | None = None) -> list[Group]: - return [Group(group_by=GroupBy.ALL, segment="all", flags=flags)] if flags else [] - - @staticmethod - def create_single_group(flags: list[CodeFlag], segment: str, repo_operator: RepoOperator | None = None) -> Group: - if segment != "all": - msg = f"❌ Invalid segment for AllGrouper: {segment}. Only 'all' is a valid segment." - raise ValueError(msg) - return Group(group_by=GroupBy.ALL, segment=segment, flags=flags) diff --git a/src/codegen/sdk/codebase/flagging/groupers/app_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/app_grouper.py deleted file mode 100644 index 171468da6..000000000 --- a/src/codegen/sdk/codebase/flagging/groupers/app_grouper.py +++ /dev/null @@ -1,33 +0,0 @@ -from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.sdk.codebase.flagging.code_flag import CodeFlag -from codegen.sdk.codebase.flagging.group import Group -from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper -from codegen.sdk.codebase.flagging.groupers.enums import GroupBy -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -class AppGrouper(BaseGrouper): - """Group flags by segment=app. - Ex: apps/profile. - """ - - type: GroupBy = GroupBy.APP - - @staticmethod - def create_all_groups(flags: list[CodeFlag], repo_operator: RepoOperator | None = None) -> list[Group]: - unique_apps = list({"/".join(flag.filepath.split("/")[:3]) for flag in flags}) - groups = [] - for idx, app in enumerate(unique_apps): - matches = [f for f in flags if f.filepath.startswith(app)] - if len(matches) > 0: - groups.append(Group(id=idx, group_by=GroupBy.APP, segment=app, flags=matches)) - return groups - - @staticmethod - def create_single_group(flags: list[CodeFlag], segment: str, repo_operator: RepoOperator | None = None) -> Group: - segment_flags = [f for f in flags if f.filepath.startswith(segment)] - if len(segment_flags) == 0: - logger.warning(f"🤷‍♀️ No flags found for APP segment: {segment}") - return Group(group_by=GroupBy.APP, segment=segment, flags=segment_flags) diff --git a/src/codegen/sdk/codebase/flagging/groupers/base_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/base_grouper.py deleted file mode 100644 index 02b072eb2..000000000 --- a/src/codegen/sdk/codebase/flagging/groupers/base_grouper.py +++ /dev/null @@ -1,29 +0,0 @@ -from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.sdk.codebase.flagging.code_flag import CodeFlag -from codegen.sdk.codebase.flagging.group import Group -from codegen.sdk.codebase.flagging.groupers.enums import GroupBy - - -class BaseGrouper: - """Base class of all groupers. - Children of this class should include in their doc string: - - a short desc of what the segment format is. ex: for FileGrouper the segment is a filename - """ - - type: GroupBy - - def __init__(self) -> None: - if type is None: - msg = "Must set type in BaseGrouper" - raise ValueError(msg) - - @staticmethod - def create_all_groups(flags: list[CodeFlag], repo_operator: RepoOperator | None = None) -> list[Group]: - msg = "Must implement create_all_groups in BaseGrouper" - raise NotImplementedError(msg) - - @staticmethod - def create_single_group(flags: list[CodeFlag], segment: str, repo_operator: RepoOperator | None = None) -> Group: - """TODO: handle the case when 0 flags are passed in""" - msg = "Must implement create_single_group in BaseGrouper" - raise NotImplementedError(msg) diff --git a/src/codegen/sdk/codebase/flagging/groupers/codeowner_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/codeowner_grouper.py deleted file mode 100644 index 752a77285..000000000 --- a/src/codegen/sdk/codebase/flagging/groupers/codeowner_grouper.py +++ /dev/null @@ -1,41 +0,0 @@ -from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.sdk.codebase.flagging.code_flag import CodeFlag -from codegen.sdk.codebase.flagging.group import Group -from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper -from codegen.sdk.codebase.flagging.groupers.enums import GroupBy - -DEFAULT_CHUNK_SIZE = 5 - - -class CodeownerGrouper(BaseGrouper): - """Group flags by CODEOWNERS. - - Parses .github/CODEOWNERS and groups by each possible codeowners - - Segment should be either a github username or github team name. - """ - - type: GroupBy = GroupBy.CODEOWNER - - @staticmethod - def create_all_groups(flags: list[CodeFlag], repo_operator: RepoOperator | None = None) -> list[Group]: - owner_to_group: dict[str, Group] = {} - no_owner_group = Group(group_by=GroupBy.CODEOWNER, segment="@no-owner", flags=[]) - for idx, flag in enumerate(flags): - flag_owners = repo_operator.codeowners_parser.of(flag.filepath) # TODO: handle codeowners_parser could be null - if not flag_owners: - no_owner_group.flags.append(flag) - continue - # NOTE: always use the first owner. ex if the line is /dir @team1 @team2 then use team1 - flag_owner = flag_owners[0][1] - group = owner_to_group.get(flag_owner, Group(id=idx, group_by=GroupBy.CODEOWNER, segment=flag_owner, flags=[])) - group.flags.append(flag) - owner_to_group[flag_owner] = group - - no_owner_group.id = len(owner_to_group) - return [*list(owner_to_group.values()), no_owner_group] - - @staticmethod - def create_single_group(flags: list[CodeFlag], segment: str, repo_operator: RepoOperator | None = None) -> Group: - msg = "TODO: implement single group creation" - raise NotImplementedError(msg) diff --git a/src/codegen/sdk/codebase/flagging/groupers/constants.py b/src/codegen/sdk/codebase/flagging/groupers/constants.py deleted file mode 100644 index 2fc2a29ab..000000000 --- a/src/codegen/sdk/codebase/flagging/groupers/constants.py +++ /dev/null @@ -1,15 +0,0 @@ -from codegen.sdk.codebase.flagging.groupers.all_grouper import AllGrouper -from codegen.sdk.codebase.flagging.groupers.app_grouper import AppGrouper -from codegen.sdk.codebase.flagging.groupers.codeowner_grouper import CodeownerGrouper -from codegen.sdk.codebase.flagging.groupers.file_chunk_grouper import FileChunkGrouper -from codegen.sdk.codebase.flagging.groupers.file_grouper import FileGrouper -from codegen.sdk.codebase.flagging.groupers.instance_grouper import InstanceGrouper - -ALL_GROUPERS = [ - AllGrouper, - AppGrouper, - CodeownerGrouper, - FileChunkGrouper, - FileGrouper, - InstanceGrouper, -] diff --git a/src/codegen/sdk/codebase/flagging/groupers/enums.py b/src/codegen/sdk/codebase/flagging/groupers/enums.py deleted file mode 100644 index c84b2f413..000000000 --- a/src/codegen/sdk/codebase/flagging/groupers/enums.py +++ /dev/null @@ -1,11 +0,0 @@ -from enum import StrEnum - - -class GroupBy(StrEnum): - ALL = "all" - APP = "app" - CODEOWNER = "codeowner" - FILE = "file" - FILE_CHUNK = "file_chunk" - HOT_COLD = "hot_cold" - INSTANCE = "instance" diff --git a/src/codegen/sdk/codebase/flagging/groupers/file_chunk_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/file_chunk_grouper.py deleted file mode 100644 index 704028970..000000000 --- a/src/codegen/sdk/codebase/flagging/groupers/file_chunk_grouper.py +++ /dev/null @@ -1,46 +0,0 @@ -from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.sdk.codebase.flagging.code_flag import CodeFlag -from codegen.sdk.codebase.flagging.group import Group -from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper -from codegen.sdk.codebase.flagging.groupers.enums import GroupBy -from codegen.shared.logging.get_logger import get_logger -from codegen.shared.string.csv_utils import comma_separated_to_list, list_to_comma_separated - -logger = get_logger(__name__) - -DEFAULT_CHUNK_SIZE = 5 - - -class FileChunkGrouper(BaseGrouper): - """Group flags by a chunk of files. - Ex: if chunk size is 10 then a Group only contains flags from max 10 unique files. - TODO: currently only supports a harcoded chunk size of 5. - - Segment is a comma separated list of filenames. - """ - - type: GroupBy = GroupBy.FILE_CHUNK - - @staticmethod - def create_all_groups(flags: list[CodeFlag], repo_operator: RepoOperator | None = None) -> list[Group]: - map = {f.filepath: f for f in flags} - filenames = sorted(map.keys()) - chunks = chunk_list(filenames, DEFAULT_CHUNK_SIZE) - groups = [] - for idx, chunk in enumerate(chunks): - chunk_flags = [map[filename] for filename in chunk] - groups.append(Group(id=idx, group_by=GroupBy.FILE_CHUNK, segment=list_to_comma_separated(chunk), flags=chunk_flags)) - return groups - - @staticmethod - def create_single_group(flags: list[CodeFlag], segment: str, repo_operator: RepoOperator | None = None) -> Group: - segment_filepaths = comma_separated_to_list(segment) - all_segment_flags = [f for f in flags if f.filepath in segment_filepaths] - if len(all_segment_flags) == 0: - logger.warning(f"🤷‍♀️ No flags found for FILE_CHUNK segment: {segment_filepaths}") - return Group(group_by=GroupBy.FILE_CHUNK, segment=segment, flags=all_segment_flags) - - -def chunk_list(lst: list, chk_size: int) -> list[list[str]]: - for index in range(0, len(lst), chk_size): - yield lst[index : index + chk_size] diff --git a/src/codegen/sdk/codebase/flagging/groupers/file_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/file_grouper.py deleted file mode 100644 index 9e2cfa0a4..000000000 --- a/src/codegen/sdk/codebase/flagging/groupers/file_grouper.py +++ /dev/null @@ -1,32 +0,0 @@ -from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.sdk.codebase.flagging.code_flag import CodeFlag -from codegen.sdk.codebase.flagging.group import Group -from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper -from codegen.sdk.codebase.flagging.groupers.enums import GroupBy -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -class FileGrouper(BaseGrouper): - """Group flags by file. - Segment is the filename. - """ - - type: GroupBy = GroupBy.FILE - - @staticmethod - def create_all_groups(flags: list[CodeFlag], repo_operator: RepoOperator | None = None) -> list[Group]: - groups = [] - filenames = sorted(list({f.filepath for f in flags})) - for idx, filename in enumerate(filenames): - filename_flags = [flag for flag in flags if flag.filepath == filename] - groups.append(Group(id=idx, group_by=GroupBy.FILE, segment=filename, flags=filename_flags)) - return groups - - @staticmethod - def create_single_group(flags: list[CodeFlag], segment: str, repo_operator: RepoOperator | None = None) -> Group: - segment_flags = [flag for flag in flags if flag.filepath == segment] - if len(segment_flags) == 0: - logger.warning(f"🤷‍♀️ No flags found for FILE segment: {segment}") - return Group(group_by=GroupBy.FILE, segment=segment, flags=segment_flags) diff --git a/src/codegen/sdk/codebase/flagging/groupers/instance_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/instance_grouper.py deleted file mode 100644 index bbdbdac97..000000000 --- a/src/codegen/sdk/codebase/flagging/groupers/instance_grouper.py +++ /dev/null @@ -1,27 +0,0 @@ -from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.sdk.codebase.flagging.code_flag import CodeFlag -from codegen.sdk.codebase.flagging.group import Group -from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper -from codegen.sdk.codebase.flagging.groupers.enums import GroupBy - - -class InstanceGrouper(BaseGrouper): - """Group flags by flags. haha - One Group per flag. - """ - - type: GroupBy = GroupBy.INSTANCE - - @staticmethod - def create_all_groups(flags: list[CodeFlag], repo_operator: RepoOperator | None = None) -> list[Group]: - return [Group(id=idx, group_by=GroupBy.INSTANCE, segment=f.hash, flags=[f]) for idx, f in enumerate(flags)] - - @staticmethod - def create_single_group(flags: list[CodeFlag], segment: str, repo_operator: RepoOperator | None = None) -> Group: - # TODO: not sure if it's possible to regenerate a group for instance grouper b/c it needs to re-generate/re-find the flag. might need to rely on the flag meta 🤦‍♀️ - try: - flag = CodeFlag.from_json(segment) - return Group(group_by=GroupBy.INSTANCE, segment=segment, flags=[flag]) - except Exception as e: - msg = f"Unable to deserialize segment ({segment}) into CodeFlag. Unable to create group." - raise ValueError(msg) diff --git a/src/codegen/sdk/codebase/flagging/groupers/utils.py b/src/codegen/sdk/codebase/flagging/groupers/utils.py deleted file mode 100644 index 38d43cfa2..000000000 --- a/src/codegen/sdk/codebase/flagging/groupers/utils.py +++ /dev/null @@ -1,14 +0,0 @@ -from codegen.sdk.codebase.flagging.groupers.all_grouper import AllGrouper -from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper -from codegen.sdk.codebase.flagging.groupers.constants import ALL_GROUPERS -from codegen.sdk.codebase.flagging.groupers.enums import GroupBy - - -def get_grouper_by_group_by(group_by: GroupBy | None) -> type[BaseGrouper]: - if group_by is None: - return AllGrouper - matched_groupers = [x for x in ALL_GROUPERS if x.type == group_by] - if len(matched_groupers) > 0: - return matched_groupers[0] - msg = f"No grouper found for group_by={group_by}. Did you add to ALL_GROUPERS?" - raise ValueError(msg) diff --git a/src/codegen/sdk/codebase/io/file_io.py b/src/codegen/sdk/codebase/io/file_io.py deleted file mode 100644 index f59a28851..000000000 --- a/src/codegen/sdk/codebase/io/file_io.py +++ /dev/null @@ -1,66 +0,0 @@ -from concurrent.futures import ThreadPoolExecutor -from pathlib import Path - -from codegen.sdk.codebase.io.io import IO, BadWriteError -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -class FileIO(IO): - """IO implementation that writes files to disk, and tracks pending changes.""" - - files: dict[Path, bytes] - allowed_paths: list[Path] | None - - def __init__(self, allowed_paths: list[Path] | None = None): - self.files = {} - self.allowed_paths = allowed_paths - - def _verify_path(self, path: Path) -> None: - if self.allowed_paths is not None: - if not any(path.resolve().is_relative_to(p.resolve()) for p in self.allowed_paths): - msg = f"Path {path.resolve()} is not within allowed paths {self.allowed_paths}" - raise BadWriteError(msg) - - def write_bytes(self, path: Path, content: bytes) -> None: - self._verify_path(path) - self.files[path] = content - - def read_bytes(self, path: Path) -> bytes: - self._verify_path(path) - if path in self.files: - return self.files[path] - else: - return path.read_bytes() - - def save_files(self, files: set[Path] | None = None) -> None: - to_save = set(filter(lambda f: f in files, self.files)) if files is not None else self.files.keys() - for path in to_save: - self._verify_path(path) - with ThreadPoolExecutor() as exec: - exec.map(lambda path: path.write_bytes(self.files[path]), to_save) - if files is None: - self.files.clear() - else: - for path in to_save: - del self.files[path] - - def check_changes(self) -> None: - if self.files: - logger.error(BadWriteError("Directly called file write without calling commit_transactions")) - self.files.clear() - - def delete_file(self, path: Path) -> None: - self._verify_path(path) - self.untrack_file(path) - if path.exists(): - path.unlink() - - def untrack_file(self, path: Path) -> None: - self._verify_path(path) - self.files.pop(path, None) - - def file_exists(self, path: Path) -> bool: - self._verify_path(path) - return path.exists() diff --git a/src/codegen/sdk/codebase/io/io.py b/src/codegen/sdk/codebase/io/io.py deleted file mode 100644 index 710474aab..000000000 --- a/src/codegen/sdk/codebase/io/io.py +++ /dev/null @@ -1,46 +0,0 @@ -from abc import ABC, abstractmethod -from pathlib import Path - - -class BadWriteError(Exception): - pass - - -class IO(ABC): - def write_file(self, path: Path, content: str | bytes | None) -> None: - if content is None: - self.untrack_file(path) - elif isinstance(content, str): - self.write_text(path, content) - else: - self.write_bytes(path, content) - - def write_text(self, path: Path, content: str) -> None: - self.write_bytes(path, content.encode("utf-8")) - - @abstractmethod - def write_bytes(self, path: Path, content: bytes) -> None: - pass - - @abstractmethod - def read_bytes(self, path: Path) -> bytes: - pass - - def read_text(self, path: Path) -> str: - return self.read_bytes(path).decode("utf-8") - - @abstractmethod - def save_files(self, files: set[Path] | None = None) -> None: - pass - - @abstractmethod - def check_changes(self) -> None: - pass - - @abstractmethod - def delete_file(self, path: Path) -> None: - pass - - @abstractmethod - def file_exists(self, path: Path) -> bool: - pass diff --git a/src/codegen/sdk/codebase/multigraph.py b/src/codegen/sdk/codebase/multigraph.py deleted file mode 100644 index 2a76fec70..000000000 --- a/src/codegen/sdk/codebase/multigraph.py +++ /dev/null @@ -1,19 +0,0 @@ -from collections import defaultdict -from dataclasses import dataclass, field -from typing import Generic, TypeVar - -from codegen.sdk import TYPE_CHECKING -from codegen.sdk.core.detached_symbols.function_call import FunctionCall - -if TYPE_CHECKING: - from codegen.sdk.core.function import Function - -TFunction = TypeVar("TFunction", bound=Function) - - -@dataclass -class MultiGraph(Generic[TFunction]): - """Mapping of API endpoints to their definitions and usages across languages.""" - - api_definitions: dict[str, TFunction] = field(default_factory=dict) - usages: defaultdict[str, list[FunctionCall]] = field(default_factory=lambda: defaultdict(list)) diff --git a/src/codegen/sdk/codebase/node_classes/__init__.py b/src/codegen/sdk/codebase/node_classes/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/sdk/codebase/node_classes/generic_node_classes.py b/src/codegen/sdk/codebase/node_classes/generic_node_classes.py deleted file mode 100644 index a3b67a8ab..000000000 --- a/src/codegen/sdk/codebase/node_classes/generic_node_classes.py +++ /dev/null @@ -1,22 +0,0 @@ -from codegen.sdk.codebase.node_classes.node_classes import NodeClasses -from codegen.sdk.core.class_definition import Class -from codegen.sdk.core.detached_symbols.code_block import CodeBlock -from codegen.sdk.core.detached_symbols.function_call import FunctionCall -from codegen.sdk.core.detached_symbols.parameter import Parameter -from codegen.sdk.core.file import File -from codegen.sdk.core.function import Function -from codegen.sdk.core.import_resolution import Import -from codegen.sdk.core.statements.comment import Comment - -GenericNodeClasses = NodeClasses( - file_cls=File, - class_cls=Class, - function_cls=Function, - import_cls=Import, - parameter_cls=Parameter, - comment_cls=Comment, - code_block_cls=CodeBlock, - function_call_cls=FunctionCall, - bool_conversion={}, - dynamic_import_parent_types={}, -) diff --git a/src/codegen/sdk/codebase/node_classes/node_classes.py b/src/codegen/sdk/codebase/node_classes/node_classes.py deleted file mode 100644 index f439dc1c3..000000000 --- a/src/codegen/sdk/codebase/node_classes/node_classes.py +++ /dev/null @@ -1,49 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass, field -from functools import cached_property -from typing import TYPE_CHECKING - -from codegen.sdk.core.interfaces.resolvable import Resolvable - -if TYPE_CHECKING: - from codegen.sdk.core.class_definition import Class - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.detached_symbols.parameter import Parameter - from codegen.sdk.core.expressions import Expression - from codegen.sdk.core.expressions.type import Type - from codegen.sdk.core.file import SourceFile - from codegen.sdk.core.function import Function - from codegen.sdk.core.import_resolution import Import - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.statements.comment import Comment - from codegen.sdk.core.symbol import Symbol - - -@dataclass -class NodeClasses: - file_cls: type[SourceFile] - class_cls: type[Class] - function_cls: type[Function] - import_cls: type[Import] - - # Detached symbols - parameter_cls: type[Parameter] - code_block_cls: type[CodeBlock] - function_call_cls: type[FunctionCall] - comment_cls: type[Comment] - bool_conversion: dict[bool, str] - dynamic_import_parent_types: set[type[Editable]] - symbol_map: dict[str, type[Symbol]] = field(default_factory=dict) - expression_map: dict[str, type[Expression]] = field(default_factory=dict) - type_map: dict[str, type[Type] | dict[str, type[Type]]] = field(default_factory=dict) - keywords: list[str] = field(default_factory=list) - type_node_type: str = "" - int_dict_key: bool = False - - @cached_property - def resolvables(self) -> set[str]: - id_types = {k for k, v in self.expression_map.items() if isinstance(v, type) and issubclass(v, Resolvable)} - id_types.update(["identifier"]) - return id_types diff --git a/src/codegen/sdk/codebase/node_classes/py_node_classes.py b/src/codegen/sdk/codebase/node_classes/py_node_classes.py deleted file mode 100644 index 7f2203f75..000000000 --- a/src/codegen/sdk/codebase/node_classes/py_node_classes.py +++ /dev/null @@ -1,130 +0,0 @@ -from tree_sitter import Node as TSNode - -from codegen.sdk.codebase.node_classes.node_classes import NodeClasses -from codegen.sdk.core.detached_symbols.function_call import FunctionCall -from codegen.sdk.core.expressions import String, Type -from codegen.sdk.core.expressions.await_expression import AwaitExpression -from codegen.sdk.core.expressions.binary_expression import BinaryExpression -from codegen.sdk.core.expressions.boolean import Boolean -from codegen.sdk.core.expressions.comparison_expression import ComparisonExpression -from codegen.sdk.core.expressions.name import Name -from codegen.sdk.core.expressions.none_type import NoneType -from codegen.sdk.core.expressions.number import Number -from codegen.sdk.core.expressions.parenthesized_expression import ParenthesizedExpression -from codegen.sdk.core.expressions.subscript_expression import SubscriptExpression -from codegen.sdk.core.expressions.unary_expression import UnaryExpression -from codegen.sdk.core.expressions.unpack import Unpack -from codegen.sdk.core.function import Function -from codegen.sdk.core.statements.comment import Comment -from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement -from codegen.sdk.core.statements.if_block_statement import IfBlockStatement -from codegen.sdk.core.statements.switch_statement import SwitchStatement -from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement -from codegen.sdk.core.statements.while_statement import WhileStatement -from codegen.sdk.core.symbol_groups.dict import Dict -from codegen.sdk.core.symbol_groups.list import List -from codegen.sdk.core.symbol_groups.tuple import Tuple -from codegen.sdk.core.symbol_groups.type_parameters import TypeParameters -from codegen.sdk.python import PyClass, PyFile, PyFunction, PyImport, PySymbol -from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock -from codegen.sdk.python.detached_symbols.parameter import PyParameter -from codegen.sdk.python.expressions.chained_attribute import PyChainedAttribute -from codegen.sdk.python.expressions.conditional_expression import PyConditionalExpression -from codegen.sdk.python.expressions.generic_type import PyGenericType -from codegen.sdk.python.expressions.named_type import PyNamedType -from codegen.sdk.python.expressions.string import PyString -from codegen.sdk.python.expressions.union_type import PyUnionType -from codegen.sdk.python.statements.import_statement import PyImportStatement -from codegen.sdk.python.statements.match_case import PyMatchCase -from codegen.sdk.python.statements.with_statement import WithStatement - - -def parse_subscript(node: TSNode, file_node_id, ctx, parent): - if (node.prev_named_sibling and node.prev_named_sibling.text.decode("utf-8") == "TypeAlias") or isinstance(parent, Type): - return PyGenericType(node, file_node_id, ctx, parent) - return SubscriptExpression(node, file_node_id, ctx, parent) - - -PyExpressionMap = { - "string": PyString, - "dictionary": Dict, - "list": List, - "name": Name, - "true": Boolean, - "false": Boolean, - "integer": Number, - "float": Number, - "identifier": Name, - "attribute": PyChainedAttribute, - "call": FunctionCall, - "binary_operator": BinaryExpression, - "boolean_operator": BinaryExpression, - "comparison_operator": ComparisonExpression, - "string_content": String, - "parenthesized_expression": ParenthesizedExpression, - "await": AwaitExpression, - "function_definition": PyFunction, - "list_splat": Unpack, - "dictionary_splat": Unpack, - "tuple": Tuple, - "conditional_expression": PyConditionalExpression, - "not_operator": UnaryExpression, - "subscript": parse_subscript, - "type_parameter": TypeParameters, - "pattern_list": List, - # "assignment": PyAssignment.from_assignment, - # "augmented_assignment": PyAssignment.from_assignment, - # "named_expression": PyAssignment.from_named_expression, -} - -PyStatementMap = { - "import_statement": PyImportStatement, - "import_from_statement": PyImportStatement, - "future_import_statement": PyImportStatement, -} - -PySymbolMap = { - "decorated_definition": PySymbol.from_decorated_definition, - "function_definition": PyFunction, - "class_definition": PyClass, -} - -PyNodeClasses = NodeClasses( - file_cls=PyFile, - class_cls=PyClass, - function_cls=PyFunction, - import_cls=PyImport, - parameter_cls=PyParameter, - comment_cls=Comment, - code_block_cls=PyCodeBlock, - function_call_cls=FunctionCall, - symbol_map=PySymbolMap, - expression_map=PyExpressionMap, - type_map={ - "union_type": PyUnionType, - "binary_operator": PyUnionType, - "generic_type": PyGenericType, - "subscript": PyGenericType, - "none": NoneType, - "identifier": PyNamedType, - "attribute": PyNamedType, - "string": PyNamedType, # TODO: handle string types (IE postponed annotations) - }, - keywords=["async"], - type_node_type="type", - int_dict_key=True, - bool_conversion={ - True: "True", - False: "False", - }, - dynamic_import_parent_types={ - Function, - IfBlockStatement, - TryCatchStatement, - WithStatement, - ForLoopStatement, - WhileStatement, - SwitchStatement, - PyMatchCase, - }, -) diff --git a/src/codegen/sdk/codebase/node_classes/ts_node_classes.py b/src/codegen/sdk/codebase/node_classes/ts_node_classes.py deleted file mode 100644 index e1d4515c2..000000000 --- a/src/codegen/sdk/codebase/node_classes/ts_node_classes.py +++ /dev/null @@ -1,184 +0,0 @@ -from tree_sitter import Node as TSNode - -from codegen.sdk.codebase.node_classes.node_classes import NodeClasses -from codegen.sdk.core.detached_symbols.function_call import FunctionCall -from codegen.sdk.core.expressions.await_expression import AwaitExpression -from codegen.sdk.core.expressions.binary_expression import BinaryExpression -from codegen.sdk.core.expressions.boolean import Boolean -from codegen.sdk.core.expressions.defined_name import DefinedName -from codegen.sdk.core.expressions.name import Name -from codegen.sdk.core.expressions.none_type import NoneType -from codegen.sdk.core.expressions.number import Number -from codegen.sdk.core.expressions.parenthesized_expression import ParenthesizedExpression -from codegen.sdk.core.expressions.subscript_expression import SubscriptExpression -from codegen.sdk.core.expressions.tuple_type import TupleType -from codegen.sdk.core.expressions.unary_expression import UnaryExpression -from codegen.sdk.core.expressions.unpack import Unpack -from codegen.sdk.core.expressions.value import Value -from codegen.sdk.core.function import Function -from codegen.sdk.core.statements.comment import Comment -from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement -from codegen.sdk.core.statements.if_block_statement import IfBlockStatement -from codegen.sdk.core.statements.switch_case import SwitchCase -from codegen.sdk.core.statements.switch_statement import SwitchStatement -from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement -from codegen.sdk.core.statements.while_statement import WhileStatement -from codegen.sdk.core.symbol_groups.list import List -from codegen.sdk.core.symbol_groups.type_parameters import TypeParameters -from codegen.sdk.typescript.class_definition import TSClass -from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock -from codegen.sdk.typescript.detached_symbols.jsx.element import JSXElement -from codegen.sdk.typescript.detached_symbols.jsx.expression import JSXExpression -from codegen.sdk.typescript.detached_symbols.jsx.prop import JSXProp -from codegen.sdk.typescript.detached_symbols.parameter import TSParameter -from codegen.sdk.typescript.enum_definition import TSEnum -from codegen.sdk.typescript.enums import TSFunctionTypeNames -from codegen.sdk.typescript.expressions.array_type import TSArrayType -from codegen.sdk.typescript.expressions.chained_attribute import TSChainedAttribute -from codegen.sdk.typescript.expressions.conditional_type import TSConditionalType -from codegen.sdk.typescript.expressions.function_type import TSFunctionType -from codegen.sdk.typescript.expressions.generic_type import TSGenericType -from codegen.sdk.typescript.expressions.lookup_type import TSLookupType -from codegen.sdk.typescript.expressions.named_type import TSNamedType -from codegen.sdk.typescript.expressions.object_type import TSObjectType -from codegen.sdk.typescript.expressions.query_type import TSQueryType -from codegen.sdk.typescript.expressions.readonly_type import TSReadonlyType -from codegen.sdk.typescript.expressions.string import TSString -from codegen.sdk.typescript.expressions.ternary_expression import TSTernaryExpression -from codegen.sdk.typescript.expressions.undefined_type import TSUndefinedType -from codegen.sdk.typescript.expressions.union_type import TSUnionType -from codegen.sdk.typescript.file import TSFile -from codegen.sdk.typescript.function import TSFunction -from codegen.sdk.typescript.import_resolution import TSImport -from codegen.sdk.typescript.interface import TSInterface -from codegen.sdk.typescript.namespace import TSNamespace -from codegen.sdk.typescript.statements.comment import TSComment -from codegen.sdk.typescript.statements.import_statement import TSImportStatement -from codegen.sdk.typescript.symbol_groups.dict import TSDict -from codegen.sdk.typescript.type_alias import TSTypeAlias - - -def parse_dict(node: TSNode, *args): - if node.prev_named_sibling and node.prev_named_sibling.text.decode("utf-8").endswith("propTypes"): - return TSObjectType(node, *args) - return TSDict(node, *args) - - -def parse_new(node: TSNode, *args): - if not node.child_by_field_name("arguments"): - return Value(node, *args) - return FunctionCall(node, *args) - - -TSExpressionMap = { - "string": TSString, - "template_string": TSString, - "object": parse_dict, - "array": List, - "name": Name, - "true": Boolean, - "false": Boolean, - "number": Number, - "property_identifier": DefinedName, - "call_expression": FunctionCall, - "identifier": Name, - "type_identifier": Name, # HACK - "shorthand_property_identifier_pattern": Name, # maybe hack?? - "null": NoneType, - "comment": TSComment, - "binary_expression": BinaryExpression, - "member_expression": TSChainedAttribute, - "method_definition": TSFunction, - "parenthesized_expression": ParenthesizedExpression, - "await_expression": AwaitExpression, - "unary_expression": UnaryExpression, - "shorthand_property_identifier": Name, - "ternary_expression": TSTernaryExpression, - "jsx_expression": JSXExpression, - "jsx_element": JSXElement, - "jsx_closing_element": JSXElement, - "jsx_opening_element": JSXElement, - "jsx_self_closing_element": JSXElement, - "jsx_attribute": JSXProp, - "spread_element": Unpack, - "subscript_expression": SubscriptExpression, - "type_parameters": TypeParameters, - "array_pattern": List, - "new_expression": parse_new, - # "variable_declarator": TSAssignment.from_named_expression, - # "property_signature": TSAssignment.from_named_expression, - # "public_field_definition": TSAssignment.from_named_expression, - # "assignment_expression": TSAssignment.from_assignment, - # "augmented_assignment_expression": TSAssignment.from_assignment, -} - -TSStatementMap = { - "import_statement": TSImportStatement, - "import": TSImportStatement, -} - -TSSymbolMap = { - **{function_type.value: TSFunction.from_function_type for function_type in TSFunctionTypeNames}, - "class_declaration": TSClass, - "abstract_class_declaration": TSClass, - "interface_declaration": TSInterface, - "type_alias_declaration": TSTypeAlias, - "enum_declaration": TSEnum, - "internal_module": TSNamespace, -} - -TSNodeClasses = NodeClasses( - file_cls=TSFile, - class_cls=TSClass, - function_cls=TSFunction, - import_cls=TSImport, - parameter_cls=TSParameter, - code_block_cls=TSCodeBlock, - function_call_cls=FunctionCall, - comment_cls=Comment, - symbol_map=TSSymbolMap, - expression_map=TSExpressionMap, - type_map={ - "union_type": TSUnionType, - "lookup_type": TSLookupType, - "predefined_type": TSNamedType, - "identifier": TSNamedType, - "type_identifier": TSNamedType, - "object_type": TSObjectType, - "generic_type": TSGenericType, - "literal_type": { - "null": NoneType, - "undefined": TSUndefinedType, - "string": TSNamedType, - }, - "parenthesized_type": { - "function_type": TSFunctionType, - "type_query": TSQueryType, - }, - "nested_type_identifier": TSNamedType, - "array_type": TSArrayType, - "member_expression": TSNamedType, # TODO: parse generics in class extends clause - "function_type": TSFunctionType, - "type_query": TSQueryType, - "readonly_type": TSReadonlyType, - "intersection_type": TSUnionType, # TODO: Not accurate, implement this properly - "type_parameter": TSNamedType, - "tuple_type": TupleType, - "conditional_type": TSConditionalType, - }, - keywords=["export", "default", "let", "const", "static", "async"], - type_node_type="type_annotation", - bool_conversion={ - True: "true", - False: "false", - }, - dynamic_import_parent_types={ - Function, - IfBlockStatement, - TryCatchStatement, - ForLoopStatement, - WhileStatement, - SwitchStatement, - SwitchCase, - }, -) diff --git a/src/codegen/sdk/codebase/progress/progress.py b/src/codegen/sdk/codebase/progress/progress.py deleted file mode 100644 index ec1c8b6e1..000000000 --- a/src/codegen/sdk/codebase/progress/progress.py +++ /dev/null @@ -1,13 +0,0 @@ -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Generic, TypeVar - -if TYPE_CHECKING: - from codegen.sdk.codebase.progress.task import Task - -T = TypeVar("T", bound="Task") - - -class Progress(ABC, Generic[T]): - @abstractmethod - def begin(self, message: str, count: int | None = None) -> T: - pass diff --git a/src/codegen/sdk/codebase/progress/stub_progress.py b/src/codegen/sdk/codebase/progress/stub_progress.py deleted file mode 100644 index 6c0aac5aa..000000000 --- a/src/codegen/sdk/codebase/progress/stub_progress.py +++ /dev/null @@ -1,7 +0,0 @@ -from codegen.sdk.codebase.progress.progress import Progress -from codegen.sdk.codebase.progress.stub_task import StubTask - - -class StubProgress(Progress[StubTask]): - def begin(self, message: str, count: int | None = None) -> StubTask: - return StubTask() diff --git a/src/codegen/sdk/codebase/progress/stub_task.py b/src/codegen/sdk/codebase/progress/stub_task.py deleted file mode 100644 index 43d25acf7..000000000 --- a/src/codegen/sdk/codebase/progress/stub_task.py +++ /dev/null @@ -1,9 +0,0 @@ -from codegen.sdk.codebase.progress.task import Task - - -class StubTask(Task): - def update(self, message: str, count: int | None = None) -> None: - pass - - def end(self) -> None: - pass diff --git a/src/codegen/sdk/codebase/progress/task.py b/src/codegen/sdk/codebase/progress/task.py deleted file mode 100644 index c0814513d..000000000 --- a/src/codegen/sdk/codebase/progress/task.py +++ /dev/null @@ -1,11 +0,0 @@ -from abc import ABC, abstractmethod - - -class Task(ABC): - @abstractmethod - def update(self, message: str, count: int | None = None) -> None: - pass - - @abstractmethod - def end(self) -> None: - pass diff --git a/src/codegen/sdk/codebase/range_index.py b/src/codegen/sdk/codebase/range_index.py deleted file mode 100644 index 9c1bdf691..000000000 --- a/src/codegen/sdk/codebase/range_index.py +++ /dev/null @@ -1,53 +0,0 @@ -import itertools -from collections import defaultdict -from functools import cached_property - -from tree_sitter import Range - -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.extensions.sort import sort_editables - - -class RangeIndex: - _ranges: defaultdict[Range, list[Editable]] - _canonical_range: defaultdict[Range, dict[int, Editable]] - - def __init__(self): - self._ranges = defaultdict(list) - self._canonical_range = defaultdict(dict) - - def add_to_range(self, editable: Editable) -> None: - self._ranges[editable.range].append(editable) - - def mark_as_canonical(self, editable: Editable) -> None: - self._canonical_range[editable.range][editable.ts_node.kind_id] = editable - - def get_all_for_range(self, range: Range) -> list[Editable]: - return self._ranges[range] - - def get_canonical_for_range(self, range: Range, kind_id: int) -> Editable | None: - if mapping := self._canonical_range.get(range, None): - return mapping.get(kind_id, None) - - def clear(self): - self._ranges.clear() - self._canonical_range.clear() - self.__dict__.pop("children", None) - self.__dict__.pop("nodes", None) - - @cached_property - def nodes(self) -> list[Editable]: - return list(itertools.chain.from_iterable(self._ranges.values())) - - @cached_property - def children(self) -> dict[Editable, list[Editable]]: - ret = defaultdict(list) - for node in self.nodes: - # if node.ctx.config.debug: - # assert node.parent != node, node.__class__ - if node.parent != node: - ret[node.parent].append(node) - return ret - - def get_children(self, parent: Editable) -> list[Editable]: - return sort_editables(self.children[parent]) diff --git a/src/codegen/sdk/codebase/resolution_stack.py b/src/codegen/sdk/codebase/resolution_stack.py deleted file mode 100644 index ccce6d38a..000000000 --- a/src/codegen/sdk/codebase/resolution_stack.py +++ /dev/null @@ -1,3 +0,0 @@ -from codegen.sdk.extensions.resolution import ResolutionStack - -__all__ = ["ResolutionStack"] diff --git a/src/codegen/sdk/codebase/span.py b/src/codegen/sdk/codebase/span.py deleted file mode 100644 index d7933927e..000000000 --- a/src/codegen/sdk/codebase/span.py +++ /dev/null @@ -1,83 +0,0 @@ -from typing import Annotated, Any - -from pydantic import BaseModel -from pydantic.config import ConfigDict -from pydantic.functional_validators import BeforeValidator -from pydantic.json_schema import JsonSchemaValue, WithJsonSchema -from pydantic_core.core_schema import ValidationInfo -from tree_sitter import Point, Range - -from codegen.shared.decorators.docs import apidoc - - -def validate_range(value: Any, info: ValidationInfo) -> Range: - if isinstance(value, dict): - value = Range( - start_byte=value["start_byte"], - end_byte=value["end_byte"], - start_point=Point(**value["start_point"]), - end_point=Point(**value["end_point"]), - ) - elif not isinstance(value, Range): - msg = "Invalid type for range field. Expected tree_sitter.Range or dict." - raise ValueError(msg) - return value - - -def range_json_schema() -> JsonSchemaValue: - return { - "type": "object", - "properties": { - "start_byte": {"type": "integer"}, - "end_byte": {"type": "integer"}, - "start_point": { - "type": "object", - "properties": { - "row": {"type": "integer"}, - "column": {"type": "integer"}, - }, - }, - "end_point": { - "type": "object", - "properties": {"row": {"type": "integer"}, "column": {"type": "integer"}}, - }, - }, - } - - -RangeAdapter = Annotated[ - Range, - BeforeValidator(validate_range), - WithJsonSchema(range_json_schema()), -] - - -@apidoc -class Span(BaseModel): - """Range within the codebase - - Attributes: - range: Adapter for the range within the codebase. - filepath: The path to the file associated with the range. - """ - - model_config = ConfigDict( - frozen=True, - arbitrary_types_allowed=True, - json_encoders={ - Range: lambda r: { - "start_byte": r.start_byte, - "end_byte": r.end_byte, - "start_point": { - "row": r.start_point.row, - "column": r.start_point.column, - }, - "end_point": { - "row": r.end_point.row, - "column": r.end_point.column, - }, - } - }, - ) - range: RangeAdapter - filepath: str diff --git a/src/codegen/sdk/codebase/transaction_manager.py b/src/codegen/sdk/codebase/transaction_manager.py deleted file mode 100644 index a59b6eb4e..000000000 --- a/src/codegen/sdk/codebase/transaction_manager.py +++ /dev/null @@ -1,306 +0,0 @@ -import time -from collections.abc import Callable -from pathlib import Path -from typing import TYPE_CHECKING - -from codegen.sdk.codebase.diff_lite import ChangeType, DiffLite -from codegen.sdk.codebase.transactions import ( - EditTransaction, - FileAddTransaction, - FileRemoveTransaction, - FileRenameTransaction, - RemoveTransaction, - Transaction, - TransactionPriority, -) -from codegen.shared.exceptions.control_flow import MaxPreviewTimeExceeded, MaxTransactionsExceeded -from codegen.shared.logging.get_logger import get_logger - -if TYPE_CHECKING: - from codegen.sdk.core.file import File - - -logger = get_logger(__name__) - - -class TransactionError(Exception): - pass - - -class TransactionManager: - """Responsible for handling `Transaction` objects - basically an atomic modification of a codebase. - - This is used by the Codebase class to queue up transactions and then commit them in bulk. - """ - - # Unsorted list of transactions, grouped by file - # TODO: consider using SortedList for better performance - queued_transactions: dict[Path, list[Transaction]] - pending_undos: set[Callable[[], None]] - _commiting: bool = False - max_transactions: int | None = None # None = no limit - stopwatch_start = None - stopwatch_max_seconds: int | None = None # None = no limit - - def __init__(self) -> None: - self.queued_transactions = dict() - self.pending_undos = set() - - def sort_transactions(self) -> None: - for file_path, file_transactions in self.queued_transactions.items(): - file_transactions.sort(key=Transaction._to_sort_key) - - def clear_transactions(self) -> None: - """Should be called between tests to remove any potential extraneous transactions. Makes sure we reset max_transactions as well.""" - if len(self.queued_transactions) > 0: - logger.warning("Not all transactions have been committed") - self.queued_transactions.clear() - for undo in self.pending_undos: - undo() - self.pending_undos.clear() - self.set_max_transactions(None) - self.reset_stopwatch() - - def _format_transactions(self, transactions: list[Transaction]) -> str: - return "\n".join([">" * 100 + f"\n[ID: {t.transaction_id}]: {t.diff_str()}" + "<" * 100 for t in transactions]) - - def get_transactions_str(self) -> str: - """Returns a human-readable string representation of the transactions""" - return "\n\n\n".join([f"{file_path}:\n{self._format_transactions(transactions)}" for file_path, transactions in self.queued_transactions.items()]) - - #################################################################################################################### - # Transation Limits - #################################################################################################################### - - def get_num_transactions(self) -> int: - """Returns total number of transactions created to date""" - return sum([len(transactions) for transactions in self.queued_transactions.values()]) - - def set_max_transactions(self, max_transactions: int | None = None) -> None: - self.max_transactions = max_transactions - - def max_transactions_exceeded(self) -> bool: - """Util method to check if the max transactions limit has been exceeded.""" - if self.max_transactions is None: - return False - return self.get_num_transactions() >= self.max_transactions - - #################################################################################################################### - # Stopwatch - #################################################################################################################### - - def reset_stopwatch(self, max_seconds: int | None = None) -> int: - self.stopwatch_start = time.time() - self.stopwatch_max_seconds = max_seconds - - def is_time_exceeded(self) -> bool: - if self.stopwatch_max_seconds is None: - return False - else: - num_seconds = time.time() - self.stopwatch_start - return num_seconds > self.stopwatch_max_seconds - - #################################################################################################################### - # Transaction Creation - #################################################################################################################### - - def add_file_add_transaction(self, filepath: Path) -> None: - t = FileAddTransaction(filepath) - self.add_transaction(t) - - def add_file_rename_transaction(self, file: "File", new_filepath: str) -> None: - t = FileRenameTransaction(file, new_filepath) - self.add_transaction(t) - - def add_file_remove_transaction(self, file: "File") -> None: - t = FileRemoveTransaction(file) - self.add_transaction(t) - - def add_transaction(self, transaction: Transaction, dedupe: bool = True, solve_conflicts: bool = True) -> bool: - # Get the list of transactions for the file - file_path = transaction.file_path - if file_path not in self.queued_transactions: - self.queued_transactions[file_path] = [] - file_queue = self.queued_transactions[file_path] - - # Dedupe transactions - if dedupe and transaction in file_queue: - logger.debug(f"Transaction already exists in queue: {transaction}") - return False - # Solve conflicts - if new_transaction := self._resolve_conflicts(transaction, file_queue, solve_conflicts=solve_conflicts): - file_queue.append(new_transaction) - - self.check_limits() - return True - - def check_limits(self): - self.check_max_transactions() - self.check_max_preview_time() - - def check_max_transactions(self): - # =====[ Max transactions ]===== - # max_transactions is set so that long-running codemods terminate early so we can quickly surface a subset - # of the results to the user. This may result in errors that do not get covered. - if self.max_transactions_exceeded(): - logger.info(f"Max transactions reached: {self.max_transactions}. Stopping codemod.") - msg = f"Max transactions reached: {self.max_transactions}" - raise MaxTransactionsExceeded(msg, threshold=self.max_transactions) - - def check_max_preview_time(self): - # =====[ Max preview time ]===== - # This is to prevent the preview from taking too long. We want to keep it at like ~5s in the frontend during debugging - if self.is_time_exceeded(): - logger.info(f"Max preview time exceeded: {self.stopwatch_max_seconds}. Stopping codemod.") - msg = f"Max preview time exceeded: {self.is_time_exceeded()}" - raise MaxPreviewTimeExceeded(msg, threshold=self.stopwatch_max_seconds) - - #################################################################################################################### - # Commit - #################################################################################################################### - - def to_commit(self, files: set[Path] | None = None) -> set[Path]: - """Get node ids of files to commit""" - if files is None: - return set(self.queued_transactions.keys()) - return files.intersection(self.queued_transactions) - - def commit(self, files: set[Path]) -> list[DiffLite]: - """Execute transactions in bulk for each file, in reverse order of start_byte. - Returns the list of diffs that were committed. - """ - if self._commiting: - logger.warn("Skipping commit, already committing") - return [] - self._commiting = True - try: - diffs: list[DiffLite] = [] - if not self.queued_transactions or len(self.queued_transactions) == 0: - return diffs - - self.sort_transactions() - - # TODO: raise error if two transactions byte ranges overlap with each other - if len(files) > 3: - num_transactions = sum([len(self.queued_transactions[file_path]) for file_path in files]) - logger.info(f"Committing {num_transactions} transactions for {len(files)} files") - else: - for file in files: - logger.info(f"Committing {len(self.queued_transactions[file])} transactions for {file}") - for file_path in files: - file_transactions = self.queued_transactions.pop(file_path, []) - modified = False - for transaction in file_transactions: - # Add diff IF the file is a source file - diff = transaction.get_diff() - if diff.change_type == ChangeType.Modified: - if not modified: - modified = True - diffs.append(diff) - else: - diffs.append(diff) - transaction.execute() - return diffs - finally: - self._commiting = False - - #################################################################################################################### - # Conflict Resolution - #################################################################################################################### - - def _resolve_conflicts(self, transaction: Transaction, file_queue: list[Transaction], solve_conflicts: bool = True) -> Transaction | None: - def break_down(to_break: EditTransaction) -> bool: - if new_transactions := to_break.break_down(): - try: - insert_idx = file_queue.index(to_break) - file_queue.pop(insert_idx) - except ValueError: - insert_idx = len(file_queue) - for new_transaction in new_transactions: - if broken_down := self._resolve_conflicts(new_transaction, file_queue, solve_conflicts=solve_conflicts): - file_queue.insert(insert_idx, broken_down) - return True - return False - - try: - conflicts = self._get_conflicts(transaction) - if solve_conflicts and conflicts: - # Check if the current transaction completely overlaps with any existing transaction - if (completely_overlapping := self._get_overlapping_conflicts(transaction)) is not None: - # If it does, check the overlapping transaction's type - # If the overlapping transaction is a remove, remove the current transaction - if isinstance(completely_overlapping, RemoveTransaction): - return None - # If the overlapping transaction is an edit, raise an error - elif isinstance(completely_overlapping, EditTransaction): - if break_down(completely_overlapping): - return transaction - - raise TransactionError() - else: - # If current transaction is deleted, remove all conflicting transactions - if isinstance(transaction, RemoveTransaction): - for t in conflicts: - file_queue.remove(t) - # If current transaction is edit, raise an error - elif isinstance(transaction, EditTransaction): - if break_down(transaction): - return None - raise TransactionError() - - # Add to priority queue and rebuild the queue - return transaction - except TransactionError as e: - logger.exception(e) - msg = ( - f"Potential conflict detected in file {transaction.file_path}!\n" - "Attempted to perform code modification:\n" - "\n" - f"{self._format_transactions([transaction])}\n" - "\n" - "That potentially conflicts with the following other modifications:\n" - "\n" - f"{self._format_transactions(conflicts)}\n" - "\n" - "Aborting!\n" - "\n" - f"[Conflict Detected] Potential Modification Conflict in File {transaction.file_path}!" - ) - raise TransactionError(msg) - - def get_transactions_at_range(self, file_path: Path, start_byte: int, end_byte: int, transaction_order: TransactionPriority | None = None, *, combined: bool = False) -> list[Transaction]: - """Returns list of queued transactions that matches the given filtering criteria. - - Args: - combined: Return a list of transactions which collectively apply to the given range - """ - matching_transactions = [] - if file_path not in self.queued_transactions: - return matching_transactions - - for t in self.queued_transactions[file_path]: - if t.start_byte == start_byte: - if t.end_byte == end_byte: - if transaction_order is None or t.transaction_order == transaction_order: - matching_transactions.append(t) - elif combined and t.start_byte != t.end_byte: - if other := self.get_transactions_at_range(t.file_path, t.end_byte, end_byte, transaction_order, combined=combined): - return [t, *other] - - return matching_transactions - - def _get_conflicts(self, transaction: Transaction) -> list[Transaction]: - """Returns all transactions that overlap with the given transaction""" - overlapping_transactions = [] - queued_transactions = list(self.queued_transactions[transaction.file_path]) - for t in queued_transactions: - if transaction.start_byte < t.end_byte and transaction.end_byte > t.start_byte: - overlapping_transactions.append(t) - return overlapping_transactions - - def _get_overlapping_conflicts(self, transaction: Transaction) -> Transaction | None: - """Returns the transaction that completely overlaps with the given transaction""" - for t in self.queued_transactions[transaction.file_path]: - if transaction.start_byte >= t.start_byte and transaction.end_byte <= t.end_byte: - return t - return None diff --git a/src/codegen/sdk/codebase/transactions.py b/src/codegen/sdk/codebase/transactions.py deleted file mode 100644 index 31d48b2e1..000000000 --- a/src/codegen/sdk/codebase/transactions.py +++ /dev/null @@ -1,302 +0,0 @@ -from collections.abc import Callable -from difflib import unified_diff -from enum import IntEnum -from functools import cached_property -from pathlib import Path -from typing import TYPE_CHECKING, Protocol, runtime_checkable - -from codegen.sdk.codebase.diff_lite import ChangeType, DiffLite - -if TYPE_CHECKING: - from codegen.sdk.core.file import File - - -class TransactionPriority(IntEnum): - Remove = 0 # Remove always has highest priority - Edit = 1 # Edit always comes next (remove and edit are incompatible with each other, so it should error out) - Insert = 2 # Insert is always the last of the edit operations - # File operations happen last, since they will mess up all other transactions - FileAdd = 10 - FileRename = 11 - FileRemove = 12 - - -@runtime_checkable -class ContentFunc(Protocol): - """A function executed to generate a content block dynamically.""" - - def __call__(self) -> str: ... - - -class Transaction: - start_byte: int - end_byte: int - file_path: Path - priority: int | tuple - transaction_order: TransactionPriority - transaction_counter: int = 0 - - def __init__( - self, - start_byte: int, - end_byte: int, - file_path: Path, - priority: int | tuple = 0, - new_content: str | None | Callable[[], str] = None, - ) -> None: - self.start_byte = start_byte - assert self.start_byte >= 0 - self.end_byte = end_byte - self.file_path = file_path - self.priority = priority - self._new_content = new_content - self.transaction_id = Transaction.transaction_counter - - Transaction.transaction_counter += 1 - - def __repr__(self) -> str: - return f"" - - def __hash__(self): - return hash((self.start_byte, self.end_byte, self.file_path, self.priority, self.new_content)) - - def __eq__(self, other): - if not isinstance(other, type(self)): - return False - - # Check for everything EXCEPT transaction_time - return ( - self.start_byte == other.start_byte - and self.end_byte == other.end_byte - and self.file_path == other.file_path - and self.priority == other.priority - and self._new_content == other._new_content - ) - - @property - def length(self): - return self.end_byte - self.start_byte - - def execute(self): - msg = "Transaction.execute() must be implemented by subclasses" - raise NotImplementedError(msg) - - def get_diff(self) -> DiffLite: - """Gets the diff produced by this transaction""" - msg = "Transaction.get_diff() must be implemented by subclasses" - raise NotImplementedError(msg) - - def diff_str(self): - """Human-readable string representation of the change""" - msg = "Transaction.diff_str() must be implemented by subclasses" - raise NotImplementedError(msg) - - def _to_sort_key(transaction: "Transaction"): - # Sort by: - # 1. Descending start_byte - # 2. Ascending transaction type - # 3. Ascending priority - # 4. Descending time of transaction= - priority = (transaction.priority,) if isinstance(transaction.priority, int) else transaction.priority - - return -transaction.start_byte, transaction.transaction_order.value, priority, -transaction.transaction_id - - @cached_property - def new_content(self) -> str | None: - return self._new_content() if isinstance(self._new_content, ContentFunc) else self._new_content - - -class RemoveTransaction(Transaction): - transaction_order = TransactionPriority.Remove - - exec_func: Callable[[], None] | None = None - - def __init__(self, start_byte: int, end_byte: int, file: "File", priority: int = 0, exec_func: Callable[[], None] | None = None) -> None: - super().__init__(start_byte, end_byte, file.path, priority=priority) - self.file = file - self.exec_func = exec_func - - def _generate_new_content_bytes(self) -> bytes: - content_bytes = self.file.content_bytes - new_content_bytes = content_bytes[: self.start_byte] + content_bytes[self.end_byte :] - return new_content_bytes - - def execute(self) -> None: - """Removes the content between start_byte and end_byte""" - self.file.write_bytes(self._generate_new_content_bytes()) - if self.exec_func: - self.exec_func() - - def get_diff(self) -> DiffLite: - """Gets the diff produced by this transaction""" - return DiffLite(ChangeType.Modified, self.file_path, old_content=self.file.content_bytes) - - def diff_str(self) -> str: - """Human-readable string representation of the change""" - diff = "".join(unified_diff(self.file.content.splitlines(True), self._generate_new_content_bytes().decode("utf-8").splitlines(True))) - return f"Remove {self.length} bytes at bytes ({self.start_byte}, {self.end_byte})\n{diff}" - - -class InsertTransaction(Transaction): - transaction_order = TransactionPriority.Insert - - exec_func: Callable[[], None] | None = None - - def __init__( - self, - insert_byte: int, - file: "File", - new_content: str | Callable[[], str], - *, - priority: int | tuple = 0, - exec_func: Callable[[], None] | None = None, - ) -> None: - super().__init__(insert_byte, insert_byte, file.path, priority=priority, new_content=new_content) - self.insert_byte = insert_byte - self.file = file - self.exec_func = exec_func - - def _generate_new_content_bytes(self) -> bytes: - new_bytes = bytes(self.new_content, encoding="utf-8") - content_bytes = self.file.content_bytes - head = content_bytes[: self.insert_byte] - tail = content_bytes[self.insert_byte :] - new_content_bytes = head + new_bytes + tail - return new_content_bytes - - def execute(self) -> None: - """Inserts new_src at the specified byte_index""" - self.file.write_bytes(self._generate_new_content_bytes()) - if self.exec_func: - self.exec_func() - - def get_diff(self) -> DiffLite: - """Gets the diff produced by this transaction""" - return DiffLite(ChangeType.Modified, self.file_path, old_content=self.file.content_bytes) - - def diff_str(self) -> str: - """Human-readable string representation of the change""" - diff = "".join(unified_diff(self.file.content.splitlines(True), self._generate_new_content_bytes().decode("utf-8").splitlines(True))) - return f"Insert {len(self.new_content)} bytes at bytes ({self.start_byte}, {self.end_byte})\n{diff}" - - -class EditTransaction(Transaction): - transaction_order = TransactionPriority.Edit - new_content: str - - def __init__( - self, - start_byte: int, - end_byte: int, - file: "File", - new_content: str, - priority: int = 0, - ) -> None: - super().__init__(start_byte, end_byte, file.path, priority=priority, new_content=new_content) - self.file = file - - def _generate_new_content_bytes(self) -> bytes: - new_bytes = bytes(self.new_content, "utf-8") - content_bytes = self.file.content_bytes - new_content_bytes = content_bytes[: self.start_byte] + new_bytes + content_bytes[self.end_byte :] - return new_content_bytes - - def execute(self) -> None: - """Edits the entirety of this node's source to new_src""" - self.file.write_bytes(self._generate_new_content_bytes()) - - def get_diff(self) -> DiffLite: - """Gets the diff produced by this transaction""" - return DiffLite(ChangeType.Modified, self.file_path, old_content=self.file.content_bytes) - - def diff_str(self) -> str: - """Human-readable string representation of the change""" - diff = "".join(unified_diff(self.file.content.splitlines(True), self._generate_new_content_bytes().decode("utf-8").splitlines(True))) - return f"Edit {self.length} bytes at bytes ({self.start_byte}, {self.end_byte}), src: ({self.new_content[:50]})\n{diff}" - - def break_down(self) -> list[InsertTransaction] | None: - old = self.file.content_bytes[self.start_byte : self.end_byte] - new = bytes(self.new_content, "utf-8") - if old and old in new: - prefix, suffix = new.split(old, maxsplit=1) - ret = [] - if suffix: - ret.append(InsertTransaction(self.end_byte, self.file, suffix.decode("utf-8"), priority=self.priority)) - if prefix: - ret.append(InsertTransaction(self.start_byte, self.file, prefix.decode("utf-8"), priority=self.priority)) - return ret - return None - - -class FileAddTransaction(Transaction): - transaction_order = TransactionPriority.FileAdd - - def __init__( - self, - file_path: Path, - priority: int = 0, - ) -> None: - super().__init__(0, 0, file_path, priority=priority) - - def execute(self) -> None: - """Adds a new file""" - pass # execute is a no-op as the file is immediately added - - def get_diff(self) -> DiffLite: - """Gets the diff produced by this transaction""" - return DiffLite(ChangeType.Added, self.file_path) - - def diff_str(self) -> str: - """Human-readable string representation of the change""" - return f"Add file at {self.file_path}" - - -class FileRenameTransaction(Transaction): - transaction_order = TransactionPriority.FileRename - - def __init__( - self, - file: "File", - new_file_path: str, - priority: int = 0, - ) -> None: - super().__init__(0, 0, file.path, priority=priority, new_content=new_file_path) - self.new_file_path = file.ctx.to_absolute(new_file_path) - self.file = file - - def execute(self) -> None: - """Renames the file""" - self.file.ctx.io.save_files({self.file.path}) - self.file_path.rename(self.new_file_path) - - def get_diff(self) -> DiffLite: - """Gets the diff produced by this transaction""" - return DiffLite(ChangeType.Renamed, self.file_path, self.file_path, self.new_file_path, old_content=self.file.content_bytes) - - def diff_str(self) -> str: - """Human-readable string representation of the change""" - return f"Rename file from {self.file_path} to {self.new_file_path}" - - -class FileRemoveTransaction(Transaction): - transaction_order = TransactionPriority.FileRemove - - def __init__( - self, - file: "File", - priority: int = 0, - ) -> None: - super().__init__(0, 0, file.path, priority=priority) - self.file = file - - def execute(self) -> None: - """Removes the file""" - self.file.ctx.io.delete_file(self.file.path) - - def get_diff(self) -> DiffLite: - """Gets the diff produced by this transaction""" - return DiffLite(ChangeType.Removed, self.file_path, old_content=self.file.content_bytes) - - def diff_str(self) -> str: - """Human-readable string representation of the change""" - return f"Remove file at {self.file_path}" diff --git a/src/codegen/sdk/codebase/validation.py b/src/codegen/sdk/codebase/validation.py deleted file mode 100644 index 54d04228a..000000000 --- a/src/codegen/sdk/codebase/validation.py +++ /dev/null @@ -1,151 +0,0 @@ -from __future__ import annotations - -import functools -import socket -from collections import Counter, defaultdict -from enum import StrEnum -from typing import TYPE_CHECKING - -from tabulate import tabulate - -from codegen.sdk.enums import NodeType -from codegen.sdk.utils import truncate_line -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - -if TYPE_CHECKING: - from rustworkx import PyDiGraph - - from codegen.sdk.core.codebase import CodebaseType - - -class PostInitValidationStatus(StrEnum): - NO_NODES = "NO_NODES" - NO_EDGES = "NO_EDGES" - MISSING_FILES = "MISSING_FILES" - LOW_IMPORT_RESOLUTION_RATE = "LOW_IMPORT_RESOLUTION_RATE" - SUCCESS = "SUCCESS" - - -def post_init_validation(codebase: CodebaseType) -> PostInitValidationStatus: - """Post codebase._init_graph verifies that the built graph is valid.""" - from codegen.sdk.codebase.codebase_context import GLOBAL_FILE_IGNORE_LIST - - # Verify the graph has nodes - if len(codebase.ctx.nodes) == 0: - return PostInitValidationStatus.NO_NODES - - # Verify the graph has the same number of files as there are in the repo - if len(codebase.files) != len(list(codebase.op.iter_files(codebase.ctx.projects[0].subdirectories, extensions=codebase.ctx.extensions, ignore_list=GLOBAL_FILE_IGNORE_LIST))): - return PostInitValidationStatus.MISSING_FILES - - # Verify import resolution - num_resolved_imports = len([imp for imp in codebase.imports if imp.imported_symbol and imp.imported_symbol.node_type != NodeType.EXTERNAL]) - if len(codebase.imports) > 0 and num_resolved_imports / len(codebase.imports) < 0.2: - logger.info(f"Codebase {codebase.repo_path} has {num_resolved_imports / len(codebase.imports)} < 0.2 resolved imports") - return PostInitValidationStatus.LOW_IMPORT_RESOLUTION_RATE - return PostInitValidationStatus.SUCCESS - - -def post_reset_validation(init_nodes, nodes, init_edges, edges, repo_name: str, subdirectories: list[str] | None) -> None: - logger.info("Verifying graph state and alerting if necessary") - hostname = socket.gethostname() - - if len(dict.fromkeys(nodes)) != len(dict.fromkeys(init_nodes)): - post_message = f"Reset graph: Nodes do not match for {repo_name} for subdirectories {subdirectories}. Hostname: {hostname}" - message = get_nodes_error(init_nodes, nodes) - log_or_throw(post_message, message) - if len(dict.fromkeys(edges)) != len(dict.fromkeys(init_edges)): - post_message = f"Reset graph: Edges do not match for {repo_name} for subdirectories {subdirectories}. Hostname: {hostname}" - message = get_edges_error(edges, init_edges) - log_or_throw(post_message, message) - - -def post_sync_validation(codebase: CodebaseType) -> bool: - """Post codebase.sync, checks that the codebase graph is in a valid state (i.e. not corrupted by codebase.sync)""" - if len(codebase.ctx.all_syncs) > 0 or len(codebase.ctx.pending_syncs) > 0 or len(codebase.ctx.transaction_manager.to_commit()) > 0: - msg = "Can only be called on a reset codebase" - raise NotImplementedError(msg) - if not codebase.ctx.config.codebase.track_graph: - msg = "Can only be called with track_graph=true" - raise NotImplementedError(msg) - return len(dict.fromkeys(codebase.ctx.old_graph.nodes())) == len(dict.fromkeys(codebase.ctx.nodes)) and len(dict.fromkeys(codebase.ctx.old_graph.weighted_edge_list())) == len( - dict.fromkeys(codebase.ctx.edges) - ) - - -def log_or_throw(message, thread_message: str): - hostname = socket.gethostname() - logger.error(message) - # logger.error(thread_message) - if hostname != "modal": - msg = f"{message}\n{thread_message}" - raise Exception(msg) - return - - -def get_edges_error(edges, init_edges): - set_edges = set(edges) - set_init_edges = set(init_edges) - missing_edges = set_init_edges - set_edges - extra_edges = set_edges - set_init_edges - message = "" - if extra_edges: - extras = tabulate((map(functools.partial(truncate_line, max_chars=50), edge) for edge in extra_edges), ["Start", "End", "Edge"], maxcolwidths=50) - message += f""" -Extra edges -``` -{extras} -``` -""" - - if missing_edges: - missing = tabulate((map(functools.partial(truncate_line, max_chars=50), edge) for edge in missing_edges), ["Start", "End", "Edge"], maxcolwidths=50) - message += f""" -Missing edges -``` -{missing} -``` -""" - missing_by_key = defaultdict(lambda: defaultdict(list)) - for u, v, data in missing_edges: - missing_by_key[u][v].append(data) - for u, v, data in extra_edges: - if u in missing_by_key and v in missing_by_key[u]: - for match in missing_by_key[u][v]: - message += f"Possible match from {u} to {v}: {match} -> {data}\n" - if len(edges) != len(set_init_edges): - message += f"{len(edges) - len(set_edges)} edges duplicated from {len(init_edges) - len(set_init_edges)}. Printing out up to 5 edges\n" - extras = tabulate(((*map(functools.partial(truncate_line, max_chars=50), edge), count) for edge, count in Counter(edges).most_common(5)), ["Start", "End", "Edge", "Count"], maxcolwidths=50) - message += extras - return message - - -def get_nodes_error(init_nodes, nodes): - set_nodes = set(nodes) - set_init_nodes = set(init_nodes) - message = f""" -Extra nodes -``` -{set_nodes - set_init_nodes} -``` - -Missing nodes -``` -{set_init_nodes - set_nodes} -``` -""" - for node in set_nodes - set_init_nodes: - from codegen.sdk.core.external_module import ExternalModule - - if isinstance(node, ExternalModule): - message += "External Module persisted with following dependencies: " + str(list((node.ctx.get_node(source), edge) for source, _, edge in node.ctx.in_edges(node.node_id))) - return message - - -def get_edges(graph: PyDiGraph): - ret = [] - for start, end, edge in graph.weighted_edge_list(): - ret.append((graph.get_node_data(start), graph.get_node_data(end), edge)) - return ret diff --git a/src/codegen/sdk/core/__init__.py b/src/codegen/sdk/core/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/sdk/core/assignment.py b/src/codegen/sdk/core/assignment.py deleted file mode 100644 index 116fca79d..000000000 --- a/src/codegen/sdk/core/assignment.py +++ /dev/null @@ -1,287 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from codegen.sdk._proxy import proxy_property -from codegen.sdk.core.autocommit import writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions import Expression, Name -from codegen.sdk.core.expressions.chained_attribute import ChainedAttribute -from codegen.sdk.core.expressions.multi_expression import MultiExpression -from codegen.sdk.core.expressions.subscript_expression import SubscriptExpression -from codegen.sdk.core.interfaces.chainable import Chainable -from codegen.sdk.core.interfaces.has_value import HasValue -from codegen.sdk.core.interfaces.typeable import Typeable -from codegen.sdk.core.symbol import Symbol -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.sdk.core.symbol_groups.dict import Dict -from codegen.sdk.enums import SymbolType -from codegen.sdk.extensions.autocommit import commiter, reader -from codegen.sdk.extensions.sort import sort_editables -from codegen.sdk.typescript.expressions.object_type import TSObjectType -from codegen.sdk.utils import find_index -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from collections.abc import Generator - - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.codebase.resolution_stack import ResolutionStack - from codegen.sdk.core.expressions.type import Type - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.statements.assignment_statement import AssignmentStatement - from codegen.sdk.core.statements.export_statement import ExportStatement - from codegen.sdk.core.statements.statement import Statement - -Parent = TypeVar("Parent", bound="AssignmentStatement | ExportStatement") - - -@apidoc -class Assignment(Symbol[Parent, ...], Typeable[Parent, ...], HasValue, Generic[Parent]): - """Represents an assignment for a single variable within an assignment statement. - - Example: - ```typescript - var z - var z = 5 - ``` - - Attributes: - symbol_type: The type of symbol, set to SymbolType.GlobalVar. - """ - - _left: Expression[Self] - symbol_type = SymbolType.GlobalVar - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent, left: TSNode, value: TSNode, name_node: TSNode, type: Type | None = None) -> None: - self._unique_node = name_node # HACK: This prevents deduplication of Assignments - super().__init__(ts_node, file_node_id, ctx, parent=parent, name_node=name_node, name_node_type=Name) - self._left = self._parse_expression(left, default=Name) - self._value_node = self._parse_expression(value) - self.type = type - if self.type is None: - self._init_type() - - @classmethod - def _from_left_and_right_nodes(cls, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent, left_node: TSNode, right_node: TSNode) -> list[Assignment]: - left = ctx.parser.parse_expression(left_node, file_node_id, ctx, parent) - value = ctx.parser.parse_expression(right_node, file_node_id, ctx, parent) - - if isinstance(left, Collection | Dict): - assignments = [] - for var in left.symbols: - # Make a deep copy of the value expression for each child - value = ctx.parser.parse_expression(right_node, file_node_id, ctx, parent) - assignments.extend(cls._from_value_expression(ts_node, file_node_id, ctx, parent, left, value, var.ts_node)) - return sort_editables(assignments) - return cls._from_value_expression(ts_node, file_node_id, ctx, parent, left, value, left_node) - - @classmethod - def _from_value_expression( - cls, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent, left: Expression[Self], value: Expression[Self] | list[Expression], name_node: TSNode - ) -> list[Assignment]: - assignments = [cls(ts_node, file_node_id, ctx, parent, left, value, name_node)] - if value and isinstance(value, MultiExpression) and isinstance(value.expressions[0], Assignment): - for expr in value.expressions: - assignments.extend(cls._from_value_expression(expr.ts_node, file_node_id, ctx, parent, expr.left, expr.value, expr.get_name().ts_node)) - return sort_editables(assignments) - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - dest = self.self_dest - if value := self.value: - value._compute_dependencies(UsageKind.BODY, dest) - if self.type: - self.type._compute_dependencies(UsageKind.TYPE_ANNOTATION, dest) - - # Check for usages in left hand side of assignment if it is an object access - if name := self.get_name(): - if isinstance(name, ChainedAttribute): - name._compute_dependencies(UsageKind.BODY, dest) - elif isinstance(name, SubscriptExpression): - name._compute_dependencies(UsageKind.BODY, dest) - - @property - @noapidoc - @reader - def left(self) -> Expression[Self]: - """The left side of the assignment. - - Only should be used for internal purposes. - """ - return self._left - - @property - @reader - def index(self) -> int: - """Returns the index of the assignment statement in its parent's code block. - - Returns: - int: The 0-based index position of the assignment statement within its parent's code block statements. - """ - return self.parent.index - - @property - @reader - def is_local_variable(self) -> bool: - """Determines if an assignment represents a local variable in the current scope. - - A local variable is an assignment that: - 1. Is not a chained attribute (e.g., not self.x or obj.x) - 2. Is not in the global (file) scope - - Returns: - bool: True if the assignment is a local variable, False otherwise. - """ - from codegen.sdk.core.file import File - - if isinstance(self._left, ChainedAttribute): - return False - - if isinstance(self.parent, File): - return False - return True - - @proxy_property - @reader - def local_usages(self) -> list[Editable[Statement]]: - """Retrieves all usages of the assigned variable within its code block scope. - - Returns all instances where the variable defined in this Assignment is used within its code block. Only returns usages that occur after the assignment, excluding the usage in the assignment - itself. - - Returns: - list[Editable[Statement]]: A sorted list of statement nodes where the variable is used. - - Note: - This method can be called as both a property or a method. If used as a property, it is equivalent to invoking it without arguments. - """ - usages = [] - for statement in self.parent.parent.statements[self.index :]: - var_references = statement.get_variable_usages(self.name) - for var_reference in var_references: - # Exclude the variable usage in the assignment itself - if self.ts_node.byte_range[0] <= var_reference.ts_node.start_byte and self.ts_node.byte_range[1] >= var_reference.ts_node.end_byte: - continue - usages.append(var_reference) - return sort_editables(usages) - - @writer - def set_value(self, src: str) -> None: - """Sets the value of an assignment expression. - - Updates the value of an assignment expression. If the assignment doesn't have an existing value, - it adds one after the type annotation (if present) or after the variable name. If the assignment - already has a value, it replaces the existing value. - - Args: - src (str): The source code string representing the new value to be assigned. - - Returns: - None - """ - if self.value is None: - if self.type: - self.type.insert_after(f" = {src}", newline=False) - else: - self.insert_after(f" = {src}", newline=False) - else: - self.value.edit(src) - - @writer - def set_type_annotation(self, type_str: str) -> None: - """Adds or updates a type annotation for the current assignment. - - This method modifies an assignment statement to include a type annotation. If the assignment already - has a type annotation, it will be overwritten with the new type. If no type annotation exists, - one will be added between the assignment name and the equals sign. - - Args: - type_str (str): The type annotation to be added or updated. - - Returns: - None - """ - type_annotation_node = self.type - if type_annotation_node: - type_annotation_node.edit(type_str) - else: - self._left.insert_after(f": {type_str}", newline=False) - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - if isinstance(self.type, Chainable) and not self.type.source == "TypeAlias": - yield from self.with_resolution_frame(self.type, direct=False) - elif self.value: - resolved = False - from codegen.sdk.core.statements.assignment_statement import AssignmentStatement - - if self.parent_of_type(AssignmentStatement) and len(self.parent_of_type(AssignmentStatement).assignments) > 0: - name_node = self._name_node.ts_node - if name_node and (val := self.value) and isinstance(val, Chainable): - for resolution in val.resolved_type_frames: - type = resolution.top.node - current = self.ts_node - while current and current.id != name_node.id: - idx = find_index(name_node, current.named_children) - current = current.named_children[idx] if idx != -1 else None - if current is None: - break - if current.type == "object_pattern": - if name_node in current.named_children: - if isinstance(type, TSObjectType): - type = type.get(self.name) - current = name_node - if current.type == "pair_pattern": - key = current.child_by_field_name("key") - if isinstance(type, TSObjectType) and (elem := type.get(key.text.decode("utf-8"))): - type = elem - - if type and type != resolution.top.node: - yield from self.with_resolution_frame(type, direct=False, chained=True) - resolved = True - if not resolved: - yield from self.with_resolution_frame(self.value, direct=False) - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - symbols = super().descendant_symbols - if self.type: - symbols.extend(self.type.descendant_symbols) - if self.value is not None: - symbols.extend(self.value.descendant_symbols) - return symbols - - def __hash__(self): - if self._hash is None: - self._hash = hash((self.filepath, self.range, self.ts_node.kind_id, self._unique_node.range)) - return self._hash - - @reader - def __eq__(self, other: object): - if isinstance(other, Assignment): - return super().__eq__(other) and self._unique_node.range == other._unique_node.range - return super().__eq__(other) - - @writer - def reduce_condition(self, bool_condition: bool, node: Editable | None = None) -> None: - """Simplifies an assignment expression by reducing it based on a boolean condition and updating all the usages. - - - Args: - bool_condition (bool): The boolean value to reduce the condition to. - - """ - self.remove() - for usage in self.usages: - if usage.match == self.name: - usage.match.reduce_condition(bool_condition) diff --git a/src/codegen/sdk/core/autocommit/__init__.py b/src/codegen/sdk/core/autocommit/__init__.py deleted file mode 100644 index 188109bf1..000000000 --- a/src/codegen/sdk/core/autocommit/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Autocommit implementation. - -Theory of Operation: -------------------- -Context: We operate on 3 kinds of nodes: Files, Symbols, and DetachedSymbols. -(Technically DetachedSymbols includes Editables and some are orphans) - -Every time we perform an operation, if the node is a File or Symbol, we will commit and reaquire if it has been updated. -Then if the operation was a write, we mark it (and the file containing it) as pending. -If the symbol is detached, we mark its parent as pending. This is recursive until we reach a symbol on the graph or an orphan -If it was a move (or rename), we also mark where the new symbol will be. This removes all need for commiting in most circumstances. -Edge Cases: ----------- -- We cannot reaquire detached symbols, so we don't autoupdate those. -- We cannot handle situations where you change the type of a symbol then operate on it -- We cannot handle removing then operating on a symbol -- We skip commits when you do raw edits and inserts, but will fall back to autocommit if needed -""" - -from codegen.sdk.core.autocommit.constants import enabled -from codegen.sdk.core.autocommit.decorators import mover, remover, repr_func, writer -from codegen.sdk.core.autocommit.manager import AutoCommit -from codegen.sdk.extensions.autocommit import commiter, reader - -__all__ = [ - "AutoCommit", - "commiter", - "enabled", - "mover", - "reader", - "remover", - "repr_func", - "writer", -] diff --git a/src/codegen/sdk/core/autocommit/constants.py b/src/codegen/sdk/core/autocommit/constants.py deleted file mode 100644 index 9bcb051d2..000000000 --- a/src/codegen/sdk/core/autocommit/constants.py +++ /dev/null @@ -1,62 +0,0 @@ -from enum import IntEnum, unique -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.editable import Editable - -REMOVED = "REMOVED" - -AutoCommitSymbol = "Editable" - - -@unique -class AutoCommitState(IntEnum): - """Current operation.""" - - Write = 0 # Can only be done inside another write or as the first state - Read = 1 # Can be done anytime - Committing = 2 # During a commit/reset, Prevents any updates - Special = 4 # During Hash or Repr, prevents further changes to state - - -class IllegalWriteError(Exception): - """Indicates there is a write, move, or commit called inside a read, commit, or repr - function. - """ - - pass - - -class NodeNotFoundError(Exception): - """Indicates a node was not found during the update process, such as when editing the type.""" - - pass - - -class OutdatedNodeError(Exception): - """Indicates a node is out of date.""" - - def __init__(self, node: "Editable") -> None: - parent = node - from codegen.sdk.core.symbol import Symbol - - while parent is not None and not isinstance(parent, Symbol): - parent = parent.parent - super().__init__( - f"Using an outdated node {node}.\n" - + "This can happen if you cache a detached symbol, then update a related symbol or file.\n" - + ( - f"Try acquiring the node from it's parent symbol: {parent}.\n" - + "For example if the node was the first parameter of a function, " - + f"call {node.name} = {parent.name}.parameters[0]" - ) - if parent - else "" - ) - - -# SAFETY TOGGLE -enabled = False -# def enabled(): -# # SAFETY TOGGLE -# return True diff --git a/src/codegen/sdk/core/autocommit/decorators.py b/src/codegen/sdk/core/autocommit/decorators.py deleted file mode 100644 index 996738ea8..000000000 --- a/src/codegen/sdk/core/autocommit/decorators.py +++ /dev/null @@ -1,115 +0,0 @@ -from codegen.shared.logging.get_logger import get_logger -import functools -from collections.abc import Callable -from typing import TYPE_CHECKING, ParamSpec, TypeVar, Union, overload - -import wrapt - -from codegen.sdk.core.autocommit.constants import AutoCommitState, enabled -from codegen.sdk.core.node_id_factory import NodeId - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.symbol import Symbol - - -logger = get_logger(__name__) -P = ParamSpec("P") -T = TypeVar("T") - - -@overload -def writer(wrapped: Callable[P, T]) -> Callable[P, T]: ... - - -@overload -def writer( - wrapped: None = None, *, commit: bool = ... -) -> Callable[[Callable[P, T]], Callable[P, T]]: ... - - -def writer( - wrapped: Callable[P, T] | None = None, *, commit: bool = True -) -> Callable[P, T] | Callable[[Callable[P, T]], Callable[P, T]]: - """Indicates the method is a writer. This will automatically update if the original is out of - date. - - Args: - ---- - commit: Whether to commit if there is an update. Do not set this to False unless you are absolutely sure the method can be retried with commit as True safely. - """ - if wrapped is None: - return functools.partial(writer, commit=commit) - - @wrapt.decorator(enabled=enabled) - def wrapper(wrapped: Callable[P, T], instance: "Editable", args, kwargs) -> T: - if instance is None: - instance = args[0] - if instance.removed: - logger.warning("Editing a removed node") - autocommit = instance.ctx._autocommit - logger.debug("Writing node %r,%r", instance, wrapped) - with autocommit.write_state(instance, commit=commit): - return wrapped(*args, **kwargs) - - return wrapper(wrapped) - - -@wrapt.decorator(enabled=enabled) -def remover( - wrapped: Callable[P, T], - instance: Union["Symbol", None] = None, - args: P.args = None, - kwargs: P.kwargs = None, -) -> Callable[P, T]: - """Indicates the node will be removed at the end of this method. - - Further usage of the node will result in undefined behaviour and a warning. - """ - if instance is None: - instance = args[0] - logger.debug("Removing node %r, %r", instance, wrapped) - with instance.ctx._autocommit.write_state(instance): - ret = wrapped(*args, **kwargs) - # instance.ctx._autocommit.set_pending(instance, REMOVED) - instance.removed = True - return ret - - -@wrapt.decorator(enabled=enabled) -def repr_func( - wrapped: Callable[P, T], - instance: Union["Editable", None] = None, - args: P.args = None, - kwargs: P.kwargs = None, -) -> Callable[P, T]: - """Indicates the method is use in debugging/logs.""" - if instance is None: - instance = args[0] - autocommit = instance.ctx._autocommit - old_state = autocommit.enter_state(AutoCommitState.Special) - try: - ret = wrapped(*args, **kwargs) - finally: - autocommit.state = old_state - return ret - - -@wrapt.decorator(enabled=enabled) -def mover( - wrapped: Callable[P, tuple[NodeId, NodeId]], - instance: Union["Symbol", None] = None, - args: P.args = None, - kwargs: P.kwargs = None, -) -> Callable[P, None]: - """Indicates the Node will be moved by the end of this method. - - It should also return the node_id of itself and the new file - """ - if instance is None: - instance = args[0] - with instance.ctx._autocommit.write_state(instance, move=True): - file_node_id, node_id = wrapped(*args, **kwargs) - instance.ctx._autocommit.set_pending(instance, node_id, file_node_id) - instance.removed = False - return None diff --git a/src/codegen/sdk/core/autocommit/manager.py b/src/codegen/sdk/core/autocommit/manager.py deleted file mode 100644 index 9d9932838..000000000 --- a/src/codegen/sdk/core/autocommit/manager.py +++ /dev/null @@ -1,294 +0,0 @@ -from codegen.shared.logging.get_logger import get_logger -from collections.abc import Iterator -from contextlib import contextmanager -from dataclasses import dataclass -from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union - -from codegen.sdk.core.autocommit.constants import ( - REMOVED, - AutoCommitState, - AutoCommitSymbol, - IllegalWriteError, - NodeNotFoundError, - OutdatedNodeError, -) -from codegen.sdk.core.autocommit.utils import is_file, is_on_graph, is_symbol -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.extensions.autocommit import update_dict - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.file import File - from codegen.sdk.core.import_resolution import Import - from codegen.sdk.core.symbol import Symbol - - -logger = get_logger(__name__) - - -@dataclass -class AutoCommitNode: - """The pending update for a node. - - Attributes: - symbol: The symbol being updated. Kept to ensure correctness - generation: Version of the symbol - new_id: New id to fetch (if applicable) - new_file: File symbol was moved to (if applicable) - """ - - symbol: AutoCommitSymbol - generation: int - new_id: NodeId | None = None - new_file: Optional["File"] = None - - -@dataclass -class PendingFiles: - """Current files autocommit is operating on. - - For example, if we read a symbol and find another symbol out of date in the same file, we would - not want to update it. - """ - - files: set[Path] | None - all: bool = False - - def __bool__(self) -> bool: - return bool(self.files) or self.all - - -class AutoCommit: - """Global autocommit state. - - Attributes: - state: Current operation being performed - _files: Mapping of files to their new filepaths, or None if they were just modified - _nodes: Mapping of nodes to their new Node IDs - _locked_files: All files that are currently being operated on - _lock_all: All files are currently being operated on - """ - - state: AutoCommitState | None = None - _files: dict[Path, NodeId | None] - _nodes: dict[NodeId, AutoCommitNode] - ctx: "CodebaseContext" - _locked_files: set[str] - _lock_all: bool = False - - def __init__(self, ctx: "CodebaseContext") -> None: - self.ctx = ctx - self._files = {} - self._nodes = {} - self._locked_files = set() - - def __repr__(self) -> str: - return str(self.__dict__) - - def _commit(self, lock: PendingFiles, additional: str | None = None) -> None: - if lock: - logger.debug( - "Running autocommit on %s", "all files" if lock.all else lock.files - ) - files = lock.files if not lock.all else None - if additional and files: - files.add(additional) - self.ctx.commit_transactions(files=files) - - def _update_file(self, symbol: "File", lock: PendingFiles) -> None: - """Check for an update to a file, and if there is one, copy its dict.""" - if symbol.file_node_id in self._files: - new_id = self._files.pop(symbol.file_node_id, None) - if new_id == REMOVED: - logger.warning("Editing a removed node") - return - self._commit(lock, new_id) - old_node = self.ctx.get_node(symbol.file_node_id) - new_node = self.ctx.get_node( - new_id if new_id is not None else symbol.file_node_id - ) - old_node.__dict__ = new_node.__dict__ - if not lock: - self._files[symbol.file_node_id] = new_id - - def _reaquire_node( - self, - symbol: Union["Symbol", "Import"], - new_node_id: NodeId, - missing_ok: bool = False, - ): - """Re-aquire a symbol.""" - # Prevent double re-aquire - new_node = self.ctx.get_node(new_node_id) - if new_node is None: - if missing_ok: - return - raise NodeNotFoundError( - f"Could not find node with {new_node_id=} {symbol.node_id=}. This may happen if you change the type of a symbol using edit (such as editing a variable into a function)" - ) - update_dict(set(), symbol, new_node) - - def _update_symbol( - self, symbol: Union["Symbol", "Import"], lock: PendingFiles - ) -> None: - """Check for an update to a symbol, and if there is one, copy its dict.""" - node_id = symbol.node_id - if symbol_update := self._nodes.pop(node_id, None): - assert self.state is not None - logger.debug("Running autocommit on %r due to %r", symbol, self.state.name) - self._commit(lock, symbol_update.new_file) - if symbol_update.new_id == REMOVED: - logger.warning("Editing a removed node") - # Incredibly cursed, but keep the update around to make re-acquire succeed - self._nodes[node_id] = symbol_update - return - if symbol.file._generation == symbol_update.generation: - self._reaquire_node(symbol, node_id) - self._nodes[node_id] = symbol_update - else: - new_id = ( - node_id if (symbol_update.new_id is None) else symbol_update.new_id - ) - self._reaquire_node(symbol, new_id) - if symbol_update.new_id == REMOVED: - # Incredibly cursed, but keep the update around to make re-acquire succeed - self._nodes[node_id] = symbol_update - elif symbol.is_outdated: - # We can't re-acquire a node twice - self._reaquire_node(symbol, node_id, missing_ok=True) - - def check_update( - self, node: AutoCommitSymbol, lock: PendingFiles, must_be_updated: bool = True - ) -> None: - """Check for an update to a node if possible.""" - assert self.state is not None - if is_on_graph(node): - self._update_symbol(node, lock=lock) - elif is_file(node): - self._update_file(node, lock=lock) - else: - if node.is_outdated: - if node.parent is not None: - self.check_update( - node.parent, lock=lock, must_be_updated=must_be_updated - ) - if not node.is_outdated: - return - if must_be_updated: - raise OutdatedNodeError(node) - - def set_pending_file( - self, - file: AutoCommitSymbol, - *, - update_id: NodeId | None = None, - new_id: NodeId | None = None, - ) -> None: - """Mark a file as pending.""" - if update_id is None: - update_id = file.filepath - if new_id is not None or update_id not in self._files: - self._files[update_id] = new_id - - def set_pending( - self, - node: AutoCommitSymbol, - new_id: NodeId | None = None, - new_file: NodeId | None = None, - ) -> None: - """Mark a node as pending. - - This also mark the file it's in, the file it's moved to, and it's parent if the node is - detached - """ - if is_file(node): - self.set_pending_file(node, new_id=new_file) - return - self.set_pending_file(node, update_id=node.file_node_id) - if new_file is not None: - self.set_pending_file(node, update_id=new_file) - if is_symbol(node): - new_file_node = self.ctx.get_node(new_file) if new_file else None - if symbol_update := self._nodes.get(node.node_id, None): - assert symbol_update.symbol == node - if new_id is not None: - logger.debug("Setting new id: %s", new_id) - symbol_update.new_id = new_id - symbol_update.new_file = new_file_node - symbol_update.generation = node.file._generation - else: - self._nodes[node.node_id] = AutoCommitNode( - node, node.file._generation, new_id, new_file_node - ) - elif node.parent: - self.set_pending(node.parent, None, None) - else: - logger.warning("Could not find parent node of %r", node) - - @contextmanager - def write_state( - self, node: AutoCommitSymbol, *, commit: bool = True, move: bool = False - ) -> Iterator[None]: - """Enter a write state.""" - if self.state not in (AutoCommitState.Write, None): - # Can't write in a read or commit - logger.error(IllegalWriteError()) - try: - with self.lock_files({node.filepath}, all=move, commit=commit) as lock: - old_state = self.enter_state(AutoCommitState.Write) - self.check_update(node, lock=lock) - yield None - logger.debug("%r: Marking pending autoupdate", node) - # self.set_pending(node, None) - finally: - self.state = old_state - - def enter_state(self, state: AutoCommitState) -> AutoCommitState | None: - """Begin a new state.""" - old_state = self.state - logger.debug( - "Starting %s, previous: %s", - state.name, - old_state.name if old_state else None, - ) - self.state = state - return old_state - - @contextmanager - def lock_files( - self, files: set[Path], all: bool = False, commit: bool = True - ) -> Iterator[PendingFiles]: - to_unlock = self.try_lock_files(files, all, commit) - try: - yield to_unlock - finally: - self.unlock_files(to_unlock) - - def try_lock_files( - self, files: set[Path], all: bool = False, commit: bool = True - ) -> PendingFiles: - if self._lock_all or not commit: - return PendingFiles(set()) - if all: - self._lock_all = True - return PendingFiles(None, True) - to_unlock = files - self._locked_files - self._locked_files |= to_unlock - return PendingFiles(to_unlock) - - def unlock_files(self, files: PendingFiles) -> None: - if files.all: - self._lock_all = False - else: - self._locked_files -= files.files - - def reset(self) -> None: - """Reset Autocommit state. - - Probably not necessary - """ - self._files.clear() - self._nodes.clear() - self._locked_files.clear() - self.state = None diff --git a/src/codegen/sdk/core/autocommit/ruff.toml b/src/codegen/sdk/core/autocommit/ruff.toml deleted file mode 100644 index 0291e8e20..000000000 --- a/src/codegen/sdk/core/autocommit/ruff.toml +++ /dev/null @@ -1,2 +0,0 @@ -extend = "../../../../../pyproject.toml" -lint.extend-select = ["G"] diff --git a/src/codegen/sdk/core/autocommit/utils.py b/src/codegen/sdk/core/autocommit/utils.py deleted file mode 100644 index a7b9e5c84..000000000 --- a/src/codegen/sdk/core/autocommit/utils.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Utilities to prevent circular imports.""" - -from typing import TYPE_CHECKING, Any, TypeGuard, Union - -if TYPE_CHECKING: - from codegen.sdk.core.file import File - from codegen.sdk.core.import_resolution import Import - from codegen.sdk.core.symbol import Symbol - - -def is_file(node: Any) -> TypeGuard["File"]: - from codegen.sdk.core.file import File - - return isinstance(node, File) - - -def is_symbol(node: Any) -> TypeGuard["Symbol"]: - from codegen.sdk.core.symbol import Symbol - - return isinstance(node, Symbol) - - -def is_on_graph(node: Any) -> TypeGuard[Union["Import", "Symbol"]]: - from codegen.sdk.core.import_resolution import Import - from codegen.sdk.core.symbol import Symbol - - return isinstance(node, Import | Symbol) diff --git a/src/codegen/sdk/core/class_definition.py b/src/codegen/sdk/core/class_definition.py deleted file mode 100644 index bbf2682ab..000000000 --- a/src/codegen/sdk/core/class_definition.py +++ /dev/null @@ -1,419 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from typing import TYPE_CHECKING, Generic, Literal, Self, overload, override - -from typing_extensions import TypeVar - -from codegen.sdk._proxy import proxy_property -from codegen.sdk.core.autocommit import commiter, reader, writer -from codegen.sdk.core.import_resolution import Import -from codegen.sdk.core.interfaces.callable import Callable -from codegen.sdk.core.interfaces.has_attribute import HasAttribute -from codegen.sdk.core.interfaces.has_block import HasBlock -from codegen.sdk.core.interfaces.inherits import Inherits -from codegen.sdk.core.statements.attribute import Attribute -from codegen.sdk.core.statements.statement import StatementType -from codegen.sdk.core.symbol import Symbol -from codegen.sdk.enums import SymbolType -from codegen.sdk.extensions.utils import cached_property -from codegen.shared.decorators.docs import apidoc, noapidoc -from codegen.shared.logging.get_logger import get_logger -from codegen.visualizations.enums import VizNode - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.detached_symbols.decorator import Decorator - from codegen.sdk.core.detached_symbols.parameter import Parameter - from codegen.sdk.core.expressions import Name - from codegen.sdk.core.expressions.chained_attribute import ChainedAttribute - from codegen.sdk.core.expressions.type import Type - from codegen.sdk.core.external_module import ExternalModule - from codegen.sdk.core.function import Function - from codegen.sdk.core.interface import Interface - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.statements.symbol_statement import SymbolStatement - from codegen.sdk.core.symbol_groups.multi_line_collection import MultiLineCollection - from codegen.sdk.core.symbol_groups.parents import Parents - - -logger = get_logger(__name__) - - -TFunction = TypeVar("TFunction", bound="Function", default="Function") -TDecorator = TypeVar("TDecorator", bound="Decorator", default="Decorator") -TCodeBlock = TypeVar("TCodeBlock", bound="CodeBlock", default="CodeBlock") -TParameter = TypeVar("TParameter", bound="Parameter", default="Parameter") -TType = TypeVar("TType", bound="Type", default="Type") - - -@apidoc -class Class(Inherits[TType], HasBlock[TCodeBlock, TDecorator], Callable[TParameter, TType], HasAttribute[TFunction | Attribute], Generic[TFunction, TDecorator, TCodeBlock, TParameter, TType]): - """Abstract representation of a Class definition. - - Attributes: - symbol_type: The type of symbol, set to SymbolType.Class. - constructor_keyword: The keyword used to identify the constructor method. - parent_classes: The parent classes of this class, if any. - """ - - symbol_type = SymbolType.Class - constructor_keyword = None - parent_classes: Parents[TType, Self] | None = None - _methods: MultiLineCollection[TFunction, Self] | None = None - - def __init__(self, ts_node: TSNode, file_id: NodeId, ctx: CodebaseContext, parent: SymbolStatement) -> None: - super().__init__(ts_node, file_id, ctx, parent) - self._methods = self._parse_methods() - self._parameters = [] - - #################################################################################################################### - # PROPERTIES - #################################################################################################################### - @proxy_property - @reader - def superclasses(self, max_depth: int | None = None) -> list[Class | ExternalModule | Interface]: - """Returns a list of all classes that this class extends, up to max_depth. - - Gets all classes that this class extends, traversing up the inheritance tree up to max_depth. - The traversal follows Python's Method Resolution Order (MRO), meaning superclasses are searched breadth-first. - - Args: - max_depth (int | None): The maximum depth to traverse up the inheritance tree. If None, traverses the entire tree. - - Returns: - list[Class | ExternalModule | Interface]: A list of all superclass symbols in MRO order, up to max_depth. - Returns an empty list if the class has no parent classes. - """ - # Implements the python MRO, IE: by level - if self.parent_classes is not None: - return self._get_superclasses(max_depth=max_depth) - return [] - - @property - @reader - def parent_class_names(self) -> list[Name | ChainedAttribute]: - """Returns a list of the parent class names that this class inherits from. - - Gets the list of parent class names from Parents object. Returns empty list if class has no parents. - - Returns: - list[Name | ChainedAttribute]: A list of parent class identifiers. Each identifier can be either a simple - name (Name) or a chained attribute (e.g., 'module.Class'). - """ - if self.parent_classes: - return self.parent_classes.parent_class_names - return [] - - @reader - def get_parent_class(self, parent_class_name: str) -> Editable | None: - """Returns the parent class node with the specified name. - - Retrieves a parent class Name or ChainedAttribute node from this class's list of parent class names that matches - the specified name. - - Args: - parent_class_name (str): The name of the parent class to find. - - Returns: - Editable | None: The matching parent class node, or None if no match is found. - """ - return next((p for p in self.parent_class_names if p.source == parent_class_name), None) - - @property - @reader - def is_subclass(self) -> bool: - """Indicates whether the current class is a subclass of another class. - - A class is considered a subclass if it inherits from at least one parent class. - - Returns: - bool: True if the class has one or more parent classes, False otherwise. - """ - return len(self.parent_class_names) > 0 - - @reader - def is_subclass_of(self, parent_class: str | Class, max_depth: int | None = None) -> bool: - """Checks if the class inherits from a specified parent class. - - Determines whether this class is a subclass (direct or indirect) of the specified parent class. The search can be limited to a certain depth in the inheritance tree. - - Args: - parent_class (str | Class): The parent class to check for. Can be specified either as a class name string or Class object. - max_depth (int | None): Maximum inheritance depth to search. None means no limit. - - Returns: - bool: True if this class inherits from the parent class, False otherwise. - """ - if self.parent_classes is None: - return False - return self.parent_classes.is_subclass_of(parent_class, max_depth=max_depth) - - @proxy_property - @reader - def subclasses(self, max_depth: int | None = None) -> list[Class]: - """Returns all classes which subclass this class. - - Retrieves a list of all classes in the codebase that inherit from this class, up to a specified depth. - - Args: - max_depth (int | None, optional): Maximum inheritance depth to search. If None, searches all depths. Defaults to None. - - Returns: - list[Class]: A list of Class objects that inherit from this class. - """ - return self._get_subclasses(max_depth) - - @noapidoc - @commiter - def compute_superclass_dependencies(self) -> None: - if self.parent_classes: - self.parent_classes.compute_superclass_dependencies() - - @cached_property - @reader - def constructor(self) -> TFunction | None: - """Returns the constructor method for this class. - - Gets the constructor of the class (e.g., __init__ in Python) by checking for a method matching the class's constructor_keyword. This includes searching through superclasses. - - Returns: - TFunction | None: The constructor method if found, None otherwise. - """ - # This now does the superclass traversal - return self.get_method(self.constructor_keyword) - - @abstractmethod - @reader - def _parse_methods(self) -> MultiLineCollection[TFunction, Self]: - """Parses the methods of the class into a multi line collection.""" - - @overload - def methods(self, *, max_depth: Literal[0] = ..., private: Literal[True] = ..., magic: Literal[True] = ...) -> MultiLineCollection[TFunction, Self]: ... - @overload - def methods(self, *, max_depth: int | None = ..., private: bool = ..., magic: Literal[False]) -> list[TFunction]: ... - @overload - def methods(self, *, max_depth: int | None = ..., private: Literal[False], magic: bool = ...) -> list[TFunction]: ... - @overload - def methods(self, *, max_depth: int | None, private: bool = ..., magic: bool = ...) -> list[TFunction]: ... - @proxy_property - @reader - def methods(self, *, max_depth: int | None = 0, private: bool = True, magic: bool = True) -> list[TFunction] | MultiLineCollection[TFunction, Self]: - """Retrieves all methods that exist on this Class, including methods from superclasses, with - filtering options. - - Args: - max_depth (int | None, optional): Include parent classes up to max_depth. None means no limit, 0 means only current class. Defaults to 0. - private (bool, optional): Whether to include private methods. Defaults to True. - magic (bool, optional): Whether to include magic methods. Defaults to False. - - Returns: - A list of methods that match the filtering criteria. Methods are ordered by class hierarchy - (methods from the current class appear before methods from parent classes). For methods with the same name, - only the first occurrence is included. Methods are returned as a MultiLineCollection for efficient access and manipulation if max depth is 0 and private and magic methods - are included. - """ - if max_depth == 0 and private and magic: - return self._methods - parents = [self, *self.superclasses(max_depth=max_depth)] - result = {} - for c in parents: - if isinstance(c, Class): - for m in c._methods: - if m.is_private and not private: - continue - if m.is_magic and not magic: - continue - if m.name not in result: - result[m.name] = m - return list(result.values()) - - @reader - def get_nested_class(self, name: str) -> Self | None: - """Returns a nested class by name from the current class. - - Searches through the nested classes defined in the class and returns the first one that matches the given name. - - Args: - name (str): The name of the nested class to find. - - Returns: - Self | None: The nested class if found, None otherwise. - """ - for m in self.nested_classes: - if m.name == name: - return m - return None - - @reader - def get_method(self, name: str) -> TFunction | None: - """Returns a specific method by name from the class or any of its superclasses. - - Searches through the class's methods and its superclasses' methods to find a method with the specified name. - - Args: - name (str): The name of the method to find. - - Returns: - TFunction | None: The method if found, None otherwise. - """ - parents = [self, *self.superclasses] - for c in parents: - if isinstance(c, Class): - for m in c.methods: - if m.name == name: - return m - return None - - @proxy_property - @reader - def attributes(self, *, max_depth: int | None = 0, private: bool = True) -> list[Attribute]: - """Retrieves all attributes from this Class including those from its superclasses up to a - specified depth. - - Args: - max_depth (int | None): The maximum depth of superclass traversal. None means no limit, 0 means only this class. - private (bool): Whether to include private attributes. Defaults to True. - - Returns: - list[Attribute]: A list of unique attributes from this class and its superclasses. If an attribute is defined in - multiple classes, the first definition encountered is used. - """ - parents = [self, *self.superclasses(max_depth=max_depth)] - result = {} - for c in parents: - if isinstance(c, Class): - for m in c.code_block.get_attributes(private): - if m.name not in result: - result[m.name] = m - return list(result.values()) - - @reader - def get_attribute(self, name: str) -> Attribute | None: - """Returns a specific attribute by name. - - Searches for an attribute with the given name in the current class and its superclasses. - - Args: - name (str): The name of the attribute to search for. - - Returns: - Attribute | None: The matching attribute if found, None otherwise. If multiple attributes with the same name exist in the inheritance hierarchy, returns the first one found. - """ - parents = [self, *self.superclasses] - for c in parents: - if isinstance(c, Class): - for m in c.code_block.get_attributes(name): - if m.name == name: - return m - return None - - #################################################################################################################### - # MANIPULATIONS - #################################################################################################################### - - @abstractmethod - def add_source(self, source: str) -> None: - """Add a block of source code to the bottom of a class definition. - - Adds the provided source code to the end of the class definition, after all existing methods and attributes. - - Args: - source (str): The source code to be added to the class definition. The code should be properly formatted - for class-level insertion. - - Returns: - None - """ - - @writer - def add_attribute_from_source(self, source: str) -> None: - """Adds an attribute to a class from raw source code, placing it in a specific location - based on the class structure. - - This method intelligently places the new attribute after existing attributes and docstrings but before methods to maintain a clean class structure. - - Args: - source (str): The source code of the attribute to be added. - - Returns: - None - """ - attributes = self.attributes - if len(attributes) > 0: - last_attribute = attributes[-1] - last_attribute.insert_after(source, fix_indentation=True) - elif (methods := self.methods) and len(methods) > 0: - first_method = methods[0] - first_method.insert_before(f"{source}\n", fix_indentation=True) - elif len(self.code_block.statements) > 0: - first_statement = self.code_block.statements[0] - first_statement.insert_before(source, fix_indentation=True) - else: - self.code_block.insert_after(source, fix_indentation=True) - - @writer - def add_attribute(self, attribute: Attribute, include_dependencies: bool = False) -> None: - """Adds an attribute to a class from another class. - - This method adds an attribute to a class, optionally including its dependencies. If dependencies are included, it will add any necessary imports to the class's file. - - Args: - attribute (Attribute): The attribute to add to the class. - include_dependencies (bool, optional): Whether to include the attribute's dependencies. If True, adds any necessary imports to the class's file. Defaults to False. - - Returns: - None - """ - # TODO: maybe this should be on Attribute API and renamed to "move_to_class" - # - my preference is to drop it altogether, or combine with add_attribute_from_source - self.add_attribute_from_source(attribute.source) - - if include_dependencies: - deps = attribute.dependencies - file = self.file - for d in deps: - if isinstance(d, Import): - file.add_import(d.imported_symbol) - elif isinstance(d, Symbol): - file.add_import(d) - - @property - @noapidoc - def viz(self) -> VizNode: - return VizNode(file_path=self.filepath, start_point=self.start_point, end_point=self.end_point, name=self.name, symbol_name=self.__class__.__name__) - - @noapidoc - @reader - @override - def resolve_attribute(self, name: str) -> Attribute | TFunction | None: - if method := self.get_method(name): - return method - if attr := self.get_attribute(name): - return attr - for c in [self, *self.superclasses]: - if isinstance(c, Class): - for child_class in c.nested_classes: - if child_class.name == name: - return child_class - - @property - def nested_classes(self) -> list[Self]: - """Retrieves the nested classes defined within this class. - - Args: - None - - Returns: - list[Self]: A list of Class objects representing nested class definitions within this class. - """ - symbols = [] - for s in self.code_block.statements: - if s.statement_type == StatementType.SYMBOL_STATEMENT: - if (c := s.symbol) and isinstance(c, Class): - symbols.append(c) - return symbols diff --git a/src/codegen/sdk/core/codebase.py b/src/codegen/sdk/core/codebase.py deleted file mode 100644 index fc3e0557e..000000000 --- a/src/codegen/sdk/core/codebase.py +++ /dev/null @@ -1,1613 +0,0 @@ -"""Codebase - main interface for Codemods to interact with the codebase""" - -import codecs -import json -import os -import re -import tempfile -from collections.abc import Generator -from contextlib import contextmanager -from functools import cached_property -from pathlib import Path -from typing import Generic, Literal, Unpack, overload - -import plotly.graph_objects as go -import rich.repr -from git import Commit as GitCommit -from git import Diff -from git.remote import PushInfoList -from github.PullRequest import PullRequest -from networkx import Graph -from openai import OpenAI -from rich.console import Console -from typing_extensions import TypeVar, deprecated - -from codegen.configs.models.codebase import CodebaseConfig, PinkMode -from codegen.configs.models.secrets import SecretsConfig -from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.git.schemas.enums import CheckoutResult -from codegen.git.schemas.repo_config import RepoConfig -from codegen.git.utils.pr_review import CodegenPR -from codegen.sdk._proxy import proxy_property -from codegen.sdk.ai.client import get_openai_client -from codegen.sdk.codebase.codebase_ai import generate_system_prompt, generate_tools -from codegen.sdk.codebase.codebase_context import ( - GLOBAL_FILE_IGNORE_LIST, - CodebaseContext, -) -from codegen.sdk.codebase.config import ProjectConfig, SessionOptions -from codegen.sdk.codebase.diff_lite import DiffLite -from codegen.sdk.codebase.flagging.code_flag import CodeFlag -from codegen.sdk.codebase.flagging.enums import FlagKwargs -from codegen.sdk.codebase.flagging.group import Group -from codegen.sdk.codebase.io.io import IO -from codegen.sdk.codebase.progress.progress import Progress -from codegen.sdk.codebase.span import Span -from codegen.sdk.core.assignment import Assignment -from codegen.sdk.core.class_definition import Class -from codegen.sdk.core.codeowner import CodeOwner -from codegen.sdk.core.detached_symbols.code_block import CodeBlock -from codegen.sdk.core.detached_symbols.parameter import Parameter -from codegen.sdk.core.directory import Directory -from codegen.sdk.core.export import Export -from codegen.sdk.core.external_module import ExternalModule -from codegen.sdk.core.file import File, SourceFile -from codegen.sdk.core.function import Function -from codegen.sdk.core.import_resolution import Import -from codegen.sdk.core.interface import Interface -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.core.symbol import Symbol -from codegen.sdk.core.type_alias import TypeAlias -from codegen.sdk.enums import NodeType, SymbolType -from codegen.sdk.extensions.sort import sort_editables -from codegen.sdk.extensions.utils import uncache_all -from codegen.sdk.output.constants import ANGULAR_STYLE -from codegen.sdk.python.assignment import PyAssignment -from codegen.sdk.python.class_definition import PyClass -from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock -from codegen.sdk.python.detached_symbols.parameter import PyParameter -from codegen.sdk.python.file import PyFile -from codegen.sdk.python.function import PyFunction -from codegen.sdk.python.import_resolution import PyImport -from codegen.sdk.python.statements.import_statement import PyImportStatement -from codegen.sdk.python.symbol import PySymbol -from codegen.sdk.typescript.assignment import TSAssignment -from codegen.sdk.typescript.class_definition import TSClass -from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock -from codegen.sdk.typescript.detached_symbols.parameter import TSParameter -from codegen.sdk.typescript.export import TSExport -from codegen.sdk.typescript.file import TSFile -from codegen.sdk.typescript.function import TSFunction -from codegen.sdk.typescript.import_resolution import TSImport -from codegen.sdk.typescript.interface import TSInterface -from codegen.sdk.typescript.statements.import_statement import TSImportStatement -from codegen.sdk.typescript.symbol import TSSymbol -from codegen.sdk.typescript.type_alias import TSTypeAlias -from codegen.shared.decorators.docs import apidoc, noapidoc, py_noapidoc -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codegen.shared.exceptions.control_flow import MaxAIRequestsError -from codegen.shared.logging.get_logger import get_logger -from codegen.shared.performance.stopwatch_utils import stopwatch -from codegen.visualizations.visualization_manager import VisualizationManager - -logger = get_logger(__name__) -MAX_LINES = 10000 # Maximum number of lines of text allowed to be logged - - -TSourceFile = TypeVar("TSourceFile", bound="SourceFile", default=SourceFile) -TDirectory = TypeVar("TDirectory", bound="Directory", default=Directory) -TSymbol = TypeVar("TSymbol", bound="Symbol", default=Symbol) -TClass = TypeVar("TClass", bound="Class", default=Class) -TFunction = TypeVar("TFunction", bound="Function", default=Function) -TImport = TypeVar("TImport", bound="Import", default=Import) -TGlobalVar = TypeVar("TGlobalVar", bound="Assignment", default=Assignment) -TInterface = TypeVar("TInterface", bound="Interface", default=Interface) -TTypeAlias = TypeVar("TTypeAlias", bound="TypeAlias", default=TypeAlias) -TParameter = TypeVar("TParameter", bound="Parameter", default=Parameter) -TCodeBlock = TypeVar("TCodeBlock", bound="CodeBlock", default=CodeBlock) -TExport = TypeVar("TExport", bound="Export", default=Export) -TSGlobalVar = TypeVar("TSGlobalVar", bound="Assignment", default=Assignment) -PyGlobalVar = TypeVar("PyGlobalVar", bound="Assignment", default=Assignment) -TSDirectory = Directory[TSFile, TSSymbol, TSImportStatement, TSGlobalVar, TSClass, TSFunction, TSImport] -PyDirectory = Directory[PyFile, PySymbol, PyImportStatement, PyGlobalVar, PyClass, PyFunction, PyImport] - - -@apidoc -class Codebase( - Generic[ - TSourceFile, - TDirectory, - TSymbol, - TClass, - TFunction, - TImport, - TGlobalVar, - TInterface, - TTypeAlias, - TParameter, - TCodeBlock, - ] -): - """This class provides the main entrypoint for most programs to analyzing and manipulating codebases. - - Attributes: - viz: Manages visualization of the codebase graph. - repo_path: The path to the repository. - console: Manages console output for the codebase. - """ - - _op: RepoOperator - viz: VisualizationManager - repo_path: Path - console: Console - - @overload - def __init__( - self, - repo_path: None = None, - *, - language: None = None, - projects: list[ProjectConfig] | ProjectConfig, - config: CodebaseConfig | None = None, - secrets: SecretsConfig | None = None, - io: IO | None = None, - progress: Progress | None = None, - ) -> None: ... - - @overload - def __init__( - self, - repo_path: str, - *, - language: Literal["python", "typescript"] | ProgrammingLanguage | None = None, - projects: None = None, - config: CodebaseConfig | None = None, - secrets: SecretsConfig | None = None, - io: IO | None = None, - progress: Progress | None = None, - ) -> None: ... - - def __init__( - self, - repo_path: str | None = None, - *, - language: Literal["python", "typescript"] | ProgrammingLanguage | None = None, - projects: list[ProjectConfig] | ProjectConfig | None = None, - config: CodebaseConfig | None = None, - secrets: SecretsConfig | None = None, - io: IO | None = None, - progress: Progress | None = None, - ) -> None: - # Sanity check inputs - if repo_path is not None and projects is not None: - msg = "Cannot specify both repo_path and projects" - raise ValueError(msg) - - if repo_path is None and projects is None: - msg = "Must specify either repo_path or projects" - raise ValueError(msg) - - if projects is not None and language is not None: - msg = "Cannot specify both projects and language. Use ProjectConfig.from_path() to create projects with a custom language." - raise ValueError(msg) - - # If projects is a single ProjectConfig, convert it to a list - if isinstance(projects, ProjectConfig): - projects = [projects] - - # Initialize project with repo_path if projects is None - if repo_path is not None: - main_project = ProjectConfig.from_path( - repo_path, - programming_language=ProgrammingLanguage(language.upper()) if language else None, - ) - projects = [main_project] - else: - main_project = projects[0] - - # Initialize codebase - self._op = main_project.repo_operator - self.viz = VisualizationManager(op=self._op) - self.repo_path = Path(self._op.repo_path) - self.ctx = CodebaseContext(projects, config=config, secrets=secrets, io=io, progress=progress) - self.console = Console(record=True, soft_wrap=True) - if self.ctx.config.use_pink != PinkMode.OFF: - import codegen_sdk_pink - - self._pink_codebase = codegen_sdk_pink.Codebase(self.repo_path) - - @noapidoc - def __str__(self) -> str: - return f"" - - @noapidoc - def __repr__(self): - return str(self) - - def __rich_repr__(self) -> rich.repr.Result: - yield "repo", self.ctx.repo_name - yield "nodes", len(self.ctx.nodes) - yield "edges", len(self.ctx.edges) - - __rich_repr__.angular = ANGULAR_STYLE - - @property - @deprecated("Please do not use the local repo operator directly") - @noapidoc - def op(self) -> RepoOperator: - return self._op - - @property - def github(self) -> RepoOperator: - """Access GitHub operations through the repo operator. - - This property provides access to GitHub operations like creating PRs, - working with branches, commenting on PRs, etc. The implementation is built - on top of PyGitHub (python-github library) and provides a simplified interface - for common GitHub operations. - - Returns: - RepoOperator: The repo operator instance that handles GitHub operations. - """ - return self._op - - #################################################################################################################### - # SIMPLE META - #################################################################################################################### - - @property - def name(self) -> str: - """The name of the repository.""" - return self.ctx.repo_name - - @property - def language(self) -> ProgrammingLanguage: - """The programming language of the repository.""" - return self.ctx.programming_language - - #################################################################################################################### - # NODES - #################################################################################################################### - - @noapidoc - def _symbols(self, symbol_type: SymbolType | None = None) -> list[TSymbol | TClass | TFunction | TGlobalVar]: - matches: list[Symbol] = self.ctx.get_nodes(NodeType.SYMBOL) - return [x for x in matches if x.is_top_level and (symbol_type is None or x.symbol_type == symbol_type)] - - # =====[ Node Types ]===== - @overload - def files(self, *, extensions: list[str]) -> list[File]: ... - @overload - def files(self, *, extensions: Literal["*"]) -> list[File]: ... - @overload - def files(self, *, extensions: None = ...) -> list[TSourceFile]: ... - @proxy_property - def files(self, *, extensions: list[str] | Literal["*"] | None = None) -> list[TSourceFile] | list[File]: - """A list property that returns all files in the codebase. - - By default, this only returns source files. Setting `extensions='*'` will return all files in the codebase, and - `extensions=[...]` will return all files with the specified extensions. - - For Python and Typescript repos WITH file parsing enabled, - `extensions='*'` is REQUIRED for listing all non source code files. - Or else, codebase.files will ONLY return source files (e.g. .py, .ts). - - For repos with file parsing disabled or repos with other languages, this will return all files in the codebase. - - Returns all Files in the codebase, sorted alphabetically. For Python codebases, returns PyFiles (python files). - For Typescript codebases, returns TSFiles (typescript files). - - Returns: - list[TSourceFile]: A sorted list of source files in the codebase. - """ - if self.ctx.config.use_pink == PinkMode.ALL_FILES: - return self._pink_codebase.files - if extensions is None and len(self.ctx.get_nodes(NodeType.FILE)) > 0: - # If extensions is None AND there is at least one file in the codebase (This checks for unsupported languages or parse-off repos), - # Return all source files - files = self.ctx.get_nodes(NodeType.FILE) - elif isinstance(extensions, str) and extensions != "*": - msg = "extensions must be a list of extensions or '*'" - raise ValueError(msg) - else: - files = [] - # Get all files with the specified extensions - for filepath, _ in self._op.iter_files( - extensions=None if extensions == "*" else extensions, - ignore_list=GLOBAL_FILE_IGNORE_LIST, - ): - files.append(self.get_file(filepath, optional=False)) - # Sort files alphabetically - return sort_editables(files, alphabetical=True, dedupe=False) - - @cached_property - def codeowners(self) -> list["CodeOwner[TSourceFile]"]: - """List all CodeOnwers in the codebase. - - Returns: - list[CodeOwners]: A list of CodeOwners objects in the codebase. - """ - if self.ctx.codeowners_parser is None: - return [] - return CodeOwner.from_parser( - self.ctx.codeowners_parser, - lambda *args, **kwargs: self.files(*args, **kwargs), - ) - - @property - def directories(self) -> list[TDirectory]: - """List all directories in the codebase. - - Returns a list of all Directory objects present in the codebase. Each Directory object represents a directory in the codebase. - This property is used to access and navigate the directory structure of the codebase. - - Returns: - list[TDirectory]: A list of Directory objects in the codebase. - """ - return list(self.ctx.directories.values()) - - @property - def imports(self) -> list[TImport]: - """Returns a list of all Import nodes in the codebase. - - Retrieves all Import nodes from the codebase graph. These imports represent all import statements across all files in the codebase, - including imports from both internal modules and external packages. - - Args: - None - - Returns: - list[TImport]: A list of Import nodes representing all imports in the codebase. - TImport can be PyImport for Python codebases or TSImport for TypeScript codebases. - """ - return self.ctx.get_nodes(NodeType.IMPORT) - - @property - @py_noapidoc - def exports(self: "TSCodebaseType") -> list[TSExport]: - """Returns a list of all Export nodes in the codebase. - - Retrieves all Export nodes from the codebase graph. These exports represent all export statements across all files in the codebase, - including exports from both internal modules and external packages. This is a TypeScript-only codebase property. - - Args: - None - - Returns: - list[TSExport]: A list of Export nodes representing all exports in the codebase. - TExport can only be a TSExport for TypeScript codebases. - - """ - if self.language == ProgrammingLanguage.PYTHON: - msg = "Exports are not supported for Python codebases since Python does not have an export mechanism." - raise NotImplementedError(msg) - - return self.ctx.get_nodes(NodeType.EXPORT) - - @property - def external_modules(self) -> list[ExternalModule]: - """Returns a list of all external modules in the codebase. - - An external module represents a dependency that is imported but not defined within the codebase itself (e.g. third-party packages like 'requests' or 'numpy'). - - Returns: - list[ExternalModule]: List of external module nodes from the codebase graph. - """ - return self.ctx.get_nodes(NodeType.EXTERNAL) - - @property - def symbols(self) -> list[TSymbol]: - """List of all top-level Symbols (Classes, Functions, etc.) in the codebase. Excludes Class - methods. - - Returns: - list[TSymbol]: A list of Symbol objects of all top-level symbols in the codebase. Includes classes, functions, and global variables defined at the module level, excludes methods. - """ - return self._symbols() - - @property - def classes(self) -> list[TClass]: - """List of all Classes in the codebase. - - Returns a sorted list of all Class nodes in the codebase. Class nodes represent class definitions in source files. - Only includes top-level classes, not inner/nested classes. - - Returns: - list[TClass]: A sorted list of all Class nodes in the codebase. - """ - return sort_editables(self._symbols(symbol_type=SymbolType.Class), dedupe=False) - - @property - def functions(self) -> list[TFunction]: - """List of all Functions in the codebase. - - Returns a sorted list of all top-level Function objects in the codebase, excluding class methods. - - Returns: - list[TFunction]: A list of Function objects representing all functions in the codebase, sorted alphabetically. - """ - return sort_editables(self._symbols(symbol_type=SymbolType.Function), dedupe=False) - - @property - def global_vars(self) -> list[TGlobalVar]: - """List of all GlobalVars in the codebase. - - A GlobalVar represents a global variable assignment in the source code. These are variables defined at the module level. - - Returns: - list[TGlobalVar]: A list of all global variable assignments in the codebase. - """ - return self._symbols(symbol_type=SymbolType.GlobalVar) - - @property - def interfaces(self) -> list[TInterface]: - """Retrieves all interfaces in the codebase. - - Returns a list of all Interface symbols defined at the top-level of source files in the codebase. - This property is only applicable for TypeScript codebases and will return an empty list for Python codebases. - - Returns: - list[TInterface]: A list of Interface objects defined in the codebase's source files. - """ - return self._symbols(symbol_type=SymbolType.Interface) - - @property - def types(self) -> list[TTypeAlias]: - """List of all Types in the codebase (Typescript only). - - Returns a list of all type aliases defined at the top level in the codebase. This method is only applicable - for TypeScript codebases. - - Returns: - list[TTypeAlias]: A list of all type aliases defined in the codebase. - """ - return self._symbols(symbol_type=SymbolType.Type) - - #################################################################################################################### - # EDGES - #################################################################################################################### - # TODO - no utilities needed here at the moment, but revisit - - #################################################################################################################### - # EXTERNAL API - #################################################################################################################### - - def create_file(self, filepath: str, content: str = "", sync: bool = True) -> TSourceFile: - """Creates a new file in the codebase with specified content. - - Args: - filepath (str): The path where the file should be created. - content (str): The content of the file to be created. Defaults to empty string. - sync (bool): Whether to sync the graph after creating the file. Defaults to True. - - Returns: - File: The newly created file object. - - Raises: - ValueError: If the provided content cannot be parsed according to the file extension. - """ - # Check if file already exists - # NOTE: This check is also important to ensure the filepath is valid within the repo! - if self.has_file(filepath): - logger.warning(f"File {filepath} already exists in codebase. Overwriting...") - - file_exts = self.ctx.extensions - # Create file as source file if it has a registered extension - if any(filepath.endswith(ext) for ext in file_exts) and not self.ctx.config.disable_file_parse: - file_cls = self.ctx.node_classes.file_cls - file = file_cls.from_content(filepath, content, self.ctx, sync=sync) - if file is None: - msg = f"Failed to parse file with content {content}. Please make sure the content syntax is valid with respect to the filepath extension." - raise ValueError(msg) - else: - # Create file as non-source file - file = File.from_content(filepath, content, self.ctx, sync=False) - - # This is to make sure we keep track of this file for diff purposes - uncache_all() - return file - - def create_directory(self, dir_path: str, exist_ok: bool = False, parents: bool = False) -> None: - """Creates a directory at the specified path. - - Args: - dir_path (str): The path where the directory should be created. - exist_ok (bool): If True, don't raise an error if the directory already exists. Defaults to False. - parents (bool): If True, create any necessary parent directories. Defaults to False. - - Raises: - FileExistsError: If the directory already exists and exist_ok is False. - """ - # Check if directory already exists - # NOTE: This check is also important to ensure the filepath is valid within the repo! - if self.has_directory(dir_path): - logger.warning(f"Directory {dir_path} already exists in codebase. Overwriting...") - - self.ctx.to_absolute(dir_path).mkdir(parents=parents, exist_ok=exist_ok) - - def has_file(self, filepath: str, ignore_case: bool = False) -> bool: - """Determines if a file exists in the codebase. - - Args: - filepath (str): The filepath to check for existence. - ignore_case (bool): If True, ignore case when checking for file existence. Defaults to False. - - Returns: - bool: True if the file exists in the codebase, False otherwise. - """ - if self.ctx.config.use_pink == PinkMode.ALL_FILES: - absolute_path = self.ctx.to_absolute(filepath) - return self._pink_codebase.has_file(absolute_path) - if self.ctx.config.use_pink == PinkMode.NON_SOURCE_FILES: - if self._pink_codebase.has_file(filepath): - return True - return self.get_file(filepath, optional=True, ignore_case=ignore_case) is not None - - @overload - def get_file(self, filepath: str, *, optional: Literal[False] = ..., ignore_case: bool = ...) -> TSourceFile: ... - @overload - def get_file(self, filepath: str, *, optional: Literal[True], ignore_case: bool = ...) -> TSourceFile | None: ... - def get_file(self, filepath: str, *, optional: bool = False, ignore_case: bool = False) -> TSourceFile | None: - """Retrieves a file from the codebase by its filepath. - - This method first attempts to find the file in the graph, then checks the filesystem if not found. Files can be either source files (e.g. .py, .ts) or binary files. - - Args: - filepath (str): The path to the file, relative to the codebase root. - optional (bool): If True, return None if file not found. If False, raise ValueError. - ignore_case (bool): If True, ignore case when checking for file existence. Defaults to False. - - Returns: - TSourceFile | None: The source file if found, None if optional=True and file not found. - - Raises: - ValueError: If file not found and optional=False. - """ - if self.ctx.config.use_pink == PinkMode.ALL_FILES: - absolute_path = self.ctx.to_absolute(filepath) - return self._pink_codebase.get_file(absolute_path) - # Try to get the file from the graph first - file = self.ctx.get_file(filepath, ignore_case=ignore_case) - if file is not None: - return file - - # If the file is not in the graph, check the filesystem - absolute_path = self.ctx.to_absolute(filepath) - if self.ctx.io.file_exists(absolute_path): - if self.ctx.config.use_pink != PinkMode.OFF: - if file := self._pink_codebase.get_file(absolute_path): - return file - return self.ctx._get_raw_file_from_path(absolute_path) - # If the file is not in the graph, check the filesystem - if absolute_path.parent.exists(): - for file in absolute_path.parent.iterdir(): - if ignore_case and str(absolute_path).lower() == str(file).lower(): - return self.ctx._get_raw_file_from_path(file) - elif not ignore_case and str(absolute_path) == str(file): - return self.ctx._get_raw_file_from_path(file) - - # If we get here, the file is not found - if not optional: - msg = f"File {filepath} not found in codebase. Use optional=True to return None instead." - raise ValueError(msg) - return None - - def has_directory(self, dir_path: str, ignore_case: bool = False) -> bool: - """Returns a boolean indicating if a directory exists in the codebase. - - Checks if a directory exists at the specified path within the codebase. - - Args: - dir_path (str): The path to the directory to check for, relative to the codebase root. - - Returns: - bool: True if the directory exists in the codebase, False otherwise. - """ - return self.get_directory(dir_path, optional=True, ignore_case=ignore_case) is not None - - def get_directory(self, dir_path: str, optional: bool = False, ignore_case: bool = False) -> TDirectory | None: - """Returns Directory by `dir_path`, or full path to the directory from codebase root. - - Args: - dir_path (str): The path to the directory to retrieve. - optional (bool): If True, return None when directory is not found. If False, raise ValueError. - - Returns: - TDirectory | None: The Directory object if found, None if optional=True and directory not found. - - Raises: - ValueError: If directory not found and optional=False. - """ - # Sanitize the path - dir_path = os.path.normpath(dir_path) - dir_path = "" if dir_path == "." else dir_path - directory = self.ctx.get_directory(self.ctx.to_absolute(dir_path), ignore_case=ignore_case) - if directory is None and not optional: - msg = f"Directory {dir_path} not found in codebase. Use optional=True to return None instead." - raise ValueError(msg) - return directory - - def has_symbol(self, symbol_name: str) -> bool: - """Returns whether a symbol exists in the codebase. - - This method checks if a symbol with the given name exists in the codebase. - - Args: - symbol_name (str): The name of the symbol to look for. - - Returns: - bool: True if a symbol with the given name exists in the codebase, False otherwise. - """ - return any([x.name == symbol_name for x in self.symbols]) - - def get_symbol(self, symbol_name: str, optional: bool = False) -> TSymbol | None: - """Returns a Symbol by name from the codebase. - - Returns the first Symbol that matches the given name. If multiple symbols are found with the same name, raises a ValueError. - If no symbol is found, returns None if optional is True, otherwise raises a ValueError. - - Args: - symbol_name (str): The name of the symbol to find. - optional (bool): If True, returns None when symbol is not found. If False, raises ValueError. Defaults to False. - - Returns: - TSymbol | None: The matched Symbol if found, None if not found and optional=True. - - Raises: - ValueError: If multiple symbols are found with the same name, or if no symbol is found and optional=False. - """ - symbols = self.get_symbols(symbol_name) - if len(symbols) == 0: - if not optional: - msg = f"Symbol {symbol_name} not found in codebase. Use optional=True to return None instead." - raise ValueError(msg) - return None - if len(symbols) > 1: - msg = f"Symbol {symbol_name} is ambiguous in codebase - more than one instance" - raise ValueError(msg) - return symbols[0] - - def get_symbols(self, symbol_name: str) -> list[TSymbol]: - """Retrieves all symbols in the codebase that match the given symbol name. - - This method is used when there may be multiple symbols with the same name, in which case get_symbol() would raise a ValueError. - - Args: - symbol_name (str): The name of the symbols to retrieve. - - Returns: - list[TSymbol]: A list of Symbol objects that match the given name, sorted alphabetically. - - Note: - When a unique symbol is required, use get_symbol() instead. It will raise ValueError if multiple symbols are found. - """ - return sort_editables(x for x in self.symbols if x.name == symbol_name) - - def get_class(self, class_name: str, optional: bool = False) -> TClass | None: - """Returns a class that matches the given name. - - Args: - class_name (str): The name of the class to find. - optional (bool): If True, return None when class is not found instead of raising ValueError. Defaults to False. - - Returns: - TClass | None: The class with the given name, or None if optional=True and class not found. - - Raises: - ValueError: If the class is not found and optional=False, or if multiple classes with the same name exist. - """ - matches = [c for c in self.classes if c.name == class_name] - if len(matches) == 0: - if not optional: - msg = f"Class {class_name} not found in codebase. Use optional=True to return None instead." - raise ValueError(msg) - return None - if len(matches) > 1: - msg = f"Class {class_name} is ambiguous in codebase - more than one instance" - raise ValueError(msg) - return matches[0] - - def get_function(self, function_name: str, optional: bool = False) -> TFunction | None: - """Retrieves a function from the codebase by its name. - - This method searches through all functions in the codebase to find one matching the given name. - If multiple functions with the same name exist, a ValueError is raised. - - Args: - function_name (str): The name of the function to retrieve. - optional (bool): If True, returns None when function is not found instead of raising ValueError. - Defaults to False. - - Returns: - TFunction | None: The matching function if found. If optional=True and no match is found, - returns None. - - Raises: - ValueError: If function is not found and optional=False, or if multiple matching functions exist. - """ - matches = [f for f in self.functions if f.name == function_name] - if len(matches) == 0: - if not optional: - msg = f"Function {function_name} not found in codebase. Use optional=True to return None instead." - raise ValueError(msg) - return None - if len(matches) > 1: - msg = f"Function {function_name} is ambiguous in codebase - more than one instance" - raise ValueError(msg) - return matches[0] - - @noapidoc - @staticmethod - def _remove_extension(filename: str) -> str: - """Removes the trailing extension from the filename if it appears at the end, - e.g. filename.ext -> filename - """ - return re.sub(r"\.[^.]+$", "", filename) - - def get_relative_path(self, from_file: str, to_file: str) -> str: - """Calculates a relative path from one file to another, removing the extension from the target file. - - This method splits both `from_file` and `to_file` by forward slashes, finds their common path prefix, - and determines how many directories to traverse upward from `from_file` before moving into the - remaining directories of `to_file` (with its extension removed). - - Args: - from_file (str): The file path from which the relative path will be computed. - to_file (str): The file path (whose extension will be removed) to which the relative path will be computed. - - Returns: - str: The relative path from `from_file` to `to_file` (with the extension removed from `to_file`). - """ - # Remove extension from the target file - to_file = self._remove_extension(to_file) - - from_parts = from_file.split("/") - to_parts = to_file.split("/") - - # Find common prefix - i = 0 - while i < len(from_parts) - 1 and i < len(to_parts) and from_parts[i] == to_parts[i]: - i += 1 - - # Number of '../' we need - up_levels = len(from_parts) - i - 1 - - # Construct relative path - relative_path = ("../" * up_levels) + "/".join(to_parts[i:]) - - return relative_path - - #################################################################################################################### - # State/Git - #################################################################################################################### - - def git_commit(self, message: str, *, verify: bool = False, exclude_paths: list[str] | None = None) -> GitCommit | None: - """Stages + commits all changes to the codebase and git. - - Args: - message (str): The commit message - verify (bool): Whether to verify the commit before committing. Defaults to False. - - Returns: - GitCommit | None: The commit object if changes were committed, None otherwise. - """ - self.ctx.commit_transactions(sync_graph=False) - if self._op.stage_and_commit_all_changes(message, verify, exclude_paths): - logger.info(f"Commited repository to {self._op.head_commit} on {self._op.get_active_branch_or_commit()}") - return self._op.head_commit - else: - logger.info("No changes to commit") - return None - - def commit(self, sync_graph: bool = True) -> None: - """Commits all staged changes to the codebase graph and synchronizes the graph with the filesystem if specified. - - This method must be called when multiple overlapping edits are made on a single entity to ensure proper tracking of changes. - For example, when renaming a symbol and then moving it to a different file, commit must be called between these operations. - - Args: - sync_graph (bool): Whether to synchronize the graph after committing changes. Defaults to True. - - Returns: - None - """ - self.ctx.commit_transactions(sync_graph=sync_graph and self.ctx.config.sync_enabled) - - @noapidoc - def git_push(self, *args, **kwargs) -> PushInfoList: - """Git push.""" - return self._op.push_changes(*args, **kwargs) - - @property - def default_branch(self) -> str: - """The default branch of this repository. - - Returns the name of the default branch (e.g. 'main' or 'master') for the current repository. - - Returns: - str: The name of the default branch. - """ - return self._op.default_branch - - @property - def current_commit(self) -> GitCommit | None: - """Returns the current Git commit that is checked out in the repository. - - Args: - None - - Returns: - GitCommit | None: The currently checked out Git commit object, or None if no commit is checked out. - """ - return self._op.git_cli.head.commit - - @stopwatch - def reset(self, git_reset: bool = False) -> None: - """Resets the codebase by: - - Discarding any staged/unstaged changes - - Resetting stop codemod limits: (max seconds, max transactions, max AI requests) - - Clearing logs - - Clearing pending transactions + pending files - - Syncing graph to synced_commit - - This will ignore changes to: - - .codegen directory (for codemod development) - - .ipynb files (Jupyter notebooks, where you are likely developing) - """ - logger.info("Resetting codebase ...") - if git_reset: - self._op.discard_changes() # Discard any changes made to the raw file state - self._num_ai_requests = 0 - self.reset_logs() - self.ctx.undo_applied_diffs() - - def checkout( - self, - *, - commit: str | GitCommit | None = None, - branch: str | None = None, - create_if_missing: bool = False, - remote: bool = False, - ) -> CheckoutResult: - """Checks out a git branch or commit and syncs the codebase graph to the new state. - - This method discards any pending changes, performs a git checkout of the specified branch or commit, - and then syncs the codebase graph to reflect the new state. - - Args: - commit (str | GitCommit | None): Hash or GitCommit object to checkout. Cannot be used with branch. - branch (str | None): Name of branch to checkout. Cannot be used with commit. - create_if_missing (bool): If True, creates the branch if it doesn't exist. Defaults to False. - remote (bool): If True, attempts to pull from remote when checking out branch. Defaults to False. - - Returns: - CheckoutResult: The result of the checkout operation. - - Raises: - AssertionError: If neither commit nor branch is specified, or if both are specified. - """ - self.reset() - if commit is None: - assert branch is not None, "Commit or branch must be specified" - logger.info(f"Checking out branch {branch}") - result = self._op.checkout_branch(branch, create_if_missing=create_if_missing, remote=remote) - else: - assert branch is None, "Cannot specify branch and commit" - logger.info(f"Checking out commit {commit}") - result = self._op.checkout_commit(commit_hash=commit) - if result == CheckoutResult.SUCCESS: - logger.info(f"Checked out {branch or commit}") - if self._op.head_commit is None: - logger.info(f"Ref: {self._op.git_cli.head.ref.name} has no commits") - return CheckoutResult.SUCCESS - - self.sync_to_commit(self._op.head_commit) - elif result == CheckoutResult.NOT_FOUND: - logger.info(f"Could not find branch {branch or commit}") - - return result - - @noapidoc - def sync_to_commit(self, target_commit: GitCommit) -> None: - """Updates the current base to a new commit.""" - origin_commit = self.ctx.synced_commit - if origin_commit.hexsha == target_commit.hexsha: - logger.info(f"Codebase is already synced to {target_commit.hexsha}. Skipping sync_to_commit.") - return - if not self.ctx.config.sync_enabled: - logger.info(f"Syncing codebase is disabled for repo {self._op.repo_name}. Skipping sync_to_commit.") - return - - logger.info(f"Syncing {self._op.repo_name} to {target_commit.hexsha}") - diff_index = origin_commit.diff(target_commit) - diff_lites = [] - for diff in diff_index: - diff_lites.append(DiffLite.from_git_diff(diff)) - self.ctx.apply_diffs(diff_lites) - self.ctx.save_commit(target_commit) - - @noapidoc - def get_diffs(self, base: str | None = None) -> list[Diff]: - """Get all changed files.""" - if base is None: - return self._op.get_diffs(self._op.head_commit) - return self._op.get_diffs(base) - - @noapidoc - def get_diff(self, base: str | None = None, stage_files: bool = False) -> str: - """Produce a single git diff for all files.""" - if stage_files: - self._op.git_cli.git.add(A=True) # add all changes to the index so untracked files are included in the diff - if base is None: - diff = self._op.git_cli.git.diff("HEAD", patch=True, full_index=True) - return diff - return self._op.git_cli.git.diff(base, patch=True, full_index=True) - - @noapidoc - def clean_repo(self): - """Cleaning a codebase repo by: - 1. Deleting all branches except the checked out one - 2. Deleting all remotes except origin - - NOTE: doesn't discard changes b/c this should be handled by self.reset - NOTE: doesn't checkout onto the default branch b/c this should be handled by self.checkout - """ - logger.info(f"Cleaning codebase repo at {self.repo_path} ...") - self._op.clean_remotes() - self._op.clean_branches() - - @noapidoc - def stash_changes(self): - """Stash all changes in the codebase.""" - self._op.stash_push() - - @noapidoc - def restore_stashed_changes(self): - """Restore the most recent stash in the codebase.""" - self._op.stash_pop() - - #################################################################################################################### - # GITHUB - #################################################################################################################### - - def create_pr(self, title: str, body: str) -> PullRequest: - """Creates a pull request from the current branch to the repository's default branch. - - This method will: - 1. Stage and commit any pending changes with the PR title as the commit message - 2. Push the current branch to the remote repository - 3. Create a pull request targeting the default branch - - Args: - title (str): The title for the pull request - body (str): The description/body text for the pull request - - Returns: - PullRequest: The created GitHub pull request object - - Raises: - ValueError: If attempting to create a PR while in a detached HEAD state - ValueError: If the current branch is the default branch - """ - if self._op.git_cli.head.is_detached: - msg = "Cannot make a PR from a detached HEAD" - raise ValueError(msg) - if self._op.git_cli.active_branch.name == self._op.default_branch: - msg = "Cannot make a PR from the default branch" - raise ValueError(msg) - self._op.stage_and_commit_all_changes(message=title) - self._op.push_changes() - return self._op.remote_git_repo.create_pull( - head_branch_name=self._op.git_cli.active_branch.name, - base_branch_name=self._op.default_branch, - title=title, - body=body, - ) - - #################################################################################################################### - # GRAPH VISUALIZATION - #################################################################################################################### - - def visualize(self, G: Graph | go.Figure, root: Editable | str | int | None = None) -> None: - """Visualizes a NetworkX graph or Plotly figure. - - Creates a visualization of the provided graph using GraphViz. This is useful for visualizing dependency graphs, call graphs, - directory structures, or other graph-based representations of code relationships. - - Args: - G (Graph | go.Figure): A NetworkX graph or Plotly figure to visualize - root (Editable | str | int | None): The root node to visualize around. When specified, the visualization will be centered on this node. Defaults to None. - - Returns: - None - """ - self.viz.write_graphviz_data(G=G, root=root) - - #################################################################################################################### - # FLAGGING - #################################################################################################################### - - @noapidoc - def flags(self) -> list[CodeFlag]: - """Returns all collected code flags in find mode. - - Returns: - list[CodeFlag]: A list of all flags in the codebase. - """ - return self.ctx.flags._flags - - @noapidoc - def flag_instance( - self, - symbol: TSymbol | None = None, - **kwargs: Unpack[FlagKwargs], - ) -> CodeFlag: - """Flags a symbol, file or import to enable enhanced tracking of changes and splitting into - smaller PRs. - - This method should be called once per flaggable entity and should be called before any edits are made to the entity. - Flags enable tracking of changes and can be used for various purposes like generating pull requests or applying changes selectively. - - Args: - symbol (TSymbol | None): The symbol to flag. Can be None if just flagging a message. - **kwargs: Arguments used to construct the flag - Returns: - CodeFlag: A flag object representing the flagged entity. - """ - return self.ctx.flags.flag_instance(symbol, **kwargs) - - def should_fix(self, flag: CodeFlag) -> bool: - """Returns True if the flag should be fixed based on the current mode and active group. - - Used to filter out flags that are not in the active group and determine if the flag should be processed or ignored. - - Args: - flag (CodeFlag): The code flag to check. - - Returns: - bool: True if the flag should be fixed, False if it should be ignored. - Returns False in find mode. - Returns True if no active group is set. - Returns True if the flag's hash exists in the active group hashes. - """ - return self.ctx.flags.should_fix(flag) - - @noapidoc - def set_find_mode(self, find_mode: bool) -> None: - self.ctx.flags.set_find_mode(find_mode) - - @noapidoc - def set_active_group(self, group: Group) -> None: - """Will only fix these flags.""" - # TODO - flesh this out more with Group datatype and GroupBy - self.ctx.flags.set_active_group(group) - - #################################################################################################################### - # LOGGING - #################################################################################################################### - - _logs = [] - - def __is_markup_loggable__(self, arg) -> bool: - return isinstance(arg, Editable) - - @noapidoc - def log(self, *args) -> None: - """Logs a message as a string. - - At the end, we will save a tail of these logs on the CodemodRun - """ - self.ctx.transaction_manager.check_max_preview_time() - if self.console.export_text(clear=False).count("\n") >= MAX_LINES: - return # if max lines has been reached, skip logging - for arg in args: - if self.__is_markup_loggable__(arg): - fullName = arg.get_name() if isinstance(arg, HasName) and arg.get_name() else "" - doc_lang = arg._api_doc_lang if hasattr(arg, "_api_doc_lang") else None - class_name = arg.__class__.__name__ - link = f"::docs/codebase-sdk/{doc_lang}/{class_name}" if doc_lang is not None else "" - self.console.print(f"{class_name}::{fullName}{link}", markup=True, soft_wrap=True) - args = [arg for arg in args if not self.__is_markup_loggable__(arg)] - if args: - self.console.print(*args, markup=True, soft_wrap=True) - - @noapidoc - def reset_logs(self) -> None: - """Resets the logs.""" - self.console.clear() - - @noapidoc - def get_finalized_logs(self) -> str: - """Returns the logs as a string, truncating if necessary.""" - return self.console.export_text() - - #################################################################################################################### - # INTERNAL UTILS - #################################################################################################################### - - @contextmanager - @noapidoc - def session( - self, - sync_graph: bool = True, - commit: bool = True, - session_options: SessionOptions = SessionOptions(), - ) -> Generator[None, None, None]: - with self.ctx.session(sync_graph=sync_graph, commit=commit, session_options=session_options): - yield None - - @noapidoc - def _enable_experimental_language_engine( - self, - async_start: bool = False, - install_deps: bool = False, - use_v8: bool = False, - ) -> None: - """Debug option to enable experimental language engine for the current codebase.""" - if install_deps and not self.ctx.language_engine: - from codegen.sdk.core.external.dependency_manager import ( - get_dependency_manager, - ) - - logger.info("Cold installing dependencies...") - logger.info("This may take a while for large repos...") - self.ctx.dependency_manager = get_dependency_manager(self.ctx.projects[0].programming_language, self.ctx, enabled=True) - self.ctx.dependency_manager.start(async_start=False) - # Wait for the dependency manager to be ready - self.ctx.dependency_manager.wait_until_ready(ignore_error=False) - logger.info("Dependencies ready") - if not self.ctx.language_engine: - from codegen.sdk.core.external.language_engine import get_language_engine - - logger.info("Cold starting language engine...") - logger.info("This may take a while for large repos...") - self.ctx.language_engine = get_language_engine( - self.ctx.projects[0].programming_language, - self.ctx, - use_ts=True, - use_v8=use_v8, - ) - self.ctx.language_engine.start(async_start=async_start) - # Wait for the language engine to be ready - self.ctx.language_engine.wait_until_ready(ignore_error=False) - logger.info("Language engine ready") - - #################################################################################################################### - # AI - #################################################################################################################### - - _ai_helper: OpenAI = None - _num_ai_requests: int = 0 - - @property - @noapidoc - def ai_client(self) -> OpenAI: - """Enables calling AI/LLM APIs - re-export of the initialized `openai` module""" - # Create a singleton AIHelper instance - if self._ai_helper is None: - if self.ctx.secrets.openai_api_key is None: - msg = "OpenAI key is not set" - raise ValueError(msg) - - self._ai_helper = get_openai_client(key=self.ctx.secrets.openai_api_key) - return self._ai_helper - - def ai( - self, - prompt: str, - target: Editable | None = None, - context: Editable | list[Editable] | dict[str, Editable | list[Editable]] | None = None, - model: str = "gpt-4o", - ) -> str: - """Generates a response from the AI based on the provided prompt, target, and context. - - A method that sends a prompt to the AI client along with optional target and context information to generate a response. - Used for tasks like code generation, refactoring suggestions, and documentation improvements. - - Args: - prompt (str): The text prompt to send to the AI. - target (Editable | None): An optional editable object (like a function, class, etc.) that provides the main focus for the AI's response. - context (Editable | list[Editable] | dict[str, Editable | list[Editable]] | None): Additional context to help inform the AI's response. - model (str): The AI model to use for generating the response. Defaults to "gpt-4o". - - Returns: - str: The generated response from the AI. - - Raises: - MaxAIRequestsError: If the maximum number of allowed AI requests (default 150) has been exceeded. - """ - # Check max transactions - logger.info("Creating call to OpenAI...") - self._num_ai_requests += 1 - if self.ctx.session_options.max_ai_requests is not None and self._num_ai_requests > self.ctx.session_options.max_ai_requests: - logger.info(f"Max AI requests reached: {self.ctx.session_options.max_ai_requests}. Stopping codemod.") - msg = f"Maximum number of AI requests reached: {self.ctx.session_options.max_ai_requests}" - raise MaxAIRequestsError(msg, threshold=self.ctx.session_options.max_ai_requests) - - params = { - "messages": [ - {"role": "system", "content": generate_system_prompt(target, context)}, - {"role": "user", "content": prompt}, - ], - "model": model, - "functions": generate_tools(), - "temperature": 0, - } - if model.startswith("gpt"): - params["tool_choice"] = "required" - - # Make the AI request - response = self.ai_client.chat.completions.create( - model=model, - messages=params["messages"], - tools=params["functions"], # type: ignore - temperature=params["temperature"], - tool_choice=params["tool_choice"], - ) - - # Handle finish reasons - # First check if there is a response - if response.choices: - # Check response reason - choice = response.choices[0] - if choice.finish_reason == "tool_calls" or choice.finish_reason == "function_call" or choice.finish_reason == "stop": - # Check if there is a tool call - if choice.message.tool_calls: - tool_call = choice.message.tool_calls[0] - response_answer = json.loads(tool_call.function.arguments) - if "answer" in response_answer: - response_answer = response_answer["answer"] - else: - msg = "No answer found in tool call. (tool_call.function.arguments does not contain answer)" - raise ValueError(msg) - else: - msg = "No tool call found in AI response. (choice.message.tool_calls is empty)" - raise ValueError(msg) - elif choice.finish_reason == "length": - msg = "AI response too long / ran out of tokens. (choice.finish_reason == length)" - raise ValueError(msg) - elif choice.finish_reason == "content_filter": - msg = "AI response was blocked by OpenAI's content filter. (choice.finish_reason == content_filter)" - raise ValueError(msg) - else: - msg = f"Unknown finish reason from AI: {choice.finish_reason}" - raise ValueError(msg) - else: - msg = "No response from AI Provider. (response.choices is empty)" - raise ValueError(msg) - - # Agent sometimes fucks up and does \\\\n for some reason. - response_answer = codecs.decode(response_answer, "unicode_escape") - logger.info(f"OpenAI response: {response_answer}") - return response_answer - - def set_ai_key(self, key: str) -> None: - """Sets the OpenAI key for the current Codebase instance.""" - # Reset the AI client - self._ai_helper = None - - # Set the AI key - self.ctx.secrets.openai_api_key = key - - def find_by_span(self, span: Span) -> list[Editable]: - """Finds editable objects that overlap with the given source code span. - - Searches for editable objects (like functions, classes, variables) within a file - that overlap with the specified byte range span. Returns an empty list if no - matching file is found. - - Args: - span (Span): The span object containing the filepath and byte range to search within. - - Returns: - list[Editable]: A list of Editable objects that overlap with the given span. - """ - if file := self.get_file(span.filepath): - return file.find_by_byte_range(span.range) - return [] - - def set_session_options(self, **kwargs: Unpack[SessionOptions]) -> None: - """Sets the session options for the current codebase. - - This method updates the session options with the provided keyword arguments and - configures the transaction manager accordingly. It sets the maximum number of - transactions and resets the stopwatch based on the updated session options. - - Args: - **kwargs: Keyword arguments representing the session options to update. - - max_transactions (int, optional): The maximum number of transactions - allowed in a session. - - max_seconds (int, optional): The maximum duration in seconds for a session - before it times out. - - max_ai_requests (int, optional): The maximum number of AI requests - allowed in a session. - """ - self.ctx.session_options = self.ctx.session_options.model_copy(update=kwargs) - self.ctx.transaction_manager.set_max_transactions(self.ctx.session_options.max_transactions) - self.ctx.transaction_manager.reset_stopwatch(self.ctx.session_options.max_seconds) - - @classmethod - def from_repo( - cls, - repo_full_name: str, - *, - tmp_dir: str | None = "/tmp/codegen", - commit: str | None = None, - language: Literal["python", "typescript"] | ProgrammingLanguage | None = None, - config: CodebaseConfig | None = None, - secrets: SecretsConfig | None = None, - full_history: bool = False, - ) -> "Codebase": - """Fetches a codebase from GitHub and returns a Codebase instance. - - Args: - repo_name (str): The name of the repository in format "owner/repo" - tmp_dir (Optional[str]): The directory to clone the repo into. Defaults to /tmp/codegen - commit (Optional[str]): The specific commit hash to clone. Defaults to HEAD - shallow (bool): Whether to do a shallow clone. Defaults to True - language (Literal["python", "typescript"] | ProgrammingLanguage | None): The programming language of the repo. Defaults to None. - config (CodebaseConfig): Configuration for the codebase. Defaults to pre-defined defaults if None. - secrets (SecretsConfig): Configuration for the secrets. Defaults to empty values if None. - - Returns: - Codebase: A Codebase instance initialized with the cloned repository - """ - logger.info(f"Fetching codebase for {repo_full_name}") - - # Parse repo name - if "/" not in repo_full_name: - msg = "repo_name must be in format 'owner/repo'" - raise ValueError(msg) - owner, repo = repo_full_name.split("/") - - # Setup temp directory - os.makedirs(tmp_dir, exist_ok=True) - logger.info(f"Using directory: {tmp_dir}") - - # Setup repo path and URL - repo_path = os.path.join(tmp_dir, repo) - repo_url = f"https://github.com/{repo_full_name}.git" - logger.info(f"Will clone {repo_url} to {repo_path}") - access_token = secrets.github_token if secrets else None - - try: - # Use RepoOperator to fetch the repository - logger.info("Cloning repository...") - if commit is None: - repo_config = RepoConfig.from_repo_path(repo_path) - repo_config.full_name = repo_full_name - repo_operator = RepoOperator.create_from_repo(repo_path=repo_path, url=repo_url, access_token=access_token, full_history=full_history) - else: - # Ensure the operator can handle remote operations - repo_operator = RepoOperator.create_from_commit(repo_path=repo_path, commit=commit, url=repo_url, full_name=repo_full_name, access_token=access_token) - - if repo_operator is None: - logger.error("Failed to clone repository") - return None - - logger.info("Clone completed successfully") - - # Initialize and return codebase with proper context - logger.info("Initializing Codebase...") - project = ProjectConfig.from_repo_operator( - repo_operator=repo_operator, - programming_language=ProgrammingLanguage(language.upper()) if language else None, - ) - codebase = Codebase(projects=[project], config=config, secrets=secrets) - logger.info("Codebase initialization complete") - return codebase - except Exception as e: - logger.exception(f"Failed to initialize codebase: {e}") - raise - - @classmethod - def from_string( - cls, - code: str, - *, - language: Literal["python", "typescript"] | ProgrammingLanguage, - ) -> "Codebase": - """Creates a Codebase instance from a string of code. - - Args: - code: String containing code - language: Language of the code. Defaults to Python. - - Returns: - Codebase: A Codebase instance initialized with the provided code - - Example: - >>> # Python code - >>> code = "def add(a, b): return a + b" - >>> codebase = Codebase.from_string(code, language="python") - - >>> # TypeScript code - >>> code = "function add(a: number, b: number): number { return a + b; }" - >>> codebase = Codebase.from_string(code, language="typescript") - """ - if not language: - msg = "missing required argument language" - raise TypeError(msg) - - logger.info("Creating codebase from string") - - # Determine language and filename - prog_lang = ProgrammingLanguage(language.upper()) if isinstance(language, str) else language - filename = "test.ts" if prog_lang == ProgrammingLanguage.TYPESCRIPT else "test.py" - - # Create codebase using factory - from codegen.sdk.codebase.factory.codebase_factory import CodebaseFactory - - files = {filename: code} - - with tempfile.TemporaryDirectory(prefix="codegen_") as tmp_dir: - logger.info(f"Using directory: {tmp_dir}") - - codebase = CodebaseFactory.get_codebase_from_files(repo_path=tmp_dir, files=files, programming_language=prog_lang) - logger.info("Codebase initialization complete") - return codebase - - @classmethod - def from_files( - cls, - files: dict[str, str], - *, - language: Literal["python", "typescript"] | ProgrammingLanguage | None = None, - ) -> "Codebase": - """Creates a Codebase instance from multiple files. - - Args: - files: Dictionary mapping filenames to their content, e.g. {"main.py": "print('hello')"} - language: Optional language override. If not provided, will be inferred from file extensions. - All files must have extensions matching the same language. - - Returns: - Codebase: A Codebase instance initialized with the provided files - - Raises: - ValueError: If file extensions don't match a single language or if explicitly provided - language doesn't match the extensions - - Example: - >>> # Language inferred as Python - >>> files = {"main.py": "print('hello')", "utils.py": "def add(a, b): return a + b"} - >>> codebase = Codebase.from_files(files) - - >>> # Language inferred as TypeScript - >>> files = {"index.ts": "console.log('hello')", "utils.tsx": "export const App = () =>
Hello
"} - >>> codebase = Codebase.from_files(files) - """ - # Create codebase using factory - from codegen.sdk.codebase.factory.codebase_factory import CodebaseFactory - - if not files: - msg = "No files provided" - raise ValueError(msg) - - logger.info("Creating codebase from files") - - prog_lang = ProgrammingLanguage.PYTHON # Default language - - if files: - py_extensions = {".py"} - ts_extensions = {".ts", ".tsx", ".js", ".jsx"} - - extensions = {os.path.splitext(f)[1].lower() for f in files} - inferred_lang = None - - # all check to ensure that the from_files method is being used for small testing purposes only. - # If parsing an actual repo, it should not be used. Instead do Codebase("path/to/repo") - if all(ext in py_extensions for ext in extensions): - inferred_lang = ProgrammingLanguage.PYTHON - elif all(ext in ts_extensions for ext in extensions): - inferred_lang = ProgrammingLanguage.TYPESCRIPT - else: - msg = f"Cannot determine single language from extensions: {extensions}. Files must all be Python (.py) or TypeScript (.ts, .tsx, .js, .jsx)" - raise ValueError(msg) - - if language is not None: - explicit_lang = ProgrammingLanguage(language.upper()) if isinstance(language, str) else language - if explicit_lang != inferred_lang: - msg = f"Provided language {explicit_lang} doesn't match inferred language {inferred_lang} from file extensions" - raise ValueError(msg) - - prog_lang = inferred_lang - else: - # Default to Python if no files provided - prog_lang = ProgrammingLanguage.PYTHON if language is None else (ProgrammingLanguage(language.upper()) if isinstance(language, str) else language) - - logger.info(f"Using language: {prog_lang}") - - with tempfile.TemporaryDirectory(prefix="codegen_") as tmp_dir: - logger.info(f"Using directory: {tmp_dir}") - - # Initialize git repo to avoid "not in a git repository" error - import subprocess - - subprocess.run(["git", "init"], cwd=tmp_dir, check=True, capture_output=True) - - codebase = CodebaseFactory.get_codebase_from_files(repo_path=tmp_dir, files=files, programming_language=prog_lang) - logger.info("Codebase initialization complete") - return codebase - - def get_modified_symbols_in_pr(self, pr_id: int) -> tuple[str, dict[str, str], list[str], str]: - """Get all modified symbols in a pull request""" - pr = self._op.get_pull_request(pr_id) - cg_pr = CodegenPR(self._op, self, pr) - patch = cg_pr.get_pr_diff() - commit_sha = cg_pr.get_file_commit_shas() - return patch, commit_sha, cg_pr.modified_symbols, pr.head.ref - - def create_pr_comment(self, pr_number: int, body: str) -> None: - """Create a comment on a pull request""" - return self._op.create_pr_comment(pr_number, body) - - def create_pr_review_comment( - self, - pr_number: int, - body: str, - commit_sha: str, - path: str, - line: int | None = None, - side: str = "RIGHT", - start_line: int | None = None, - ) -> None: - """Create a review comment on a pull request. - - Args: - pr_number: The number of the pull request - body: The body of the comment - commit_sha: The SHA of the commit to comment on - path: The path of the file to comment on - line: The line number to comment on - side: The side of the comment to create - start_line: The start line number to comment on - - Returns: - None - """ - return self._op.create_pr_review_comment(pr_number, body, commit_sha, path, line, side, start_line) - - -# The last 2 lines of code are added to the runner. See codegen-backend/cli/generate/utils.py -# Type Aliases -CodebaseType = Codebase[ - SourceFile, - Directory, - Symbol, - Class, - Function, - Import, - Assignment, - Interface, - TypeAlias, - Parameter, - CodeBlock, -] -PyCodebaseType = Codebase[ - PyFile, - PyDirectory, - PySymbol, - PyClass, - PyFunction, - PyImport, - PyAssignment, - Interface, - TypeAlias, - PyParameter, - PyCodeBlock, -] -TSCodebaseType = Codebase[ - TSFile, - TSDirectory, - TSSymbol, - TSClass, - TSFunction, - TSImport, - TSAssignment, - TSInterface, - TSTypeAlias, - TSParameter, - TSCodeBlock, -] diff --git a/src/codegen/sdk/core/codeowner.py b/src/codegen/sdk/core/codeowner.py deleted file mode 100644 index 8db24cc67..000000000 --- a/src/codegen/sdk/core/codeowner.py +++ /dev/null @@ -1,102 +0,0 @@ -from collections.abc import Iterable, Iterator -from typing import Callable, Generic, Literal - -from codeowners import CodeOwners as CodeOwnersParser - -from codegen.sdk._proxy import proxy_property -from codegen.sdk.core.interfaces.has_symbols import ( - FilesParam, - HasSymbols, - TClass, - TFile, - TFunction, - TGlobalVar, - TImport, - TImportStatement, - TSymbol, -) -from codegen.sdk.core.utils.cache_utils import cached_generator -from codegen.shared.decorators.docs import apidoc, noapidoc -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -@apidoc -class CodeOwner( - HasSymbols[TFile, TSymbol, TImportStatement, TGlobalVar, TClass, TFunction, TImport], - Generic[TFile, TSymbol, TImportStatement, TGlobalVar, TClass, TFunction, TImport], -): - """CodeOwner is a class that represents a code owner in a codebase. - - It is used to iterate over all files that are owned by a specific owner. - - Attributes: - owner_type: The type of the owner (USERNAME, TEAM, EMAIL). - owner_value: The value of the owner. - files_source: A callable that returns an iterable of all files in the codebase. - """ - - _instance_iterator: Iterator[TFile] - owner_type: Literal["USERNAME", "TEAM", "EMAIL"] - owner_value: str - files_source: Callable[FilesParam, Iterable[TFile]] - - def __init__( - self, - files_source: Callable[FilesParam, Iterable[TFile]], - owner_type: Literal["USERNAME", "TEAM", "EMAIL"], - owner_value: str, - ): - self.owner_type = owner_type - self.owner_value = owner_value - self.files_source = files_source - - @classmethod - def from_parser( - cls, - parser: CodeOwnersParser, - file_source: Callable[FilesParam, Iterable[TFile]], - ) -> list["CodeOwner"]: - """Create a list of CodeOwner objects from a CodeOwnersParser. - - Args: - parser (CodeOwnersParser): The CodeOwnersParser to use. - file_source (Callable[FilesParam, Iterable[TFile]]): A callable that returns an iterable of all files in the codebase. - - Returns: - list[CodeOwner]: A list of CodeOwner objects. - """ - codeowners = [] - for _, _, owners, _, _ in parser.paths: - for owner_label, owner_value in owners: - codeowners.append(CodeOwner(file_source, owner_label, owner_value)) - return codeowners - - @cached_generator(maxsize=16) - @noapidoc - def files_generator(self, *args: FilesParam.args, **kwargs: FilesParam.kwargs) -> Iterable[TFile]: - for source_file in self.files_source(*args, **kwargs): - # Filter files by owner value - if self.owner_value in source_file.owners: - yield source_file - - @proxy_property - def files(self, *args: FilesParam.args, **kwargs: FilesParam.kwargs) -> Iterable[TFile]: - """Recursively iterate over all files in the codebase that are owned by the current code owner.""" - return self.files_generator(*args, **kwargs) - - @property - def name(self) -> str: - """The name of the code owner.""" - return self.owner_value - - def __iter__(self) -> Iterator[TFile]: - self._instance_iterator = iter(self.files_generator()) - return self - - def __next__(self) -> str: - return next(self._instance_iterator) - - def __repr__(self) -> str: - return f"CodeOwner(owner_type={self.owner_type}, owner_value={self.owner_value})" diff --git a/src/codegen/sdk/core/dataclasses/usage.py b/src/codegen/sdk/core/dataclasses/usage.py deleted file mode 100644 index 60c44a196..000000000 --- a/src/codegen/sdk/core/dataclasses/usage.py +++ /dev/null @@ -1,92 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from enum import IntEnum, IntFlag, auto, unique -from typing import TYPE_CHECKING - -from dataclasses_json import dataclass_json - -from codegen.shared.decorators.docs import apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.export import Export - from codegen.sdk.core.expressions import Name - from codegen.sdk.core.expressions.chained_attribute import ChainedAttribute - from codegen.sdk.core.file import SourceFile - from codegen.sdk.core.import_resolution import Import - from codegen.sdk.core.symbol import Symbol - - -@apidoc -@dataclass_json -@dataclass(frozen=True) -class Usage: - """A reference to an exportable object in a file. - - Attributes: - match: The exact match of the usage - usage_symbol: The symbol this object is used in - imported_by: The import statement that brought this symbol into scope, or None if not imported - usage_type: How this symbol was used - kind: Where this symbol was used (IE: in a type parameter or in the body of the class, etc) - """ - - match: Name | ChainedAttribute | FunctionCall - usage_symbol: Import | Symbol | Export | SourceFile - imported_by: Import | None - usage_type: UsageType - kind: UsageKind - - -@unique -@apidoc -class UsageType(IntFlag): - """Describes how a symbol is used elsewhere. Used in conjunction with get_usages - - Attributes: - DIRECT: Direct imports and usages within the same file - CHAINED: Chained references (ie: module.foo) - INDIRECT: Indirect usages with the same name - ALIASED: Aliased indirect usages - """ - - DIRECT = auto() - CHAINED = auto() - INDIRECT = auto() - ALIASED = auto() - - -@apidoc -class UsageKind(IntEnum): - """SymbolUsageType is an enumeration class that defines different types of symbol usage within Python code. - - Attributes: - SUBCLASS: Used in symbol inheritance. - TYPED_PARAMETER: Used as a typed parameter in a function/method. - TYPE_ANNOTATION: Used as a type annotation on a parameter or assignment statement. - BODY: Usage within the body of a function/method. - DECORATOR: Usage within a decorator. - RETURN_TYPE: Used as a return type annotation. - TYPE_DEFINITION: Used in a type alias. - EXPORTED_SYMBOL: Used in an export statement. - EXPORTED_WILDCARD: Re-exported by a wildcard export. - GENERIC: Used as a type parameter to another type. - IMPORTED: Imported with an import statement. - IMPORTED_WILDCARD: Imported with a wildcard import statement. - DEFAULT_VALUE: Represents a default value in a function/method parameter. - """ - - SUBCLASS = auto() - TYPED_PARAMETER = auto() - TYPE_ANNOTATION = auto() - BODY = auto() - DECORATOR = auto() - RETURN_TYPE = auto() - TYPE_DEFINITION = auto() - EXPORTED_SYMBOL = auto() - EXPORTED_WILDCARD = auto() - GENERIC = auto() - IMPORTED = auto() - IMPORTED_WILDCARD = auto() - DEFAULT_VALUE = auto() diff --git a/src/codegen/sdk/core/detached_symbols/__init__.py b/src/codegen/sdk/core/detached_symbols/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/sdk/core/detached_symbols/argument.py b/src/codegen/sdk/core/detached_symbols/argument.py deleted file mode 100644 index 948619f15..000000000 --- a/src/codegen/sdk/core/detached_symbols/argument.py +++ /dev/null @@ -1,163 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar, override - -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.expressions.multi_expression import MultiExpression -from codegen.sdk.core.expressions.name import Name -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.core.interfaces.has_value import HasValue -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.detached_symbols.parameter import Parameter - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - - -Parent = TypeVar("Parent", bound="FunctionCall") -TParameter = TypeVar("TParameter", bound="Parameter") - - -@apidoc -class Argument(Expression[Parent], HasName, HasValue, Generic[Parent, TParameter]): - """Represents an argument passed into a FunctionCall.""" - - _pos: int - - def __init__(self, node: TSNode, positional_idx: int, parent: FunctionCall) -> None: - super().__init__(node, parent.file_node_id, parent.ctx, parent) - self._pos = positional_idx - - # TODO: Make the python and typescript implementations into different classes - # Python - if node.type == "keyword_argument": - name_node = node.child_by_field_name("name") - _value_node = node.child_by_field_name("value") - # TypeScript - elif node.type == "assignment_expression": - name_node = node.child_by_field_name("left") - _value_node = node.child_by_field_name("right") - else: - name_node = None - _value_node = node - - self._name_node = self._parse_expression(name_node, default=Name) - self._value_node = self._parse_expression(_value_node) - - def __repr__(self) -> str: - keyword = f"keyword={self.name}, " if self.name else "" - value = f"value='{self.value}', " if self.value else "" - type = f"type={self.type}" if self.type else "" - - return f"Argument({keyword}{value}{type})" - - @noapidoc - @classmethod - def from_argument_list(cls, node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: FunctionCall) -> MultiExpression[Parent, Argument]: - args = [Argument(x, file_node_id, ctx, parent, i) for i, x in enumerate(node.named_children) if x.type != "comment"] - return MultiExpression(node, file_node_id, ctx, parent, expressions=args) - - @property - @reader - def index(self) -> int: - """Returns the zero-based index of this argument within its parent function call. - - Args: - None - - Returns: - int: The zero-based position of this argument in the function call's argument list. - """ - return self._pos - - @property - @reader - def type(self) -> str: - """Gets the `Tree-sitter` type of the argument's value node. - - Returns the type string of the underlying TreeSitter node that represents this argument's value. - This can be useful for checking if the argument is a specific type of expression or literal. - - Returns: - str: The TreeSitter node type of the argument's value. - """ - return self._value_node.ts_node.type - - @property - @reader - def is_named(self) -> bool: - """Determines if an argument is being passed as a named keyword argument. - - Args: - None - - Returns: - bool: True if the argument is being passed with a name (e.g., param=value), False if it's a positional argument. - """ - return self.name is not None - - #################################################################################################################### - # MANIPULATIONS - #################################################################################################################### - - @writer - def add_keyword(self, keyword: str) -> None: - """Converts an unnamed argument to a named argument by adding a keyword. - - Adds the specified keyword to an unnamed argument in a function call, making it a named argument. - For example, turning a positional argument 'value' into a named argument 'param=value'. - - Args: - keyword (str): The keyword name to be added to the argument. - - Raises: - ValueError: If the argument is already a named argument. - """ - if self.is_named: - msg = f"Argument {self.source} already has a keyword argument at file {self.file_node_id}" - raise ValueError(msg) - - self.insert_before(f"{keyword}=", newline=False) - - @noapidoc - @commiter - @override - def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: - if value := self.value: - value._compute_dependencies(usage_type, dest) - - @property - @reader - @noapidoc - def parameter(self) -> TParameter | None: - """Provides access to the corresponding Parameter (defined on the Function being called) for this argument.""" - if self.is_named: - return self.parent.find_parameter_by_name(self.name) - return self.parent.find_parameter_by_index(self.index) - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Returns a list of function calls present in the value of this argument. - - Retrieves all function call nodes that are present within the value of this argument. This is useful for call graph analysis and tracking function usage within arguments. - - Returns: - list[FunctionCall]: A list containing all function calls within the argument's value. - """ - return self.value.function_calls - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - if self.value: - return self.value.descendant_symbols - return [] diff --git a/src/codegen/sdk/core/detached_symbols/code_block.py b/src/codegen/sdk/core/detached_symbols/code_block.py deleted file mode 100644 index a5fde4d62..000000000 --- a/src/codegen/sdk/core/detached_symbols/code_block.py +++ /dev/null @@ -1,555 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from collections import deque -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from typing_extensions import deprecated - -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.dataclasses.usage import UsageKind, UsageType -from codegen.sdk.core.expressions import Expression, Value -from codegen.sdk.core.statements.statement import Statement, StatementType -from codegen.sdk.extensions.sort import sort_editables -from codegen.sdk.extensions.utils import find_line_start_and_end_nodes -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from collections.abc import Generator - - from tree_sitter import Node as TSNode - - from codegen.sdk.core.assignment import Assignment - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.interfaces.has_block import HasBlock - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.statements.assignment_statement import AssignmentStatement - from codegen.sdk.core.statements.attribute import Attribute - from codegen.sdk.core.statements.comment import Comment - from codegen.sdk.core.statements.if_block_statement import IfBlockStatement - from codegen.sdk.core.statements.return_statement import ReturnStatement - from codegen.sdk.core.statements.symbol_statement import SymbolStatement - from codegen.sdk.core.symbol_groups.multi_line_collection import MultiLineCollection - from codegen.sdk.output.ast import AST - - -Parent = TypeVar("Parent", bound="HasBlock") -TAssignment = TypeVar("TAssignment", bound="Assignment") - - -@apidoc -class CodeBlock(Expression[Parent], Generic[Parent, TAssignment]): - """Container class for a list of code Statements that share an indentation level, e.g. a - function body or class body. - - Enables various types of queries and operations on the code block. - - Attributes: - level: The indentation level of the code block. - parent_block: The parent code block containing this block, or None if it is a top-level block. - """ - - level: int - parent_block: CodeBlock | None - _statements: MultiLineCollection[Statement, Self] - - def __init__(self, ts_node: TSNode, level: int, parent_block: CodeBlock | None, parent: Parent) -> None: - super().__init__(ts_node, parent.file_node_id, parent.ctx, parent) - self.parent_block = parent_block - self.level = level - # self.parse() - - @noapidoc - def parse(self) -> None: - self._statements = self._parse_statements() - - @abstractmethod - @noapidoc - def _parse_statements(self) -> MultiLineCollection[Statement, Self]: - """Parses top level statements in the code block.""" - - @property - @reader - def statements(self) -> MultiLineCollection[Statement, Self]: - """Gets a view of the top-level statements in the code block. - - Returns a collection of statements that appear directly in this code block, ordered by their appearance. - This does not include statements nested within other blocks (e.g., if statements, functions). - - Returns: - MultiLineCollection[Statement, Self]: An ordered collection of top-level statements in the code block. - """ - return self._statements - - @reader - def _get_statements(self, statement_type: StatementType | None = None, max_level: int | None = None) -> Generator[Statement[Self]]: - """Private implementation of get_statements that returns a generator of statements.""" - queue = deque([(self._statements, self.level)]) - while queue: - current_statements, level = queue.popleft() - - for statement in current_statements: - if statement_type is None or statement.statement_type == statement_type: - yield statement - if statement.statement_type == StatementType.SYMBOL_STATEMENT: - continue - if max_level is None or level < max_level: - for nested_statements in statement.nested_statements: - queue.append((nested_statements.symbols, level + 1)) - - @reader - def get_statements(self, statement_type: StatementType | None = None, max_level: int | None = None) -> list[Statement[Self]]: - """Returns all statements of a given type up to the specified block level. - - This method retrieves statements from the code block and its nested blocks. Statements can be filtered by type and depth. - - Args: - statement_type (StatementType | None): The type of statements to return. If None, returns all statement types. - max_level (int | None): The maximum block depth level to search. If None, searches all levels. - - Returns: - A sorted list of matching statements. - """ - return sort_editables(self._get_statements(statement_type, max_level)) - - @property - @reader - def symbol_statements(self) -> list[SymbolStatement]: - """Returns list of top level symbol statements in the code block. - - Retrieves all statements in the block that have a statement type of SYMBOL_STATEMENT. - Symbol statements are statements that declare or manipulate symbols like functions or classes. - - Returns: - list[SymbolStatement]: A list of all the symbol statements at the top level of this code block. - """ - return [x for x in self.statements if x.statement_type == StatementType.SYMBOL_STATEMENT] - - @property - @reader - def comments(self) -> list[Comment[Parent, Self]]: - """Gets list of top level comments in the code block. - - Returns a list of comment statements that occur at the top level of this code block. Does not include nested comments. - - Returns: - list[Comment[Parent, Self]]: A list of Comment objects that are immediate children of this code block. - """ - return [x for x in self.statements if x.statement_type == StatementType.COMMENT] - - @reader - def get_comment(self, comment_src: str) -> Comment[Parent, Self] | None: - """Gets the first comment statement containing a specific text string. - - Searches through all nested statement levels in the code block to find a comment that contains - the specified text. - - Args: - comment_src (str): The text string to search for within comment statements. - - Returns: - Comment[Parent, Self] | None: The first comment statement containing the search text, - or None if no matching comment is found. - """ - return next((x for x in self._get_statements(StatementType.COMMENT) if comment_src in x.source), None) - - @property - @reader - def if_blocks(self) -> list[IfBlockStatement[Self]]: - """Returns a list of top level if statements in the code block. - - A property that retrieves all the immediate if statements within this code block. - These are if statements that exist at the same indentation level as other statements in the block, not nested ones. - - Returns: - list[IfBlockStatement[Parent, Self]]: A list of top-level if statement objects in the code block. - """ - return [x for x in self.statements if x.statement_type == StatementType.IF_BLOCK_STATEMENT] - - @property - @reader - def attributes(self) -> list[Attribute[Parent, Self]]: - """Returns a list of top level class attribute statements in the code block. - - Get all attribute statements (Attribute objects) that are direct children of the current code block. - These represent class-level attribute declarations. - - Returns: - list[Attribute[Parent, Self]]: A list of Attribute objects representing the class-level attributes, - ordered by their appearance in the code block. - """ - return [x for x in self.statements if x.statement_type == StatementType.CLASS_ATTRIBUTE] - - @reader - def get_attributes(self, private: bool) -> list[Attribute[Parent, Self]]: - """Returns attributes from the code block, with the option to include or exclude private - attributes. - - Retrieves a list of top level attribute statements from the code block, filtering based on the private parameter. - When private is True, both private and public attributes are returned. When private is False, only public - attributes are returned. - - Args: - private (bool): Whether to include private attributes in the returned list. If True, returns both private and - public attributes. If False, returns only public attributes. - - Returns: - list[Attribute[Parent, Self]]: A list of attribute statements matching the privacy criteria. - """ - return [x for x in self.attributes if not x.is_private or private] - - @property - @reader - def assignment_statements(self) -> list[AssignmentStatement[Self, TAssignment]]: - """Returns list of top level assignment statements in the code block. - - Retrieves all statements in the code block whose type is AssignmentStatement. These statements represent direct assignments - at the current code block level (not nested within other blocks). - - Returns: - A list of assignment statements found at the top level of the code block. - """ - return [x for x in self.statements if x.statement_type == StatementType.ASSIGNMENT] - - @property - @reader - def return_statements(self) -> list[ReturnStatement[Self]]: - """Gets all return statements at the top level of the code block. - - Args: - None - - Returns: - list[ReturnStatement[Parent, Self]]: A list of return statements that appear at the top level of the code block. Does not include return statements in nested blocks. - """ - return [x for x in self.statements if x.statement_type == StatementType.RETURN_STATEMENT] - - @property - @reader - def assignments(self) -> list[Assignment[Parent, Self]]: - """Returns all assignments in the code block across all nesting levels. - - Gets every assignment from every assignment statement in the code block, including assignments in nested blocks. - - Returns: - list[Assignment[Parent, Self]]: A list of Assignment objects from all nested levels of the code block. - """ - variables = [] - for statement in self._get_statements(StatementType.ASSIGNMENT): - variables.extend([x for x in statement.assignments]) - return variables - - @reader - def get_assignments(self, var_name: str, *, fuzzy: bool = False, parameters: bool = False) -> list[Assignment[Parent, Self]]: - """Returns a list of assignments with the specified variable name. - - Returns all assignments in the code block that match the given variable name. - - Args: - var_name (str): The name of the variable to find assignments for. - - Returns: - list[Assignment[Parent, Self]]: A list of Assignment objects that match the variable name. - """ - assignments = list(self.parent.parameters) + self.assignments if parameters else self.assignments - - return [x for x in assignments if (var_name in x.name if fuzzy else x.name == var_name)] - - @property - @reader - def local_var_assignments(self) -> list[Assignment[Parent, Self]]: - """Returns all local variable assignment in the code block, for all nest levels. - - A property that returns all variable assignments that are marked as local variables within the code block, - including assignments in nested code blocks. - - Returns: - list[Assignment[Parent, Self]]: A list of Assignment objects representing local variable assignments. - """ - return [x for x in self.assignments if x.is_local_variable] - - @reader - def get_local_var_assignment(self, var_name: str) -> Assignment[Parent, Self] | None: - """Returns the first code statement that assigns a local variable with the specified name. - - Searches through all local variable assignments in the code block and returns the first one that matches - the given variable name. - - Args: - var_name (str): The name of the local variable to search for. - - Returns: - Assignment[Parent, Self] | None: The first matching local variable assignment, or None if no match is found. - """ - return next((x for x in self.local_var_assignments if x.name == var_name), None) - - @reader - def get_local_var_assignments(self, var_name: str, fuzzy_match: bool = False) -> list[Assignment[Parent, Self]]: - """Returns all instances of local variable assignments that match the specified variable - name. - - Finds local variable assignments within the code block that match the provided variable name, with optional fuzzy matching. - - Args: - var_name (str): The name of the local variable to search for. - fuzzy_match (bool, optional): If True, matches variables whose names contain var_name. - If False, only matches exact variable names. Defaults to False. - - - Returns: - list[Assignment[Parent, Self]]: List of Assignment objects representing local variable assignments - that match the specified name criteria. - """ - return [x for x in self.local_var_assignments if (var_name in x.name if fuzzy_match else var_name == x.name)] - - @reader - def get_variable_usages(self, var_name: str, fuzzy_match: bool = False) -> list[Editable[Self]]: - """Returns all instances of variable usages in a code block. - - This method searches through all statements in the code block to find variable usages that match the specified variable name. - Variable usages are instances where the variable is referenced or used in expressions, function calls, or other code constructs. - - Args: - var_name (str): The name of the variable to search for. - fuzzy_match (bool): When True, matches on variable names that contain var_name. When False (default), only matches exact variable names. - - Returns: - list[Editable[Self]]: A sorted list of variable usage instances as Editable objects. - """ - usages = list() - for assignment in self.get_assignments(var_name, fuzzy=fuzzy_match, parameters=True): - usages.extend(usage.match for usage in assignment.usages(UsageType.DIRECT | UsageType.CHAINED)) - return sort_editables(usages) - - @writer - def rename_variable_usages(self, old_var_name: str, new_var_name: str, fuzzy_match: bool = False) -> None: - """Renames all instances of variable usages in the code block. - - This method modifies variable usages in the code block by replacing occurrences of the old variable name with a new one. - It uses get_assignments() and rename() internally to find all instances of the variable. - - Args: - old_var_name (str): The current name of the variable to rename. - new_var_name (str): The new name to give the variable. - fuzzy_match (bool): When True, matches variables containing old_var_name. When False, only exact matches. Defaults to False. - - Returns: - None: This method mutates the code block in place. - """ - for assignment in self.get_assignments(old_var_name, fuzzy=fuzzy_match, parameters=True): - assignment.rename(assignment.name.replace(old_var_name, new_var_name)) - - @deprecated("Use `self.statements.insert(0, ...)` instead.") - @writer - def insert_before(self, new_src: str) -> None: - """Inserts new source code at the top of the code block. - - This method has been deprecated in favor of using `self.statements.insert(0, ...)`. - - Args: - new_src (str): The source code to insert at the top of the code block. - - Returns: - None - """ - start_lines = self._get_line_starts() - start_line = start_lines[0] - start_line.insert_before(new_src, fix_indentation=True, newline=True) - - @deprecated("Use `self.statements.append(...)` instead.") - @writer - def insert_after(self, new_src: str, fix_indentation=True, newline=True) -> None: - """Inserts source code at the bottom of the code block. - - This method is deprecated. Use `self.statements.append(...)` instead. - - Args: - new_src (str): The source code to insert. - fix_indentation (bool): Whether to fix the indentation of the inserted code. Defaults to True. - newline (bool): Whether to add a newline before the inserted code. Defaults to True. - - Returns: - None - """ - if fix_indentation is False: - super().insert_after(new_src, fix_indentation=fix_indentation, newline=newline) - end_lines = self._get_line_ends() - end_line = end_lines[-1] - end_line.insert_after(new_src, fix_indentation=fix_indentation, newline=newline) - - @writer - def indent(self, level: int) -> None: - """Adjusts the indentation level of the entire code block. - - Modifies the indentation of all lines in the code block by adding or removing spaces at the start of each line. - The amount of indentation per level is determined by either the existing indentation level or defaults to 4 spaces. - - Args: - level (int): The number of indentation levels to adjust. Positive values indent right, negative values indent left. - - Returns: - None - """ - if level == 0: - return - - start_lines = self._get_line_starts() - indent_size = int(start_lines[0].start_point[1] / self.level) if self.level > 0 else 4 - total_indent_size = indent_size * abs(level) - for start_node in start_lines: - if level < 0: - (_, column) = start_node.start_point - new_column = max(0, column - total_indent_size) - offset = column - new_column - start_node.remove_byte_range(start_node.start_byte - offset, start_node.start_byte) - else: - start_node.insert_before(" " * total_indent_size, newline=False) - - @writer - def wrap(self, before_src: str, after_src: str = "") -> None: - """Wraps a code block with a statement and indents it. - - This method wraps an existing code block with a preceding statement (and optionally a following statement), - and indents the block appropriately. Common use cases include wrapping code blocks with if statements, - try/except blocks, with statements, or other control flow structures. - - Args: - before_src (str): The source code to insert before the block. - after_src (str): The source code to insert after the block. Defaults to an empty string. - - Returns: - None - """ - # Step 1: Add before_src before the block - self.insert_before(before_src) - - # Step 2: Add after_src before the block - if after_src: - self.insert_after(after_src) - - # Step 3: Indent the block - self.indent(1) - - @writer - def unwrap(self) -> None: - """Extracts a code block from its parent wrapper container by removing the wrapping - statement and adjusting indentation. - - This method unwraps a code block from its parent container (like if statements, with statements, function definitions, etc.) - by removing the parent wrapper code and unindenting the block content. - - This method handles two cases: - 1. When the wrapper is the only statement on its line - 2. When the wrapper shares a line with other statements - - For example, transforming: - if a: - return b - into: - return b - - Args: - None - - Returns: - None - """ - self.indent(-1) - - # If the wrapper doesn't start at the beginning of the line, only remove up to the end of the wrapper - wrapper_row = self.ts_node.parent.start_point[0] - if (prev_sibling := self.ts_node.parent.prev_sibling) is not None and prev_sibling.start_point[0] == wrapper_row: - while prev_sibling.prev_sibling and prev_sibling.prev_sibling.start_point[0] == wrapper_row: - prev_sibling = prev_sibling.prev_sibling - - remove_start_byte = prev_sibling.start_byte - 1 - wrapper_line_nodes = find_line_start_and_end_nodes(self.ts_node.parent) - wrapper_end_row = self.statements[0].start_point[0] - 1 - wrapper_end_node = next(x[1] for x in wrapper_line_nodes if x[1].start_point[0] == wrapper_end_row) - self.remove_byte_range(remove_start_byte, wrapper_end_node.end_byte) - - # Else, remove the entire top wrapper up to the start of the block - else: - self.remove_byte_range(self.ts_node.parent.start_byte, self.statements[0].start_byte) - - @reader - @noapidoc - def _get_line_starts(self) -> list[Editable]: - """Returns an ordered list of nodes located at the left-most of each line in the code block - eg. - - Given the code: - ``` - def foo(): - x = 1 - y = 2 - ``` - returns [Node(def foo():), Node(x), Node(y)] - """ - starts = [] - for comment in self.get_statements(statement_type=StatementType.COMMENT, max_level=self.level): - if comment.start_byte < self.start_byte: - starts.append(comment) - starts.extend([Value(x[0], self.file_node_id, self.ctx, self) for x in find_line_start_and_end_nodes(self.ts_node)]) - return starts - - @reader - @noapidoc - def _get_line_ends(self) -> list[Editable]: - """Returns an ordered list of nodes located at the right-most of each line in the code block - eg. - - Given the code: - ``` - def foo(): - x = 1 - y = 2 - ``` - returns [Node(def foo():), Node(1), Node(2)] - """ - ends = [] - for comment in self.get_statements(statement_type=StatementType.COMMENT, max_level=self.level): - if comment.start_byte < self.start_byte: - ends.append(comment) - ends.extend([Value(x[1], self.file_node_id, self.ctx, self) for x in find_line_start_and_end_nodes(self.ts_node)]) - return ends - - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - dest = dest or self.parent.self_dest - for statement in self.statements: - statement._compute_dependencies(UsageKind.BODY, dest) - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Returns a list of all function calls in the code block. - - Gets a list of all function calls in the code block, including those within nested statements. The function calls are ordered by their appearance in the code block. - - Returns: - list[FunctionCall]: A list of FunctionCall objects representing all function calls in the code block. - """ - fcalls = [] - for s in self.statements: - fcalls.extend(s.function_calls) - return fcalls - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - symbols = [] - for s in self.get_statements(): - symbols.extend(s.descendant_symbols) - return symbols - - def _smart_remove(self, child, *args, **kwargs) -> bool: - if len(self.statements) <= 1 and self.level > 0: - self.parent.remove(*args, **kwargs) - return True - return False - - @override - def _get_ast_children(self) -> list[tuple[str | None, AST]]: - return [("statements", self._statements.ast())] diff --git a/src/codegen/sdk/core/detached_symbols/decorator.py b/src/codegen/sdk/core/detached_symbols/decorator.py deleted file mode 100644 index 5a756b392..000000000 --- a/src/codegen/sdk/core/detached_symbols/decorator.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from typing import TYPE_CHECKING, Generic - -from typing_extensions import TypeVar - -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.expressions.name import Name -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.core.class_definition import Class - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.detached_symbols.parameter import Parameter - from codegen.sdk.core.function import Function - - -TClass = TypeVar("TClass", bound="Class", default="Class") -TFunction = TypeVar("TFunction", bound="Function", default="Function") -TParameter = TypeVar("TParameter", bound="Parameter", default="Parameter") - - -@apidoc -class Decorator(Expression[TClass | TFunction], HasName, Generic[TClass, TFunction, TParameter]): - """Abstract representation of a Decorator.""" - - def __init__(self, ts_node: TSNode, parent: TClass | TFunction) -> None: - super().__init__(ts_node, parent.file_node_id, parent.ctx, parent) - self._name_node = self._parse_expression(self._get_name_node(), default=Name) - - @abstractmethod - @reader - @noapidoc - def _get_name_node(self) -> TSNode: - """Returns the TSNode of the name of the decorator.""" - - @property - @reader - @abstractmethod - def call(self) -> FunctionCall | None: - """Returns any function call made by this decorator. - - This property identifies whether a decorator makes a function call and provides access to the call details. - - Returns: - FunctionCall | None: The FunctionCall object representing the function call made by the decorator if one exists, - None if the decorator does not make a function call. - """ - - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - self._add_all_identifier_usages(UsageKind.DECORATOR, dest or self.parent.self_dest) diff --git a/src/codegen/sdk/core/detached_symbols/function_call.py b/src/codegen/sdk/core/detached_symbols/function_call.py deleted file mode 100644 index 4507e65ff..000000000 --- a/src/codegen/sdk/core/detached_symbols/function_call.py +++ /dev/null @@ -1,721 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.autocommit import reader, remover, writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.detached_symbols.argument import Argument -from codegen.sdk.core.expressions import Expression, Name, Value -from codegen.sdk.core.expressions.chained_attribute import ChainedAttribute -from codegen.sdk.core.expressions.generic_type import GenericType -from codegen.sdk.core.expressions.unpack import Unpack -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.core.interfaces.resolvable import Resolvable -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.sdk.enums import NodeType -from codegen.sdk.extensions.sort import sort_editables -from codegen.sdk.extensions.utils import cached_property, is_descendant_of -from codegen.sdk.typescript.detached_symbols.promise_chain import TSPromiseChain -from codegen.sdk.typescript.enums import TSFunctionTypeNames -from codegen.sdk.utils import find_first_ancestor -from codegen.shared.decorators.docs import apidoc, noapidoc -from codegen.shared.enums.programming_language import ProgrammingLanguage - -if TYPE_CHECKING: - from collections.abc import Generator - - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.detached_symbols.parameter import Parameter - from codegen.sdk.core.function import Function - from codegen.sdk.core.interfaces.callable import Callable - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.visualizations.enums import VizNode - -Parent = TypeVar("Parent", bound="Expression | None") - - -@apidoc -class FunctionCall(Expression[Parent], HasName, Resolvable, Generic[Parent]): - """Abstract representation of a function invocation, e.g. in Python: - ``` - def f(): - g() # FunctionCall - ``` - """ - - _arg_list: Collection[Argument, Self] - - def __init__(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent) -> None: - super().__init__(node, file_node_id, ctx, parent) - # =====[ Grab the function name ]===== - self._name_node = self.child_by_field_name("function", default=Name) or self.child_by_field_name("constructor", default=Name) - if self._name_node is not None and self._name_node.ts_node.type in ("unary_expression", "await_expression"): - self._name_node = self._parse_expression(self._name_node.ts_node.children[-1], default=Name) - # =====[ Grab the arg list ]===== - arg_list_node = node.child_by_field_name("arguments") - if arg_list_node is None: - msg = f"Failed to parse function call. Child 'argument_list' node does not exist. Source: {self.source}" - raise ValueError(msg) - args = [Argument(x, i, self) for i, x in enumerate(arg_list_node.named_children) if x.type != "comment"] - self._arg_list = Collection(arg_list_node, self.file_node_id, self.ctx, self, children=args) - - def __repr__(self) -> str: - """Custom string representation showing the function call chain structure. - - Format: FunctionCall(name=current, pred=pred_name, succ=succ_name, base=base_name) - - It will only print out predecessor, successor, and base that are of type FunctionCall. If it's a property, it will not be logged - """ - # Helper to safely get name - - # Get names for each part - parts = [f"name='{self.name}'"] - - if self.predecessor and isinstance(self.predecessor, FunctionCall): - parts.append(f"predecessor=FunctionCall(name='{self.predecessor.name}')") - - if self.successor and isinstance(self.successor, FunctionCall): - parts.append(f"successor=FunctionCall(name='{self.successor.name}')") - - parts.append(f"filepath='{self.file.filepath}'") - - return f"FunctionCall({', '.join(parts)})" - - @classmethod - def from_usage(cls, node: Editable[Parent], parent: Parent | None = None) -> Self | None: - """Creates a FunctionCall object from an Editable instance that represents a function call. - - Takes an Editable node that potentially represents a function call and creates a FunctionCall object from it. - Useful when working with search results from the Editable API that may contain function calls. - - Args: - node (Editable[Parent]): The Editable node that potentially represents a function call. - parent (Parent | None): The parent node for the new FunctionCall. If None, uses the parent from the input node. - - Returns: - Self | None: A new FunctionCall object if the input node represents a function call, None otherwise. - """ - call_node = find_first_ancestor(node.ts_node, ["call", "call_expression"]) - if call_node is None: - return None - return cls(call_node, node.file_node_id, node.ctx, parent or node.parent) - - @property - @reader - def parent_function(self) -> Function | None: - """Retrieves the parent function of the current function call. - - Returns the Function object that contains this function call, useful for understanding the context in which a function call is made. - - Returns: - Function | None: The parent Function object containing this function call, or None if not found or if the function call is not within a function. - """ - # HACK: This is temporary until we establish a full parent path - if self.file.programming_language == ProgrammingLanguage.TYPESCRIPT: - if func := find_first_ancestor(self.ts_node, [function_type.value for function_type in TSFunctionTypeNames]): - from codegen.sdk.typescript.function import TSFunction - - return TSFunction.from_function_type(func, self.file_node_id, self.ctx, self.parent) - elif self.file.programming_language == ProgrammingLanguage.PYTHON: - if func := find_first_ancestor(self.ts_node, ["function_definition"]): - return self.ctx.node_classes.function_cls(func, self.file_node_id, self.ctx, self.parent) - - return None - - @property - @reader - def is_awaited(self) -> bool: - """Determine if this function call is ultimately awaited in the TypeScript AST. - - This method returns ``True`` if one of the following conditions is met: - * The call is directly under an ``await_expression`` (i.e., `await foo()`). - * The call is part of another function call's argument list where that parent call is awaited. - * The call is inside an arrow function (block or single-expression) that returns it, - and that arrow function is ultimately passed to an awaited call. The arrow - function does not need to be marked ``async``. - - Specifically: - 1. The method first checks if the nearest non-parenthesized ancestor is an - ``await_expression``. - 2. If not, it looks for the nearest parent function call. If there is none, - the call is not awaited. - 3. If there is a parent call and it is awaited, the method checks whether this - function call is “returned” (explicitly or implicitly) up the chain toward - that awaited call. - - Returns: - bool: ``True`` if this function call is considered awaited (directly or indirectly), - otherwise ``False``. - """ - # 1) Direct check: are we directly under an 'await' node? - ancestor = self.ts_node.parent - while ancestor and ancestor.type == "parenthesized_expression": - ancestor = ancestor.parent - if ancestor and ancestor.type in ("await_expression", "await"): - return True - - # 2) Find the nearest parent call - nearest_call = None - arrow_nodes = [] - is_returned = False - - node = self.ts_node.parent - while node: - if node.type in ("call_expression", "call"): - nearest_call = node - break - - if node.type == "arrow_function": - arrow_nodes.append(node) - elif node.type == "return_statement": - is_returned = True - - node = node.parent - - if not nearest_call: - return False - - # 3) Check if the nearest parent call is awaited - parent_call_obj = FunctionCall(nearest_call, self.file_node_id, self.ctx, None) - if not parent_call_obj.is_awaited: - return False - - # If we have no arrow boundaries in between, we're certainly awaited - if not arrow_nodes: - return True - - # Otherwise, check if we're effectively returned (explicitly or implicitly) in the arrow callbacks - if is_returned: - return True - - for arrow_node in arrow_nodes: - arrow_body = arrow_node.child_by_field_name("body") - if arrow_body: - # Single-expression arrow => implicitly returns the entire expression - if arrow_body.type != "statement_block": - if is_descendant_of(arrow_body, self.ts_node): - return True - # If it's a block body, rely on is_returned above - - return False - - @writer - def asyncify(self) -> None: - """Converts the function call to an async function call by wrapping it with 'await'. - - This method adds 'await' syntax to a function call if it is not already awaited. It wraps the function call in parentheses and prefixes it with 'await'. - - Args: - None - - Returns: - None - """ - if self.is_awaited: - return - self.insert_before("await (", newline=False) - self.insert_after(")", newline=False) - - @property - @reader - def predecessor(self) -> FunctionCall[Parent] | None: - """Returns the previous function call in a function call chain. - - Returns the previous function call in a function call chain. This method is useful for traversing function call chains - to analyze or modify sequences of chained function calls. - - Returns: - FunctionCall[Parent] | None: The previous function call in the chain, or None if there is no predecessor - or if the predecessor is not a function call. - """ - # Recursively travel down the tree to find the previous function call (child nodes are previous calls) - name = self.get_name() - while name: - if isinstance(name, FunctionCall): - return name - elif isinstance(name, ChainedAttribute): - name = name.object - else: - break - return None - - @property - @reader - def successor(self) -> FunctionCall[Parent] | None: - """Returns the next function call in a function call chain. - - Returns the next function call in a function call chain. This method is useful for traversing function call chains - to analyze or modify sequences of chained function calls. - - Returns: - FunctionCall[Parent] | None: The next function call in the chain, or None if there is no successor - or if the successor is not a function call. - """ - # this will avoid parent function calls in tree-sitter that are NOT part of the chained calls - if not isinstance(self.parent, ChainedAttribute): - return None - - return self.parent_of_type(FunctionCall) - - @property - @noapidoc - @override - def viz(self) -> VizNode: - from codegen.visualizations.enums import VizNode - - func = self.function_definition - from codegen.sdk.core.function import Function - - if isinstance(func, Function) and func.is_method: - name = f"{func.parent_class.name}.{self.name}" - return VizNode(file_path=self.filepath, start_point=self.start_point, end_point=self.end_point, name=name, symbol_name=self.__class__.__name__) - else: - return VizNode(file_path=self.filepath, start_point=self.start_point, end_point=self.end_point, name=self.name, symbol_name=self.__class__.__name__) - - @property - @reader - def source(self) -> str: - """Gets the source code representation of this FunctionCall. - - Returns the textual representation of the function call. For chained function calls (e.g., a().b()), - it returns only the current function call's source code by removing the predecessor's source. - - Args: - None - - Returns: - str: The source code representation of this function call. For chained calls, returns only the current - function call's portion of the chain. - """ - if self.predecessor: - # TODO: breaks edit logic b/c start/end bytes no longer match up - # Remove the parent function call from the source - return self.extended_source.replace(self.predecessor.extended_source, "").strip()[1:].strip() - else: - return self.extended_source - - @property - @reader - def args(self) -> Collection[Argument, Self]: - """Returns a list of arguments passed into the function invocation. - - The `args` property provides access to all arguments, both positional and keyword, that are passed to the function call. - - Args: - None - - Returns: - Collection[Argument, Self]: A collection containing the function's arguments. - """ - # TODO - this may be language-specific - return self._arg_list - - def set_kwarg(self, name: str, value: str, *, create_on_missing: bool = True, override_existing: bool = True) -> None: - """Set a keyword argument in a function call. - - Sets or modifies a keyword argument in the function call. Can create new arguments or modify existing ones based on configuration. - - Args: - name (str): The name of the parameter/argument to set. - value (str): The value to set the argument to. - create_on_missing (bool, optional): If True, creates a new keyword argument if it doesn't exist. Defaults to True. - override_existing (bool, optional): If True, modifies the value of existing argument. Defaults to True. - - Returns: - None - - Raises: - None - """ - if existing := self.get_arg_by_parameter_name(name): - if not existing.is_named: - existing.add_keyword(name) - if override_existing: - existing.set_value(value) - - elif create_on_missing: - if param := self.find_parameter_by_name(name): - # Smart insert into the right place: - for idx, arg in enumerate(self.args): - if other_param := arg.parameter: - if other_param.index > param.index: - self.args.insert(idx, f"{name}={value}") - return - self.args.append(f"{name}={value}") - - @noapidoc - @reader - def find_parameter_by_index(self, index: int) -> Parameter | None: - from codegen.sdk.python import PyFunction - - for function_definition in self.function_definitions: - if function_definition.node_type == NodeType.EXTERNAL or function_definition.parameters is None: - continue - - if isinstance(function_definition, PyFunction) and (function_definition.is_method and not function_definition.is_static_method): - index += 1 - for param in function_definition.parameters: - if index == param.index: - return param - - @noapidoc - @reader - def find_parameter_by_name(self, name: str) -> Parameter | None: - for function_definition in self.function_definitions: - if function_definition.node_type == NodeType.EXTERNAL or function_definition.parameters is None: - continue - for param in function_definition.parameters: - if param.name == name: - return param - - @reader - def get_arg_by_parameter_name(self, param_name: str) -> Argument | None: - """Returns an argument by its parameter name. - - Searches through the arguments of a function call to find an argument that matches - a specified parameter name. This first checks for named arguments (kwargs) that match - the parameter name directly, then checks for positional arguments by resolving their - corresponding parameter names. - - Args: - param_name (str): The name of the parameter to search for. - - Returns: - Argument | None: The matching argument if found, None otherwise. - """ - args = self.args - if len(args) == 0: - return None - - # =====[ Named args ]===== - for arg in args: - if arg.name == param_name: - return arg - - for arg in self.args: - if param := arg.parameter: - if param.name == param_name: - return arg - - @reader - def get_arg_by_index(self, arg_idx: int) -> Argument | None: - """Returns the Argument with the given index from the function call's argument list. - - Args: - arg_idx (int): The index of the argument to retrieve. - - Returns: - Argument | None: The Argument object at the specified index, or None if the index is out of bounds. - """ - try: - return self.args[arg_idx] - except IndexError: - return None - - #################################################################################################################### - # MANIPULATIONS - #################################################################################################################### - - @writer - def convert_args_to_kwargs(self, exclude: int = 0) -> None: - """Converts positional arguments in a function call to keyword arguments. - - This method converts positional arguments to keyword arguments, excluding any leading arguments specified by the exclude parameter. - This is useful when refactoring function calls to be more explicit and self-documenting. - - Args: - exclude (int): Number of leading positional arguments to exclude from conversion. Defaults to 0. - - Returns: - None - - Note: - - Skips conversion if the argument is already named - - Skips arguments within the exclude range - - Skips unpacked arguments (e.g. **kwargs) - - Stops converting if it encounters a named argument that would conflict with an existing one - - Requires the function definition to be resolvable and have parameters - """ - definition = self.function_definition - from codegen.sdk.core.interfaces.callable import Callable - - if definition is None or definition.parameters is None or not isinstance(definition, Callable): - return - - for arg in reversed(self.args): - if arg.is_named: - # skip if the argument is already named - continue - - if arg.index < exclude: - # skip if the argument is in the exclude range - continue - if isinstance(arg.value, Unpack): - # Skip unpack (ie **kwargs) - continue - if param := arg.parameter: - if other_arg := self.get_arg_by_parameter_name(param.name): - if other_arg.is_named and other_arg != arg: - return # Already exists, can't keep converting - arg.add_keyword(param.name) - - @cached_property - @reader - @noapidoc - def function_definition_frames(self) -> list[ResolutionStack[Callable]]: - from codegen.sdk.core.class_definition import Class - from codegen.sdk.core.interfaces.callable import Callable - - result = [] - if self.get_name(): - for resolution in self.get_name().resolved_type_frames: - top_node = resolution.top.node - if isinstance(top_node, Callable): - if isinstance(top_node, Class): - if constructor := top_node.constructor: - result.append(resolution.with_new_base(constructor, direct=True)) - continue - result.append(resolution) - return result - - @cached_property - @reader - def function_definitions(self) -> list[Callable]: - """Returns a list of callable objects that could potentially be the target of this function - call. - - Finds and returns all possible functions that this call could be invoking based on name resolution. - This is useful for analyzing parameter names, parameter types, and return types of the potential target functions. - - Returns: - list[Callable]: A list of Callable objects representing the possible function definitions that this call could be invoking. - """ - result = [] - for frame in self.function_definition_frames: - result.append(frame.top.node) - return result - - @property - @reader - def function_definition(self) -> Callable | None: - """Returns the resolved function definition that is being called. - - This method returns the function definition associated with this function call. - This is useful for accessing parameter names, parameter types, and return types of the called function. - - Returns: - Callable | None: The resolved function definition, or None if no definition is found. - """ - return next(iter(self.function_definitions), None) - - @remover - def remove(self, delete_formatting: bool = True, priority: int = 0, dedupe: bool = True) -> None: - """Removes a node and optionally its related extended nodes. - - This method removes a FunctionCall node from the codebase. If the node is part of an expression statement, - it removes the entire expression statement. Otherwise, it performs a standard node removal. - - Args: - delete_formatting (bool, optional): Whether to delete associated formatting nodes. Defaults to True. - priority (int, optional): Priority level for the removal operation. Defaults to 0. - dedupe (bool, optional): Whether to deduplicate identical removals. Defaults to True. - - Returns: - None - """ - if self.ts_node.parent.type == "expression_statement": - Value(self.ts_node.parent, self.file_node_id, self.ctx, self.parent).remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe) - else: - super().remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe) - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - from codegen.sdk.core.class_definition import Class - from codegen.sdk.core.function import Function - - if self.get_name().ts_node.type == "import" or self.full_name == "require": - # TS imports - for imp in self.file.imports: - if imp.ts_node.start_point[0] == self.ts_node.start_point[0]: - yield from imp.resolved_type_frames - return - if len(self.function_definitions) == 0: - resolved = False - for resolution in self.get_name().resolved_type_frames: - if len(resolution.generics) == 1: - yield from self.with_resolution_frame(next(iter(resolution.generics.values())), direct=resolution.direct) - resolved = True - elif len(resolution.generics) > 1: - yield from self.with_resolution(resolution) - resolved = True - if not resolved: - yield ResolutionStack(self) # This let's us still calculate dependencies even if we can't resolve a function call's definition - for function_def_frame in self.function_definition_frames: - function_def = function_def_frame.top.node - if isinstance(function_def, Function): - if function_def.is_constructor: - yield from self.with_resolution_frame(function_def.parent_class, direct=function_def_frame.direct) - elif return_type := function_def.return_type: - if function_def_frame.generics: - if generic := function_def_frame.generics.get(return_type.source, None): - yield from self.with_resolution_frame(generic, direct=function_def_frame.direct) - return - if self.ctx.config.generics: - for arg in self.args: - if arg.parameter and (type := arg.parameter.type): - if type.source == return_type.source: - yield from self.with_resolution_frame(arg.value, direct=function_def_frame.direct) - return - if isinstance(type, GenericType): - for param in type.parameters: - if param.source == return_type.source: - yield from self.with_resolution_frame(arg.value, direct=function_def_frame.direct) - return - - yield from self.with_resolution_frame(return_type, direct=False) - elif isinstance(function_def, Class): - yield from self.with_resolution_frame(function_def, direct=function_def_frame.direct, aliased=function_def_frame.aliased) - # else: - - # yield from self.with_resolution_frame(function_def, direct=False) # Untyped functions - # else: - # yield from self.with_resolution_frame(function_def, direct=False) # External Modules - - @noapidoc - def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: - for arg in self.args: - arg._compute_dependencies(usage_type, dest) - if desc := self.child_by_field_name("type_arguments"): - desc._compute_dependencies(UsageKind.GENERIC, dest) - match = self.get_name() - if match: - if len(self.function_definition_frames) > 0: - if isinstance(match, ChainedAttribute): - match.object._compute_dependencies(usage_type, dest) - if isinstance(match, FunctionCall): - match._compute_dependencies(usage_type, dest) - for definition in self.function_definition_frames: - definition.add_usage(self, usage_type, dest, self.ctx) - else: - match._compute_dependencies(usage_type, dest) - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Returns a list of all function calls contained within this function call. - - This method traverses through all arguments and the function name node to find any nested - function calls. For example, if a function call has arguments that are themselves function - calls, these will be included in the returned list. - - Returns: - list[FunctionCall]: A list of FunctionCall instances contained within this function call, - including the call itself. Sorted by their appearance in the code. - """ - calls = [self] - for arg in self.args: - calls.extend(arg.function_calls) - calls.extend(self._name_node.function_calls) - # for call in self._name_node.function_calls: - # if isinstance(call.parent, TSChainedAttribute): - # call.parent = self - # calls.append(call) - return sort_editables(calls, dedupe=False) - - @property - @reader - def attribute_chain(self) -> list[FunctionCall | Name]: - """Returns a list of elements in the chainedAttribute that the function call belongs in. - - Breaks down chained expressions into individual components in order of appearance. - For example: `a.b.c().d` -> [Name("a"), Name("b"), FunctionCall("c"), Name("d")] - - Returns: - list[FunctionCall | Name]: List of Name nodes (property access) and FunctionCall nodes (method calls) - """ - if isinstance(self.get_name(), ChainedAttribute): # child is chainedAttribute. MEANING that this is likely in the middle or the last function call of a chained function call chain. - return self.get_name().attribute_chain - elif isinstance( - self.parent, ChainedAttribute - ): # does not have child chainedAttribute, but parent is chainedAttribute. MEANING that this is likely the TOP function call of a chained function call chain. - return self.parent.attribute_chain - else: # this is a standalone function call - return [self] - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - symbols = self.get_name().descendant_symbols - for arg in self.args: - symbols.extend(arg.descendant_symbols) - return symbols - - @noapidoc - @writer - def rename_if_matching(self, old: str, new: str): - if name := self.get_name(): - name.rename_if_matching(old, new) - - @noapidoc - def register_api_call(self, url: str): - assert url, self - self.ctx.global_context.multigraph.usages[url].append(self) - - @property - @reader - def call_chain(self) -> list[FunctionCall]: - """Returns a list of all function calls in this function call chain, including this call. Does not include calls made after this one.""" - ret = [] - - # backward traversal - curr = self - pred = curr.predecessor - while pred is not None and isinstance(pred, FunctionCall): - ret.insert(0, pred) - pred = pred.predecessor - - ret.append(self) - - # forward traversal - curr = self - succ = curr.successor - while succ is not None and isinstance(succ, FunctionCall): - ret.append(succ) - succ = succ.successor - - return ret - - @property - @reader - def base(self) -> Editable | None: - """Returns the base object of this function call chain. - - Args: - Editable | None: The base object of this function call chain. - """ - name = self.get_name() - while isinstance(name, ChainedAttribute): - if isinstance(name.object, FunctionCall): - return name.object.base - else: - name = name.object - return name - - @property - @reader - def promise_chain(self) -> TSPromiseChain | None: - """Return the promise chain associated with this function call, if a then call is found. - - Returns: - TSPromiseChain | None: The promise chain associated with this function call, if a then call is found. - """ - if any(call.name == "then" for call in self.call_chain) is True: - return TSPromiseChain(self.attribute_chain) - return None diff --git a/src/codegen/sdk/core/detached_symbols/parameter.py b/src/codegen/sdk/core/detached_symbols/parameter.py deleted file mode 100644 index 7dc08d3e8..000000000 --- a/src/codegen/sdk/core/detached_symbols/parameter.py +++ /dev/null @@ -1,232 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from typing_extensions import deprecated - -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.dataclasses.usage import UsageType -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.expressions.name import Name -from codegen.sdk.core.interfaces.has_value import HasValue -from codegen.sdk.core.interfaces.typeable import Typeable -from codegen.sdk.core.interfaces.usable import Usable -from codegen.sdk.extensions.autocommit import commiter -from codegen.sdk.extensions.resolution import UsageKind -from codegen.sdk.utils import find_first_descendant -from codegen.shared.decorators.docs import apidoc, noapidoc -from codegen.shared.logging.get_logger import get_logger - -if TYPE_CHECKING: - from collections.abc import Generator - - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.resolution_stack import ResolutionStack - from codegen.sdk.core.expressions.type import Type - from codegen.sdk.core.function import Function - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.symbol_groups.collection import Collection - - -logger = get_logger(__name__) - -TType = TypeVar("TType", bound="Type") -Parent = TypeVar("Parent", bound="Collection[Parameter, Function]") - - -@apidoc -class Parameter(Usable[Parent], Typeable[TType, Parent], HasValue, Expression[Parent], Generic[TType, Parent]): - """Abstract representation of a parameter in a Function definition.""" - - _pos: int - _name_node: Name | None = None - - def __init__(self, ts_node: TSNode, index: int, parent: Parent) -> None: - super().__init__(ts_node, parent.file_node_id, parent.ctx, parent) - self._pos = index - name_node = self._get_name_node(ts_node) - self._name_node = self._parse_expression(name_node, default=Name) - self._init_type() - value_node = self._get_value_node(ts_node) - self._value_node = self._parse_expression(value_node) if value_node else None - - @reader - def _get_name_node(self, ts_node: TSNode) -> TSNode | None: - if ts_node.type == "identifier": - return ts_node - else: - name_node = find_first_descendant(ts_node, ["identifier", "shorthand_property_identifier_pattern", "this"]) - if name_node is None: - # Some parameters don't have names, e.g. the {} in `async run({}, arg2, arg3) {..}` - self._log_parse("Unable to find name node in parameter: %s", ts_node.text.decode("utf-8")) - return name_node - - @reader - def _get_value_node(self, ts_node: TSNode) -> TSNode | None: - return ts_node.child_by_field_name("value") - - @property - @reader - def index(self) -> int: - """Returns the 0-based index of this parameter within its parent function's parameter list. - - Args: - None - - Returns: - int: The position of the parameter in the function's parameter list, 0-based. - """ - return self._pos - - @deprecated("Use `type.edit` instead") - @writer - def set_type_annotation(self, type_annotation: str) -> None: - """Sets the type annotation for this parameter. - - This method is deprecated in favor of `type.edit`. - - Args: - type_annotation (str): The type annotation to set for the parameter. - - Returns: - None - """ - self.type.edit(type_annotation) - - @property - @reader - def default(self) -> str | None: - """Returns the default value of a parameter if one exists. - - Gets the default value of a parameter in a function definition. This is the value that would be used if the parameter is not provided in a function call. - - Args: - None - - Returns: - str | None: The string representation of the default value if one exists, None otherwise. - """ - default_node = self.ts_node.child_by_field_name("value") - if default_node is None: - return None - return default_node.text.decode("utf-8") - - @property - @abstractmethod - def is_optional(self) -> bool: - """Returns whether the parameter is optional in its function definition. - - A parameter is optional if either: - 1. It has a default value - 2. Its type annotation is Optional[T] or T | None - 3. It is variadic (*args, **kwargs) - - Returns: - bool: True if the parameter is optional, False otherwise - """ - msg = "Subclasses must implement this method" - raise NotImplementedError(msg) - - @property - @abstractmethod - def is_variadic(self) -> bool: - """Returns whether the parameter is a variadic parameter. - - A variadic parameter allows a function to accept a variable number of arguments (e.g., *args in Python). - - Returns: - bool: True if the parameter is variadic (can accept variable number of arguments), - False otherwise. - """ - msg = "Subclasses must implement this method" - raise NotImplementedError(msg) - - @writer - def remove(self, delete_formatting: bool = True, priority: int = 0, dedupe: bool = True) -> None: - """Removes the parameter from the function definition and all its call sites. - - Removes the parameter from a function's definition and also removes the corresponding argument - from all call sites of the function. If an argument cannot be found at a call site, logs a message - and continues with other call sites. - - Args: - delete_formatting (bool, optional): Whether to delete formatting around the parameter. Defaults to True. - priority (int, optional): Priority level for the removal operation. Defaults to 0. - dedupe (bool, optional): Whether to deduplicate removal operations. Defaults to True. - - Returns: - None - """ - # Step 1: Remove all usages of the parameter in call sites - call_sites = self.parent_function.call_sites - for call_site in call_sites: - arg = call_site.get_arg_by_parameter_name(self.name) - if arg is None: - arg = call_site.get_arg_by_index(self.index) - if arg is None: - logger.info(f"Unable to find argument with parameter name {self.name} at call site {call_site}") - continue - arg.remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe) - - # Step 2: Actually remove the parameter from the function header - super().remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe) - - @writer - def rename(self, new_name: str, priority: int = 0) -> None: - """Renames a parameter in a function definition and updates all related references. - - Performs a comprehensive rename operation by updating the parameter name in the function definition, - all variable usages within the function body, and any keyword arguments in call sites. - - Args: - new_name (str): The new name for the parameter. - priority (int, optional): The priority of the edit operation. Defaults to 0. - - Returns: - None - """ - # Step 1: Rename the parameter in the function definition itself - self.set_name(new_name) - - # Step 2: Rename the parameter variable usages in the function body - for usage in self.usages(UsageType.DIRECT): - usage.match.edit(new_name) - - # Step 3: Rename any keyword arguments in all call sites - parent_function = self.parent_function - call_sites = parent_function.call_sites - for call_site in call_sites: - arg_to_rename = [arg for arg in call_site.args if arg.is_named and arg.name == self.name] - for arg in arg_to_rename: - arg.rename(new_name) - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - if self.type: - yield from self.with_resolution_frame(self.type) - if value := self.value: - yield from self.with_resolution_frame(value) - - @noapidoc - @commiter - @override - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - if self.type: - self.type._compute_dependencies(UsageKind.TYPE_ANNOTATION, self.parent.self_dest) - if self.value: - self.value._compute_dependencies(UsageKind.DEFAULT_VALUE, self.parent.self_dest) - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - ret = super().descendant_symbols - if self.type: - ret.extend(self.type.descendant_symbols) - if self.value: - ret.extend(self.value.descendant_symbols) - return ret diff --git a/src/codegen/sdk/core/directory.py b/src/codegen/sdk/core/directory.py deleted file mode 100644 index 806e90ff8..000000000 --- a/src/codegen/sdk/core/directory.py +++ /dev/null @@ -1,269 +0,0 @@ -import os -from collections.abc import Iterator -from pathlib import Path -from typing import TYPE_CHECKING, Generic, Literal, Self - -from codegen.sdk._proxy import proxy_property -from codegen.sdk.core.interfaces.has_symbols import ( - FilesParam, - HasSymbols, - TClass, - TFile, - TFunction, - TGlobalVar, - TImport, - TImportStatement, - TSymbol, -) -from codegen.sdk.core.utils.cache_utils import cached_generator -from codegen.sdk.enums import NodeType -from codegen.sdk.extensions.sort import sort_editables -from codegen.shared.decorators.docs import apidoc, noapidoc -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - - -@apidoc -class Directory( - HasSymbols[TFile, TSymbol, TImportStatement, TGlobalVar, TClass, TFunction, TImport], - Generic[TFile, TSymbol, TImportStatement, TGlobalVar, TClass, TFunction, TImport], -): - """Directory representation for codebase. - - GraphSitter abstraction of a file directory that can be used to look for files and symbols within a specific directory. - - Attributes: - path: Absolute path of the directory. - dirpath: Relative path of the directory. - parent: The parent directory, if any. - items: A dictionary containing files and subdirectories within the directory. - """ - - ctx: "CodebaseContext" - path: Path # Absolute Path - dirpath: str # Relative Path - _files: list[str] # List of file names - _subdirectories: list[str] # List of subdirectory names - - def __init__(self, ctx: "CodebaseContext", path: Path, dirpath: str): - self.ctx = ctx - self.path = path - self.dirpath = dirpath - self._files = [] - self._subdirectories = [] - - def __iter__(self): - return iter(self.items) - - def _is_a_subdirectory_of(self, target_directory: Self): - """Checks whether this directory is a subdirectory of another directory""" - if self.parent == target_directory: - return True - if self.parent is None: - return False - return self.parent._is_a_subdirectory_of(target_directory=target_directory) - - def __contains__(self, item: str | TFile | Self) -> bool: - from codegen.sdk.core.file import File - - # Try to match all file and subdirectory names - if isinstance(item, str): - if item in self.item_names: - return True - # Try to match all subdirectories - elif isinstance(item, Directory): - if item.name in [directory.name for directory in self.subdirectories]: - return True - # Try to match all files - elif isinstance(item, File): - if item.name in [file.name for file in self.files(extensions="*")]: - return True - - # Attempt to match recursively - for directory in self.subdirectories(recursive=False): - if item in directory: - return True - - # If no match, return False - return False - - def __len__(self) -> int: - # Using item names here as items will cause an infinite loop - return len(self.item_names) - - def __getitem__(self, item_name: str) -> TFile | Self: - return next((item for item in self.items if item.name == item_name), None) - - def __repr__(self) -> str: - return f"Directory(name='{self.name}', items={self.item_names})" - - @property - def name(self) -> str: - """Get the base name of the directory's path. - - Extracts the final component of the directory path. For example, for a path '/home/user/project', returns 'project'. - - Returns: - str: The directory's base name. - """ - return os.path.basename(self.dirpath) - - @proxy_property - def files(self, *, extensions: list[str] | Literal["*"] | None = None, recursive: bool = False) -> list[TFile]: - """Gets a list of all top level files in the directory. - - Set `recursive=True` to get all files recursively. - - By default, this only returns source files. Setting `extensions='*'` will return all files, and - `extensions=[...]` will return all files with the specified extensions. - - For Python and Typescript repos WITH file parsing enabled, - `extensions='*'` is REQUIRED for listing all non source code files. - Or else, codebase.files will ONLY return source files (e.g. .py, .ts). - - For repos with file parsing disabled or repos with other languages, this will return all files in the codebase. - - Returns all Files in the codebase, sorted alphabetically. For Python codebases, returns PyFiles (python files). - For Typescript codebases, returns TSFiles (typescript files). - - Returns: - list[TSourceFile]: A sorted list of source files in the codebase. - """ - # If there are no source files, return ALL files - if len(self.ctx.get_nodes(NodeType.FILE)) == 0: - extensions = "*" - # If extensions is not set, use the extensions from the codebase - elif extensions is None: - extensions = self.ctx.extensions - - files = [] - for file_name in self._files: - if extensions == "*": - files.append(self.get_file(file_name)) - elif extensions is not None: - if any(file_name.endswith(ext) for ext in extensions): - files.append(self.get_file(file_name)) - - if recursive: - for directory in self.subdirectories: - files.extend(directory.files(extensions=extensions, recursive=True)) - - return sort_editables(files, alphabetical=True, dedupe=False) - - @proxy_property - def subdirectories(self, recursive: bool = False) -> list[Self]: - """Get a list of all top level subdirectories in the directory. - - Set `recursive=True` to get all subdirectories recursively. - - Returns: - list[Directory]: A sorted list of subdirectories in the directory. - """ - subdirectories = [] - for directory_name in self._subdirectories: - subdirectories.append(self.get_subdirectory(directory_name)) - - if recursive: - for directory in self.subdirectories: - subdirectories.extend(directory.subdirectories(recursive=True)) - - return sorted(subdirectories, key=lambda x: x.name) - - @proxy_property - def items(self, recursive: bool = False) -> list[Self | TFile]: - """Get a list of all files and subdirectories in the directory. - - Set `recursive=True` to get all files and subdirectories recursively. - - Returns: - list[Self | TFile]: A sorted list of files and subdirectories in the directory. - """ - return self.files(extensions="*", recursive=recursive) + self.subdirectories(recursive=recursive) - - @property - def item_names(self, recursive: bool = False) -> list[str]: - """Get a list of all file and subdirectory names in the directory. - - Set `recursive=True` to get all file and subdirectory names recursively. - - Returns: - list[str]: A list of file and subdirectory names in the directory. - """ - return self._files + self._subdirectories - - @property - def file_names(self) -> list[str]: - """Get a list of all file names in the directory.""" - return self._files - - @property - def tree(self) -> list[Self | TFile]: - """Get a recursive list of all files and subdirectories in the directory. - - Returns: - list[Self | TFile]: A recursive list of files and subdirectories in the directory. - """ - return self.items(recursive=True) - - @property - def parent(self) -> Self | None: - """Get the parent directory of the current directory.""" - return self.ctx.get_directory(self.path.parent) - - @noapidoc - @cached_generator() - def files_generator(self, *args: FilesParam.args, **kwargs: FilesParam.kwargs) -> Iterator[TFile]: - """Yield files recursively from the directory.""" - yield from self.files(*args, extensions="*", **kwargs, recursive=True) - - def get_file(self, filename: str, ignore_case: bool = False) -> TFile | None: - """Get a file by its name relative to the directory.""" - file_path = os.path.join(self.dirpath, filename) - absolute_path = self.ctx.to_absolute(file_path) - # Try to get the file from the graph first - file = self.ctx.get_file(file_path, ignore_case=ignore_case) - if file is not None: - return file - # If the file is not in the graph, check the filesystem - for file in absolute_path.parent.iterdir(): - if ignore_case and str(absolute_path).lower() == str(file).lower(): - return self.ctx._get_raw_file_from_path(file) - elif not ignore_case and str(absolute_path) == str(file): - return self.ctx._get_raw_file_from_path(file) - return None - - def get_subdirectory(self, subdirectory_name: str) -> Self | None: - """Get a subdirectory by its name (relative to the directory).""" - return self.ctx.get_directory(os.path.join(self.dirpath, subdirectory_name)) - - def update_filepath(self, new_filepath: str) -> None: - """Update the filepath of the directory and its contained files.""" - old_path = self.dirpath - new_path = new_filepath - for file in self.files(recursive=True): - new_file_path = os.path.join(new_path, os.path.relpath(file.file_path, old_path)) - file.update_filepath(new_file_path) - - def remove(self) -> None: - """Remove all the files in the files container.""" - for f in self.files(recursive=True): - f.remove() - - def rename(self, new_name: str) -> None: - """Rename the directory.""" - parent_dir, _ = os.path.split(self.dirpath) - new_path = os.path.join(parent_dir, new_name) - self.update_filepath(new_path) - - def _add_file(self, file_name: str) -> None: - """Add a file to the directory.""" - self._files.append(file_name) - - def _add_subdirectory(self, subdirectory_name: str) -> None: - """Add a subdirectory to the directory.""" - self._subdirectories.append(subdirectory_name) diff --git a/src/codegen/sdk/core/export.py b/src/codegen/sdk/core/export.py deleted file mode 100644 index e9d63a9e5..000000000 --- a/src/codegen/sdk/core/export.py +++ /dev/null @@ -1,105 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from typing import TYPE_CHECKING, Generic, Self, TypeVar - -from codegen.sdk.core.interfaces.exportable import Exportable -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.statements.export_statement import ExportStatement - from codegen.sdk.core.symbol_groups.collection import Collection - - -Parent = TypeVar("Parent", bound="Collection[Export, ExportStatement]") - - -@apidoc -class Export(Exportable[Parent], Generic[Parent]): - """Represents a single symbol being exported. - - Attributes: - export_statement: The statement representing the export. - """ - - export_statement: ExportStatement - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent) -> None: - self.to_file_id = file_node_id - super().__init__(ts_node=ts_node, file_node_id=file_node_id, ctx=ctx, parent=parent) - - @noapidoc - @abstractmethod - def parse(self, ctx: CodebaseContext) -> None: - """Add self to the graph and SYMBOL_USAGE edges from export to exported symbol.""" - - @property - @abstractmethod - def exported_symbol(self) -> Exportable | None: - """Returns the symbol, file, or import being exported from this export object. - - Returns: - Exportable | None: The exported symbol, file, or import, or None if it cannot be resolved. - """ - - @property - @abstractmethod - def resolved_symbol(self) -> Exportable | None: - """Returns the resolved symbol for an export. - - Gets the final symbol, file, or external module that this export resolves to by following through indirect imports and exports. - - Returns: - Exportable | None: The final resolved symbol, which can be a Symbol, File, or External module. Returns None if the symbol cannot be resolved. - """ - - @abstractmethod - def is_named_export(self) -> bool: - """Determines if the export is named or default. - - Returns: - bool: True if the export is named, False if it is default. - """ - - @abstractmethod - def is_module_export(self) -> bool: - """Determines if the export is a module-level export. - - This method checks if the export statement represents a module-level export, - such as wildcard exports or default object exports. - - Returns: - bool: True if the export is a module-level export, False otherwise. - """ - - def is_aliased(self) -> bool: - """Determines if the Export object is aliased. - - Checks if the exported symbol has a different name than the name it is exported as. - - Returns: - bool: True if the exported symbol has a different name than the name it is exported as, False otherwise. - """ - from codegen.sdk.core.import_resolution import Import - - if self.exported_symbol is None: - return False - if isinstance(self.exported_symbol, Import): - return self.exported_name != self.exported_symbol.symbol_name - return self.exported_name != self.exported_symbol.name - - @noapidoc - @commiter - def compute_export_dependencies(self) -> None: - raise NotImplementedError - - @property - @noapidoc - def parent_symbol(self) -> Self: - """Returns the parent symbol of the symbol.""" - return self diff --git a/src/codegen/sdk/core/expressions/__init__.py b/src/codegen/sdk/core/expressions/__init__.py deleted file mode 100644 index 28664353e..000000000 --- a/src/codegen/sdk/core/expressions/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import TYPE_CHECKING - -from codegen.sdk.core.expressions.expression import Expression -from codegen.sdk.core.expressions.name import Name -from codegen.sdk.core.expressions.string import String -from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.expressions.value import Value -from codegen.sdk.core.symbol_groups.dict import Dict -from codegen.sdk.core.symbol_groups.list import List - -if TYPE_CHECKING: - from codegen.sdk.core.detached_symbols.function_call import FunctionCall # noqa: TC004 - -__all__ = ["Dict", "Expression", "FunctionCall", "List", "Name", "String", "Type", "Value"] diff --git a/src/codegen/sdk/core/expressions/await_expression.py b/src/codegen/sdk/core/expressions/await_expression.py deleted file mode 100644 index e3643de34..000000000 --- a/src/codegen/sdk/core/expressions/await_expression.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.detached_symbols.function_call import FunctionCall -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.interfaces.has_value import HasValue -from codegen.sdk.core.interfaces.wrapper_expression import IWrapper -from codegen.sdk.extensions.autocommit import reader -from codegen.shared.decorators.docs import apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.editable import Editable - -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class AwaitExpression(Expression[Parent], HasValue, IWrapper, Generic[Parent]): - """An awaited expression, only found in asynchronous contexts, e.g. await(foo(bar))""" - - def __init__(self, ts_node, file_node_id, ctx, parent: Parent): - super().__init__(ts_node, file_node_id, ctx, parent=parent) - value_node = self.ts_node.named_children[0] - self._value_node = self.ctx.parser.parse_expression(value_node, self.file_node_id, self.ctx, parent) if value_node else None - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Gets all function calls within the await expression. - - Returns: - list[FunctionCall]: A list of function call nodes contained within the await expression's value. - """ - return self.resolve().function_calls diff --git a/src/codegen/sdk/core/expressions/binary_expression.py b/src/codegen/sdk/core/expressions/binary_expression.py deleted file mode 100644 index fdb4ec349..000000000 --- a/src/codegen/sdk/core/expressions/binary_expression.py +++ /dev/null @@ -1,132 +0,0 @@ -import itertools -from collections import deque -from collections.abc import Generator -from functools import cached_property -from typing import Generic, Self, TypeVar, override - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.autocommit import writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.interfaces.chainable import Chainable -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.core.interfaces.unwrappable import Unwrappable -from codegen.sdk.core.symbol_groups.expression_group import ExpressionGroup -from codegen.sdk.extensions.autocommit import commiter, reader -from codegen.sdk.extensions.sort import sort_editables -from codegen.shared.decorators.docs import apidoc, noapidoc - -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class BinaryExpression(Expression[Parent], Chainable, Generic[Parent]): - """Represents binary expressions, e.g. all of +,-,*,/, as well as boolean operations (and, or) etc. - - Attributes: - left: The left operand of the binary expression. - right: The right operand of the binary expression. - """ - - left: Expression[Self] | None - right: Expression[Self] | None - - def __init__(self, ts_node, file_node_id, ctx, parent: Parent) -> None: - super().__init__(ts_node, file_node_id, ctx, parent=parent) - self.left = self.child_by_field_name("left") - self.right = self.child_by_field_name("right") - - @property - @noapidoc - def operator(self) -> ExpressionGroup[Expression[Self], Self]: - """Returns the operator of the binary expression.""" - operator_nodes = self.ts_node.children[1:-1] - return ExpressionGroup(self.file_node_id, self.ctx, self, children=[self._parse_expression(node) for node in operator_nodes]) - - @property - def operators(self) -> list[ExpressionGroup[Expression[Self], Self]]: - """Returns a list of operators in a chain of binary operations. - - Returns all operators found in a chain of binary operations, maintaining the order in which they appear. For example, - in the expression "a + b - c * d / e", it would return the operators [+, -, *, /] in that order. - - Returns: - list[ExpressionGroup[Expression[Self], Self]]: The list of operators in the binary expression chain, ordered as they appear in the code. - """ - operators = [self.operator] - nodes_to_process = deque([self.left, self.right]) - while nodes_to_process: - node = nodes_to_process.popleft() - if isinstance(node, BinaryExpression): - operators.append(node.operator) - nodes_to_process.extend([node.left, node.right]) - return sort_editables(operators, dedupe=False) - - @cached_property - def elements(self) -> list[Expression[Self]]: - """Returns all elements in a binary expression chain. - - Retrieves all elements that appear in a chain of binary operations in the expression, - traversing through nested binary expressions to extract individual elements. - - Args: - None - - Returns: - list[Expression[Self]]: A sorted list of non-binary expression elements in the chain. - For example, in the expression 'a + b - c * d / e', returns [a, b, c, d, e] in order. - """ - elements = [] - nodes_to_process = deque([self.left, self.right]) - while nodes_to_process: - node = nodes_to_process.popleft() - if isinstance(node, BinaryExpression): - nodes_to_process.extend([node.left, node.right]) - else: - elements.append(node) - return sort_editables(elements, dedupe=False) - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - for e in self.elements: - yield from self.with_resolution_frame(e) - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - return list(itertools.chain.from_iterable(elem.descendant_symbols for elem in self.elements)) - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - self.left._compute_dependencies(usage_type, dest) - self.right._compute_dependencies(usage_type, dest) - - @writer - def reduce_condition(self, bool_condition: bool, node: Editable) -> None: - """Simplifies a binary expression by reducing it based on a boolean condition. - - - Args: - bool_condition (bool): The boolean value to reduce the condition to. - - """ - reduce_operator = False - if "and" in self.operator or "&&" in self.operator: - reduce_operator = not bool_condition - # We can inline the entire operator if the condition if False. - # a and b evaluates to False if either a or b is False - elif "or" in self.operator or "||" in self.operator: - reduce_operator = bool_condition # We can inline the entire operator if the condition is True - # a or b evaluates to True if either a or b is True - if reduce_operator: - self.parent.reduce_condition(bool_condition, self) - else: - node.remove() - if isinstance(self.parent, Unwrappable): - other_node = self.left if node == self.right else self.right - self.parent.unwrap(other_node) diff --git a/src/codegen/sdk/core/expressions/boolean.py b/src/codegen/sdk/core/expressions/boolean.py deleted file mode 100644 index 28655adaf..000000000 --- a/src/codegen/sdk/core/expressions/boolean.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Generic, TypeVar, override - -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.expressions.builtin import Builtin -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import apidoc, noapidoc - -Parent = TypeVar("Parent", bound="Expression") - - -@apidoc -class Boolean(Expression[Parent], Builtin, Generic[Parent]): - """A boolean value eg. - - True, False - """ - - def __bool__(self): - return self.ts_node.type == "true" - - @noapidoc - @commiter - @override - def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: - pass - - @property - def __class__(self): - return bool diff --git a/src/codegen/sdk/core/expressions/builtin.py b/src/codegen/sdk/core/expressions/builtin.py deleted file mode 100644 index 04be98f97..000000000 --- a/src/codegen/sdk/core/expressions/builtin.py +++ /dev/null @@ -1,28 +0,0 @@ -from collections.abc import Generator -from typing import TYPE_CHECKING, Self, override - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.interfaces.chainable import Chainable -from codegen.sdk.core.interfaces.has_attribute import HasAttribute -from codegen.sdk.extensions.autocommit import reader -from codegen.shared.decorators.docs import noapidoc - -if TYPE_CHECKING: - from codegen.sdk.core.external_module import ExternalModule - - -@noapidoc -class Builtin(Chainable, HasAttribute): - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - # TODO: resolve builtin type - yield ResolutionStack(self) - - @noapidoc - @override - def resolve_attribute(self, name: str) -> "ExternalModule | None": - # HACK/TODO - return None - # return ExternalModule(self.ts_node, self.file_node_id, self.ctx, name) diff --git a/src/codegen/sdk/core/expressions/chained_attribute.py b/src/codegen/sdk/core/expressions/chained_attribute.py deleted file mode 100644 index ccd5a788f..000000000 --- a/src/codegen/sdk/core/expressions/chained_attribute.py +++ /dev/null @@ -1,182 +0,0 @@ -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Optional, Self, TypeVar, override - -from tree_sitter import Node as TSNode - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions import Name -from codegen.sdk.core.expressions.expression import Expression -from codegen.sdk.core.interfaces.chainable import Chainable -from codegen.sdk.core.interfaces.has_attribute import HasAttribute -from codegen.sdk.core.interfaces.resolvable import Resolvable -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.interfaces.importable import Importable - - -Object = TypeVar("Object", bound="Chainable") -Attribute = TypeVar("Attribute", bound="Resolvable") -Parent = TypeVar("Parent", bound="Expression") - - -@apidoc -class ChainedAttribute(Expression[Parent], Resolvable, Generic[Object, Attribute, Parent]): - """An attribute of an object. (IE a method on a class, a function from a module, etc) - - Examples: - A.method() - """ - - _object: Object - _attribute: Attribute - - def __init__(self, ts_node, file_node_id, ctx, parent: Parent, object: TSNode, attribute: TSNode): - super().__init__(ts_node, file_node_id, ctx, parent=parent) - self._object = self._parse_expression(object, default=Name) - if self.ctx.parser._should_log: - if not isinstance(self._object, Chainable): - msg = f"{self._object.__class__} is not chainable: {self._object.source}\nfile: {self.filepath}" - raise ValueError(msg) - self._attribute = self._parse_expression(attribute, default=Name) - if self.ctx.parser._should_log: - if not isinstance(self._attribute, Resolvable): - msg = f"{self._attribute.__class__} is not resolvable: {self._attribute.source}\nfile: {self.filepath}" - raise ValueError(msg) - - @property - @reader - def full_name(self) -> str: - """Returns the full name of the attribute, including the object expression. - - Gets the complete name representation of a chained attribute, which includes both the object and attribute parts (e.g., 'my_object.my_attribute'). - - Returns: - str: The full string representation of the chained attribute expression. - """ - return self.source - - @property - @reader - def attribute(self) -> Attribute: - """Gets the attribute being accessed in a chained attribute expression. - - This property returns the Attribute component of a chained attribute expression (e.g., in `object.attribute`, returns the `attribute` part). - - Args: - None - - Returns: - Attribute: The attribute component of the chained expression. - """ - return self._attribute - - @property - @reader - def attribute_chain(self) -> list["FunctionCall | Name"]: - """Returns a list of elements in a chained attribute expression. - - Breaks down chained expressions into individual components in order of appearance. - For example: `a.b.c().d` -> [Name("a"), Name("b"), FunctionCall("c"), Name("d")] - - Returns: - list[FunctionCall | Name]: List of Name nodes (property access) and FunctionCall nodes (method calls) - """ - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - - ret = [] - curr = self - - # Traverse backwards in code (children of tree node) - while isinstance(curr, ChainedAttribute): - curr = curr.object - - if isinstance(curr, FunctionCall): - ret.insert(0, curr) - curr = curr.get_name() - elif isinstance(curr, ChainedAttribute): - ret.insert(0, curr.attribute) - - # This means that we have reached the base of the chain and the first item was an attribute (i.e a.b.c.func()) - if isinstance(curr, Name) and not isinstance(curr.parent, FunctionCall): - ret.insert(0, curr) - - curr = self - - # Traversing forward in code (parents of tree node). Will add the current node as well - while isinstance(curr, ChainedAttribute) or isinstance(curr, FunctionCall): - if isinstance(curr, FunctionCall): - ret.append(curr) - elif isinstance(curr, ChainedAttribute) and not isinstance(curr.parent, FunctionCall): - ret.append(curr.attribute) - - curr = curr.parent - - return ret - - @property - def object(self) -> Object: - """Returns the object that contains the attribute being looked up. - - Provides access to the object part of a chained attribute expression (e.g., in 'A.method', returns the 'A' part). - - Returns: - Object: The object component of the chained attribute expression. Guaranteed to be an instance of Chainable. - """ - return self._object - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - if not self.ctx.config.method_usages: - return - if res := self.file.valid_import_names.get(self.full_name, None): - # Module imports - yield from self.with_resolution_frame(res) - return - - for resolved_type in self.object.resolved_type_frames: - top = resolved_type.top - - if not isinstance(top.node, HasAttribute): - generics: dict = resolved_type.generics.copy() - if top.node.source.lower() == "dict" and self.attribute.source in ("values", "get", "pop"): - if len(generics) == 2: - generics.pop(next(iter(generics.keys()))) - yield from self.with_resolution_frame(top.node, generics=generics, direct=resolved_type.is_direct_usage, chained=True) - self._log_parse("%r does not have attributes, passing %s generics", top.node, len(generics)) - continue - name = self.attribute.source - if attr := top.node.resolve_attribute(name): - yield from self.with_resolution_frame(attr, chained=True, generics=resolved_type.generics) - else: - self._log_parse("Couldn't resolve attribute %s on %s", name, top.node) - yield from self.with_resolution_frame(top.node, direct=resolved_type.is_direct_usage, chained=True) - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind, dest: Optional["HasName | None"] = None) -> None: - edges = [] - for used_frame in self.resolved_type_frames: - edges.extend(used_frame.get_edges(self, usage_type, dest, self.ctx)) - edges = list(dict.fromkeys(edges)) - self.ctx.add_edges(edges) - if self.object.source not in ("self", "this"): - self.object._compute_dependencies(usage_type, dest) - - @property - @noapidoc - def descendant_symbols(self) -> list["Importable"]: - return self.object.descendant_symbols + self.attribute.descendant_symbols - - @noapidoc - @writer - def rename_if_matching(self, old: str, new: str): - if self.attribute.source == old: - self.attribute.edit(new) diff --git a/src/codegen/sdk/core/expressions/comparison_expression.py b/src/codegen/sdk/core/expressions/comparison_expression.py deleted file mode 100644 index fc0cb2f2a..000000000 --- a/src/codegen/sdk/core/expressions/comparison_expression.py +++ /dev/null @@ -1,59 +0,0 @@ -from functools import cached_property -from typing import Self, TypeVar - -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.expressions.binary_expression import BinaryExpression -from codegen.sdk.core.symbol_groups.expression_group import ExpressionGroup -from codegen.shared.decorators.docs import apidoc - -Parent = TypeVar("Parent") - - -@apidoc -class ComparisonExpression(BinaryExpression): - """Any comparison expression in the code. - - Includes all set of `<`, `<=`, `>`, `>=`, `==`, `!=` etc. - """ - - def __init__(self, ts_node, file_node_id, ctx, parent: Parent) -> None: - super().__init__(ts_node, file_node_id, ctx, parent=parent) - self.left = self.elements[0] - self.right = self.elements[-1] - - @property - def operators(self) -> list[ExpressionGroup[Expression[Self], Self]]: - """Returns a list of operator expressions in a comparison expression. - - Extracts and groups the non-named operators (e.g., <, <=, >, >=, ==, !=) from the - comparison expression's tree-sitter node. Each group of operators is wrapped in an - ExpressionGroup. - - Returns: - list[ExpressionGroup[Expression[Self], Self]]: A list of ExpressionGroups - containing one or more expression operators that appear between the compared - elements. - """ - elements = set(self.ts_node.named_children) - operators = [] - operator_group = [] - for n in self.ts_node.children: - if n not in elements: - operator_group.append(n) - elif operator_group: - operator = ExpressionGroup(self.file_node_id, self.ctx, self, children=[self._parse_expression(op) for op in operator_group]) - operators.append(operator) - operator_group.clear() - return operators - - @cached_property - def elements(self) -> list[Expression[Self]]: - """Returns a list of expressions for named child nodes. - - Args: - None - - Returns: - list[Expression[Self]]: A list of Expression objects for each named child node. - """ - return [self._parse_expression(node) for node in self.ts_node.named_children] diff --git a/src/codegen/sdk/core/expressions/defined_name.py b/src/codegen/sdk/core/expressions/defined_name.py deleted file mode 100644 index 7e28b4b8c..000000000 --- a/src/codegen/sdk/core/expressions/defined_name.py +++ /dev/null @@ -1,26 +0,0 @@ -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.expressions import Name -from codegen.sdk.extensions.autocommit import reader -from codegen.shared.decorators.docs import noapidoc - -if TYPE_CHECKING: - from codegen.sdk.core.symbol import Symbol - - -Parent = TypeVar("Parent", bound="Symbol") - - -class DefinedName(Name[Parent], Generic[Parent]): - """A name that defines a symbol. - - Does not reference any other names - """ - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - yield ResolutionStack(self) diff --git a/src/codegen/sdk/core/expressions/expression.py b/src/codegen/sdk/core/expressions/expression.py deleted file mode 100644 index 85cc5d794..000000000 --- a/src/codegen/sdk/core/expressions/expression.py +++ /dev/null @@ -1,39 +0,0 @@ -from __future__ import annotations - -from typing import Generic, TypeVar - -from codegen.sdk.core.interfaces.chainable import Chainable -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.enums import NodeType -from codegen.sdk.extensions.autocommit import reader -from codegen.shared.decorators.docs import apidoc - -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class Expression(Editable[Parent], Generic[Parent]): - """Represents an arbitrary Expression, such as List, Dict, Binary Expression, String. - - Attributes: - node_type: The type of the node, set to NodeType.EXPRESSION. - """ - - node_type: NodeType = NodeType.EXPRESSION - - @property - @reader - def resolved_value(self) -> Expression | list[Expression]: - """Returns the resolved type of an Expression. - - Returns the inferred type of the expression. For example, a function call's resolved value will be its definition. - - Returns: - Expression | list[Expression]: The resolved expression type(s). Returns a single Expression if there is only one resolved type, - or a list of Expressions if there are multiple resolved types. Returns self if the expression is not resolvable or has no resolved types. - """ - if isinstance(self, Chainable) and (resolved_types := self.resolved_types): - if len(resolved_types) == 1: - return resolved_types[0] - return resolved_types - return self diff --git a/src/codegen/sdk/core/expressions/generic_type.py b/src/codegen/sdk/core/expressions/generic_type.py deleted file mode 100644 index 5fc60e7c9..000000000 --- a/src/codegen/sdk/core/expressions/generic_type.py +++ /dev/null @@ -1,76 +0,0 @@ -from abc import abstractmethod -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from tree_sitter import Node as TSNode - -from codegen.sdk.core.autocommit import commiter, reader -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.named_type import NamedType -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.sdk.extensions.resolution import ResolutionStack -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.expressions.type import Type - from codegen.sdk.core.interfaces.editable import Editable - -Parent = TypeVar("Parent") - - -TType = TypeVar("TType", bound="Type") -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class GenericType(NamedType[Parent], Generic[TType, Parent]): - """Abstract representation of the generic types of the programming language.""" - - _parameters: Collection[TType, Self] - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent): - super().__init__(ts_node, file_node_id, ctx, parent) - self._parameters = self._get_parameters() - - @property - @reader - def parameters(self) -> Collection[TType, Self]: - """Retrieves the generic type parameters associated with this type. - - Args: - None - - Returns: - Collection[TType, Self]: A collection of generic type parameters associated with this type. - """ - return self._parameters - - @abstractmethod - def _get_parameters(self) -> Collection[TType, Self]: - pass - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): - super()._compute_dependencies(usage_type, dest) - for param in self._parameters: - param._compute_dependencies(UsageKind.GENERIC, dest) - - @property - @noapidoc - def descendant_symbols(self) -> list["Importable"]: - """Returns the nested symbols of the importable object, including itself.""" - ret = self.get_name().descendant_symbols - for param in self._parameters: - ret.extend(param.descendant_symbols) - return ret - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - if name := self.get_name(): - yield from self.with_resolution_frame(name, generic_parameters=self.parameters) diff --git a/src/codegen/sdk/core/expressions/multi_expression.py b/src/codegen/sdk/core/expressions/multi_expression.py deleted file mode 100644 index 58d09874c..000000000 --- a/src/codegen/sdk/core/expressions/multi_expression.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar, override - -from codegen.sdk.core.expressions import Expression -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.node_id_factory import NodeId - - -Parent = TypeVar("Parent", bound="Expression") -TExpression = TypeVar("TExpression", bound="Expression") - - -@apidoc -class MultiExpression(Expression[Parent], Generic[Parent, TExpression]): - """Represents an group of Expressions, such as List, Dict, Binary Expression, String. - - Attributes: - expressions: A list of expressions contained within the MultiExpression. - """ - - expressions: list[TExpression] - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent, expressions: list[TExpression]) -> None: - super().__init__(ts_node, file_node_id, ctx, parent) - self.expressions = expressions - - @noapidoc - @commiter - @override - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - for exp in self.expressions: - exp._compute_dependencies(usage_type, dest) diff --git a/src/codegen/sdk/core/expressions/name.py b/src/codegen/sdk/core/expressions/name.py deleted file mode 100644 index 78554a5b7..000000000 --- a/src/codegen/sdk/core/expressions/name.py +++ /dev/null @@ -1,122 +0,0 @@ -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Optional, Self, TypeVar, override - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.expression import Expression -from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock -from codegen.sdk.core.interfaces.resolvable import Resolvable -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.core.import_resolution import Import, WildcardImport - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.symbol import Symbol - -Parent = TypeVar("Parent", bound="Expression") - - -@apidoc -class Name(Expression[Parent], Resolvable, Generic[Parent]): - """Editable attribute on any given code objects that has a name. - - For example, function, classes, global variable, interfaces, attributes, parameters are all - composed of a name. - """ - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - """Resolve the types used by this symbol.""" - for used in self.resolve_name(self.source, self.start_byte): - yield from self.with_resolution_frame(used) - - @commiter - def _compute_dependencies(self, usage_type: UsageKind, dest: Optional["HasName | None "] = None) -> None: - """Compute the dependencies of the export object.""" - edges = [] - for used_frame in self.resolved_type_frames: - edges.extend(used_frame.get_edges(self, usage_type, dest, self.ctx)) - if self.ctx.config.debug: - edges = list(dict.fromkeys(edges)) - self.ctx.add_edges(edges) - - @noapidoc - @writer - def rename_if_matching(self, old: str, new: str): - if self.source == old: - self.edit(new) - - @noapidoc - def _resolve_conditionals(self, conditional_parent: ConditionalBlock, name: str, original_resolved): - """Resolves name references within conditional blocks by traversing the conditional chain. - - This method handles name resolution within conditional blocks (like if/elif/else statements) by: - 1. Finding the appropriate search boundary based on the conditional block's position - 2. Handling "fake" conditionals by traversing up the conditional chain - 3. Yielding resolved names while respecting conditional block boundaries - - Args: - conditional_parent (ConditionalBlock): The parent conditional block containing the name reference - name (str): The name being resolved - original_resolved: The originally resolved symbol that triggered this resolution - - Yields: - Symbol | Import | WildcardImport: Resolved symbols found within the conditional blocks - - Notes: - - A "fake" conditional is one where is_true_conditional() returns False - - The search_limit ensures we don't resolve names that appear after our target - - The method stops when it either: - a) Reaches the top of the conditional chain - b) Returns to the original conditional block - c) Can't find any more resolutions - """ - search_limit = conditional_parent.start_byte_for_condition_block - if search_limit >= original_resolved.start_byte: - search_limit = original_resolved.start_byte - 1 - if not conditional_parent.is_true_conditional(original_resolved): - # If it's a fake conditional we must skip any potential enveloping conditionals - def get_top_of_fake_chain(conditional, resolved, search_limit=0): - if skip_fake := conditional.parent_of_type(ConditionalBlock): - if skip_fake.is_true_conditional(resolved): - return skip_fake.start_byte_for_condition_block - search_limit = skip_fake.start_byte_for_condition_block - return get_top_of_fake_chain(skip_fake, conditional, search_limit) - return search_limit - - if search_limit := get_top_of_fake_chain(conditional_parent, original_resolved): - search_limit = search_limit - else: - return - - original_conditional = conditional_parent - while next_resolved := next(conditional_parent.resolve_name(name, start_byte=search_limit, strict=False), None): - yield next_resolved - next_conditional = next_resolved.parent_of_type(ConditionalBlock) - if not next_conditional or next_conditional == original_conditional: - return - search_limit = next_conditional.start_byte_for_condition_block - if next_conditional and not next_conditional.is_true_conditional(original_resolved): - pass - if search_limit >= next_resolved.start_byte: - search_limit = next_resolved.start_byte - 1 - - @noapidoc - @reader - def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator["Symbol | Import | WildcardImport"]: - resolved_name = next(super().resolve_name(name, start_byte or self.start_byte, strict=strict), None) - if resolved_name: - yield resolved_name - else: - return - - if hasattr(resolved_name, "parent") and (conditional_parent := resolved_name.parent_of_type(ConditionalBlock)): - if self.parent_of_type(ConditionalBlock) == conditional_parent: - # Use in the same block, should only depend on the inside of the block - return - - yield from self._resolve_conditionals(conditional_parent=conditional_parent, name=name, original_resolved=resolved_name) diff --git a/src/codegen/sdk/core/expressions/named_type.py b/src/codegen/sdk/core/expressions/named_type.py deleted file mode 100644 index f908244c0..000000000 --- a/src/codegen/sdk/core/expressions/named_type.py +++ /dev/null @@ -1,75 +0,0 @@ -from abc import abstractmethod -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from tree_sitter import Node as TSNode - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.autocommit import commiter, reader, writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions import Name, String -from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.core.interfaces.resolvable import Resolvable -from codegen.sdk.core.node_id_factory import NodeId -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.interfaces.editable import Editable - -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class NamedType(Resolvable, Type[Parent], HasName, Generic[Parent]): - """An abstract representation of a named type.""" - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent): - super().__init__(ts_node, file_node_id, ctx, parent) - self._name_node = self._parse_expression(self._get_name_node(), default=Name) - - def __eq__(self, other: object) -> bool: - from codegen.sdk.core.symbol import Symbol - - if isinstance(other, Symbol): - for resolved in self.resolved_types: - if other == resolved: - return True - return super().__eq__(other) - - def __hash__(self) -> int: - # needed so this class is hashable - return super().__hash__() - - @abstractmethod - def _get_name_node(self) -> TSNode: - pass - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - if name := self.get_name(): - yield from self.with_resolution_frame(name) - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): - if isinstance(self.get_name(), String): - # TODO: string annotations - self._log_parse("String type annotations are not currently supported") - return - self.get_name()._compute_dependencies(usage_type, dest) - - @property - @noapidoc - def descendant_symbols(self) -> list["Importable"]: - """Returns the nested symbols of the importable object, including itself.""" - return self.get_name().descendant_symbols - - @noapidoc - @writer - def rename_if_matching(self, old: str, new: str): - self.get_name().rename_if_matching(old, new) diff --git a/src/codegen/sdk/core/expressions/none_type.py b/src/codegen/sdk/core/expressions/none_type.py deleted file mode 100644 index 8a5956090..000000000 --- a/src/codegen/sdk/core/expressions/none_type.py +++ /dev/null @@ -1,29 +0,0 @@ -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.extensions.autocommit import reader -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.editable import Editable - -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class NoneType(Type[Parent], Generic[Parent]): - """Represents a None or Null object.""" - - @noapidoc - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): - pass - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - yield from [] diff --git a/src/codegen/sdk/core/expressions/number.py b/src/codegen/sdk/core/expressions/number.py deleted file mode 100644 index a52c3605b..000000000 --- a/src/codegen/sdk/core/expressions/number.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import Generic, TypeVar, override - -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.expressions.builtin import Builtin -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import apidoc, noapidoc - -Parent = TypeVar("Parent", bound="Expression") - - -@apidoc -class Number(Expression[Parent], Builtin, Generic[Parent]): - """A number value. - - eg. 1, 2.0, 3.14 - """ - - @noapidoc - @commiter - @override - def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: - pass - - @property - def __class__(self): - return int diff --git a/src/codegen/sdk/core/expressions/parenthesized_expression.py b/src/codegen/sdk/core/expressions/parenthesized_expression.py deleted file mode 100644 index 6e05de3fa..000000000 --- a/src/codegen/sdk/core/expressions/parenthesized_expression.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Generic, TypeVar, override - -from codegen.sdk.codebase.transactions import TransactionPriority -from codegen.sdk.core.autocommit import writer -from codegen.sdk.core.detached_symbols.function_call import FunctionCall -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.interfaces.has_value import HasValue -from codegen.sdk.core.interfaces.unwrappable import Unwrappable -from codegen.sdk.core.interfaces.wrapper_expression import IWrapper -from codegen.sdk.extensions.autocommit import reader -from codegen.sdk.typescript.statements.if_block_statement import TSIfBlockStatement -from codegen.shared.decorators.docs import apidoc - -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class ParenthesizedExpression(Unwrappable[Parent], HasValue, IWrapper, Generic[Parent]): - """An expression surrounded in a set of parenthesis. - - Example: - ```typescript - (5 + 5) - ``` - """ - - def __init__(self, ts_node, file_node_id, ctx, parent: Parent): - super().__init__(ts_node, file_node_id, ctx, parent=parent) - value_node = self.ts_node.named_children[0] - self._value_node = self.ctx.parser.parse_expression(value_node, self.file_node_id, self.ctx, self) if value_node else None - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Retrieves a list of function calls within a parenthesized expression. - - Gets all function calls from the resolved value of this parenthesized expression. - - Returns: - list[FunctionCall]: A list of FunctionCall objects representing all function calls within the parenthesized expression. - """ - return self.resolve().function_calls - - @writer - @override - def unwrap(self, node: Expression | None = None) -> None: - """Removes the parentheses from a parenthesized expression node. - - Modifies the AST by removing the parentheses from a ParenthesizedExpression node, leaving only the inner expression. - - Args: - node (Expression | None, optional): The node to unwrap. Defaults to None. - - Returns: - None - """ - if isinstance(self.parent, TSIfBlockStatement): - return - if node is None: - remaining = list( - child - for child in self.value.children - if not self.transaction_manager.get_transactions_at_range(self.file.path, start_byte=child.start_byte, end_byte=child.end_byte, transaction_order=TransactionPriority.Remove) - ) - if len(remaining) == 1: - node = remaining[0] - else: - return - if node.start_point[0] == node.end_point[0]: - for child in self._anonymous_children: - child.remove() - if isinstance(self.parent, Unwrappable): - self.parent.unwrap(node) - - @writer - def reduce_condition(self, bool_condition: bool, node: Editable) -> None: - """Simplifies an expression based on a boolean condition. - - Args: - bool_condition (bool): The boolean value to reduce the condition to. - node (Editable): The node to be simplified. - - Returns: - None - """ - self.unwrap() - self.parent.reduce_condition(bool_condition, self) diff --git a/src/codegen/sdk/core/expressions/placeholder_type.py b/src/codegen/sdk/core/expressions/placeholder_type.py deleted file mode 100644 index 5ce966f21..000000000 --- a/src/codegen/sdk/core/expressions/placeholder_type.py +++ /dev/null @@ -1,32 +0,0 @@ -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.autocommit import commiter -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.extensions.autocommit import reader -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.editable import Editable - -TType = TypeVar("TType", bound="Type") -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class PlaceholderType(Type[Parent], Generic[TType, Parent]): - """Represents a type that has not been implemented yet.""" - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): - self._add_all_identifier_usages(usage_type, dest=dest) - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - yield from [] diff --git a/src/codegen/sdk/core/expressions/string.py b/src/codegen/sdk/core/expressions/string.py deleted file mode 100644 index 3be669a98..000000000 --- a/src/codegen/sdk/core/expressions/string.py +++ /dev/null @@ -1,74 +0,0 @@ -from typing import TYPE_CHECKING, Generic, Self, TypeVar - -from tree_sitter import Node as TSNode - -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.builtin import Builtin -from codegen.sdk.core.expressions.expression import Expression -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - - -Parent = TypeVar("Parent", bound="Expression") - - -@apidoc -class String(Expression[Parent], Builtin, Generic[Parent]): - """GraphSitter representation of String. - - Attributes: - content: The content of the string - content_nodes: A collection of string fragments and escape sequences in TS, or a single string content in Python. - expressions: Embedded expressions in the string, only applicable for templated or formatted strings. - """ - - content: str - content_nodes: Collection[Expression[Editable], Self] # string content is a collection of string_fragments and escape_sequences in TS and a single string_content in Python - expressions: list[Expression[Editable]] # expressions in the string, only applicable for template strings - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent) -> None: - super().__init__(ts_node, file_node_id, ctx, parent=parent) - content_children = list(self.children_by_field_types({"string_content", "string_fragment", "escape_sequence"})) - self.content_nodes = Collection(ts_node, self.file_node_id, self.ctx, self, delimiter="", children=content_children) - self.content = "".join(x.ts_node.text.decode("utf-8") for x in content_children) - - @reader - def __eq__(self, other: object) -> bool: - if isinstance(other, str) and other == self.content: - return True - return super().__eq__(other) - - def __str__(self): - return self.content - - def __hash__(self): - return super().__hash__() - - @property - @reader - def with_quotes(self) -> str: - """Retrieves the string representation with quotation marks. - - Returns: - str: The string value with its surrounding quotation marks. - """ - return self.source - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - # If the string is a template string, we need to compute the dependencies of the string content - for expression in self.expressions: - expression._compute_dependencies(usage_type, dest) - - @property - def __class__(self): - return str diff --git a/src/codegen/sdk/core/expressions/subscript_expression.py b/src/codegen/sdk/core/expressions/subscript_expression.py deleted file mode 100644 index 51d92a7aa..000000000 --- a/src/codegen/sdk/core/expressions/subscript_expression.py +++ /dev/null @@ -1,62 +0,0 @@ -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Optional, Self, TypeVar, override - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions import Name -from codegen.sdk.core.expressions.expression import Expression -from codegen.sdk.core.interfaces.resolvable import Resolvable -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.chainable import Chainable - from codegen.sdk.core.interfaces.has_name import HasName - - -Object = TypeVar("Object", bound="Chainable") -Index = TypeVar("Index", bound="Expression") -Parent = TypeVar("Parent", bound="Expression") - - -@apidoc -class SubscriptExpression(Expression[Parent], Resolvable[Parent], Generic[Object, Index, Parent]): - """Indexing onto an object (Aka using brackets on an object) - - Examples: - A[] - - Attributes: - object: The object being indexed. - indices: A list of indices used for indexing the object. - - """ - - object: Object - indices: list[Index] - - def __init__(self, ts_node, file_node_id, ctx, parent: Parent): - super().__init__(ts_node, file_node_id, ctx, parent=parent) - self.object = self._parse_expression(self.ts_node.children[0], default=Name) - self.indices = self.children[1:] - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - # TODO: implement this properly - yield from self.object.resolved_type_frames - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind, dest: Optional["HasName | None"] = None) -> None: - self.object._compute_dependencies(usage_type, dest) - for index in self.indices: - index._compute_dependencies(usage_type, dest) - - @writer - @noapidoc - def rename_if_matching(self, old: str, new: str) -> None: - if self.object: - self.object.rename_if_matching(old, new) diff --git a/src/codegen/sdk/core/expressions/ternary_expression.py b/src/codegen/sdk/core/expressions/ternary_expression.py deleted file mode 100644 index c3681cd3c..000000000 --- a/src/codegen/sdk/core/expressions/ternary_expression.py +++ /dev/null @@ -1,78 +0,0 @@ -import itertools -from collections.abc import Generator -from typing import Generic, Self, TypeVar, override - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.autocommit import writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.interfaces.chainable import Chainable -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.core.interfaces.unwrappable import Unwrappable -from codegen.sdk.extensions.autocommit import commiter, reader -from codegen.shared.decorators.docs import apidoc, noapidoc - -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class TernaryExpression(Expression[Parent], Chainable, Generic[Parent]): - """Any ternary expression in the code where a condition will determine branched execution. - - Attributes: - condition: The condition expression that determines which branch to execute. - consequence: The expression to execute if the condition is true. - alternative: The expression to execute if the condition is false. - """ - - condition: Expression[Self] | None - consequence: Expression[Self] | None - alternative: Expression[Self] | None - - @writer - def reduce_condition(self, bool_condition: bool, node: Editable | None = None) -> None: - """Simplifies a ternary expression based on a boolean condition. - - Args: - bool_condition (bool): The boolean value to reduce the condition to. If True, keeps the consequence branch. If False, keeps the alternative branch. - node (Editable | None, optional): The node to be edited. Defaults to None. - - Returns: - None: Modifies the ternary expression in place. - """ - # ==== [ Reduce condition to True ] ==== - to_keep = self.consequence if bool_condition else self.alternative - for node in self._anonymous_children: - node.remove() - self.condition.remove() - if bool_condition: - self.alternative.remove() - else: - self.consequence.remove() - self.remove_byte_range(self.alternative.ts_node.prev_sibling.end_byte, self.alternative.start_byte) - if isinstance(to_keep, Unwrappable): - to_keep.unwrap() - if isinstance(self.parent, Unwrappable): - self.parent.unwrap(to_keep) - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - yield from self.with_resolution_frame(self.consequence) - yield from self.with_resolution_frame(self.alternative) - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - elems = [self.condition, self.consequence, self.alternative] - return list(itertools.chain.from_iterable(elem.descendant_symbols for elem in elems if elem)) - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - self.condition._compute_dependencies(usage_type, dest) - self.consequence._compute_dependencies(usage_type, dest) - self.alternative._compute_dependencies(usage_type, dest) diff --git a/src/codegen/sdk/core/expressions/tuple_type.py b/src/codegen/sdk/core/expressions/tuple_type.py deleted file mode 100644 index 1495fe8a5..000000000 --- a/src/codegen/sdk/core/expressions/tuple_type.py +++ /dev/null @@ -1,57 +0,0 @@ -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from tree_sitter import Node as TSNode - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.sdk.extensions.autocommit import reader -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.interfaces.importable import Importable - - -TType = TypeVar("TType", bound="Type") -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class TupleType(Collection[Type, Parent], Type[Parent], Generic[TType, Parent]): - """An abstract representation of a tuple type. - For example `[number, number]`. - """ - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent): - super().__init__(ts_node, file_node_id, ctx, parent, delimiter=" |") - elements = list(self._get_types(ts_node)) - self._init_children(elements) - self._bracket_size = 0 - - def _get_types(self, node: TSNode) -> Generator[TType, None, None]: - for child in node.named_children: - type_cls = self.ctx.node_classes.type_map.get(child.type, None) - if isinstance(type_cls, type) and issubclass(type_cls, self.__class__): - yield from self._get_types(child) - else: - yield self._parse_type(child) - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - for type in self.symbols: - yield from self.with_resolution_frame(type) - - @property - @noapidoc - def descendant_symbols(self) -> list["Importable"]: - """Returns the nested symbols of the importable object, including itself.""" - ret = [] - for param in self.symbols: - ret.extend(param.descendant_symbols) - return ret diff --git a/src/codegen/sdk/core/expressions/type.py b/src/codegen/sdk/core/expressions/type.py deleted file mode 100644 index 2630427d4..000000000 --- a/src/codegen/sdk/core/expressions/type.py +++ /dev/null @@ -1,48 +0,0 @@ -import itertools -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Generic, TypeVar - -from typing_extensions import deprecated - -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.expression import Expression -from codegen.sdk.core.interfaces.chainable import Chainable -from codegen.sdk.extensions.autocommit import reader -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.symbol import Symbol - - -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class Type(Expression[Parent], Chainable, ABC, Generic[Parent]): - """Abstract representation of a type - Used to store the types of variables, parameters, or return values in functions, classes, etc. - """ - - @noapidoc - @abstractmethod - def _compute_dependencies(self, usage_type: UsageKind, dest: "Importable"): ... - - @property - @deprecated("Use resolved_types instead for internal uses") - @noapidoc - @reader - def resolved_symbol(self) -> "Symbol | str | None": - from codegen.sdk.core.symbol import Symbol - - for resolved in self.resolved_types: - if isinstance(resolved, Symbol): - return resolved - return None - - @property - @noapidoc - def descendant_symbols(self) -> list["Importable"]: - """Returns the nested symbols of the importable object, including itself.""" - return list(itertools.chain.from_iterable(child.descendant_symbols for child in self.children)) diff --git a/src/codegen/sdk/core/expressions/unary_expression.py b/src/codegen/sdk/core/expressions/unary_expression.py deleted file mode 100644 index e1f8c3ba8..000000000 --- a/src/codegen/sdk/core/expressions/unary_expression.py +++ /dev/null @@ -1,58 +0,0 @@ -from collections.abc import Generator -from typing import Generic, Self, TypeVar, override - -from codegen.sdk.codebase.codebase_context import CodebaseContext -from codegen.sdk.core.autocommit import writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.expression import Expression -from codegen.sdk.core.interfaces.chainable import Chainable -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.extensions.autocommit import commiter, reader -from codegen.sdk.extensions.resolution import ResolutionStack -from codegen.sdk.extensions.utils import TSNode -from codegen.shared.decorators.docs import apidoc, noapidoc - -Parent = TypeVar("Parent", bound="Expression") - - -@apidoc -class UnaryExpression(Expression[Parent], Chainable, Generic[Parent]): - """Unary expression which is a single operation on a single operand. eg. -5, !true. - - Attributes: - argument: The argument of the unary expression - """ - - argument: Expression[Self] - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent) -> None: - super().__init__(ts_node, file_node_id, ctx, parent) - self.argument = self._parse_expression(ts_node.child_by_field_name("argument")) - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - """Resolve the types used by this symbol.""" - yield from self.with_resolution_frame(self.argument) - - @commiter - @noapidoc - def _compute_dependencies(self, usage_type: UsageKind = UsageKind.BODY, dest: HasName | None = None) -> None: - self.argument._compute_dependencies(usage_type, dest) - - @writer - def reduce_condition(self, bool_condition: bool, node: Editable | None = None) -> None: - """Simplifies a unary expression by reducing it based on a boolean condition. - - - Args: - bool_condition (bool): The boolean value to reduce the condition to. - - """ - if self.ts_node.type == "not_operator" or self.source.startswith("!"): - self.parent.reduce_condition(not bool_condition, self) - else: - super().reduce_condition(bool_condition, node) diff --git a/src/codegen/sdk/core/expressions/union_type.py b/src/codegen/sdk/core/expressions/union_type.py deleted file mode 100644 index a9c979cae..000000000 --- a/src/codegen/sdk/core/expressions/union_type.py +++ /dev/null @@ -1,57 +0,0 @@ -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from tree_sitter import Node as TSNode - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.sdk.extensions.autocommit import reader -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.interfaces.importable import Importable - - -TType = TypeVar("TType", bound="Type") -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class UnionType(Collection[Type, Parent], Type[Parent], Generic[TType, Parent]): - """An abstract representation of a union type. - For example `str | None` or `string | number`. - """ - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent): - super().__init__(ts_node, file_node_id, ctx, parent, delimiter=" |") - elements = list(self._get_types(ts_node)) - self._init_children(elements) - self._bracket_size = 0 - - def _get_types(self, node: TSNode) -> Generator[TType, None, None]: - for child in node.named_children: - type_cls = self.ctx.node_classes.type_map.get(child.type, None) - if isinstance(type_cls, type) and issubclass(type_cls, self.__class__): - yield from self._get_types(child) - else: - yield self._parse_type(child) - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - for type in self.symbols: - yield from self.with_resolution_frame(type) - - @property - @noapidoc - def descendant_symbols(self) -> list["Importable"]: - """Returns the nested symbols of the importable object, including itself.""" - ret = [] - for param in self.symbols: - ret.extend(param.descendant_symbols) - return ret diff --git a/src/codegen/sdk/core/expressions/unpack.py b/src/codegen/sdk/core/expressions/unpack.py deleted file mode 100644 index 10dd7be52..000000000 --- a/src/codegen/sdk/core/expressions/unpack.py +++ /dev/null @@ -1,63 +0,0 @@ -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.interfaces.has_value import HasValue -from codegen.sdk.core.interfaces.unwrappable import Unwrappable -from codegen.sdk.core.interfaces.wrapper_expression import IWrapper -from codegen.shared.decorators.docs import apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.editable import Editable - -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class Unpack(Unwrappable[Parent], HasValue, IWrapper, Generic[Parent]): - """Unpacking of an iterable. - - Example: - ```python - [a, *b] - ``` - """ - - def __init__(self, ts_node, file_node_id, ctx, parent: Parent): - super().__init__(ts_node, file_node_id, ctx, parent=parent) - self._value_node = self.children[0] - - def unwrap(self, node: Expression | None = None) -> None: - """Unwraps a node's content into its parent node. - - Unwraps the content of a node by removing its wrapping syntax and merging its content with its parent node. - Specifically handles dictionary unwrapping, maintaining proper indentation and formatting. - - Args: - node (Expression | None): The node to unwrap. If None, uses the instance's value node. - - Returns: - None - """ - from codegen.sdk.core.symbol_groups.dict import Dict - - node = node or self._value_node - if isinstance(node, Dict) and isinstance(self.parent, Dict): - if self.start_point[0] != self.parent.start_point[0]: - self.remove(delete_formatting=False) - self.remove_byte_range(self.start_byte - self.start_point[1], self.start_byte) - next_sibling = self.next_sibling - if next_sibling.source == ",": - next_sibling = next_sibling.next_sibling - indent_start = next_sibling.start_byte - next_sibling.start_point[1] - self.remove_byte_range(self.end_byte, next_sibling.start_byte) - self.insert_at(next_sibling.start_byte, self.file.content_bytes[indent_start : next_sibling.start_byte].decode("utf-8"), priority=-10) - else: - # Delete the remaining characters on this line - self.remove_byte_range(self.end_byte, next_sibling.start_byte - next_sibling.start_point[1]) - - else: - self.remove() - for k, v in node.items(): - self.parent[k] = v.source.strip() - if node.unpack: - self.parent._underlying.append(self.node.unpack.source) diff --git a/src/codegen/sdk/core/expressions/value.py b/src/codegen/sdk/core/expressions/value.py deleted file mode 100644 index c70dec4ba..000000000 --- a/src/codegen/sdk/core/expressions/value.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.expressions.expression import Expression -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.interfaces.has_name import HasName - -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class Value(Expression[Parent], Generic[Parent]): - """Editable attribute on code objects that has a value. - - For example, Functions, Classes, Assignments, Interfaces, Expressions, Arguments and Parameters all have values. - - See also HasValue. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.ctx.parser.log_unparsed(self.ts_node) - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None): - for node in self.children: - node._compute_dependencies(usage_type, dest=dest) diff --git a/src/codegen/sdk/core/external/dependency_manager.py b/src/codegen/sdk/core/external/dependency_manager.py deleted file mode 100644 index 8f9ea7a3a..000000000 --- a/src/codegen/sdk/core/external/dependency_manager.py +++ /dev/null @@ -1,38 +0,0 @@ -from abc import abstractmethod -from typing import TYPE_CHECKING - -from codegen.sdk.core.external.external_process import ExternalProcess -from codegen.shared.enums.programming_language import ProgrammingLanguage - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - - -class DependencyManager(ExternalProcess): - """Manages dependencies for the given repository. - - Handles reading, installing, and managing any dependency-based operations. - """ - - @abstractmethod - def parse_dependencies(self): - pass - - @abstractmethod - def install_dependencies(self): - pass - - @abstractmethod - def remove_dependencies(self): - pass - - -def get_dependency_manager(language: ProgrammingLanguage, codebase_context: "CodebaseContext", enabled: bool = False) -> DependencyManager | None: - from codegen.sdk.typescript.external.dependency_manager import TypescriptDependencyManager - - ts_enabled = enabled or codebase_context.config.ts_dependency_manager - if language == ProgrammingLanguage.TYPESCRIPT: - if ts_enabled: - return TypescriptDependencyManager(repo_path=codebase_context.repo_path, base_path=codebase_context.projects[0].base_path) - - return None diff --git a/src/codegen/sdk/core/external/external_process.py b/src/codegen/sdk/core/external/external_process.py deleted file mode 100644 index f1951bafb..000000000 --- a/src/codegen/sdk/core/external/external_process.py +++ /dev/null @@ -1,78 +0,0 @@ -import os -import threading -import time -from abc import ABC, abstractmethod - -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -class ExternalProcess(ABC): - """Base class for all additional extrnal services that require a separate process. - - Examples include language engines, dependency managers, etc. - - Attributes: - repo_path (str): Path to the repository root directory - base_path (str | None): Optional subdirectory path within the repo to analyze - full_path (str): Complete path combining repo_path and base_path - is_ready (bool): Whether the engine has completed initialization and is ready - error (BaseException | None): Whether the engine encountered an error during startup - """ - - repo_path: str - base_path: str | None - full_path: str - is_ready: bool - _error: BaseException | None - - def __init__(self, repo_path: str, base_path: str | None = None): - self.repo_path: str = repo_path - self.base_path: str | None = base_path - self.full_path = os.path.join(repo_path, base_path) if base_path else repo_path - self.is_ready: bool = False - self._error: BaseException | None = None - - def start(self, async_start: bool = False): - if async_start: - # Create a new thread to start the engine - thread = threading.Thread(target=self._start) - thread.start() - else: - self._start() - - @abstractmethod - def _start(self): - pass - - def reparse(self, async_start: bool = False): - # Reparse logic is handled by re-running start() - self.is_ready = False - self.start(async_start=async_start) - - def ready(self) -> bool: - return self.is_ready - - def error(self) -> BaseException | None: - return self._error - - def wait_until_ready(self, ignore_error: bool = False): - logger.info(f"Waiting for {self.__class__.__name__} to be ready...") - # Wait for 3 minutes first - start_time = time.time() - while not self.ready() and not self.error() and (time.time() - start_time) < 60 * 3: - time.sleep(1) - - # After 3 minutes, check every 15 seconds and warn - while not self.ready() and not self.error() and (time.time() - start_time) < 60 * 5: - logger.warning(f"{self.__class__.__name__} still not ready after 3 minutes for {self.full_path}") - time.sleep(15) - - # After 5 minutes, check every 30 seconds and error - while not self.ready() and not self.error(): - logger.error(f"{self.__class__.__name__} still not ready after 5 minutes for {self.full_path}") - time.sleep(30) - - if not ignore_error and self.error(): - raise self.error() diff --git a/src/codegen/sdk/core/external/language_engine.py b/src/codegen/sdk/core/external/language_engine.py deleted file mode 100644 index 1673bd658..000000000 --- a/src/codegen/sdk/core/external/language_engine.py +++ /dev/null @@ -1,37 +0,0 @@ -from abc import abstractmethod -from typing import TYPE_CHECKING - -from codegen.sdk.core.external.external_process import ExternalProcess -from codegen.shared.enums.programming_language import ProgrammingLanguage - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.interfaces.editable import Editable - - -class LanguageEngine(ExternalProcess): - """Base class for all third part language engine support. - - This class provides the foundation for integrating external language analysis engines. - It handles initialization, startup, and status tracking of the engine. - """ - - @abstractmethod - def get_return_type(self, node: "Editable") -> str | None: - pass - - -def get_language_engine(language: ProgrammingLanguage, codebase_context: "CodebaseContext", use_ts: bool = False, use_v8: bool = False) -> LanguageEngine | None: - from codegen.sdk.typescript.external.ts_analyzer_engine import NodeTypescriptEngine, V8TypescriptEngine - - use_ts = use_ts or codebase_context.config.ts_language_engine - use_v8 = use_v8 or codebase_context.config.v8_ts_engine - if language == ProgrammingLanguage.TYPESCRIPT: - if use_ts and use_v8: - # Enables with both ts_language_engine and v8_ts_engine feature flags are on - return V8TypescriptEngine(repo_path=codebase_context.repo_path, base_path=codebase_context.projects[0].base_path, dependency_manager=codebase_context.dependency_manager) - elif use_ts: - # Enabled with only ts_language_engine feature flag is on - return NodeTypescriptEngine(repo_path=codebase_context.repo_path, base_path=codebase_context.projects[0].base_path, dependency_manager=codebase_context.dependency_manager) - - return None diff --git a/src/codegen/sdk/core/external_module.py b/src/codegen/sdk/core/external_module.py deleted file mode 100644 index 0e97bbf04..000000000 --- a/src/codegen/sdk/core/external_module.py +++ /dev/null @@ -1,160 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Literal, override - -from codegen.sdk.core.autocommit import commiter, reader -from codegen.sdk.core.interfaces.callable import Callable -from codegen.sdk.core.interfaces.has_attribute import HasAttribute -from codegen.sdk.core.placeholder.placeholder_stub import StubPlaceholder -from codegen.sdk.enums import ImportType, NodeType -from codegen.shared.decorators.docs import apidoc, noapidoc -from codegen.visualizations.enums import VizNode - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.detached_symbols.parameter import Parameter - from codegen.sdk.core.expressions.name import Name - from codegen.sdk.core.import_resolution import Import - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.node_id_factory import NodeId - - -@apidoc -class ExternalModule( - Callable, - HasAttribute["ExternalModule"], -): - """Represents an external module, like `datetime`, that can be referenced. - - These are only added to the graph during import resolution and will not exist in a local file's subgraph. This is because we don't know what an import is referencing or resolves to until we see - the full codebase. - - Attributes: - node_type: The type of node, set to NodeType.EXTERNAL. - """ - - node_type: Literal[NodeType.EXTERNAL] = NodeType.EXTERNAL - _import: Import | None = None - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, import_name: Name, import_node: Import | None = None) -> None: - self.node_id = ctx.add_node(self) - super().__init__(ts_node, file_node_id, ctx, None) - self._name_node = import_name - self.return_type = StubPlaceholder(parent=self) - assert self._idx_key not in self.ctx._ext_module_idx - self.ctx._ext_module_idx[self._idx_key] = self.node_id - self._import = import_node - - @property - def _idx_key(self) -> str: - return self.source + "::" + self.name - - @noapidoc - @commiter - def parse(self, ctx: CodebaseContext) -> None: - msg = f"{type(self)} is not part of the graph at the moment" - raise NotImplementedError(msg) - - @classmethod - def from_import(cls, imp: Import) -> ExternalModule: - """Creates an ExternalModule instance from an Import instance. - - This class method creates a new ExternalModule object that represents an external module - that can be referenced in the codebase, such as 'datetime' or other imported modules. - External modules are added to the graph during import resolution. - - Args: - imp (Import): An Import instance containing the module information. - - Returns: - ExternalModule: A new ExternalModule instance representing the external module. - """ - return cls(imp.ts_node, imp.file_node_id, imp.ctx, imp._unique_node, imp) - - @property - @reader - def parameters(self) -> list[Parameter]: - """Returns list of named parameters from an external function symbol. - - Retrieves the parameter list from an external module function. This is not yet implemented and will raise an error. - - Returns: - list[Parameter]: A list of parameters associated with the external function. - - Raises: - NotImplementedError: This functionality is not yet supported for external modules. - """ - # TODO: figure out how to get parameters from this module - msg = "Parsing parameters from an external module is not yet supported." - raise NotImplementedError(msg) - - @reader - def get_import_string(self, alias: str | None = None, module: str | None = None, import_type: ImportType = ImportType.UNKNOWN, is_type_import: bool = False) -> str: - """Returns the import string used to import this module. - - Gets the string representation needed to import this external module. This method is used to generate import statements. - - Args: - alias (str | None, optional): An alternative name for the imported module. - module (str | None, optional): The module from which to import. - import_type (ImportType, optional): The type of import to generate. Defaults to ImportType.UNKNOWN. - is_type_import (bool, optional): Whether this is a type import. Defaults to False. - - Returns: - str: The import string that can be used to import this module. - """ - # TODO - will need to fix the relative imports - return self.source - - @property - def file(self) -> None: - """File property for ExternalModule class. - - Returns None since ExternalModule represents an external module that is not part of any local file. - - Returns: - None: Always returns None as ExternalModule is not associated with any file. - """ - return None - - @property - def filepath(self) -> str: - """Returns the filepath of the module. - - For an ExternalModule, this will always return an empty string as it represents an external module that - does not belong to the local codebase. - - Returns: - str: An empty string representing the filepath of the external module. - """ - return "" - - @property - @noapidoc - def viz(self) -> VizNode: - return VizNode(file_path=self.filepath, start_point=self.start_point, end_point=self.end_point, name=self.name, symbol_name=self.__class__.__name__) - - @noapidoc - @reader - def resolve_attribute(self, name: str) -> ExternalModule | None: - return self._import.resolve_attribute(name) or self - - @noapidoc - @commiter - @override - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - pass - - def __hash__(self): - if self._hash is None: - self._hash = hash((self.filepath, self.range, self.ts_node.kind_id, self._idx_key)) - return self._hash - - @reader - def __eq__(self, other: object): - if isinstance(other, ExternalModule): - return super().__eq__(other) and self._idx_key == other._idx_key - return super().__eq__(other) diff --git a/src/codegen/sdk/core/file.py b/src/codegen/sdk/core/file.py deleted file mode 100644 index 12bcab303..000000000 --- a/src/codegen/sdk/core/file.py +++ /dev/null @@ -1,1190 +0,0 @@ -import os -import re -import resource -import sys -from abc import abstractmethod -from collections.abc import Generator, Sequence -from functools import cached_property -from os import PathLike -from pathlib import Path -from typing import TYPE_CHECKING, Generic, Literal, Self, TypeVar, override - -from tree_sitter import Node as TSNode -from typing_extensions import deprecated - -from codegen.sdk._proxy import proxy_property -from codegen.sdk.codebase.codebase_context import CodebaseContext -from codegen.sdk.codebase.range_index import RangeIndex -from codegen.sdk.codebase.span import Range -from codegen.sdk.core.autocommit import commiter, mover, reader, remover, writer -from codegen.sdk.core.class_definition import Class -from codegen.sdk.core.dataclasses.usage import UsageType -from codegen.sdk.core.directory import Directory -from codegen.sdk.core.import_resolution import Import, WildcardImport -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.interfaces.has_attribute import HasAttribute -from codegen.sdk.core.interfaces.has_block import HasBlock -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.core.interfaces.usable import Usable -from codegen.sdk.core.statements.import_statement import ImportStatement -from codegen.sdk.core.symbol import Symbol -from codegen.sdk.enums import EdgeType, ImportType, NodeType, SymbolType -from codegen.sdk.extensions.sort import sort_editables -from codegen.sdk.topological_sort import pseudo_topological_sort -from codegen.sdk.tree_sitter_parser import get_parser_by_filepath_or_extension, parse_file -from codegen.sdk.typescript.function import TSFunction -from codegen.sdk.utils import is_minified_js -from codegen.shared.decorators.docs import apidoc, noapidoc -from codegen.shared.logging.get_logger import get_logger -from codegen.visualizations.enums import VizNode - -if TYPE_CHECKING: - from codegen.sdk.core.assignment import Assignment - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.function import Function - from codegen.sdk.core.interface import Interface - -logger = get_logger(__name__) - - -@apidoc -class File(Editable[None]): - """Represents a generic file. - - Could represent a source file or a non-code file such as a markdown file or image file. - - Attributes: - name: The name of the file. - file_path: The relative file path as a string. - path: The absolute path of the file as a Path object. - node_type: The type of node, set to NodeType.FILE. - """ - - name: str - file_path: str - path: Path - node_type: Literal[NodeType.FILE] = NodeType.FILE - _pending_imports: set[str] - _binary: bool = False - _range_index: RangeIndex - - def __init__(self, filepath: PathLike, ctx: CodebaseContext, ts_node: TSNode | None = None, binary: bool = False) -> None: - if ts_node is None: - # TODO: this is a temp hack to deal with all symbols needing a TSNode. - parser = get_parser_by_filepath_or_extension(".py") - ts_node = parser.parse(bytes("", "utf-8")).root_node - self._range_index = RangeIndex() - super().__init__(ts_node, getattr(self, "node_id", None), ctx, None) - self.path = self.ctx.to_absolute(filepath) - self.file_path = str(self.ctx.to_relative(self.path)) - self.name = self.path.stem - self._binary = binary - - @property - @reader - @override - def _source(self): - """Text representation of the Editable instance.""" - if self._binary: - return f"[Binary Blob of size {len(self.content_bytes)} Bytes]" - else: - return self.content - - @property - def file(self) -> Self: - """A property that returns the file object for non-source files. - - This is used by Editable.file to work with non-source files, allowing consistent interface usage across both source and non-source files. - - Returns: - Self: The current file object. - """ - # This is a hack to allow Editable.file to work for non-source files - return self - - @classmethod - @noapidoc - def from_content(cls, filepath: str | Path, content: str | bytes, ctx: CodebaseContext, sync: bool = False, binary: bool = False) -> Self | None: - """Creates a new file from content.""" - if sync: - logger.warn("Creating & Syncing non-source files are not supported. Ignoring sync...") - path = ctx.to_absolute(filepath) - if not path.exists(): - update_graph = True - path.parent.mkdir(parents=True, exist_ok=True) - ctx.io.write_file(path, content) - ctx.io.save_files({path}) - - new_file = cls(filepath, ctx, ts_node=None, binary=binary) - return new_file - - @property - @noapidoc - @reader - def content_bytes(self) -> bytes: - """Loaded dynamically every time to preserve source of truth. - - TODO: move rest of graph sitter to operate in bytes to prevent multi byte character issues? - """ - return self.ctx.io.read_bytes(self.path) - - @property - @reader - def content(self) -> str: - """Returns the content of the file as a UTF-8 encoded string. - - Gets the content of the file, either from pending changes or by reading from disk. Binary files cannot be read as strings. - - Args: - None - - Returns: - str: The content of the file as a UTF-8 encoded string. - - Raises: - ValueError: If the file is binary. Use content_bytes instead for binary files. - """ - if self._binary: - msg = "Cannot read binary file as string. Use content_bytes instead." - raise ValueError(msg) - - return self.content_bytes.decode(encoding="utf-8") - - @noapidoc - def write(self, content: str | bytes, to_disk: bool = False) -> None: - """Writes contents to the file.""" - self.ctx.io.write_file(self.path, content) - if to_disk: - self.ctx.io.save_files({self.path}) - if self.ts_node.start_byte == self.ts_node.end_byte: - # TS didn't parse anything, register a write to make sure the transaction manager can restore the file later. - self.edit("") - - @noapidoc - @deprecated("Use write instead") - def write_bytes(self, content_bytes: bytes, to_disk: bool = False) -> None: - self.write(content_bytes, to_disk=to_disk) - - @property - @reader - def directory(self) -> Directory | None: - """Returns the directory that contains this file. - - The file can be housed within a directory in the codebase, and this property will return that directory instance. - - Returns: - Directory | None: The directory containing this file, or None if the file is not in any directory. - """ - return self.ctx.get_directory(self.path.parent) - - @property - def is_binary(self) -> bool: - """Indicates whether the file contains binary data. - - A property that returns True if the file contains binary data, False if it contains text data. - - Returns: - bool: True if the file contains binary data, False if it contains text data. - """ - return self._binary - - @property - @reader - def extension(self) -> str: - """Returns the file extension. - - Returns: - str: The file extension including the dot (e.g., '.py', '.ts', '.js'). - """ - return os.path.splitext(self.file_path)[1] - - @property - @reader - def owners(self) -> set[str]: - """Returns the CODEOWNERS of the file. - - Returns all Github CODEOWNERS associated with this file. If there is no CODEOWNERS file in the codebase, returns an empty set. - - Returns: - set[str]: A set of Github usernames or team names that own this file. Empty if no CODEOWNERS file exists. - """ - if self.ctx.codeowners_parser: - # return get_filepath_owners(codeowners=self.ctx.codeowners_parser, filepath=self.file_path) - filename_owners = self.ctx.codeowners_parser.of(self.file_path) - return {owner[1] for owner in filename_owners} - return set() - - @cached_property - @noapidoc - def github_url(self) -> str | None: - if self.ctx.base_url: - if self.ctx.base_url.endswith(".git"): - return self.ctx.base_url.replace(".git", "/blob/develop/") + self.file_path - else: - return self.ctx.base_url + "/" + self.file_path - - @property - @reader - def start_byte(self) -> int: - """Returns the starting byte position of a file in its content. - - The start byte is always 0 for a file as it represents the beginning of the file's content. - - Returns: - int: Always returns 0. - """ - return 0 - - @remover - def remove(self) -> None: - """Removes the file from the file system and graph. - - Queues the file to be removed during the next commit operation. The file will be removed from the filesystem and its node will be removed from the graph. - - Args: - None - - Returns: - None - """ - self.transaction_manager.add_file_remove_transaction(self) - self.ctx.io.write_file(self.path, None) - - @property - def filepath(self) -> str: - """Retrieves the file path of the file that this Editable instance belongs to. - - Returns: - str: The file path of the file. - """ - return self.file_path - - @mover - def rename(self, new_name: str) -> None: - """Renames the file to the specified name, preserving the file extension. - - Args: - new_name (str): The new name for the file. If the new name includes the file extension, it will be used as-is. - Otherwise, the original file extension will be preserved. - - Returns: - None - - Note: - This method will update all imports that reference this file to use the new filepath. - The file will be physically moved on disk and all graph references will be updated. - """ - # Split the filepath into directory, filename, and extension - directory = self.path.parent - extension = self.path.suffix - - # Check if new name already contains the extension - if new_name.endswith(extension): - new_filename = new_name - else: - # Create the new filename with the original extension - new_filename = new_name + extension - - # Join the directory with the new filename - new_filepath = directory / new_filename - - # Rename the file - self.update_filepath(str(new_filepath)) - - @mover - def update_filepath(self, new_filepath: str) -> None: - """Updates the file path and inbound imports of a file. - - Updates the file path of the file on disk and in the codebase graph. Additionally updates all - inbound imports to reference the new file path. - - Args: - new_filepath (str): The new file path to rename the file to. - - Raises: - BadWriteError: If there are pending file writes that haven't been committed. - ValueError: If the new file path already exists in the codebase graph. - """ - # =====[ Change the file on disk ]===== - self.transaction_manager.add_file_rename_transaction(self, new_filepath) - - def parse(self, ctx: "CodebaseContext") -> None: - """Parses the file representation into the graph. - - This method is called during file initialization to parse the file and build its graph representation within the codebase graph. - - Args: - ctx (CodebaseContext): The codebase context that the file belongs to. - - Returns: - None - """ - pass - - @noapidoc - @commiter - def _compute_dependencies(self, *args, **kwargs) -> None: - pass - - @writer - def edit(self, new_src: str, fix_indentation: bool = False, priority: int = 0, dedupe: bool = True) -> None: - """Replace the source of this file with new_src. - - For non-source files, replaces the entire content. For source files, delegates to the parent - Editable implementation which uses TreeSitter nodes for precise editing. - - Args: - new_src (str): The new source text to replace the current text with. - fix_indentation (bool): If True, adjusts the indentation of new_src to match the current - text's indentation level. Only applies to source files. Defaults to False. - priority (int): The priority of the edit transaction. Higher priority edits are - applied first. Defaults to 0. - dedupe (bool): If True, deduplicates identical transactions. Defaults to True. - - Raises: - ValueError: If attempting to edit a binary file. - - Returns: - None - """ - if self.is_binary: - msg = "Cannot replace content in binary files" - raise ValueError(msg) - - if self.ts_node is None or not isinstance(self, SourceFile): - self._edit_byte_range(new_src, 0, len(self.content_bytes), priority, dedupe) - else: - super().edit(new_src, fix_indentation, priority, dedupe) - - @writer - def replace(self, old: str, new: str, count: int = -1, is_regex: bool = False, priority: int = 0) -> int: - """Replace occurrences of text in the file. - - For non-source files, performs a direct string replacement. For source files, delegates to the - parent Editable implementation which uses TreeSitter nodes for precise replacements. - - Args: - old (str): The text to be replaced. - new (str): The text to replace with. - count (int): Maximum number of replacements to make. -1 means replace all occurrences. - Only applies to source files. Defaults to -1. - is_regex (bool): If True, treat 'old' as a regular expression pattern. - Only applies to source files. Defaults to False. - priority (int): The priority of the edit transaction. Higher priority edits are - applied first. Defaults to 0. - - Raises: - ValueError: If attempting to replace content in a binary file. - - Returns: - list[Editable]: List of affected Editable objects. For non-source files, always returns - an empty list since they don't have Editable sub-components. - """ - if self.is_binary: - msg = "Cannot replace content in binary files" - raise ValueError(msg) - - if self.ts_node is None or not isinstance(self, SourceFile): - if old not in self.content: - return 0 - - self._edit_byte_range(self.content.replace(old, new), 0, len(self.content_bytes), priority) - return 1 - else: - return super().replace(old, new, count, is_regex, priority) - - @staticmethod - @noapidoc - def get_extensions() -> list[str]: - """Returns a list of file extensions for the given programming language file.""" - return [] # By default, no extensions are "supported" for generic files - - -TImport = TypeVar("TImport", bound="Import") -TFunction = TypeVar("TFunction", bound="Function") -TClass = TypeVar("TClass", bound="Class") -TGlobalVar = TypeVar("TGlobalVar", bound="Assignment") -TInterface = TypeVar("TInterface", bound="Interface") -TCodeBlock = TypeVar("TCodeBlock", bound="CodeBlock") - - -@apidoc -class SourceFile( - File, - HasBlock, - Usable, - HasAttribute[Symbol | TImport], - Generic[TImport, TFunction, TClass, TGlobalVar, TInterface, TCodeBlock], -): - """Represents a file with source code in the codebase. - - Enables creating, reading, updating, and deleting files and searching through their contents, - etc. - - Attributes: - code_block: Represents the block of code contained in the file. - """ - - code_block: TCodeBlock - _nodes: list[Importable] - - def __init__(self, ts_node: TSNode, filepath: PathLike, ctx: CodebaseContext) -> None: - self.node_id = ctx.add_node(self) - self._nodes = [] - super().__init__(filepath, ctx, ts_node=ts_node) - self._nodes.clear() - self.ctx.filepath_idx[self.file_path] = self.node_id - self._pending_imports = set() - try: - self.parse(ctx) - except RecursionError as e: - logger.exception(f"RecursionError parsing file {filepath}: {e} at depth {sys.getrecursionlimit()} and {resource.getrlimit(resource.RLIMIT_STACK)}") - raise e - except Exception as e: - logger.exception(f"Failed to parse file {filepath}: {e}") - raise e - - @property - @reader - @override - def _source(self): - """Text representation of the Editable instance.""" - return self.ts_node.text.decode("utf-8") - - @noapidoc - @commiter - def parse(self, ctx: CodebaseContext) -> None: - self.__dict__.pop("_source", None) - # Add self to the graph - self.code_block = self._parse_code_block(self.ts_node) - - self.code_block.parse() - # We need to clear the valid symbol/import names before we start resolving exports since these can be outdated. - self.invalidate() - sort_editables(self._nodes) - - @noapidoc - @commiter - def remove_internal_edges(self) -> None: - """Removes all its direct nodes and edges for each of its internal symbols and imports.""" - # ==== [ Classes, Assignments, Function, Interfaces ] ==== - for symbol in self.symbols(nested=True): - symbol._remove_internal_edges() - - # ==== [ Exports ] ==== - if hasattr(self, "exports"): - for export in self.exports: - export._remove_internal_edges() - - # ==== [ Imports ] ==== - for imp in self.imports: - imp._remove_internal_edges() - - @noapidoc - @commiter - def unparse(self, reparse: bool = False) -> list[Importable]: - """Removes all its direct nodes and edges for each of its internal symbols and imports. - - Returns a list of external import node ids that need to be re-resolved - """ - external_edges_to_resolve = [] - - # Collect node ids of all the file's nested children and itself to remove - node_ids_to_remove = set() - # ==== [ Classes, Assignments, Function, Interfaces ] ==== - for symbol in self.get_nodes(): - node_ids_to_remove.add(symbol.node_id) - - # ==== [ File ] ==== - node_ids_to_remove.add(self.node_id) - self._remove_internal_edges() - - # Save any external import resolution edges to be re-resolved before removing the nodes - for node_id in node_ids_to_remove: - external_edges_to_resolve.extend(self.ctx.predecessors(node_id)) - - # Finally, remove the nodes - for node_id in node_ids_to_remove: - if reparse and node_id == self.node_id: - continue - if self.ctx.has_node(node_id): - self.ctx.remove_node(node_id) - if not reparse: - self.ctx.filepath_idx.pop(self.file_path, None) - self._nodes.clear() - return list(filter(lambda node: self.ctx.has_node(node.node_id) and node is not None, external_edges_to_resolve)) - - @noapidoc - @commiter - def sync_with_file_content(self) -> None: - """Re-parses parent file and re-sets current TSNode.""" - self._pending_imports.clear() - self.ts_node = parse_file(self.filepath, self.content) - if self.node_id is None: - self.ctx.filepath_idx[self.file_path] = self.node_id - self.file_node_id = self.node_id - else: - assert self.ctx.has_node(self.node_id) - self.name = self.path.stem - self._range_index.clear() - self.parse(self.ctx) - - @staticmethod - @noapidoc - def get_extensions() -> list[str]: - """Returns a list of file extensions for the given programming language file.""" - - @abstractmethod - def symbol_can_be_added(self, symbol: Symbol) -> bool: - """Checks if the file type supports adding the given symbol. - - Determines whether the given symbol can be added to this file based on the symbol's type and the file's - language/type support. - - Args: - symbol (Symbol): The symbol to check for add compatibility. - - Returns: - bool: True if the symbol can be added to this file type, False otherwise. - """ - - @noapidoc - @commiter - def _compute_dependencies(self, *args, **kwargs) -> None: - self.invalidate() - self.code_block._compute_dependencies() - - @noapidoc - def invalidate(self): - self.__dict__.pop("valid_symbol_names", None) - self.__dict__.pop("valid_import_names", None) - for imp in self.imports: - imp.__dict__.pop("_wildcards", None) - - @classmethod - @noapidoc - def from_content(cls, filepath: str | PathLike | Path, content: str, ctx: CodebaseContext, sync: bool = True, verify_syntax: bool = True) -> Self | None: - """Creates a new file from content and adds it to the graph.""" - path = ctx.to_absolute(filepath) - - # Sanity check to ensure file is not a minified file - if is_minified_js(content): - logger.info(f"File {filepath} is a minified file. Skipping...", extra={"filepath": filepath}) - return None - - ts_node = parse_file(path, content) - if ts_node.has_error and verify_syntax: - logger.info("Failed to parse file %s", filepath) - return None - - update_graph = False - if not ctx.io.file_exists(path): - update_graph = True - path.parent.mkdir(parents=True, exist_ok=True) - ctx.io.write_file(path, content) - ctx.io.save_files({path}) - - if update_graph and sync: - ctx.add_single_file(path) - return ctx.get_file(filepath) - else: - return cls(ts_node, Path(filepath), ctx) - - @classmethod - @noapidoc - def create_from_filepath(cls, filepath: str, ctx: CodebaseContext) -> Self | None: - """Makes a new empty file and adds it to the graph. - - Graph-safe. - """ - if filepath in ctx.filepath_idx: - msg = f"File already exists in graph: {filepath}" - raise ValueError(msg) - - ts_node = parse_file(filepath, "") - if ts_node.has_error: - logger.info("Failed to parse file %s", filepath) - raise SyntaxError - - file = cls(ts_node, filepath, ctx) - file.write("", to_disk=True) - return file - - @property - @reader(cache=False) - def inbound_imports(self) -> list[TImport]: - """Returns all imports that are importing symbols contained in this file. - - Retrieves a list of Import objects representing imports that reference symbols or content defined in this file. - This includes imports of symbols declared in the file and imports of the file itself. - - Returns: - list[TImport]: A list of Import objects that reference content from this file. - """ - inbound_imports = set() - for s in self.symbols: - inbound_imports.update(i for i in s.symbol_usages(UsageType.DIRECT | UsageType.CHAINED) if isinstance(i, Import)) - for imp in self.imports: - inbound_imports.update(i for i in imp.symbol_usages(UsageType.DIRECT | UsageType.CHAINED) if isinstance(i, Import)) - - inbound_imports.update(i for i in self.symbol_usages(UsageType.DIRECT | UsageType.CHAINED) if isinstance(i, Import)) - return list(inbound_imports) - - @property - @reader(cache=False) - def import_statements(self) -> list[ImportStatement]: - """Returns all ImportStatements in the file, where each import statement can contain - multiple imports. - - Retrieves a list of all import statements in the file, sorted by their position. Each ImportStatement can contain - multiple individual imports (e.g., 'from module import a, b, c'). - - Returns: - list[ImportStatement]: A sorted list of import statements contained in the file. - """ - return sort_editables(x.import_statement for x in self.imports) - - @property - @reader - def importers(self) -> list[TImport]: - """Returns all imports that directly imports this file as a module. - - This method returns a list of imports where this file is imported directly as a module, - not individual symbols from this file. - - For example: - - `from a import ` will be included - - `from import a` will NOT be included - - Args: - None - - Returns: - list[TImport]: List of Import objects that import this file as a module, - sorted by file location. - """ - imps = [x for x in self.ctx.in_edges(self.node_id) if x[2].type == EdgeType.IMPORT_SYMBOL_RESOLUTION] - return sort_editables((self.ctx.get_node(x[0]) for x in imps), by_file=True, dedupe=False) - - @property - @reader(cache=False) - def imports(self) -> list[TImport]: - """List of all Imports in this file. - - Retrieves all imports defined in this file. The imports are sorted by their position in the file. - - Returns: - list[TImport]: A list of Import instances contained in this file, ordered by their position. - """ - return list(filter(lambda node: isinstance(node, Import), self.get_nodes(sort_by_id=True))) - - @reader - def has_import(self, symbol_alias: str) -> bool: - """Returns True if the file has an import with the given alias. - - Checks if the file contains an import statement with a specific alias. - - Args: - symbol_alias (str): The alias to check for in the import statements. - - Returns: - bool: True if an import with the given alias exists, False otherwise. - """ - aliases = [x.alias for x in self.imports if x.alias] - return any(a.source == symbol_alias for a in aliases) - - @reader - def get_import(self, symbol_alias: str) -> TImport | None: - """Returns the import with matching alias. Returns None if not found. - - Args: - symbol_alias (str): The alias name to search for. This can match either the direct import name or the aliased name. - - Returns: - TImport | None: The import statement with the matching alias if found, None otherwise. - """ - return next((x for x in self.imports if x.alias is not None and x.alias.source == symbol_alias), None) - - @proxy_property - def symbols(self, nested: bool = False) -> list[Symbol | TClass | TFunction | TGlobalVar | TInterface]: - """Returns all Symbols in the file, sorted by position in the file. - - Args: - nested: Include nested symbols - - Returns: - list[Symbol | TClass | TFunction | TGlobalVar | TInterface]: A list of all top-level symbols in the file, sorted by their position in the file. Symbols can be one of the following types: - - Symbol: Base symbol class - - TClass: Class definition - - TFunction: Function definition - - TGlobalVar: Global variable assignment - - TInterface: Interface definition - """ - return sort_editables([x for x in self.get_nodes(sort=False) if isinstance(x, Symbol) and (nested or x.is_top_level)], dedupe=False) - - @reader(cache=False) - @noapidoc - def get_nodes(self, *, sort_by_id: bool = False, sort: bool = True) -> Sequence[Importable]: - """Returns all nodes in the file, sorted by position in the file.""" - ret = self._nodes - if sort: - return sort_editables(ret, by_id=sort_by_id, dedupe=False) - return ret - - @reader - def get_symbol(self, name: str) -> Symbol | None: - """Gets a symbol by its name from the file. - - Attempts to resolve the symbol by name using name resolution rules first. If that fails, - searches through the file's symbols list for a direct name match. - - Args: - name (str): The name of the symbol to find. - - Returns: - Symbol | None: The found symbol, or None if not found. - """ - if symbol := next(self.resolve_name(name, self.end_byte), None): - if isinstance(symbol, Symbol): - return symbol - return next((x for x in self.symbols if x.name == name), None) - - @property - @reader(cache=False) - def symbols_sorted_topologically(self) -> list[Symbol]: - """Returns all Symbols in the file, sorted topologically (parents first). Robust to - dependency loops. - - Performs a topological sort of the symbols in the file based on symbol dependencies. This ensures that parent symbols - appear before their dependents while handling potential dependency loops gracefully. - - Args: - None - - Returns: - list[Symbol]: A list of symbols sorted topologically with parents appearing before their dependents. - """ - ids = [x.node_id for x in self.symbols] - # Create a subgraph based on G - subgraph = self.ctx.build_subgraph(ids) - symbol_names = pseudo_topological_sort(subgraph) - return [subgraph.get_node_data(x) for x in symbol_names] - - @property - @reader(cache=False) - def global_vars(self) -> list[TGlobalVar]: - """Returns all GlobalVars in the file. - - Retrieves all global variables (assignments) defined at the top level in the file, sorted by their position in the file. - - Returns: - list[TGlobalVar]: A list of global variable assignments, where each element is an Assignment representing a global variable. - """ - return [s for s in self.symbols if s.symbol_type == SymbolType.GlobalVar] - - @reader - def get_global_var(self, name: str) -> TGlobalVar | None: - """Returns a specific global var by name. Returns None if not found. - - Args: - name (str): The name of the global variable to find. - - Returns: - TGlobalVar | None: The global variable if found, None otherwise. - """ - return next((x for x in self.global_vars if x.name == name), None) - - @property - @reader(cache=False) - def classes(self) -> list[TClass]: - """Returns all Classes in the file. - - Returns a list of all Classes defined in the file, sorted by position in the file. - Use this method to iterate over all classes in a file or to get information about class definitions. - - Returns: - list[TClass]: A list of Class objects in the file, sorted by position in the file. - """ - return [s for s in self.symbols if s.symbol_type == SymbolType.Class] - - @reader - def get_class(self, name: str) -> TClass | None: - """Returns a specific Class by full name. Returns None if not found. - - Searches for a class in the file with the specified name. Similar to get_symbol, but specifically for Class types. - - Args: - name (str): The full name of the class to search for. - - Returns: - TClass | None: The matching Class object if found, None otherwise. - """ - if symbol := next(self.resolve_name(name, self.end_byte), None): - if isinstance(symbol, Class): - return symbol - - @property - @reader(cache=False) - def functions(self) -> list[TFunction]: - """Returns all Functions in the file. - - Returns a list of all top-level functions defined in the file, sorted by their position in the file. - Does not include nested functions (functions defined within other functions or classes). - - Returns: - list[TFunction]: A list of Function objects representing all top-level functions in the file. - """ - return [s for s in self.symbols if s.symbol_type == SymbolType.Function] - - @reader - def get_function(self, name: str) -> TFunction | None: - """Returns a specific Function by name. - - Gets a Function object from the file by searching for a function with the given name. - - Args: - name (str): The name of the function to find. - - Returns: - TFunction | None: The matching Function object if found, None otherwise. - """ - return next((x for x in self.functions if x.name == name), None) - - @noapidoc - @reader - def get_node_by_name(self, name: str) -> Symbol | TImport | None: - """Returns something defined in this file by name. - - Used during import resolution - """ - symbol = self.get_symbol(name) - if symbol is not None: - return symbol - imp = self.get_import(name) - if imp is not None: - return imp - return None - - @cached_property - @noapidoc - @reader(cache=True) - def valid_symbol_names(self) -> dict[str, Symbol | TImport | WildcardImport[TImport]]: - """Returns a dict mapping name => Symbol (or import) in this file.""" - valid_symbol_names = {} - for s in self.symbols: - valid_symbol_names[s.full_name] = s - for imp in self.imports: - for name, dest in imp.names: - valid_symbol_names[name] = dest - return valid_symbol_names - - @noapidoc - @reader - def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: - """Resolves a name to a symbol, import, or wildcard import within the file's scope. - - Performs name resolution by first checking the file's valid symbols and imports. When a start_byte - is provided, ensures proper scope handling by only resolving to symbols that are defined before - that position in the file. - - Args: - name (str): The name to resolve. - start_byte (int | None): If provided, only resolves to symbols defined before this byte position - in the file. Used for proper scope handling. Defaults to None. - strict (bool): When True and using start_byte, only yields symbols if found in the correct scope. - When False, allows falling back to global scope. Defaults to True. - - Yields: - Symbol | Import | WildcardImport: The resolved symbol, import, or wildcard import that matches - the name and scope requirements. Yields at most one result. - """ - if resolved := self.valid_symbol_names.get(name): - # If we have a start_byte and the resolved symbol is after it, - # we need to look for earlier definitions of the symbol - if start_byte is not None and resolved.end_byte > start_byte: - # Search backwards through symbols to find the most recent definition - # that comes before our start_byte position - for symbol in reversed(self.symbols): - if symbol.start_byte <= start_byte and symbol.name == name: - yield symbol - return - # If strict mode and no valid symbol found, return nothing - if not strict: - return - # Either no start_byte constraint or symbol is before start_byte - yield resolved - return - return - - @property - @reader - def import_module_name(self) -> str: - """Returns the module name that this file gets imported as. - - Gets the module name for this file in the context of imports. This name is used when other files import this file, either directly or when importing symbols from this file. - - Returns: - str: The module name used when importing this file. - """ - return self.get_import_module_name_for_file(self.filepath, self.ctx) - - @classmethod - @abstractmethod - @noapidoc - def get_import_module_name_for_file(cls, filepath: str, ctx: CodebaseContext) -> str: ... - - @abstractmethod - def remove_unused_exports(self) -> None: - """Removes unused exports from the file. - - Removes all exports that have no usages by any other files in the codebase. This helps reduce unnecessary exports and maintain a cleaner API surface. - - Returns: - None - """ - - #################################################################################################################### - # MANIPULATIONS - #################################################################################################################### - - @mover - def update_filepath(self, new_filepath: str) -> None: - """Renames the file and updates all imports to point to the new location. - - When a file is renamed, this method does three things: - 1. Creates a new file node in the graph with the new filepath - 2. Moves the file on disk to the new location - 3. Updates all inbound imports to point to the new module location - - Args: - new_filepath (str): The new filepath to move the file to. - - Returns: - None - """ - # =====[ Add the new filepath as a new file node in the graph ]===== - new_file = self.ctx.node_classes.file_cls.from_content(new_filepath, self.content, self.ctx) - # =====[ Change the file on disk ]===== - super().update_filepath(new_filepath) - # =====[ Update all the inbound imports to point to the new module ]===== - new_module_name = new_file.import_module_name - for imp in self.inbound_imports: - imp.set_import_module(new_module_name) - - @writer - def add_import(self, imp: Symbol | str, *, alias: str | None = None, import_type: ImportType = ImportType.UNKNOWN, is_type_import: bool = False) -> Import | None: - """Adds an import to the file. - - This method adds an import statement to the file. It can handle both string imports and symbol imports. - If the import already exists in the file, or is pending to be added, it won't be added again. - If there are existing imports, the new import will be added before the first import, - otherwise it will be added at the beginning of the file. - - Args: - imp (Symbol | str): Either a Symbol to import or a string representation of an import statement. - alias (str | None): Optional alias for the imported symbol. Only used when imp is a Symbol. Defaults to None. - import_type (ImportType): The type of import to use. Only used when imp is a Symbol. Defaults to ImportType.UNKNOWN. - is_type_import (bool): Whether this is a type-only import. Only used when imp is a Symbol. Defaults to False. - - Returns: - Import | None: The existing import for the symbol if found, otherwise None. - """ - # Handle Symbol imports - if isinstance(imp, str): - # Handle string imports - import_string = imp - # Check for duplicate imports - if any(import_string.strip() in imp.source for imp in self.imports): - return None - else: - # Check for existing imports of this symbol - imports = self.imports - match = next((x for x in imports if x.imported_symbol == imp), None) - if match: - return match - - # Convert symbol to import string - import_string = imp.get_import_string(alias, import_type=import_type, is_type_import=is_type_import) - - if import_string.strip() in self._pending_imports: - # Don't add the import string if it will already be added by another symbol - return None - - # Add to pending imports and setup undo - self._pending_imports.add(import_string.strip()) - self.transaction_manager.pending_undos.add(lambda: self._pending_imports.clear()) - - # Insert the import at the appropriate location - if self.imports: - self.imports[0].insert_before(import_string, priority=1) - else: - self.insert_before(import_string, priority=1) - - return None - - @writer - def add_symbol_from_source(self, source: str) -> None: - """Adds a symbol to a file from a string representation. - - This method adds a new symbol definition to the file by appending its source code string. The symbol will be added - after existing symbols if present, otherwise at the beginning of the file. - - Args: - source (str): String representation of the symbol to be added. This should be valid source code for - the file's programming language. - - Returns: - None: The symbol is added directly to the file's content. - """ - symbols = self.symbols - if len(symbols) > 0: - symbols[-1].insert_after("\n" + source, fix_indentation=True) - else: - self.insert_after("\n" + source) - - @writer - def add_symbol(self, symbol: Symbol, should_export: bool = True) -> Symbol | None: - """Adds `symbol` to the file. - - Adds the given symbol to the file, optionally exporting it if applicable. If the symbol already exists in the file, returns the existing symbol. - - Args: - symbol (Symbol): The symbol to add to the file. - should_export (bool, optional): Whether to export the symbol. Defaults to True. - - Returns: - Symbol | None: The existing symbol if it already exists in the file or None if it was added. - - Raises: - ValueError: If the symbol type cannot be added to this file type. - """ - # Check if the symbol already exists in file - existing_symbol = self.get_symbol(symbol.name) - if existing_symbol is not None: - return existing_symbol - if not self.symbol_can_be_added(symbol): - msg = f"Symbol {symbol.name} cannot be added to this file type." - raise ValueError(msg) - - source = symbol.source - if isinstance(symbol, TSFunction) and symbol.is_arrow: - raw_source = symbol._named_arrow_function.text.decode("utf-8") - else: - raw_source = symbol.ts_node.text.decode("utf-8") - if should_export and hasattr(symbol, "export") and (not symbol.is_exported or raw_source not in symbol.export.source): - source = source.replace(raw_source, f"export {raw_source}") - - self.add_symbol_from_source(source) - - @noapidoc - @writer - def convert_js_to_esm(self) -> None: - """Converts a JS file to an ES module.""" - # Convert `require` to `import` - content = self.content - lines = content.split("\n") - converted_lines = [] - router_lines = [] - last_import_index = -1 - import_fixed = False - - for i, line in enumerate(lines): - # Handle require statements with destructuring - if "require(" in line and "{" in line: - line = re.sub( - r"const {([\w\s,]+)} = require\('(.+?)'\);", - lambda m: f"import {{{m.group(1)}}} from '{m.group(2)}';", - line, - ) - last_import_index = i - import_fixed = True - - # Handle regular require statements - elif "require(" in line: - line = re.sub(r"const (\w+) = require\('(.+?)'\);", r"import \1 from '\2';", line) - last_import_index = i - import_fixed = True - - # Convert module.exports - if "module.exports = " in line: - line = re.sub(r"module.exports = (\w+);", r"export default \1;", line) - - # TODO: remove express.Router() specifics - # Check for express.Router() assignment - if "= express.Router();" in line and import_fixed: - router_lines.append((i, line + "\n")) - else: - converted_lines.append(line) - - # Reinsert lines that contain "= express.Router();" after the last import - if router_lines: - # If no imports were found, router lines will be added at the beginning - insert_position = last_import_index + 1 if last_import_index != -1 else 0 - for _, router_line in router_lines: - converted_lines.insert(insert_position, router_line) - insert_position += 1 - - self.write("\n".join(converted_lines), to_disk=True) - - @property - @noapidoc - def viz(self) -> VizNode: - return VizNode(file_path=self.filepath, start_point=self.start_point, end_point=self.end_point, name=self.name, symbol_name=self.__class__.__name__) - - #################################################################################################################### - # AST-GREP - #################################################################################################################### - - # @writer - # def ast_grep_replace(self, pattern: str, replace: str) -> None: - # """Searches the file's AST for nodes that match the query""" - # root = SgRoot(self.content, "python").root() # 1. parse - # node = root.find(pattern=pattern) # 3. find - # edit = node.replace(replace) - # new_src = node.commit_edits([edit]) - # self.edit(new_src) - @property - @noapidoc - @reader(cache=True) - def valid_import_names(self) -> dict[str, Symbol | TImport | WildcardImport[TImport]]: - """Returns a dict mapping name => Symbol (or import) in this file that can be imported from - another file. - """ - return self.valid_symbol_names - - @noapidoc - @reader - @override - def resolve_attribute(self, name: str) -> Symbol | TImport | None: - return self.valid_import_names.get(name, None) - - @property - @noapidoc - def self_dest(self) -> HasBlock: - """Returns the symbol usage resolution destination node for the symbol.""" - return self - - @property - @noapidoc - def parent_symbol(self) -> Self: - return self - - @reader - def find_by_byte_range(self, range: Range) -> list[Editable]: - """Finds all editable objects that overlap with the given byte range in the file. - - Uses the file's range index to efficiently retrieve all editable objects (like functions, - classes, variables) that intersect with the specified byte range. - - Args: - range (Range): The byte range to search within the file. - - Returns: - list[Editable]: A list of all Editable objects that overlap with the given range. - """ - return self._range_index.get_all_for_range(range) - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - return self.get_nodes() diff --git a/src/codegen/sdk/core/function.py b/src/codegen/sdk/core/function.py deleted file mode 100644 index ea5b8fc95..000000000 --- a/src/codegen/sdk/core/function.py +++ /dev/null @@ -1,417 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from typing import TYPE_CHECKING, Generic, Self, override - -from typing_extensions import TypeVar - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.detached_symbols.code_block import CodeBlock -from codegen.sdk.core.detached_symbols.decorator import Decorator -from codegen.sdk.core.detached_symbols.parameter import Parameter -from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.interfaces.callable import Callable -from codegen.sdk.core.interfaces.chainable import Chainable -from codegen.sdk.core.interfaces.has_block import HasBlock -from codegen.sdk.core.interfaces.supports_generic import SupportsGenerics -from codegen.sdk.core.statements.statement import StatementType -from codegen.sdk.enums import SymbolType -from codegen.sdk.extensions.sort import sort_editables -from codegen.sdk.extensions.utils import cached_property -from codegen.shared.decorators.docs import apidoc, noapidoc -from codegen.visualizations.enums import VizNode - -if TYPE_CHECKING: - from collections.abc import Generator, Sequence - - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.export import Export - from codegen.sdk.core.file import File - from codegen.sdk.core.import_resolution import Import, WildcardImport - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.statements.return_statement import ReturnStatement - from codegen.sdk.core.symbol import Symbol - - -TDecorator = TypeVar("TDecorator", bound="Decorator", default=Decorator) -TCodeBlock = TypeVar("TCodeBlock", bound="CodeBlock", default=CodeBlock) -TParameter = TypeVar("TParameter", bound="Parameter", default=Parameter) -TType = TypeVar("TType", bound="Type", default=Type) - - -@apidoc -class Function( - SupportsGenerics[TType], - HasBlock[TCodeBlock, TDecorator], - Callable[TParameter, TType], - Chainable, - Generic[TDecorator, TCodeBlock, TParameter, TType], -): - """Abstract representation of a Function. - - Attributes: - symbol_type: The type of symbol, set to SymbolType.Function. - """ - - symbol_type = SymbolType.Function - - @property - @abstractmethod - def is_private(self) -> bool: - """Determines if a function has a private access modifier. - - A function is considered private if it starts with an underscore (_) in Python, or has a private keyword in other languages. - - Returns: - bool: True if the function has a private access modifier, False otherwise. - """ - - @property - @abstractmethod - def is_magic(self) -> bool: - """Returns True if function is a magic method. - - Determines if the function is a magic method based on Python's double underscore naming convention. - A magic method in Python is a special method surrounded by double underscores (e.g., __init__, __str__). - - Returns: - bool: True if the function is a magic method, False otherwise. - """ - - @property - def is_overload(self) -> bool: - """Indicates whether the function is an overloaded function in a multi-function definition. - - Determines if this function is part of a function overload group in the codebase. This property helps identify - functions that have multiple implementations with different parameter types. - - Returns: - bool: False, as this base implementation does not support overloads. - """ - return False - - @property - @abstractmethod - def is_property(self) -> bool: - """Returns whether this function is a property. - - Determines if the function has been decorated with `@property` decorator. - - Returns: - bool: True if the function is a property, False otherwise. - """ - pass - - @property - def is_method(self) -> bool: - """Returns whether the function is a method of a class. - - Determines if this function is defined within a class context. It checks if the parent of the function is a Class. - - Returns: - bool: True if the function is a method within a class, False otherwise. - """ - from codegen.sdk.core.class_definition import Class - - return isinstance(self.parent.parent.parent, Class) - - @property - def is_constructor(self) -> bool: - """Determines if the current function is a constructor method. - - A constructor method is a special method associated with a class. This property checks if the function - is both a class method and has a name that matches the class's constructor keyword. - - Returns: - bool: True if the function is a constructor method of a class, False otherwise. - """ - return self.is_method and self.name == self.parent_class.constructor_keyword - - @property - def is_async(self) -> bool: - """Returns True if the function is asynchronous. - - A property that determines whether the function has been defined with the 'async' keyword. - - Returns: - bool: True if the function is asynchronous, False otherwise. - """ - return any("async" == x.type for x in self.ts_node.children) - - @noapidoc - @reader - def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: - from codegen.sdk.core.class_definition import Class - - for symbol in self.valid_symbol_names: - if symbol.name == name and (start_byte is None or (symbol.start_byte if isinstance(symbol, Class | Function) else symbol.end_byte) <= start_byte): - yield symbol - return - yield from super().resolve_name(name, start_byte, strict=strict) - - @cached_property - @noapidoc - def valid_symbol_names(self) -> list[Importable]: - return sort_editables(self.parameters.symbols + self.descendant_symbols, reverse=True) - - # Faster implementation which uses more memory - # @noapidoc - # @reader - # def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: - # if symbols := self.valid_symbol_names.get(name, None): - # for symbol in symbols: - # from codegen.sdk.core.class_definition import Class - # - # if (symbol.start_byte if isinstance(symbol, Class | Function) else symbol.end_byte) <= start_byte: - # return symbol - # return super().resolve_name(name, start_byte) - # - # @cached_property - # @noapidoc - # def valid_symbol_names(self) -> dict[str, list[Importable]]: - # ret = defaultdict(list) - # for elem in sort_editables(self.parameters.symbols + self.descendant_symbols, reverse=True): - # ret[elem.name].append(elem) - # return ret - # - ########################################################################################################### - # PROPERTIES - ########################################################################################################### - - @property - @abstractmethod - @reader - def function_signature(self) -> str: - """Returns the signature of the function as a string. - - A property that returns the complete function signature including its declaration, parameters, and return type annotation. The signature format - varies based on the language, but follows the standard syntax for function declarations in that language. - - Returns: - str: A string representation of the function's complete signature. - """ - # TODO: rename to declaration_docstring? - - @property - @reader - def return_statements(self) -> list[ReturnStatement]: - """Returns a list of all return statements within this function's body. - - Provides access to return statements in the function's code block, which is useful for analyzing return patterns, - identifying early returns, and examining return types. - - Args: - None - - Returns: - list[ReturnStatement]: A list of all return statements found within the function's body. - """ - return self.code_block.get_statements(statement_type=StatementType.RETURN_STATEMENT) - - @property - @reader - def nested_functions(self) -> list[Self]: - """Returns a list of nested functions defined within this function's code block. - - Retrieves all functions that are defined within the current function's body. The functions are sorted by their position in the file. - - Returns: - list[Self]: A list of Function objects representing nested functions within this function's body, sorted by position in the file. - """ - functions = [m.symbol for m in self.code_block.symbol_statements if isinstance(m.symbol, self.__class__)] - return functions - - #################################################################################################################### - # MANIPULATIONS - #################################################################################################################### - - @writer - def set_return_type(self, new_return_type: str) -> None: - """Sets the return type annotation for the function. - - Sets or updates the return type annotation of the function. If an empty string is provided, - the return type annotation will be removed. - - Args: - new_return_type (str): The new return type annotation to be set. Use an empty string to remove - the return type annotation. - - Returns: - None - """ - # TODO: other set APIs should be consistent and also offer a remove option - # TODO: if new_return_type is empty string, should remove the return type - self.return_type.edit(new_return_type) - - @writer - def asyncify(self) -> None: - """Modifies the function to be asynchronous. - - Converts a synchronous function to be asynchronous by adding the 'async' keyword to its definition if it is not already - marked as asynchronous. - - Returns: - None - - Note: - This method has no effect if the function is already asynchronous. - """ - if self.is_async: - return - - self.add_keyword("async") - - @writer - def rename_local_variable(self, old_var_name: str, new_var_name: str, fuzzy_match: bool = False) -> None: - """Renames a local variable and all its usages within a function body. - - The method searches for matches of the old variable name within the function's code block and replaces them with the new variable name. It excludes parameter names from being renamed. - - Args: - old_var_name (str): The current name of the local variable to be renamed. - new_var_name (str): The new name to give to the local variable. - fuzzy_match (bool, optional): If True, matches variable names that contain old_var_name. Defaults to False. - - Returns: - None: The method modifies the AST in place. - """ - matches = self.code_block.get_assignments(old_var_name, fuzzy=fuzzy_match, parameters=False) - for match in matches: - new_name = new_var_name - if fuzzy_match: - new_name = match.name.replace(old_var_name, new_var_name) - match.rename(new_name) - - @writer - def insert_statements(self, lines: str, index: int = 0) -> None: - """Inserts lines of code into the function body at the specified index. - - Adds the provided lines as statements within the function's body at the given position. If index is 0, the lines will be prepended at the start of the function body. - - Args: - lines (str): The code lines to insert into the function body. - index (int, optional): The position in the function body where the lines should be inserted. Defaults to 0. - - Returns: - None - - Raises: - ValueError: If the provided index is out of range for the function's statements. - """ - if index == 0: - return self.prepend_statements(lines) - - statements = self.code_block.statements - if index >= len(statements): - msg = f"Index {index} out of range for function {self.name}" - raise ValueError(msg) - - first_statement = self.code_block.statements[index] - first_statement.insert_before(lines) - - @writer - def prepend_statements(self, lines: str) -> None: - """Prepends the provided code to the beginning of the function body. - - Args: - lines (str): The code to be prepended to the function body. - - Returns: - None - - Note: - This method handles indentation automatically to maintain proper code formatting. - """ - self.code_block.statements[0].insert_before(lines, fix_indentation=True) - - @writer - def add_statements(self, lines: str) -> None: - """Adds statements to the end of a function body. - - Adds the provided lines of code to the end of the function's code block. The method handles proper indentation automatically. - - Args: - lines (str): The lines of code to be added at the end of the function body. - - Returns: - None - """ - last_statement = self.code_block.statements[-1] - last_statement.insert_after(lines, fix_indentation=True) - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - if self.is_method and self.is_property: - if ret := self.return_type: - yield from self.with_resolution_frame(ret, direct=False) - else: - yield ResolutionStack(self) - - @property - @noapidoc - def viz(self) -> VizNode: - return VizNode(file_path=self.filepath, start_point=self.start_point, end_point=self.end_point, name=self.name, symbol_name=self.__class__.__name__) - - @property - @noapidoc - def parent_symbol(self) -> Symbol | File | Import | Export: - """Searches up its parent stack until it finds a top level symbol.""" - if self.is_method: - if self.parent_class.is_top_level: - return self - return super().parent_symbol - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Gets all function calls within the function and its parameters. - - Retrieves all function calls that appear within this function's body and within its parameter - declarations, sorted by position in the file. - - Args: - None - - Returns: - list[FunctionCall]: A sorted list of all function calls within the function and its parameters. - Function calls may appear multiple times in the list. - """ - fcalls = super().function_calls - for p in self.parameters: - fcalls.extend(p.function_calls) - return sort_editables(fcalls, dedupe=False) - - #################################################################################################################### - # EXTERNAL APIS - #################################################################################################################### - - @property - @reader - def inferred_return_type(self) -> str | None: - """Gets the inferred type of the function from the language's native language engine / compiler. - - Only enabled for specific languages that support native type inference. - """ - if self.ctx.language_engine: - return self.ctx.language_engine.get_return_type(self) - else: - msg = "Language engine not enabled for this repo or language." - raise NotImplementedError(msg) - - @property - @noapidoc - def descendant_symbols(self) -> Sequence[Importable]: - symbols = [self] - for param in self.parameters: - symbols.extend(param.descendant_symbols) - if self.return_type: - symbols.extend(self.return_type.descendant_symbols) - symbols.extend(self.code_block.descendant_symbols) - return symbols - - @noapidoc - def register_api(self, url: str): - self.ctx.global_context.multigraph.api_definitions[url] = self diff --git a/src/codegen/sdk/core/import_resolution.py b/src/codegen/sdk/core/import_resolution.py deleted file mode 100644 index c6d11af9d..000000000 --- a/src/codegen/sdk/core/import_resolution.py +++ /dev/null @@ -1,717 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from dataclasses import dataclass -from typing import TYPE_CHECKING, ClassVar, Generic, Literal, Self, TypeVar, override - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.codebase.transactions import TransactionPriority -from codegen.sdk.core.autocommit import commiter, reader, remover, writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.name import Name -from codegen.sdk.core.external_module import ExternalModule -from codegen.sdk.core.interfaces.chainable import Chainable -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.interfaces.has_attribute import HasAttribute -from codegen.sdk.core.interfaces.usable import Usable -from codegen.sdk.core.statements.import_statement import ImportStatement -from codegen.sdk.enums import EdgeType, ImportType, NodeType -from codegen.sdk.extensions.utils import cached_property -from codegen.sdk.output.constants import ANGULAR_STYLE -from codegen.shared.decorators.docs import apidoc, noapidoc -from codegen.visualizations.enums import VizNode - -if TYPE_CHECKING: - from collections.abc import Generator - - import rich.repr - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.file import SourceFile - from codegen.sdk.core.interfaces.exportable import Exportable - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.symbol import Symbol - - -TSourceFile = TypeVar("TSourceFile", bound="SourceFile") - - -@dataclass -class ImportResolution(Generic[TSourceFile]): - """Represents the resolution of an import statement to a symbol defined in another file. - - Has the following properties: - - from_file: Optional[SourceFile]. None when import resolves to an external module - - symbol: Optional[Union[Symbol, ExternalModule]]. None when import resolves to an external module - - imports_file: bool. True when we import the entire file (e.g. `from a.b.c import foo`) - """ - - from_file: TSourceFile | None = None # SourceFile object. None when import resolves to an external module - symbol: Symbol | ExternalModule | None = None # None when we import the entire file (e.g. `from a.b.c import foo`) - imports_file: bool = False # True when we import the entire file (e.g. `from a.b.c import foo`) - - -TSourceFile = TypeVar("TSourceFile", bound="SourceFile") - - -@apidoc -class Import(Usable[ImportStatement], Chainable, Generic[TSourceFile], HasAttribute[TSourceFile]): - """Represents a single symbol being imported. - - Attributes: - to_file_id: The node ID of the file to which this import belongs. - module: The module from which the symbol is being imported, if applicable. - symbol_name: The name of the symbol being imported. For instance import a as b has a symbol_name of a. - alias: The alias of the imported symbol, if one exists. - node_type: The type of node, set to NodeType.IMPORT. - import_type: The type of import, indicating how the symbol is imported. - import_statement: The statement that this import is part of. - import_statement: the ImportStatement that this import belongs to - """ - - to_file_id: NodeId - module: Editable | None - symbol_name: Editable | None - alias: Editable | None - node_type: ClassVar[Literal[NodeType.IMPORT]] = NodeType.IMPORT - import_type: ImportType - import_statement: ImportStatement - - def __init__( - self, - ts_node: TSNode, - file_node_id: NodeId, - ctx: CodebaseContext, - parent: ImportStatement, - module_node: TSNode | None, - name_node: TSNode | None, - alias_node: TSNode | None, - import_type: ImportType = ImportType.UNKNOWN, - ) -> None: - self.to_file_id = file_node_id - super().__init__(ts_node, file_node_id, ctx, parent) - self.module = self.ctx.parser.parse_expression(module_node, self.file_node_id, ctx, self, default=Name) if module_node else None - self.alias = self.ctx.parser.parse_expression(alias_node, self.file_node_id, ctx, self, default=Name) if alias_node else None - self.symbol_name = self.ctx.parser.parse_expression(name_node, self.file_node_id, ctx, self, default=Name) if name_node else None - self._name_node = self._parse_expression(name_node, default=Name) - self.import_type = import_type - - def __rich_repr__(self) -> rich.repr.Result: - if self.module: - yield "module", self.module.source - if self.name: - yield "name", self.name - if self.alias: - yield "alias", self.alias.source, self.name - yield "wildcard", self.is_wildcard_import(), False - yield from super().__rich_repr__() - - __rich_repr__.angular = ANGULAR_STYLE - - @noapidoc - @abstractmethod - def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[TSourceFile] | None: - """Resolves the import to a symbol defined outside the file. - - Returns an ImportResolution object. - """ - - @noapidoc - @commiter - def add_symbol_resolution_edge(self) -> None: - """Resolves the import to a symbol defined outside the file. - - If import is successfully resolved, a new edge is added to the graph. Must be called after - `parse()` has been called for every file in the codebase. Returns the node id of the - resolved import object. - """ - resolution = self.resolve_import() - - # =====[ Case: Can't resolve the filepath ]===== - if resolution is None: - # =====[ Check if we are importing an external module in the graph ]===== - ext = self.ctx.get_external_module(self.source, self._unique_node.source) - if ext is None: - ext = ExternalModule.from_import(self) - self.ctx.add_edge(self.node_id, ext.node_id, type=EdgeType.IMPORT_SYMBOL_RESOLUTION) - # =====[ Case: Can resolve the filepath ]===== - elif resolution.symbol: - if resolution.symbol.node_id == self.node_id: - return [] # Circular to self - self.ctx.add_edge( - self.node_id, - resolution.symbol.node_id, - type=EdgeType.IMPORT_SYMBOL_RESOLUTION, - ) - - elif resolution.imports_file: - self.ctx.add_edge(self.node_id, resolution.from_file.node_id, type=EdgeType.IMPORT_SYMBOL_RESOLUTION) - # for symbol in resolution.from_file.symbols: - # usage = SymbolUsage(parent_symbol_name=self.name, child_symbol_name=self.name, type=SymbolUsageType.IMPORTED, match=self, usage_type=UsageType.DIRECT) - # self.ctx.add_edge(self.node_id, symbol.node_id, type=EdgeType.SYMBOL_USAGE, usage=usage) - - # Referenced symbols that we can't find. - # Could be: - # - a broken import - # - it's actually importing a full file (i.e. resolution.imports_file should be True) - # - an indirect import of an external module - # TODO: add as external module only if it resolves to an external module from resolution.from_file - # Solution: return the resolution object to be processed in a separate loop in `compute_codebase_graph` - return [] - - @property - @reader - def name(self) -> str | None: - """Returns the name or alias of the symbol being imported. - - Returns an identifier for the import which can be either the alias name of an imported symbol if it exists, or None. - For example, in `from a.b import c as d`, this returns 'd'. - For example, in `import { c as d } from 'a/b'`, this returns 'd'. - - Args: - None - - Returns: - str | None: The alias of the imported symbol if it exists, otherwise None. - """ - if self.alias is None: - return None - return self.alias.source - - @reader - def is_aliased_import(self) -> bool: - """Returns True if this import is aliased. - - Checks if the current import has an alias that is different from its original name. - For example, in 'from foo import bar as baz', returns True because 'baz' is different from 'bar'. - In 'from foo import bar', returns False because there is no alias. - - Args: - None - - Returns: - bool: True if the import has an alias different from its original name, False otherwise. - """ - if self.alias is None or self.symbol_name is None: - return False - return self.alias.source != self.symbol_name.source - - @abstractmethod - def is_module_import(self) -> bool: - """Returns True if this import is importing an entire module/file. - - Used to identify module imports vs symbol imports. This method evaluates whether - the import is bringing in an entire module rather than specific symbols. - - Returns: - bool: True if this import represents a module/file import, False if it represents a symbol import. - """ - - @reader - def is_symbol_import(self) -> bool: - """Returns True if this import is importing a symbol rather than a module. - - A symbol import is any import that references a specific object from a module, rather than importing the entire module. This method is the opposite of `is_module_import`. - - Returns: - bool: True if this import is a symbol import, False if it is a module import. - """ - return not self.is_module_import() - - @reader - def is_wildcard_import(self) -> bool: - """Returns True if the import symbol is a wildcard import. - - Determines whether this Import is a wildcard import, which means it imports all named exports from a module. - Wildcard imports are represented using `*` in Python (e.g. `from module import *`) - or `*` in TypeScript (e.g. `import * as name from 'module'`). - - Returns: - bool: True if this is a wildcard import, False otherwise. - """ - return self.import_type == ImportType.WILDCARD - - @property - @abstractmethod - def namespace(self) -> str | None: - """Returns the namespace prefix that must be used with dot notation to reference the - imported symbol. - - The namespace is the prefix required to access the imported symbol through dot notation. - For example, in 'import foo as bar', bar is the namespace needed to access foo's exports as 'bar.xyz'. - - Returns: - str | None: The namespace prefix if one exists, None otherwise. - - For symbol imports or unnamed wildcard imports: None - - For module imports: The module name or the module alias - """ - - @property - @reader - def from_file(self) -> TSourceFile | None: - """Returns the SourceFile that an Import is importing from. - - This property traverses the Symbol edge to find the source file where the imported symbol is defined. - - Args: - None - - Returns: - TSourceFile | None: The SourceFile containing the imported symbol. - Returns None if: - - The import resolves to an external module - - The imported symbol cannot be resolved - """ - imported = self.imported_symbol - if imported is None: - return None - elif imported.node_type == NodeType.EXTERNAL: - return None - elif imported.__class__.__name__.endswith("SourceFile"): # TODO - this is a hack for when you import a full file/module - return imported - else: - return imported.file - - @property - @reader - def to_file(self) -> TSourceFile: - """SourceFile that this import resides in. - - Returns the source file in which the current import statement is located. This property helps track the location - and context of import statements within the codebase graph. - - Returns: - TSourceFile: The source file containing this import statement. - """ - return self.ctx.get_node(self.to_file_id) - - @property - @reader - def resolved_symbol(self) -> Symbol | ExternalModule | TSourceFile | None: - """Returns the symbol, source file, or external module that this import ultimately resolves - to. - - This method follows indirect import chains to find the final resolved object. For example, if file A imports from B, which imports from C, this method returns the object from C. - - Returns: - Symbol | ExternalModule | TSourceFile | None: The final resolved object that this import points to. - - Symbol: If the import resolves to a symbol defined in the codebase - - ExternalModule: If the import resolves to an external module - - TSourceFile: If the import resolves to an entire source file - - None: If the import cannot be resolved - - Note: - If there is a circular import chain, returns the first repeated import in the chain. - """ - # TODO: rename to `resolved_object` to capture that it can return a SourceFile instance as well - imports_seen = set() - resolved_symbol = self.imported_symbol - - while resolved_symbol is not None and resolved_symbol.node_type == NodeType.IMPORT: - if resolved_symbol in imports_seen: - return resolved_symbol - - imports_seen.add(resolved_symbol) - resolved_symbol = resolved_symbol.imported_symbol - - return resolved_symbol - - @reader - def _imported_symbol(self, resolve_exports: bool = False) -> Symbol | ExternalModule | TSourceFile | Import | None: - """Returns the symbol directly being imported, including an indirect import and an External - Module. - """ - from codegen.sdk.python.file import PyFile - from codegen.sdk.typescript.file import TSFile - - symbol = next(iter(self.ctx.successors(self.node_id, edge_type=EdgeType.IMPORT_SYMBOL_RESOLUTION, sort=False)), None) - if symbol is None: - # Unresolve import - could occur during unparse() - return None - if resolve_exports and isinstance(symbol, TSFile): - name = self.symbol_name.source if self.symbol_name else "" - if self.import_type == ImportType.DEFAULT_EXPORT: - assert isinstance(symbol, TSFile) - default = symbol - if len(symbol.default_exports) == 1 and name != symbol.name: - default = symbol.default_exports[0] - return symbol.valid_import_names.get(name, default) - if self.import_type == ImportType.NAMED_EXPORT: - if export := symbol.valid_import_names.get(name, None): - return export - elif resolve_exports and isinstance(symbol, PyFile): - name = self.symbol_name.source if self.symbol_name else "" - if self.import_type == ImportType.NAMED_EXPORT: - if symbol.name == name: - return symbol - if imp := symbol.valid_import_names.get(name, None): - return imp - - if symbol is not self: - return symbol - - @property - @reader - def imported_symbol(self) -> Symbol | ExternalModule | TSourceFile | Import | None: - """Returns the symbol directly being imported, including an indirect import and an External - Module. - - This property resolves the import's target and handles export-chain resolution. If the imported symbol - is an export, this method will follow the export chain until it reaches the final target. - - Returns: - Union[Symbol, ExternalModule, TSourceFile, Import, None]: The final resolved import target. - Can be: - - Symbol: The imported symbol - - ExternalModule: If import resolves to an external module - - SourceFile: If importing an entire file/module - - Import: If there is a circular import - - None: If the import is unresolved - """ - if symbol := self._imported_symbol(): - while symbol and symbol.node_type == NodeType.EXPORT: - symbol = symbol.exported_symbol - return symbol - - @property - @abstractmethod - def imported_exports(self) -> list[Exportable]: - """Returns the enumerated list of symbols imported from a module import. - - If the import represents a module/file import, returns a list of all exported symbols from that module. - If the import is a symbol import, returns a list containing only the imported symbol. - - Returns: - list[Exportable]: A list of exported symbols. For module imports, contains all exports from the module. - For symbol imports, contains only the single imported symbol. - """ - - @property - @reader - def is_dynamic(self) -> bool: - """Determines if this import is dynamically loaded based on its parent symbol. - - A dynamic import is one that appears within control flow or scope-defining statements, such as: - - Inside function definitions - - Inside class definitions - - Inside if/else blocks - - Inside try/except blocks - - Inside with statements - - Dynamic imports are only loaded when their containing block is executed, unlike - top-level imports which are loaded when the module is imported. - - Examples: - Dynamic imports: - ```python - def my_function(): - import foo # Dynamic - only imported when function runs - - - if condition: - from bar import baz # Dynamic - only imported if condition is True - - with context(): - import qux # Dynamic - only imported within context - ``` - - Static imports: - ```python - import foo # Static - imported when module loads - from bar import baz # Static - imported when module loads - ``` - - Returns: - bool: True if the import is dynamic (within a control flow or scope block), - False if it's a top-level import. - """ - return self.parent_of_types(self.ctx.node_classes.dynamic_import_parent_types) is not None - - #################################################################################################################### - # MANIPULATIONS - #################################################################################################################### - - @writer - def set_import_module(self, new_module: str) -> None: - """Sets the module of an import. - - Updates the module of an import statement while maintaining the import symbol. For named imports, this changes the module path that follows 'from' or is wrapped in quotes. - - Args: - new_module (str): The new module path to import from. - - Returns: - None - - Note: - If the import has no module (e.g., direct imports), this method has no effect. - """ - # TODO: if the import belongs in a multi-import statement, we need to break out the imports into individual import statements (CG-8349) - if self.module is None: - return - - self.module.source = new_module - - @writer - def set_import_symbol_alias(self, new_alias: str) -> None: - """Sets alias or name of an import at the declaration level. - - Changes the name used to refer to an imported symbol at its import declaration, either by modifying the alias if one exists, - or the name itself if no alias is used.The change only affects the import declaration, not import usages or callsites. - - Args: - new_alias (str): The new name to use for the imported symbol. - - Returns: - None - """ - if self.alias == self.symbol_name: - self.rename(new_alias) - else: - for imported_usage in self.usages: - if imported_usage.match is not None: - imported_usage.match.edit(new_alias) - self.alias.source = new_alias - - def rename(self, new_name: str, priority: int = 0) -> tuple[NodeId, NodeId]: - """Renames the import symbol and updates all its usages throughout the codebase. - - Renames both the import symbol name and any usage references to match the new name. If the import is aliased, only changes the symbol name and not the alias. - - Args: - new_name (str): The new name to give the imported symbol. - priority (int, optional): Priority of the rename operation. Defaults to 0. - - Returns: - tuple[NodeId, NodeId]: A tuple containing (file_node_id, new_import_node_id). - - Note: - For an import like 'from a.b.c import d as e', renaming with 'XYZ' will result in: - 'from a.b.c import XYZ as e' - - For an import like 'import { d as e } from 'a/b/c'', renaming with 'XYZ' will result in: - 'import { XYZ as e } from 'a/b/c'' - """ - if self.is_aliased_import(): - self.symbol_name.edit(new_name) - else: - super().rename(new_name, priority) - - @remover - def remove(self, delete_formatting: bool = True, priority: int = 0, dedupe: bool = True) -> None: - """Remove this import from the import statement. - - If this import belongs to an import statement with multiple imports, removes just this single import from it. - If this is the only import in the import statement, removes the entire import statement. - - Args: - delete_formatting (bool, optional): Whether to delete any associated formatting. Defaults to True. - priority (int, optional): The priority of the operation. Defaults to 0. - dedupe (bool, optional): Whether to deduplicate imports. Defaults to True. - - Returns: - None - """ - import_statement = self.import_statement - # Hack to remove the entire import statement if it only has one import - if import_statement.imports.uncommitted_len <= 1: - super().remove(delete_formatting=delete_formatting, priority=priority) - else: - # If the import belongs in a multi-import statement, remove the import specifier - self.import_specifier.remove(delete_formatting=delete_formatting, priority=priority) - - @property - @reader - def import_specifier(self) -> Editable: - """Returns the specific editable text representation of the import identifier within the - import statement. - - Retrieves the import specifier text that appears in the actual import statement. This is the portion of text that identifies what is being imported. - - Returns: - Editable: The editable text object representing the import specifier. - For named imports like 'import { a as b } from 'c'', returns 'a as b'. - For from imports like 'from a.b import c', returns 'c'. - - Raises: - ValueError: If the subclass does not implement this property. - """ - msg = "Subclass must implement `import_specifier`" - raise ValueError(msg) - - @reader - def is_reexport(self) -> bool: - """Returns true if the Import object is also an Export object. - - Checks whether this Import node has a corresponding Export node with the same source. - If the import is an export, it implies there are no direct usages of the import within the file it is defined in. - - Returns: - bool: True if the import is re-exported, False otherwise. - """ - return self.export and self.export.source == self.source - - def _removed_child_commit(self) -> None: - self.parent.imports._removed_child_commit() - - def _removed_child(self) -> None: - self.parent.imports._removed_child() - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - """Resolve the types used by this import.""" - # if self.is_wildcard_import(): - # if from_file := self.from_file: - # yield parent.with_frame(from_file, direct=False, to_find=parent.to_find) - # return - - ix_seen = set() - - aliased = self.is_aliased_import() - if imported := self._imported_symbol(resolve_exports=True): - yield from self.with_resolution_frame(imported, direct=False, aliased=aliased) - else: - yield ResolutionStack(self, aliased=aliased) - - @cached_property - @noapidoc - @reader - def _wildcards(self) -> dict[str, WildcardImport[Self]]: - """A list of all imports or wildcard imports.""" - from codegen.sdk.core.file import SourceFile - - res = {} - if self.is_wildcard_import(): - resolved = self.resolved_symbol - if isinstance(resolved, SourceFile): - resolved.invalidate() - for name, symbol in resolved.valid_import_names.items(): - res[name] = WildcardImport(self, symbol) - return res - - @property - @noapidoc - def names(self) -> Generator[tuple[str, Self | WildcardImport[Self]], None, None]: - if self.is_wildcard_import() and not self.is_aliased_import(): - if getattr(self, "_resolving_wildcards", False): - return - self._resolving_wildcards = True - if self._wildcards: - yield from self._wildcards.items() - self._resolving_wildcards = False - for imp in self.file.importers: - imp.file.invalidate() - - return - elif self.resolved_symbol is None: - self._resolving_wildcards = False - yield self.name, self - - @property - @noapidoc - def viz(self) -> VizNode: - return VizNode(file_path=self.filepath, start_point=self.start_point, end_point=self.end_point, name=self.name, symbol_name=self.__class__.__name__) - - @property - @noapidoc - def parent_symbol(self) -> Self: - """Returns the parent symbol of the symbol.""" - return self - - @noapidoc - @commiter - def _compute_dependencies(self, *args, **kwargs) -> None: - """Compute the dependencies of the export object.""" - # if self.is_wildcard_import(): - # for _, wildcard in self._wildcards.items(): - # for used_frame in wildcard.resolved_type_frames: - # if used_frame.parent_frame: - # used_frame.parent_frame.add_usage(self.symbol_name or self.module, SymbolUsageType.IMPORTED_WILDCARD, self, self.ctx) - # else: - if isinstance(self, Import) and self.import_type == ImportType.NAMED_EXPORT: - # It could be a wildcard import downstream, hence we have to pop the cache - if file := self.from_file: - file.invalidate() - - for used_frame in self.resolved_type_frames: - if used_frame.parent_frame: - used_frame.parent_frame.add_usage(self._unique_node, UsageKind.IMPORTED, self, self.ctx) - - @property - def _unique_node(self): - """A unique node for this import to identify it by""" - # HACK: very much a hack - return self.symbol_name or self.alias or self.module or self - - def __hash__(self): - if self._hash is None: - self._hash = hash((self.filepath, self.range, self.ts_node.kind_id, self._unique_node.range)) - return self._hash - - @reader - def __eq__(self, other: object): - if isinstance(other, Import): - return super().__eq__(other) and self._unique_node.range == other._unique_node.range - return super().__eq__(other) - - @noapidoc - @reader - def remove_if_unused(self) -> None: - if all( - self.transaction_manager.get_transactions_at_range(self.filepath, start_byte=usage.match.start_byte, end_byte=usage.match.end_byte, transaction_order=TransactionPriority.Remove) - for usage in self.usages - ): - self.remove() - - @noapidoc - @reader - def resolve_attribute(self, attribute: str) -> TSourceFile | None: - # Handles implicit namespace imports in python - if not isinstance(self._imported_symbol(), ExternalModule): - return None - resolved = self.resolve_import(add_module_name=attribute) - if resolved and (isinstance(resolved.symbol, Editable) or isinstance(resolved.from_file, Editable)): - return resolved.symbol or resolved.from_file - return None - - -TImport = TypeVar("TImport", bound="Import") - - -class WildcardImport(Chainable, Generic[TImport]): - """Class to represent one of many wildcard imports.""" - - imp: TImport - symbol: Importable - - def __init__(self, imp: TImport, symbol: Importable): - self.imp = imp - self.symbol = symbol - self.ts_node = imp.ts_node - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - """Resolve the types used by this import.""" - yield from self.imp.with_resolution_frame(self.symbol, direct=True) - - @noapidoc - @reader - def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: - pass - - @property - @override - def filepath(self) -> str: - return self.imp.filepath - - -class ExternalImportResolver: - def resolve(self, imp: Import) -> str | None: - return None diff --git a/src/codegen/sdk/core/interface.py b/src/codegen/sdk/core/interface.py deleted file mode 100644 index 2c605d694..000000000 --- a/src/codegen/sdk/core/interface.py +++ /dev/null @@ -1,83 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from codegen.sdk._proxy import proxy_property -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.interfaces.has_attribute import HasAttribute -from codegen.sdk.core.interfaces.has_block import HasBlock -from codegen.sdk.core.interfaces.inherits import Inherits -from codegen.sdk.enums import SymbolType -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.core.class_definition import Class - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.expressions.type import Type - from codegen.sdk.core.function import Function - from codegen.sdk.core.statements.attribute import Attribute - from codegen.sdk.core.symbol_groups.parents import Parents - - -TCodeBlock = TypeVar("TCodeBlock", bound="CodeBlock") -TAttribute = TypeVar("TAttribute", bound="Attribute") -TFunction = TypeVar("TFunction", bound="Function") -TType = TypeVar("TType", bound="Type") - - -@apidoc -class Interface(Inherits, HasBlock, HasAttribute[TAttribute], Generic[TCodeBlock, TAttribute, TFunction, TType]): - """Abstract representation of an Interface class. - - Attributes: - parent_interfaces: All the interfaces that this interface extends. - """ - - symbol_type = SymbolType.Interface - parent_interfaces: Parents[TType, Self] | None = None - code_block: TCodeBlock - - @noapidoc - @commiter - def compute_superclass_dependencies(self) -> None: - if self.parent_interfaces: - self.parent_interfaces.compute_superclass_dependencies() - - @property - @reader - def attributes(self) -> list[TAttribute]: - """List of attributes defined in this Interface.""" - msg = "Subclass must implement `parse`" - raise NotImplementedError(msg) - - @reader - def get_attribute(self, name: str) -> TAttribute | None: - """Returns the attribute with the given name, if it exists. - - Otherwise, returns None. - """ - return next((x for x in self.attributes if x.name == name), None) - - @reader - def extends(self, parent_interface: str | Interface, max_depth: int | None = None) -> bool: - """Returns True if the interface implements the given parent interface.""" - if self.parent_interfaces is None: - return False - return self.parent_interfaces.is_subclass_of(parent_interface, max_depth=max_depth) - - @proxy_property - @reader - def implementations(self, max_depth: int | None = None) -> list[Interface | Class]: - """Returns all classes and interfaces that implement a given interface. - - Note: - This method can be called as both a property or a method. If used as a property, it is equivalent to invoking it without arguments. - """ - return self._get_subclasses(max_depth) - - @noapidoc - @reader - @override - def resolve_attribute(self, name: str) -> TAttribute | None: - return self.get_attribute(name) diff --git a/src/codegen/sdk/core/interfaces/__init__.py b/src/codegen/sdk/core/interfaces/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/sdk/core/interfaces/callable.py b/src/codegen/sdk/core/interfaces/callable.py deleted file mode 100644 index 83a2db7b9..000000000 --- a/src/codegen/sdk/core/interfaces/callable.py +++ /dev/null @@ -1,130 +0,0 @@ -from dataclasses import dataclass -from typing import TYPE_CHECKING, Generic, Self, TypeVar - -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.detached_symbols.function_call import FunctionCall -from codegen.sdk.core.interfaces.usable import Usable -from codegen.sdk.core.placeholder.placeholder import Placeholder -from codegen.sdk.core.symbol_group import SymbolGroup -from codegen.shared.decorators.docs import apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.class_definition import Class - from codegen.sdk.core.detached_symbols.parameter import Parameter - from codegen.sdk.core.expressions.type import Type - from codegen.sdk.core.external_module import ExternalModule - from codegen.sdk.core.function import Function - from codegen.sdk.core.symbol import Symbol - - -@dataclass -class FunctionCallDefinition: - """Represents a function call and its definitions. - - This class encapsulates information about a function call and the possible - callable entities that define it. - - Attributes: - call (FunctionCall): The function call object representing the invocation. - callables (List[Union[Function, Class, ExternalModule]]): A list of callable - entities that define the function being called. - """ - - call: FunctionCall - callables: list["Function | Class | ExternalModule"] - - -TParameter = TypeVar("TParameter", bound="Parameter") -TType = TypeVar("TType", bound="Type") - - -@apidoc -class Callable(Usable, Generic[TParameter, TType]): - """Any symbol that can be invoked with arguments eg. - - Function, Class, Decorator, ExternalModule - - Attributes: - return_type: The type of value returned by the callable, or a placeholder. - """ - - _parameters: SymbolGroup[TParameter, Self] | list[TParameter] - - return_type: TType | Placeholder[Self] - - @property - @reader(cache=False) - def call_sites(self) -> list[FunctionCall]: - """Returns all call sites (invocations) of this callable in the codebase. - - Finds all locations in the codebase where this callable is invoked/called. Call sites exclude imports, certain exports, and external references. - - Returns: - list[FunctionCall]: A list of FunctionCall objects representing each invocation of this callable. - Returns empty list if the callable has no name. - """ - # TODO - rename this and `function_calls` to be more clear - call_sites: list[FunctionCall] = [] - - for usage in self.usages: - if isinstance(usage.match, FunctionCall): - call_sites.append(usage.match) - - return list(dict.fromkeys(call_sites)) - - @property - @reader - def parameters(self) -> SymbolGroup[TParameter, Self] | list[TParameter]: - """Retrieves all parameters of a callable symbol. - - This property provides access to all parameters of a callable symbol (function, class, decorator, or external module). - Parameters are stored as a SymbolGroup containing Parameter objects. - - Returns: - SymbolGroup[TParameter, Self] | list[TParameter]: A group of Parameter objects representing the callable's parameters, - or an empty list if the callable has no parameters. - """ - return self._parameters - - @reader - def get_parameter(self, name: str) -> TParameter | None: - """Gets a specific parameter from the callable's parameters list by name. - - Args: - name (str): The name of the parameter to retrieve. - - Returns: - TParameter | None: The parameter with the specified name, or None if no parameter with that name exists or if there are no parameters. - """ - return next((x for x in self._parameters if x.name == name), None) - - @reader - def get_parameter_by_index(self, index: int) -> TParameter | None: - """Returns the parameter at the given index. - - Retrieves a parameter from the callable's parameter list based on its positional index. - - Args: - index (int): The index of the parameter to retrieve. - - Returns: - TParameter | None: The parameter at the specified index, or None if the parameter list - is empty or the index does not exist. - """ - return next((x for x in self._parameters if x.index == index), None) - - @reader - def get_parameter_by_type(self, type: "Symbol") -> TParameter | None: - """Retrieves a parameter from the callable by its type. - - Searches through the callable's parameters to find a parameter with the specified type. - - Args: - type (Symbol): The type to search for. - - Returns: - TParameter | None: The parameter with the specified type, or None if no parameter is found or if the callable has no parameters. - """ - if self._parameters is None: - return None - return next((x for x in self._parameters if x.type == type), None) diff --git a/src/codegen/sdk/core/interfaces/chainable.py b/src/codegen/sdk/core/interfaces/chainable.py deleted file mode 100644 index e12446a48..000000000 --- a/src/codegen/sdk/core/interfaces/chainable.py +++ /dev/null @@ -1,76 +0,0 @@ -from abc import abstractmethod -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Self, TypeVar - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.extensions.utils import cached_property -from codegen.shared.decorators.docs import noapidoc - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.has_attribute import HasAttribute - -Parent = TypeVar("Parent", bound="Editable") - - -@noapidoc -class Chainable(Editable[Parent], Generic[Parent]): - """Represents a class that can be used as an object in a function call chain.""" - - _resolving: bool = False - - @abstractmethod - def _resolved_types(self) -> Generator["ResolutionStack[Self]", None, None]: ... - - @cached_property - @noapidoc - def resolved_type_frames(self) -> list[ResolutionStack["Self"]]: - """Resolve the definition(s) of this object.""" - if self._resolving: - return [ResolutionStack(self)] # Break cycles - self._resolving = True - try: - ret = list(self._resolved_types()) - self.__dict__.pop("resolved_type_frames", None) - return ret - finally: - self._resolving = False - - @noapidoc - def with_resolution( - self, resolution: ResolutionStack["Self"], *args, generic_parameters: list | None = None, generics: dict | None = None, **kwargs - ) -> Generator[ResolutionStack["Self"], None, None]: - from codegen.sdk.core.interfaces.supports_generic import SupportsGenerics - - assert resolution is not self - generics = generics or resolution.generics - if generic_parameters: - if isinstance(resolution.top.node, SupportsGenerics) and self.ctx.config.generics: - generics = {k: v for v, k in zip(generic_parameters, resolution.top.node.generics)} - elif not generics: - generics = {i: v for i, v in enumerate(generic_parameters)} - yield resolution.with_frame(self, *args, **kwargs, generics=generics) - - @noapidoc - def with_resolution_frame(self, child: Editable, *args, generic_parameters: list | None = None, generics: dict | None = None, **kwargs) -> Generator[ResolutionStack["Self"], None, None]: - """Resolve the definition(s) of this object.""" - if isinstance(child, Chainable): - assert child is not self - if not child._resolving: - resolved = child.resolved_type_frames - if len(resolved) > 0: - for resolution in resolved: - yield from self.with_resolution(resolution, *args, generic_parameters=generic_parameters, generics=generics, **kwargs) - return - if generics is None: - generics = {i: v for i, v in enumerate(generic_parameters)} if generic_parameters else None - yield ResolutionStack(child).with_frame(self, *args, **kwargs, generics=generics) - - @cached_property - @noapidoc - def resolved_types(self) -> list["HasAttribute"]: - """Resolve the definition(s) of this object. - - Returns type at the top of the resolution stack. - """ - return list(frame.top.node for frame in self.resolved_type_frames) diff --git a/src/codegen/sdk/core/interfaces/conditional_block.py b/src/codegen/sdk/core/interfaces/conditional_block.py deleted file mode 100644 index 1f9de6e19..000000000 --- a/src/codegen/sdk/core/interfaces/conditional_block.py +++ /dev/null @@ -1,37 +0,0 @@ -from abc import ABC, abstractmethod -from collections.abc import Sequence - -from codegen.sdk.core.statements.statement import Statement -from codegen.shared.decorators.docs import noapidoc - - -class ConditionalBlock(Statement, ABC): - """An interface for any code block that might not be executed in the code, - e.g if block/else block, try block/catch block ect. - """ - - @property - @abstractmethod - @noapidoc - def other_possible_blocks(self) -> Sequence["ConditionalBlock"]: - """Should return all other "branches" that might be executed instead.""" - - @property - @noapidoc - def end_byte_for_condition_block(self) -> int: - """Returns the end byte for the specific condition block""" - return self.end_byte - - @property - @noapidoc - def start_byte_for_condition_block(self) -> int: - """Returns the start byte for the specific condition block""" - return self.start_byte - - @noapidoc - def is_true_conditional(self, descendant) -> bool: - """Returns if this conditional is truly conditional, - this is necessary as an override for things like finally - statements that share a parent with try blocks - """ - return True diff --git a/src/codegen/sdk/core/interfaces/editable.py b/src/codegen/sdk/core/interfaces/editable.py deleted file mode 100644 index 86e08c844..000000000 --- a/src/codegen/sdk/core/interfaces/editable.py +++ /dev/null @@ -1,1177 +0,0 @@ -from __future__ import annotations - -import itertools -import re -from abc import abstractmethod -from functools import cached_property -from typing import TYPE_CHECKING, Generic, Self, TypeVar, Unpack, final, overload - -from rich.markup import escape -from rich.pretty import Pretty - -from codegen.sdk.codebase.span import Span -from codegen.sdk.codebase.transactions import EditTransaction, InsertTransaction, RemoveTransaction, TransactionPriority -from codegen.sdk.core.autocommit import commiter, reader, remover, repr_func, writer -from codegen.sdk.core.placeholder.placeholder import Placeholder -from codegen.sdk.extensions.utils import get_all_identifiers -from codegen.sdk.output.ast import AST -from codegen.sdk.output.constants import ANGULAR_STYLE, MAX_STRING_LENGTH -from codegen.sdk.output.jsonable import JSONable -from codegen.sdk.output.utils import style_editable -from codegen.sdk.utils import descendant_for_byte_range, find_all_descendants, find_first_ancestor, find_index, truncate_line -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from collections.abc import Callable, Generator, Iterable, Sequence - - import rich.repr - from rich.console import Console, ConsoleOptions, RenderResult - from tree_sitter import Node as TSNode - from tree_sitter import Point, Range - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.codebase.flagging.code_flag import CodeFlag - from codegen.sdk.codebase.flagging.enums import FlagKwargs - from codegen.sdk.codebase.transaction_manager import TransactionManager - from codegen.sdk.core.class_definition import Class - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.export import Export - from codegen.sdk.core.expressions import Expression - from codegen.sdk.core.expressions.type import Type - from codegen.sdk.core.file import File, SourceFile - from codegen.sdk.core.function import Function - from codegen.sdk.core.import_resolution import Import, WildcardImport - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.statements.statement import Statement - from codegen.sdk.core.symbol import Symbol - from codegen.sdk.core.symbol_group import SymbolGroup - from codegen.sdk.enums import NodeType - from codegen.visualizations.enums import VizNode -CONTAINER_CHARS = (b"(", b")", b"{", b"}", b"[", b"]", b"<", b">", b"import") -MAX_REPR_LEN: int = 200 - - -def _contains_container_chars(text: bytes) -> bool: - return any([char in text for char in CONTAINER_CHARS]) - - -def _is_empty_container(text: str) -> bool: - stripped_str = re.sub(r"\s+", "", text) - return len(stripped_str) == 2 and all([char in CONTAINER_CHARS for char in text]) - - -_EXCLUDE_FROM_REPR: list[str] = [ - "ctx", - "autocommit_cache", - "parent", - "file_node_id", - "to_file_id", - "ts_node", - "node_id", - "resolved_type_frames", - "resolved_types", - "valid_symbol_names", - "valid_import_names", - "predecessor", - "successor", - "base", - "call_chain", - "code_block", - "parent_statement", - "symbol_usages", - "usages", - "function_definition_frames", - "start_point", - "end_point", - "span", - "range", - "methods", - "ts_config", - "symbols", - "exports", -] - -Parent = TypeVar("Parent", bound="Editable") -P = TypeVar("P", bound=Placeholder) -T = TypeVar("T", bound="Editable") - - -@apidoc -class Editable(JSONable, Generic[Parent]): - """An editable instance is an abstract text representation of any text in a file. - - Attributes: - ts_node: The TreeSitter node associated with this Editable instance. - file_node_id: The unique identifier for the file node. - ctx: The codebase context that this Editable instance is part of. - parent: The parent node of this Editable instance. - node_type: The type of node this Editable instance represents. - """ - - ts_node: TSNode - file_node_id: NodeId - ctx: CodebaseContext - parent: Parent - node_type: NodeType - _file: File | None = None - _hash: int | None = None - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent) -> None: - self.ts_node = ts_node - self.file_node_id = file_node_id - self.ctx = ctx - self.parent = parent - if ctx.config.debug: - seen = set() - while parent is not None: - assert (parent.ts_node, parent.__class__) not in seen - seen.add((parent.ts_node, parent.__class__)) - parent = parent.parent - if self.file and self.ctx.config.full_range_index: - self._add_to_index - - def __hash__(self): - if self._hash is None: - self._hash = hash((self.filepath, self.range, self.ts_node.kind_id)) - return self._hash - - def __str__(self) -> str: - return self.source - - @repr_func - def __repr__(self) -> str: - """Represent the string for logging purposes.""" - if hasattr(self, "__dict__"): - keys = list(self.__dict__.keys()) - elif hasattr(self, "__slots__"): - keys = list(self.__slots__) - else: - keys = list() - keys = ["name", "filepath", "start_point", "end_point", *keys] - if not hasattr(self, "name"): - keys[0] = "source" - elif "source" in keys: - keys.remove("source") - kws = [f"{k}={truncate_line(repr(getattr(self, k, None)), MAX_REPR_LEN)}" for k in dict.fromkeys(keys) if k not in _EXCLUDE_FROM_REPR and not k.startswith("_") and hasattr(self, k)] - return "{}({})".format(type(self).__name__, ", ".join(kws)) - - def __rich_repr__(self) -> rich.repr.Result: - yield escape(self.filepath) - - __rich_repr__.angular = ANGULAR_STYLE # type: ignore - - def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: - yield Pretty(self, max_string=MAX_STRING_LENGTH) - if self.file: - yield from style_editable(self.ts_node, self.file.path, self.file.ts_node) - - @reader - def __eq__(self, other: object): - if other is None: - return False - if isinstance(other, Editable): - return self.filepath == other.filepath and self.ts_node.kind_id == other.ts_node.kind_id and self.range == other.range - if isinstance(other, str): - return self.source == other - return False - - @reader - def __contains__(self, item: str | Editable) -> bool: - if isinstance(item, Editable): - return item.source in self.source - return item in self.source - - @property - @noapidoc - def transaction_manager(self) -> TransactionManager: - return self.ctx.transaction_manager - - @property - @noapidoc - @reader - def start_byte(self) -> int: - """The start byte of the Editable instance that appears in file.""" - return self.ts_node.start_byte - - @property - @noapidoc - @reader - @final - def end_byte(self) -> int: - """The end byte of the Editable instance that appears in file.""" - return self.ts_node.end_byte - - @property - @noapidoc - @reader - @final - def start_point(self) -> Point: - """The start point (row, column) of the Editable instance that appears in file.""" - return self.ts_node.start_point - - @property - @noapidoc - @reader - @final - def end_point(self) -> Point: - """The end point (row, column) of the Editable instance that appears in file.""" - return self.ts_node.end_point - - @property - @noapidoc - @reader - def line_range(self) -> range: - """The 0-indexed line/row range that the Editable instance spans in the file.""" - return range(self.start_point[0], self.end_point[0] + 1) # +1 b/c end_point[0] is inclusive - - @property - @noapidoc - @reader - def _source(self) -> str: - """Text representation of the Editable instance.""" - return self.ts_node.text.decode("utf-8") - - @property - @reader - def source(self) -> str: - """Text representation of the Editable instance. - - Returns the source text of the Editable instance. This is the main property used to access the text content of any code element in GraphSitter. - - Returns: - str: The text content of this Editable instance. - """ - return self._source - - @source.setter - @writer - def source(self, value) -> None: - """Sets the source (text representation) of the Editable instance using .edit(..). - - Only edits if the new value is different from the current source. - - Args: - value (str): The new text representation to set. - - Returns: - None: The method returns nothing. - """ - if self.source != value: - self.edit(value) - - @property - @noapidoc - @reader(cache=False) - def extended_nodes(self) -> list[Editable]: - """List of Editable instances that includes itself and its extended symbols like `export`, - `public` or `decorator` - """ - return [self] - - @property - def extended(self) -> SymbolGroup: - """Returns a SymbolGroup of all extended nodes associated with this element. - - Creates a SymbolGroup that provides a common interface for editing all extended nodes, - such as decorators, modifiers, and comments associated with the element. - - Args: - None - - Returns: - SymbolGroup: A group containing this node and its extended nodes that allows - batch modification through a common interface. - """ - from codegen.sdk.core.symbol_group import SymbolGroup - - return SymbolGroup(self.file_node_id, self.ctx, self.parent, children=self.extended_nodes) - - @property - @reader - def extended_source(self) -> str: - """Returns the source text representation of all extended nodes. - - Gets the source text of all extended nodes combined. This property allows reading the source text - of all extended nodes (e.g. decorators, export statements) associated with this node. - - Returns: - str: The combined source text of all extended nodes. - """ - return self.extended.source - - @extended_source.setter - def extended_source(self, value: str) -> None: - """Set the source of all extended nodes. - - Updates the source of all nodes in the extended nodes list by calling .edit(..). This is useful for updating multiple related nodes (e.g. decorators, export statements) at once. - - Args: - value (str): The new source text to set for all extended nodes. - - Returns: - None - """ - self.extended.edit(value) - - @property - @reader - @noapidoc - def children(self) -> list[Editable[Self]]: - """List of Editable instances that are children of this node.""" - return [self._parse_expression(child) for child in self.ts_node.named_children] - - @property - @reader - @noapidoc - def _anonymous_children(self) -> list[Editable[Self]]: - """All anonymous children of an editable.""" - return [self._parse_expression(child) for child in self.ts_node.children if not child.is_named] - - @property - @reader - @noapidoc - def next_sibling(self) -> Editable | None: - """Returns the Editable instance that next appears in the file.""" - if self.ts_node is None: - return None - - next_sibling_node = self.ts_node.next_sibling - if next_sibling_node is None: - return None - - return self._parse_expression(next_sibling_node) - - @property - @reader - @noapidoc - def next_named_sibling(self) -> Editable[Parent] | None: - if self.ts_node is None: - return None - - next_named_sibling_node = self.ts_node.next_named_sibling - if next_named_sibling_node is None: - return None - - return self.parent._parse_expression(next_named_sibling_node) - - @property - @reader - @noapidoc - def previous_named_sibling(self) -> Editable[Parent] | None: - if self.ts_node is None: - return None - - previous_named_sibling_node = self.ts_node.prev_named_sibling - if previous_named_sibling_node is None: - return None - - return self.parent._parse_expression(previous_named_sibling_node) - - @property - def file(self) -> SourceFile: - """The file object that this Editable instance belongs to. - - Retrieves or caches the file object associated with this Editable instance. - - Returns: - File: The File object containing this Editable instance. - """ - if self._file is None: - self._file = self.ctx.get_node(self.file_node_id) - return self._file # type: ignore - - @property - def filepath(self) -> str: - """The file path of the file that this Editable instance belongs to. - - Returns a string representing the absolute file path of the File that contains this Editable instance. - - Returns: - str: The absolute file path. - """ - return self.file.file_path - - @reader - def find_string_literals(self, strings_to_match: list[str], fuzzy_match: bool = False) -> list[Editable[Self]]: - """Returns a list of string literals within this node's source that match any of the given - strings. - - Args: - strings_to_match (list[str]): A list of strings to search for in string literals. - fuzzy_match (bool): If True, matches substrings within string literals. If False, only matches exact strings. Defaults to False. - - Returns: - list[Editable[Self]]: A list of Editable objects representing the matching string literals. - """ - matches: list[Editable[Self]] = [] - for node in self.extended_nodes: - matches.extend(node._find_string_literals(strings_to_match, fuzzy_match)) - return matches - - @noapidoc - @reader - def _find_string_literals(self, strings_to_match: list[str], fuzzy_match: bool = False) -> Sequence[Editable[Self]]: - all_string_nodes = find_all_descendants(self.ts_node, type_names={"string"}) - editables = [] - for string_node in all_string_nodes: - assert string_node.text is not None - full_string = string_node.text.strip(b'"').strip(b"'") - if fuzzy_match: - if not any([str_to_match.encode("utf-8") in full_string for str_to_match in strings_to_match]): - continue - else: - if not any([str_to_match.encode("utf-8") == full_string for str_to_match in strings_to_match]): - continue - editables.append(self._parse_expression(string_node)) - return editables - - @writer - def replace(self, old: str, new: str, count: int = -1, is_regex: bool = False, priority: int = 0) -> int: - """Search and replace occurrences of text within this node's source and its extended nodes. - - This method performs string replacement similar to Python's string.replace(), with support for regex patterns. - It operates on both the main node and any extended nodes (e.g. decorators, exports). - - Args: - old (str): The text or pattern to search for. - new (str): The text to replace matches with. - count (int, optional): Maximum number of replacements to make. Defaults to -1 (replace all). - is_regex (bool, optional): Whether to treat 'old' as a regex pattern. Defaults to False. - priority (int, optional): Priority of the replacement operation. Defaults to 0. - - Returns: - int: The total number of replacements made. - - Raises: - ValueError: If there are multiple occurrences of the substring in a node's source. - """ - total_count = 0 - for node in self.extended_nodes: - total_count += node._replace(old, new, count - total_count, is_regex, priority) - if 0 < count <= total_count: - break - return total_count - - @noapidoc - @writer - def _replace(self, old: str, new: str, count: int = -1, is_regex: bool = False, priority: int = 0) -> int: - """Search and replace an instance of `substring` within this node's source. - - Only replaces up to the `count` specified, and returns the total instances replaced. - """ - total_count = 0 - if not is_regex: - old = re.escape(old) - - for match in re.finditer(old.encode("utf-8"), self.ts_node.text): # type: ignore - start_byte = self.ts_node.start_byte + match.start() - end_byte = self.ts_node.start_byte + match.end() - t = EditTransaction( - start_byte, - end_byte, - self.file, - new, - priority=priority, - ) - self.transaction_manager.add_transaction(t, dedupe=True) - - total_count += 1 - if 0 < count <= total_count: - break - return total_count - - @reader - def find(self, strings_to_match: list[str] | str, *, exact: bool = False) -> list[Editable]: - """Find and return matching nodes or substrings within an Editable instance. - - This method searches through the extended_nodes of the Editable instance and returns all nodes or substrings that match the given search criteria. - - Args: - strings_to_match (Union[list[str], str]): One or more strings to search for. - exact (bool): If True, only return nodes whose source exactly matches one of the strings_to_match. - If False, return nodes that contain any of the strings_to_match as substrings. - Defaults to False. - - Returns: - list[Editable]: A list of Editable instances that match the search criteria. - """ - matches = [] - for node in self.extended_nodes: - matches.extend(node._find(strings_to_match, exact)) - return matches - - @noapidoc - @reader - def _find(self, strings_to_match: list[str] | str, exact: bool = False) -> list[Editable]: - if isinstance(strings_to_match, str): - strings_to_match = [strings_to_match] - # Use search to find string - search_results = itertools.chain.from_iterable(map(self._search, map(re.escape, strings_to_match))) - if exact: - search_results = filter(lambda result: result.source in strings_to_match, search_results) - - # Combine and deduplicate results - return list(search_results) - - @reader - def search(self, regex_pattern: str, include_strings: bool = True, include_comments: bool = True) -> list[Editable]: - """Returns a list of all regex match of `regex_pattern`, similar to python's re.search(). - - Searches for matches of a regular expression pattern within the text of this node and its extended nodes. - - Args: - regex_pattern (str): The regular expression pattern to search for. - include_strings (bool): When False, excludes the contents of string literals from the search. Defaults to True. - include_comments (bool): When False, excludes the contents of comments from the search. Defaults to True. - - Returns: - list[Editable]: A list of Editable objects corresponding to the matches found. - """ - matches = [] - for node in self.extended_nodes: - matches.extend(node._search(regex_pattern, include_strings=include_strings, include_comments=include_comments)) - return matches - - @noapidoc - @reader - def _search(self, regex_pattern: str, include_strings: bool = True, include_comments: bool = True) -> list[Editable]: - matching_byte_ranges: list[tuple[int, int]] = [] - string = self.ts_node.text - - pattern = re.compile(regex_pattern.encode("utf-8")) - start_byte_offset = self.ts_node.byte_range[0] - for match in pattern.finditer(string): # type: ignore - matching_byte_ranges.append((match.start() + start_byte_offset, match.end() + start_byte_offset)) - - matches: list[Editable] = [] - for byte_range in matching_byte_ranges: - ts_match = descendant_for_byte_range(self.ts_node, byte_range[0], byte_range[1], allow_comment_boundaries=include_comments) - if ts_match is not None: - # Check for inclusion of comments and/or strings - if (include_strings or ts_match.type not in ("string", "string_content", "string_fragment")) and (include_comments or ts_match.type != "comment"): - matches.append(self._parse_expression(ts_match)) - return list(matches) - - @writer(commit=False) - @noapidoc - def insert_at(self, byte: int, new_src: str | Callable[[], str], *, priority: int | tuple = 0, dedupe: bool = True, exec_func: Callable[[], None] | None = None) -> None: - # Insert the new_src - t = InsertTransaction( - byte, - self.file, - new_src, - priority=priority, - exec_func=exec_func, - ) - self.transaction_manager.add_transaction(t, dedupe=dedupe) - - def _get_indent(self) -> int: - return self.ts_node.start_point[1] - - @writer(commit=False) - def insert_before(self, new_src: str, fix_indentation: bool = False, newline: bool = True, priority: int = 0, dedupe: bool = True) -> None: - """Inserts text before this node's source with optional indentation and newline handling. - - This method inserts the provided text before the current node's source code. It can automatically handle indentation and newline placement. - - Args: - new_src (str): The text to insert before this node. - fix_indentation (bool): Whether to fix the indentation of new_src to match the current node. Defaults to False. - newline (bool): Whether to add a newline after new_src. Defaults to True. - priority (int): Transaction priority for managing multiple edits. Defaults to 0. - dedupe (bool): Whether to deduplicate identical transactions. Defaults to True. - - Returns: - None - """ - if self.ts_node is None: - return - - indentation = " " * min(node._get_indent() for node in self.extended_nodes) - if fix_indentation: - src_lines = new_src.split("\n") - src_lines = src_lines[:1] + [line if line == "" else indentation + line for line in src_lines[1:]] - new_src = "\n".join(src_lines) - - # Add a newline before the new_src - if newline: - new_src += "\n" - - if fix_indentation: - new_src += indentation - self.insert_at(self.start_byte, new_src, priority=priority, dedupe=dedupe) - - @writer(commit=False) - def insert_after(self, new_src: str, fix_indentation: bool = False, newline: bool = True, priority: int = 0, dedupe: bool = True) -> None: - """Inserts code after this node. - - Args: - new_src (str): The source code to insert after this node. - fix_indentation (bool, optional): Whether to adjust the indentation of new_src to match the current node. Defaults to False. - newline (bool, optional): Whether to add a newline before the new_src. Defaults to True. - priority (int, optional): Priority of the insertion transaction. Defaults to 0. - dedupe (bool, optional): Whether to deduplicate identical transactions. Defaults to True. - - Returns: - None - """ - if self.ts_node is None: - return - - if fix_indentation: - indentation = " " * min(node._get_indent() for node in self.extended_nodes) - src_lines = new_src.split("\n") - src_lines = [line if line == "" else indentation + line for line in src_lines] - new_src = "\n".join(src_lines) - - # Add a newline before the new_src - if newline: - new_src = "\n" + new_src - - self.insert_at(self.ts_node.end_byte, new_src, priority=priority, dedupe=dedupe) - - @writer - def edit(self, new_src: str, fix_indentation: bool = False, priority: int = 0, dedupe: bool = True) -> None: - """Replace the source of this `Editable` with `new_src`. - - Replaces the text representation of this Editable instance with new text content. The method handles indentation adjustments and transaction management. - - Args: - new_src (str): The new source text to replace the current text with. - fix_indentation (bool): If True, adjusts the indentation of `new_src` to match the current text's indentation level. Defaults to False. - priority (int): The priority of the edit transaction. Higher priority edits are applied first. Defaults to 0. - dedupe (bool): If True, deduplicates identical transactions. Defaults to True. - - Returns: - None - """ - if fix_indentation: - line = self.file.content.split("\n")[self.ts_node.start_point[0]] - indentation = line[: len(line) - len(line.strip())] - src_lines = new_src.split("\n") - src_lines = src_lines[:1] + [line if line == "" else indentation + line for line in src_lines[1:]] - new_src = "\n".join(src_lines) - - t = EditTransaction( - self.start_byte, - self.end_byte, - self.file, - new_src, - priority=priority, - ) - self.transaction_manager.add_transaction(t, dedupe=dedupe) - - @writer - def _edit_byte_range(self, new_src: str, start_byte: int, end_byte: int, priority: int = 0, dedupe: bool = True) -> None: - t = EditTransaction( - start_byte, - end_byte, - self.file, - new_src, - priority=priority, - ) - self.transaction_manager.add_transaction(t, dedupe=dedupe) - - @remover - @noapidoc - def remove_byte_range(self, start_byte: int, end_byte: int) -> None: - if self.ctx.config.debug: - assert start_byte < end_byte - t = RemoveTransaction(start_byte, end_byte, self.file) - self.transaction_manager.add_transaction(t) - - @remover - def remove(self, delete_formatting: bool = True, priority: int = 0, dedupe: bool = True) -> None: - """Deletes this Node and its related extended nodes (e.g. decorators, comments). - - Removes the current node and its extended nodes (e.g. decorators, comments) from the codebase. - After removing the node, it handles cleanup of any surrounding formatting based on the context. - - Args: - delete_formatting (bool): Whether to delete surrounding whitespace and formatting. Defaults to True. - priority (int): Priority of the removal transaction. Higher priority transactions are executed first. Defaults to 0. - dedupe (bool): Whether to deduplicate removal transactions at the same location. Defaults to True. - - Returns: - None - """ - for node in self.extended_nodes: - node._remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe) - - @remover - @noapidoc - def _remove(self, delete_formatting: bool = True, priority: int = 0, dedupe: bool = True) -> None: - if self.parent._smart_remove(self, delete_formatting=delete_formatting, priority=priority, dedupe=dedupe): - return - # If the node deleted is the only node, delete the entire node - parent = self.ts_node.parent - removed_start_byte = self.start_byte - removed_end_byte = self.end_byte - if parent is not None and parent.type in ("parenthesized_expression", "jsx_expression") and self.ts_node.is_named: - removed_start_byte = min(parent.start_byte, removed_start_byte) - removed_end_byte = max(parent.end_byte, removed_end_byte) - parent = parent.parent - while parent is not None and parent.byte_range == self.ts_node.byte_range: - parent = parent.parent - if parent is not None and parent.type in ("named_imports", "export_statement") and len(parent.named_children) == 1 and self.ts_node.is_named: - removed_start_byte = min(parent.start_byte, removed_start_byte) - removed_end_byte = max(parent.end_byte, removed_end_byte) - parent = parent.parent - - def should_keep(node: TSNode): - if node.type == "comment": - # Remove comments on the same rows as the deleted node - if node.end_point[0] <= self.end_point[0] and node.start_byte > removed_start_byte: - return False - return True - - siblings = None if parent is None else list(filter(should_keep, parent.named_children if self.ts_node.is_named else parent.children)) - # same line - - # In the case this is an import_from_statement, the first sibling is the module_name, and the rest are the imports - if parent is not None and parent.type == "import_from_statement" and siblings and len(siblings) > 0: - siblings = siblings[1:] - - if isinstance(self.parent, Editable): - exec_func = self.parent._removed_child_commit - else: - exec_func = None - - # Delete the node - t = RemoveTransaction(removed_start_byte, removed_end_byte, self.file, priority=priority, exec_func=exec_func) - if self.transaction_manager.add_transaction(t, dedupe=dedupe): - if exec_func is not None: - self.parent._removed_child() - - # If there are sibling nodes, delete the surrounding whitespace & formatting (commas) - if delete_formatting and siblings and len(siblings) > 1: - index = find_index(self.ts_node, siblings) - - # Check if all previous siblings are being deleted - all_previous_deleted = all( - self.transaction_manager.get_transactions_at_range(self.file.path, start_byte=siblings[i].start_byte, end_byte=siblings[i].end_byte, transaction_order=TransactionPriority.Remove) - for i in range(index) - ) - - if all_previous_deleted: - if index != 0: - self.remove_byte_range(siblings[index - 1].end_byte, removed_start_byte) - # If it's the first import or all previous imports are being deleted, - # remove the comma after - start_byte = removed_end_byte - if index + 1 < len(siblings): - end_byte = siblings[index + 1].start_byte - else: - return # Do not delete if it's the last node - elif _contains_container_chars(self.file.content_bytes[siblings[index - 1].end_byte : removed_start_byte]): - if index + 1 < len(siblings): - start_byte = removed_end_byte - end_byte = siblings[index + 1].start_byte - else: - return # Do not delete the last node - else: - start_byte = siblings[index - 1].end_byte - end_byte = removed_start_byte - - # Check that it is not deleting a list container - if _contains_container_chars(self.file.content_bytes[start_byte:end_byte]): - return - - t = RemoveTransaction( - start_byte, - end_byte, - self.file, - priority=priority, - ) - self.transaction_manager.add_transaction(t, dedupe=dedupe) - - # ================================================================================================================== - # Utilities - # ================================================================================================================== - # TODO: not sure if these functions should be here tbh - @overload - def child_by_field_name(self, field_name: str, *, placeholder: type[P], default: type[Expression] | None = None) -> Expression[Self] | P: ... - - @overload - def child_by_field_name(self, field_name: str, *, placeholder: None = ..., default: type[Expression] | None = None) -> Expression[Self] | None: ... - - @reader - @noapidoc - def child_by_field_name(self, field_name: str, *, placeholder: type[P] | None = None, **kwargs) -> Expression[Self] | P | None: - """Get child by field name.""" - node = self.ts_node.child_by_field_name(field_name) - if node is None: - if placeholder is not None: - return placeholder(self) - return None - return self._parse_expression(node, **kwargs) - - @reader - @noapidoc - def children_by_field_types(self, field_types: str | Iterable[str]) -> Generator[Expression[Self], None, None]: - """Get child by field types.""" - if isinstance(field_types, str): - field_types = [field_types] - for child in self.ts_node.children: - if child.type in field_types: - if node := self._parse_expression(child): - yield node - - @reader - @noapidoc - def child_by_field_types(self, field_types: str | Iterable[str]) -> Expression[Self] | None: - """Get child by fiexld types.""" - return next(self.children_by_field_types(field_types), None) - - @property - @reader - @noapidoc - def ts_node_type(self) -> str: - """This is the underlying type of the TreeSitter node corresponding to this entity, and the - value will correspond to the tree-sitter language grammar. - """ - return self.ts_node.type - - @commiter - @noapidoc - def commit(self) -> None: - """Commits any pending transactions for the current node to the codebase. - - Commits only the transactions that affect the file this node belongs to. This is useful when you want to - commit changes made to a specific node without committing all pending transactions in the codebase. - - Args: - None - - Returns: - None - """ - self.ctx.commit_transactions(files={self.file.path}) - - @noapidoc - def _removed_child(self) -> None: - pass - - @noapidoc - def _removed_child_commit(self) -> None: - pass - - @property - @reader - def variable_usages(self) -> list[Editable]: - """Returns Editables for all TreeSitter node instances of variable usages within this node's - scope. - - This method finds all variable identifier nodes in the TreeSitter AST, excluding: - - Function names in function calls - - Import names in import statements - - Property access identifiers (except the base object) - - Keyword argument names (in Python and TypeScript) - - This is useful for variable renaming and usage analysis within a scope. - - Returns: - list[Editable]: A list of Editable nodes representing variable usages. Each - Editable corresponds to a TreeSitter node instance where the variable - is referenced. - """ - usages: Sequence[Editable[Self]] = [] - identifiers = get_all_identifiers(self.ts_node) - for identifier in identifiers: - # Excludes function names - parent = identifier.parent - if parent is None: - continue - if parent.type in ["call", "call_expression"]: - continue - # Excludes local import statements - if parent.parent is not None and parent.parent.type in ["import_statement", "import_from_statement"]: - continue - # Excludes property identifiers - if parent.type == "attribute" and parent.children.index(identifier) != 0: - continue - # Excludes arg keyword (Python specific) - if parent.type == "keyword_argument" and identifier == parent.child_by_field_name("name"): - continue - # Excludes arg keyword (Typescript specific) - arguments = find_first_ancestor(parent, ["arguments"]) - if arguments is not None and any(identifier == arg.child_by_field_name("left") for arg in arguments.named_children): - continue - - usages.append(self._parse_expression(identifier)) - - return usages - - @reader - def get_variable_usages(self, var_name: str, fuzzy_match: bool = False) -> Sequence[Editable[Self]]: - """Returns Editables for all TreeSitter nodes corresponding to instances of variable usage - that matches the given variable name. - - Retrieves a list of variable usages that match a specified name, with an option for fuzzy matching. By default, excludes property identifiers and argument keywords. - - Args: - var_name (str): The variable name to search for. - fuzzy_match (bool): If True, matches variables where var_name is a substring. If False, requires exact match. Defaults to False. - - Returns: - list[Editable]: List of Editable objects representing variable usage nodes matching the given name. - """ - if fuzzy_match: - return [usage for usage in self.variable_usages if var_name in usage.source] - else: - return [usage for usage in self.variable_usages if var_name == usage.source] - - @overload - def _parse_expression(self, node: TSNode, **kwargs) -> Expression[Self]: ... - - @overload - def _parse_expression(self, node: TSNode | None, **kwargs) -> Expression[Self] | None: ... - - def _parse_expression(self, node: TSNode | None, **kwargs) -> Expression[Self] | None: - return self.ctx.parser.parse_expression(node, self.file_node_id, self.ctx, self, **kwargs) - - def _parse_type(self, node: TSNode) -> Type[Self] | None: - return self.ctx.parser.parse_type(node, self.file_node_id, self.ctx, self) - - def flag(self, **kwargs: Unpack[FlagKwargs]) -> CodeFlag[Self]: - """Adds a visual flag comment to the end of this Editable's source text. - - Flags this Editable by appending a comment with emoji flags at the end of its source text. - This is useful for visually highlighting specific nodes in the source code during development - and debugging. - - Returns: - None - """ - # TODO: remove this once the frontend can process code flags - return self.ctx.flags.flag_instance(self, **kwargs) - - @noapidoc - @abstractmethod - @commiter - def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: - """Compute the dependencies of the export object.""" - pass - - @commiter - @noapidoc - def _add_symbol_usages(self: HasName, identifiers: list[TSNode], usage_type: UsageKind, dest: HasName | None = None) -> None: - from codegen.sdk.core.expressions import Name - from codegen.sdk.core.interfaces.resolvable import Resolvable - - if dest is None: - dest = self - for x in identifiers: - if dep := self._parse_expression(x, default=Name): - assert isinstance(dep, Resolvable) - dep._compute_dependencies(usage_type, dest) - - @commiter - @noapidoc - def _add_all_identifier_usages(self, usage_type: UsageKind, dest: HasName | None = None) -> None: - id_types = self.ctx.node_classes.resolvables - # Skip identifiers that are part of a property - identifiers = find_all_descendants(self.ts_node, id_types, nested=False) - return self._add_symbol_usages(identifiers, usage_type, dest) - - @commiter - @noapidoc - def add_all_identifier_usages_for_child_node(self, usage_type: UsageKind, child: TSNode, dest=None) -> None: - # Interim hack. Don't use - id_types = self.ctx.node_classes.resolvables - # Skip identifiers that are part of a property - identifiers = find_all_descendants(child, id_types, nested=False) - return self._add_symbol_usages(identifiers, usage_type, dest) - - @noapidoc - def _log_parse(self, msg: str, *args, **kwargs): - self.ctx.parser.log(msg, *args, **kwargs) - - @property - @noapidoc - def viz(self) -> VizNode: - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.visualizations.enums import VizNode - - if isinstance(self, HasName): - return VizNode(file_path=self.filepath, start_point=self.start_point, end_point=self.end_point, name=self.name, symbol_name=self.__class__.__name__) - else: - return VizNode(file_path=self.filepath, start_point=self.start_point, end_point=self.end_point, symbol_name=self.__class__.__name__) - - @noapidoc - @reader - def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: - if self.parent is not None: - yield from self.parent.resolve_name(name, start_byte or self.start_byte, strict=strict) - else: - yield from self.file.resolve_name(name, start_byte or self.start_byte, strict=strict) - - @cached_property - @noapidoc - def github_url(self) -> str | None: - if self.file.github_url: - return self.file.github_url + f"#L{self.start_point[0] + 1}-L{self.end_point[0] + 1}" - - @property - @noapidoc - def parent_symbol(self) -> Symbol | File | Import | Export: - """Returns the parent symbol of the symbol.""" - return self.parent.parent_symbol - - @property - @noapidoc - @final - def range(self) -> Range: - return self.ts_node.range - - @cached_property - @noapidoc - @final - def span(self) -> Span: - return Span(range=self.range, filepath=self.filepath) - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - """Returns the nested symbols of the importable object, including itself.""" - return [] - # return list(itertools.chain.from_iterable(child.descendant_symbols for child in self.children)) - - @writer - def reduce_condition(self, bool_condition: bool, node: Editable | None = None) -> None: - """Reduces an editable to the following condition""" - if node is not None: - node.edit(self.ctx.node_classes.bool_conversion[bool_condition]) - else: - self.parent.reduce_condition(bool_condition, self) - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Returns a list of all function calls contained within this expression. - - Traverses the extended nodes of this expression to find all function calls within it. This is useful for tasks like analyzing call patterns or renaming function invocations. - - Returns: - list[FunctionCall]: A list of FunctionCall objects representing all function calls contained within this expression. - """ - calls = [] - for node in self.children: - calls.extend(node.function_calls) - return calls - - @property - @noapidoc - def self_dest(self) -> Importable: - """Returns the symbol usage resolution destination node for the symbol.""" - from codegen.sdk.core.interfaces.importable import Importable - - dest = self - while dest and not isinstance(dest, Importable): - dest = dest.parent - return dest - - @cached_property - @noapidoc - def _add_to_index(self) -> None: - self.file._range_index.add_to_range(self) - - @noapidoc - def _smart_remove(self, child, *args, **kwargs) -> bool: - """Check if a node should remove itself based on the removal of its children nodes""" - return False - - @reader - def is_wrapped_in(self, cls: type[Expression]) -> bool: - """Check if this node is contained another node of the given class""" - return self.parent_of_type(cls) is not None - - @reader - def parent_of_type(self, type: type[T]) -> T | None: - """Find the first ancestor of the node of the given type. Does not return itself""" - if isinstance(self.parent, type): - return self.parent - if self.parent is not self and self.parent is not None: - return self.parent.parent_of_type(type) - return None - - def parent_of_types(self, types: set[type[T]]) -> T | None: - """Find the first ancestor of the node of the given type. Does not return itself""" - if self.parent and any(isinstance(self.parent, t) for t in types): - return self.parent - if self.parent is not self and self.parent is not None: - return self.parent.parent_of_types(types) - return None - - def is_child_of(self, instance: Editable) -> bool: - """Checks if this node is a descendant of the given editable instance in the AST.""" - if not self.parent: - return False - if self.parent is instance: - return True - else: - return self.parent.is_child_of(instance=instance) - - @reader - def ancestors(self, type: type[T]) -> list[T]: - """Find all ancestors of the node of the given type. Does not return itself""" - if self.parent is not self and self.parent is not None: - ret = self.parent.ancestors(type) - else: - ret = [] - if isinstance(self.parent, type): - ret.append(self.parent) - return ret - - @reader - @noapidoc - def first_ancestors(self, type: type[T]) -> T | None: - """Find the first ancestor of the node of the given type.""" - return next(iter(self.ancestors(type)), None) - - @property - @reader - def parent_statement(self) -> Statement | None: - """Find the statement this node is contained in""" - from codegen.sdk.core.statements.statement import Statement - - return self.parent_of_type(Statement) - - @property - @reader - def parent_function(self) -> Function | None: - """Find the function this node is contained in""" - from codegen.sdk.core.function import Function - - return self.parent_of_type(Function) - - @property - @reader - def parent_class(self) -> Class | None: - """Find the class this node is contained in""" - from codegen.sdk.core.class_definition import Class - - return self.parent_of_type(Class) - - def _get_ast_children(self) -> list[tuple[str | None, AST]]: - children = [] - names = {} - for name, val in self._list_members(include_methods=True).items(): - if isinstance(val, Editable): - names[val] = name - for child in self.file._range_index.get_children(self): - if self.ctx.config.debug: - assert child != self, child - elif child == self: - continue - children.append((names.get(child, None), child.ast())) - return children - - @noapidoc - @final - def ast(self) -> AST: - children = self._get_ast_children() - return AST(codegen_sdk_type=self.__class__.__name__, span=self.span, tree_sitter_type=self.ts_node_type, children=children) diff --git a/src/codegen/sdk/core/interfaces/exportable.py b/src/codegen/sdk/core/interfaces/exportable.py deleted file mode 100644 index b201316ac..000000000 --- a/src/codegen/sdk/core/interfaces/exportable.py +++ /dev/null @@ -1,105 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar - -from rustworkx import NoSuitableNeighbors - -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.interfaces.usable import Usable -from codegen.sdk.enums import EdgeType, ImportType, NodeType -from codegen.shared.decorators.docs import apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.export import Export - from codegen.sdk.core.interfaces.editable import Editable -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class Exportable(Usable[Parent], Generic[Parent]): - """An interface for any node object that can be exported - eg. Class, class name, top-level functions, imports - """ - - @property - def is_exported(self) -> bool: - """Indicates if the symbol is exported from its defining file. - - Returns: - bool: True if the symbol has an export object, False otherwise. - """ - return self.export is not None - - @property - @reader(cache=False) - def export(self) -> Export | None: - """Returns the export object that exports this symbol. - - Retrieves the export object by examining incoming EXPORT edges in the CodebaseContext. - - Args: - None - - Returns: - Export | None: The Export object that exports this symbol, or None if not exported. - """ - try: - if self.node_id is None: - return None - return self.ctx.predecessor(self.node_id, edge_type=EdgeType.EXPORT) - except NoSuitableNeighbors: - return None - - @property - @reader(cache=False) - def exported_name(self) -> str | None: - """Retrieves the exported name of a symbol from its file. - - If the symbol is an export node, returns the node's name. If the symbol is not exported, returns None. - - Returns: - str | None: The name the symbol is exported as, or None if not exported. - """ - if self.node_type == NodeType.EXPORT: - # Export's exported name is itself - return self.name - - export = self.export - if export is None: - return None - return export.name - - @property - @reader - def is_reexported(self) -> bool: - """Determines if the symbol is re-exported from a different file. - - A re-export occurs when a symbol is imported into a file and then exported - from that same file. - - Returns: - bool: True if the symbol is re-exported from a different file than where - it was defined, False otherwise. - """ - return any(x.node_type == NodeType.EXPORT and x.file != self.file for x in self.symbol_usages + self.file.symbol_usages) - - @reader - def get_import_string(self, alias: str | None = None, module: str | None = None, import_type: ImportType = ImportType.UNKNOWN, is_type_import: bool = False) -> str: - """Returns the import string for a symbol. - - Generates the import statement needed to import a symbol from its module. - - Args: - alias (str | None): Optional alias for the symbol. - module (str | None): Optional module name to import from. - import_type (ImportType): Type of import to generate. - is_type_import (bool): Indicates if it's a type-only import. - - Returns: - str: The formatted import string. - - Raises: - NotImplementedError: If called on the base class. - """ - msg = "The subclass must implement `to_import_string`." - raise NotImplementedError(msg) diff --git a/src/codegen/sdk/core/interfaces/has_attribute.py b/src/codegen/sdk/core/interfaces/has_attribute.py deleted file mode 100644 index 1fcada30b..000000000 --- a/src/codegen/sdk/core/interfaces/has_attribute.py +++ /dev/null @@ -1,15 +0,0 @@ -from abc import abstractmethod -from typing import TYPE_CHECKING, Generic, TypeVar - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.editable import Editable - - -Attribute = TypeVar("Attribute", bound="Editable") - - -class HasAttribute(Generic[Attribute]): - @abstractmethod - def resolve_attribute(self, name: str) -> Attribute | None: - """Resolve an attribute belonging to this object.""" - pass diff --git a/src/codegen/sdk/core/interfaces/has_block.py b/src/codegen/sdk/core/interfaces/has_block.py deleted file mode 100644 index 0d26a92a1..000000000 --- a/src/codegen/sdk/core/interfaces/has_block.py +++ /dev/null @@ -1,146 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.statements.comment import Comment -from codegen.sdk.extensions.sort import sort_editables -from codegen.shared.decorators.docs import apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.detached_symbols.decorator import Decorator - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.symbol_groups.comment_group import CommentGroup - -TCodeBlock = TypeVar("TCodeBlock", bound="CodeBlock") -TDecorator = TypeVar("TDecorator", bound="Decorator") - - -@apidoc -class HasBlock(Expression, Generic[TCodeBlock, TDecorator]): - """An interface for any code object that has a block of code, e.g. a function, class, etc. - - Attributes: - code_block: The block of code associated with the code object. - """ - - code_block: TCodeBlock - - # =======[ CODE BLOCK ]====== - def _parse_code_block(self, body_node: TSNode | None = None) -> TCodeBlock | None: - """Returns the code block of the function.""" - body_node = body_node or self.ts_node.child_by_field_name("body") - if not body_node: - return None - parent_block = None - level = 0 # Level 0 is reserved for files - parent = self.parent - while parent is not None and parent is not parent.parent: - if isinstance(parent, HasBlock) and hasattr(parent, "code_block"): - parent_block = parent.code_block - level = parent_block.level + 1 - break - parent = parent.parent - - return self.ctx.node_classes.code_block_cls(body_node, level, parent_block, self) - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Returns all function calls within the code block and its decorators. - - Args: - None - - Returns: - list[FunctionCall]: A sorted list of FunctionCall objects representing all - function calls in the code block and its decorators. The list may contain - duplicates. - """ - fcalls = self.code_block.function_calls - for dec in self.decorators: - fcalls.extend(dec.function_calls) - return sort_editables(fcalls, dedupe=False) - - # =======[ DECORATORS ]======= - - @property - @abstractmethod - def is_decorated(self) -> bool: - """Check if the symbol has decorators. - - A helper method to determine if a function, class, or method has any - applied decorators. - - Returns: - bool: True if the symbol has one or more decorators, False otherwise. - """ - - # TODO: class def + function are almost copied of this function? just use the HasBlock definition? - @property - @abstractmethod - def decorators(self) -> list[TDecorator]: - """Returns list of all decorators on this Symbol. - - Gets all decorators associated with a code entity (function, class, method). - - Returns: - list[TDecorator]: A list of Decorator objects. Empty list if no decorators are present. - """ - - @writer - def add_decorator(self, new_decorator: str, skip_if_exists: bool = False) -> bool: - """Adds a decorator to a function or method. - - Adds a new decorator to the symbol's definition before the first non-comment extended node with proper indentation. - - Args: - new_decorator (str): The decorator to add, including the '@' symbol. - skip_if_exists (bool, optional): If True, skips adding if the decorator exists. - - Returns: - bool: True if the decorator was added, False if skipped. - """ - if skip_if_exists: - if new_decorator in self.decorators: - return False - # Get the top most extended ts_node that excludes docstrings - extended_ts_nodes = self.extended_nodes - # Iterate through the extended nodes and find the first node that is not a comment - for node in extended_ts_nodes: - if not isinstance(node, Comment): - break - node.insert_before(new_decorator, fix_indentation=True) - return True - - @property - @abstractmethod - @reader - def docstring(self) -> CommentGroup | None: - """Retrieves the docstring of the expression. - - Args: - None - - Returns: - CommentGroup | None: The docstring as a CommentGroup if it exists, None otherwise. - """ - - @abstractmethod - @writer - def set_docstring(self, docstring: str) -> None: - """Sets or updates the docstring for the current entity. - - Modifies the entity's docstring by either replacing an existing one or creating a new one. - - Args: - docstring (str): The new docstring content to set. - - Returns: - None: This method doesn't return anything. - """ diff --git a/src/codegen/sdk/core/interfaces/has_name.py b/src/codegen/sdk/core/interfaces/has_name.py deleted file mode 100644 index e8c09be5a..000000000 --- a/src/codegen/sdk/core/interfaces/has_name.py +++ /dev/null @@ -1,96 +0,0 @@ -from functools import cached_property - -from codegen.sdk.core.autocommit import commiter, reader, writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.chained_attribute import ChainedAttribute -from codegen.sdk.core.expressions.defined_name import DefinedName -from codegen.sdk.core.expressions.name import Name -from codegen.shared.decorators.docs import apidoc, noapidoc - - -@apidoc -class HasName: - """An interface for any node object that has a name.""" - - _name_node: Name | ChainedAttribute | DefinedName | None = None - - @cached_property - @reader - def name(self) -> str | None: - """Retrieves the base name of the object without namespace prefixes. - - Returns: - str | None: The base name of the object, or None if no name node is associated. - """ - if isinstance(self._name_node, ChainedAttribute): - return self._name_node.attribute.source - return self._name_node._source if self._name_node else None - - @cached_property - @reader - def full_name(self) -> str | None: - """Returns the full name of the object, including the namespace path. - - For class methods, this returns the parent class's full name followed by the method name. For chained attributes (e.g., 'a.b'), this returns the full chained name. - - Returns: - str | None: The complete qualified name of the object. Returns None if no name is available. - """ - if isinstance(self._name_node, ChainedAttribute): - return self._name_node.full_name - if isinstance(self._name_node, DefinedName): - from codegen.sdk.core.function import Function - - if isinstance(self, Function) and self.is_method: - return self.parent_class.full_name + "." + self.name - # if self.parent_symbol == self or self.parent_symbol.full_name is None: - # return self.name - # return self.parent_symbol.full_name + "." + self.name - return self.name - - @reader - def get_name(self) -> Name | ChainedAttribute | None: - """Returns the name node of the object. - - Args: - None - - Returns: - Name | ChainedAttribute | None: The name node of the object. Can be a Name node for simple names, - a ChainedAttribute for names with namespaces (e.g., a.b), or None if the object has no name. - """ - return self._name_node - - @writer - def set_name(self, name: str) -> None: - """Sets the name of a code element. - - Modifies the name of the object's underlying name node. Works with both simple names and chained attributes (e.g., 'a.b'). - - Args: - name (str): The new name to set for the object. - - Returns: - None - """ - if self._name_node: - self._name_node.rename_if_matching(self.name, name) - - @writer - def rename(self, name: str) -> None: - """Sets the name of an object and updates all its usages. - - Args: - name (str): The new name to assign to the object. - - Returns: - None - """ - self.set_name(name) - - @noapidoc - @commiter - def _add_name_usage(self, usage_type: UsageKind): - if name := self.get_name(): - if resolved := name.resolved_symbol(): - self._add_symbol_usages(usage_type, [resolved]) diff --git a/src/codegen/sdk/core/interfaces/has_symbols.py b/src/codegen/sdk/core/interfaces/has_symbols.py deleted file mode 100644 index 2c8bbe445..000000000 --- a/src/codegen/sdk/core/interfaces/has_symbols.py +++ /dev/null @@ -1,117 +0,0 @@ -from collections.abc import Iterator -from itertools import chain -from typing import TYPE_CHECKING, Generic, ParamSpec, TypeVar - -from codegen.sdk.core.utils.cache_utils import cached_generator -from codegen.shared.decorators.docs import py_noapidoc -from codegen.shared.logging.get_logger import get_logger - -if TYPE_CHECKING: - from codegen.sdk.core.assignment import Assignment - from codegen.sdk.core.class_definition import Class - from codegen.sdk.core.file import SourceFile - from codegen.sdk.core.function import Function - from codegen.sdk.core.import_resolution import Import, ImportStatement - from codegen.sdk.core.symbol import Symbol - from codegen.sdk.typescript.class_definition import TSClass - from codegen.sdk.typescript.export import TSExport - from codegen.sdk.typescript.file import TSFile - from codegen.sdk.typescript.function import TSFunction - from codegen.sdk.typescript.import_resolution import TSImport - from codegen.sdk.typescript.statements.import_statement import TSImportStatement - from codegen.sdk.typescript.symbol import TSSymbol - -logger = get_logger(__name__) - - -TFile = TypeVar("TFile", bound="SourceFile") -TSymbol = TypeVar("TSymbol", bound="Symbol") -TImportStatement = TypeVar("TImportStatement", bound="ImportStatement") -TGlobalVar = TypeVar("TGlobalVar", bound="Assignment") -TClass = TypeVar("TClass", bound="Class") -TFunction = TypeVar("TFunction", bound="Function") -TImport = TypeVar("TImport", bound="Import") -FilesParam = ParamSpec("FilesParam") - -TSGlobalVar = TypeVar("TSGlobalVar", bound="Assignment") - - -class HasSymbols(Generic[TFile, TSymbol, TImportStatement, TGlobalVar, TClass, TFunction, TImport]): - """Abstract interface for files in a codebase. - - Abstract interface for files in a codebase. - """ - - @cached_generator() - def files_generator(self, *args: FilesParam.args, **kwargs: FilesParam.kwargs) -> Iterator[TFile]: - """Generator for yielding files of the current container's scope.""" - msg = "This method should be implemented by the subclass" - raise NotImplementedError(msg) - - @property - def symbols(self) -> list[TSymbol]: - """Get a recursive list of all symbols in files container.""" - return list(chain.from_iterable(f.symbols for f in self.files_generator())) - - @property - def import_statements(self) -> list[TImportStatement]: - """Get a recursive list of all import statements in files container.""" - return list(chain.from_iterable(f.import_statements for f in self.files_generator())) - - @property - def global_vars(self) -> list[TGlobalVar]: - """Get a recursive list of all global variables in files container.""" - return list(chain.from_iterable(f.global_vars for f in self.files_generator())) - - @property - def classes(self) -> list[TClass]: - """Get a recursive list of all classes in files container.""" - return list(chain.from_iterable(f.classes for f in self.files_generator())) - - @property - def functions(self) -> list[TFunction]: - """Get a recursive list of all functions in files container.""" - return list(chain.from_iterable(f.functions for f in self.files_generator())) - - @property - @py_noapidoc - def exports(self) -> "list[TSExport]": - """Get a recursive list of all exports in files container.""" - return list(chain.from_iterable(f.exports for f in self.files_generator())) - - @property - def imports(self) -> list[TImport]: - """Get a recursive list of all imports in files container.""" - return list(chain.from_iterable(f.imports for f in self.files_generator())) - - def get_symbol(self, name: str) -> TSymbol | None: - """Get a symbol by name in files container.""" - return next((s for s in self.symbols if s.name == name), None) - - def get_import_statement(self, name: str) -> TImportStatement | None: - """Get an import statement by name in files container.""" - return next((s for s in self.import_statements if s.name == name), None) - - def get_global_var(self, name: str) -> TGlobalVar | None: - """Get a global variable by name in files container.""" - return next((s for s in self.global_vars if s.name == name), None) - - def get_class(self, name: str) -> TClass | None: - """Get a class by name in files container.""" - return next((s for s in self.classes if s.name == name), None) - - def get_function(self, name: str) -> TFunction | None: - """Get a function by name in files container.""" - return next((s for s in self.functions if s.name == name), None) - - @py_noapidoc - def get_export( - self: "HasSymbols[TSFile, TSSymbol, TSImportStatement, TSGlobalVar, TSClass, TSFunction, TSImport]", - name: str, - ) -> "TSExport | None": - """Get an export by name in files container (supports only typescript).""" - return next((s for s in self.exports if s.name == name), None) - - def get_import(self, name: str) -> TImport | None: - """Get an import by name in files container.""" - return next((s for s in self.imports if s.name == name), None) diff --git a/src/codegen/sdk/core/interfaces/has_value.py b/src/codegen/sdk/core/interfaces/has_value.py deleted file mode 100644 index eaffb870e..000000000 --- a/src/codegen/sdk/core/interfaces/has_value.py +++ /dev/null @@ -1,35 +0,0 @@ -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.expressions.expression import Expression -from codegen.shared.decorators.docs import apidoc - - -@apidoc -class HasValue: - """An interface for any node object that has a value.""" - - _value_node: Expression | None - - @property - @reader - def value(self) -> Expression | None: - """Gets the value node of the object. - - Returns: - Expression | None: The value node of the object. None if no value is set. - """ - return self._value_node - - @writer - def set_value(self, value: str) -> None: - """Sets the value of the node's value Expression. - - Updates the value of the underlying Expression node if it exists. No action is taken if the value node is None. - - Args: - value (str): The new value to set. - - Returns: - None - """ - if self._value_node is not None: - self._value_node.edit(value) diff --git a/src/codegen/sdk/core/interfaces/importable.py b/src/codegen/sdk/core/interfaces/importable.py deleted file mode 100644 index 4ea73eafe..000000000 --- a/src/codegen/sdk/core/interfaces/importable.py +++ /dev/null @@ -1,129 +0,0 @@ -from typing import TYPE_CHECKING, Generic, Self, TypeVar, Union - -from tree_sitter import Node as TSNode - -from codegen.sdk._proxy import proxy_property -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.dataclasses.usage import UsageType -from codegen.sdk.core.expressions.expression import Expression -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.enums import EdgeType -from codegen.sdk.extensions.autocommit import commiter -from codegen.sdk.extensions.sort import sort_editables -from codegen.shared.decorators.docs import apidoc, noapidoc -from codegen.shared.logging.get_logger import get_logger - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.import_resolution import Import - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.symbol import Symbol - -Parent = TypeVar("Parent", bound="Editable") - -logger = get_logger(__name__) - - -@apidoc -class Importable(Expression[Parent], HasName, Generic[Parent]): - """An interface for any node object that can import (or reference) an exportable symbol eg. All nodes that are on the graph must inherit from here - - Class, function, imports, exports, etc. - """ - - node_id: int - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent) -> None: - if not hasattr(self, "node_id"): - self.node_id = ctx.add_node(self) - super().__init__(ts_node, file_node_id, ctx, parent) - if self.file: - self.file._nodes.append(self) - - @proxy_property - @reader(cache=False) - def dependencies(self, usage_types: UsageType | None = UsageType.DIRECT, max_depth: int | None = None) -> list[Union["Symbol", "Import"]]: - """Returns a list of symbols that this symbol depends on. - - Args: - usage_types (UsageType | None): The types of dependencies to search for. Defaults to UsageType.DIRECT. - max_depth (int | None): Maximum depth to traverse in the dependency graph. If provided, will recursively collect - dependencies up to this depth. Defaults to None (only direct dependencies). - - Returns: - list[Union[Symbol, Import]]: A list of symbols and imports that this symbol depends on, - sorted by file location. - - Note: - This method can be called as both a property or a method. If used as a property, it is equivalent to invoking it without arguments. - """ - # Get direct dependencies for this symbol and its descendants - avoid = set(self.descendant_symbols) - deps = [] - for symbol in self.descendant_symbols: - deps.extend(filter(lambda x: x not in avoid, symbol._get_dependencies(usage_types))) - - if max_depth is not None and max_depth > 1: - # For max_depth > 1, recursively collect dependencies - seen = set(deps) - for dep in list(deps): # Create a copy of deps to iterate over - if isinstance(dep, Importable): - next_deps = dep.dependencies(usage_types=usage_types, max_depth=max_depth - 1) - for next_dep in next_deps: - if next_dep not in seen: - seen.add(next_dep) - deps.append(next_dep) - - return sort_editables(deps, by_file=True) - - @reader(cache=False) - @noapidoc - def _get_dependencies(self, usage_types: UsageType) -> list[Union["Symbol", "Import"]]: - """Symbols that this symbol depends on. - - Opposite of `usages` - """ - # TODO: sort out attribute usages in dependencies - edges = [x for x in self.ctx.out_edges(self.node_id) if x[2].type == EdgeType.SYMBOL_USAGE] - unique_dependencies = [] - for edge in edges: - if edge[2].usage.usage_type is None or edge[2].usage.usage_type in usage_types: - dependency = self.ctx.get_node(edge[1]) - unique_dependencies.append(dependency) - return sort_editables(unique_dependencies, by_file=True) - - @commiter - @noapidoc - def recompute(self, incremental: bool = False) -> list["Importable"]: - """Recompute the dependencies of this symbol. - - Returns: - A list of importables that need to be updated now this importable has been updated. - """ - if incremental: - self._remove_internal_edges(EdgeType.SYMBOL_USAGE) - try: - self._compute_dependencies() - except Exception as e: - logger.exception(f"Error in file {self.file.path} while computing dependencies for symbol {self.name}") - raise e - if incremental: - return self.descendant_symbols + self.file.get_nodes(sort=False) - return [] - - @commiter - @noapidoc - def _remove_internal_edges(self, edge_type: EdgeType | None = None) -> None: - """Removes edges from itself to its children from the codebase graph. - - Returns a list of node ids for edges that were removed. - """ - # Must store edges to remove in a static read-only view before removing to avoid concurrent dict modification - for v in self.ctx.successors(self.node_id, edge_type=edge_type): - self.ctx.remove_edge(self.node_id, v.node_id, edge_type=edge_type) - - @property - @noapidoc - def descendant_symbols(self) -> list[Self]: - return [self] diff --git a/src/codegen/sdk/core/interfaces/inherits.py b/src/codegen/sdk/core/interfaces/inherits.py deleted file mode 100644 index 10a67a3d6..000000000 --- a/src/codegen/sdk/core/interfaces/inherits.py +++ /dev/null @@ -1,73 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.autocommit import commiter, reader -from codegen.sdk.core.expressions import Type -from codegen.sdk.core.interfaces.supports_generic import SupportsGenerics -from codegen.sdk.enums import EdgeType - -if TYPE_CHECKING: - from collections.abc import Generator - - from codegen.sdk.core.class_definition import Class - from codegen.sdk.core.external_module import ExternalModule - from codegen.sdk.core.interface import Interface - -TType = TypeVar("TType", bound=Type) - - -class Inherits(SupportsGenerics, Generic[TType]): - """This symbol inherits from other symbols.""" - - @commiter - @abstractmethod - def compute_superclass_dependencies(self) -> None: - pass - - @reader - def _get_superclasses(self, max_depth: int | None = None) -> list[Class | ExternalModule | Interface]: - """Returns a list of all classes that this class extends, up to max_depth.""" - from codegen.sdk.core.class_definition import Class - from codegen.sdk.core.interface import Interface - - # Implements the python MRO, IE: by level - seen = set() - - def traverse_classes(classes: list[Inherits], depth: int = 0) -> Generator[Class | Interface | ExternalModule, None, None]: - if max_depth is not None and depth >= max_depth: - return - next_level = [] - for node in classes: - for result in self.ctx.successors(node.node_id, edge_type=EdgeType.SUBCLASS): - if result.node_id not in seen: - seen.add(result.node_id) - yield result - if isinstance(result, Class) or isinstance(result, Interface): - next_level.append(result) - if len(next_level) > 0: - yield from traverse_classes(next_level, depth + 1) - - return list(traverse_classes([self])) - - @reader - def _get_subclasses(self, max_depth: int | None = None) -> list[Class | ExternalModule | Interface]: - """Returns a list of all classes that subclass this class, up to max_depth.""" - # Implements the python MRO, IE: by level - seen = set() - - def traverse_classes(classes: list[Inherits], depth: int = 0) -> Generator[Class | Interface, None, None]: - if max_depth and depth >= max_depth: - return - next_level = [] - for node in classes: - for result in self.ctx.predecessors(node.node_id, edge_type=EdgeType.SUBCLASS): - if result.node_id not in seen: - seen.add(result.node_id) - yield result - next_level.append(result) - if len(next_level) > 0: - yield from traverse_classes(next_level, depth + 1) - - return list(traverse_classes([self])) diff --git a/src/codegen/sdk/core/interfaces/parseable.py b/src/codegen/sdk/core/interfaces/parseable.py deleted file mode 100644 index 28a995892..000000000 --- a/src/codegen/sdk/core/interfaces/parseable.py +++ /dev/null @@ -1,11 +0,0 @@ -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - - -class Parseable(ABC): - @abstractmethod - def parse(self, ctx: "CodebaseContext") -> None: - """Adds itself and its children to the codebase graph.""" diff --git a/src/codegen/sdk/core/interfaces/resolvable.py b/src/codegen/sdk/core/interfaces/resolvable.py deleted file mode 100644 index 4906cc0cd..000000000 --- a/src/codegen/sdk/core/interfaces/resolvable.py +++ /dev/null @@ -1,20 +0,0 @@ -from abc import abstractmethod -from typing import Generic - -from typing_extensions import TypeVar - -from codegen.sdk.core.autocommit import writer -from codegen.sdk.core.interfaces.chainable import Chainable -from codegen.sdk.core.interfaces.editable import Editable -from codegen.shared.decorators.docs import noapidoc - -Parent = TypeVar("Parent", bound=Editable) - - -class Resolvable(Chainable[Parent], Generic[Parent]): - """Represents a class resolved to another symbol during the compute dependencies step.""" - - @abstractmethod - @noapidoc - @writer - def rename_if_matching(self, old: str, new: str) -> None: ... diff --git a/src/codegen/sdk/core/interfaces/supports_generic.py b/src/codegen/sdk/core/interfaces/supports_generic.py deleted file mode 100644 index 725df9076..000000000 --- a/src/codegen/sdk/core/interfaces/supports_generic.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import TYPE_CHECKING, Generic, Self - -from typing_extensions import TypeVar - -from codegen.sdk.core.expressions.named_type import NamedType -from codegen.sdk.core.symbol import Symbol -from codegen.sdk.core.symbol_groups.type_parameters import TypeParameters -from codegen.sdk.extensions.utils import cached_property -from codegen.shared.decorators.docs import noapidoc - -if TYPE_CHECKING: - from codegen.sdk.core.expressions import Type - -TType = TypeVar("TType", bound="Type") - - -class SupportsGenerics(Symbol, Generic[TType]): - """A symbol that supports generics. - - Attributes: - type_parameters: The type parameters of the symbol, if any. - """ - - type_parameters: TypeParameters[TType, Self] | None = None - - @cached_property - @noapidoc - def generics(self) -> dict[str, TType]: - if self.type_parameters: - return {param.name: param for param in self.type_parameters if isinstance(param, NamedType)} - return {} diff --git a/src/codegen/sdk/core/interfaces/typeable.py b/src/codegen/sdk/core/interfaces/typeable.py deleted file mode 100644 index 3d23a4f93..000000000 --- a/src/codegen/sdk/core/interfaces/typeable.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from codegen.sdk.core.autocommit import commiter, reader -from codegen.sdk.core.interfaces.chainable import Chainable -from codegen.sdk.core.placeholder.placeholder_type import TypePlaceholder -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from collections.abc import Generator - - from codegen.sdk.codebase.resolution_stack import ResolutionStack - from codegen.sdk.core.expressions.type import Type - from codegen.sdk.core.interfaces.editable import Editable - - -TType = TypeVar("TType", bound="Type") -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class Typeable(Chainable[Parent], Generic[TType, Parent]): - """An interface for any node object that can be typed, eg. function parameters, variables, etc. - - Attributes: - type: The type annotation associated with this node - """ - - type: TType | TypePlaceholder[Self] - - @commiter - def _init_type(self, type_name: str = "type") -> None: - self.type = self.child_by_field_name(type_name, placeholder=TypePlaceholder) - - @property - @reader - def is_typed(self) -> bool: - """Indicates if a node has an explicit type annotation. - - Returns: - bool: True if the node has an explicit type annotation, False otherwise. - """ - return self.type - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - if isinstance(self.type, Chainable): - yield from self.with_resolution_frame(self.type) diff --git a/src/codegen/sdk/core/interfaces/unwrappable.py b/src/codegen/sdk/core/interfaces/unwrappable.py deleted file mode 100644 index 341e3352f..000000000 --- a/src/codegen/sdk/core/interfaces/unwrappable.py +++ /dev/null @@ -1,25 +0,0 @@ -from abc import abstractmethod -from typing import Generic - -from typing_extensions import TypeVar - -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.interfaces.editable import Editable -from codegen.shared.decorators.docs import apidoc - -Parent = TypeVar("Parent", bound=Editable) - - -@apidoc -class Unwrappable(Expression[Parent], Generic[Parent]): - """An abstract representation of an expression that can be unwrapped. - Expressions that can be unwrapped include binary expressions and ternary expressions. - """ - - @abstractmethod - def unwrap(self, node: Expression | None = None) -> None: - """Unwrap this expression, removing parenthesis and other syntax elements while maintaining the function of the code. - - Args: - node: the node that's remaining. If None, assume all children of this expression are kept - """ diff --git a/src/codegen/sdk/core/interfaces/usable.py b/src/codegen/sdk/core/interfaces/usable.py deleted file mode 100644 index f1d2ed450..000000000 --- a/src/codegen/sdk/core/interfaces/usable.py +++ /dev/null @@ -1,91 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk._proxy import proxy_property -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.dataclasses.usage import Usage, UsageType -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.enums import EdgeType -from codegen.shared.decorators.docs import apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.export import Export - from codegen.sdk.core.import_resolution import Import - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.symbol import Symbol -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class Usable(Importable[Parent], Generic[Parent]): - """An interface for any node object that can be referenced by another node.""" - - @proxy_property - @reader(cache=False) - def symbol_usages(self, usage_types: UsageType | None = None) -> list[Import | Symbol | Export]: - """Returns a list of symbols that use or import the exportable object. - - Args: - usage_types (UsageType | None): The types of usages to search for. Defaults to any. - - Returns: - list[Import | Symbol | Export]: A list of symbols that use or import the exportable object. - - Note: - This method can be called as both a property or a method. If used as a property, it is equivalent to invoking it without arguments. - """ - symbol_usages = [] - for usage in self.usages(usage_types=usage_types): - symbol_usages.append(usage.usage_symbol.parent_symbol) - return list(dict.fromkeys(symbol_usages)) - - @proxy_property - @reader(cache=False) - def usages(self, usage_types: UsageType | None = None) -> list[Usage]: - """Returns a list of usages of the exportable object. - - Retrieves all locations where the exportable object is used in the codebase. By default, returns all usages, such as imports or references within the same file. - - Args: - usage_types (UsageType | None): Specifies which types of usages to include in the results. Default is any usages. - - Returns: - list[Usage]: A sorted list of Usage objects representing where this exportable is used, ordered by source location in reverse. - - Raises: - ValueError: If no usage types are specified or if only ALIASED and DIRECT types are specified together. - - Note: - This method can be called as both a property or a method. If used as a property, it is equivalent to invoking it without arguments. - """ - if usage_types == UsageType.DIRECT | UsageType.ALIASED: - msg = "Combination of only Aliased and Direct usages makes no sense" - raise ValueError(msg) - - assert self.node_id is not None - usages_to_return = [] - in_edges = self.ctx.in_edges(self.node_id) - for edge in in_edges: - meta_data = edge[2] - if meta_data.type == EdgeType.SYMBOL_USAGE: - usage = meta_data.usage - if usage_types is None or usage.usage_type in usage_types: - usages_to_return.append(usage) - return sorted(dict.fromkeys(usages_to_return), key=lambda x: x.match.ts_node.start_byte if x.match else x.usage_symbol.ts_node.start_byte, reverse=True) - - def rename(self, new_name: str, priority: int = 0) -> tuple[NodeId, NodeId]: - """Renames a symbol and updates all its references in the codebase. - - Args: - new_name (str): The new name for the symbol. - priority (int): Priority of the edit operation. Defaults to 0. - - Returns: - tuple[NodeId, NodeId]: A tuple containing the file node ID and the new node ID of the renamed symbol. - """ - self.set_name(new_name) - - for usage in self.usages(UsageType.DIRECT | UsageType.INDIRECT | UsageType.CHAINED): - usage.match.rename_if_matching(self.name, new_name) diff --git a/src/codegen/sdk/core/interfaces/wrapper_expression.py b/src/codegen/sdk/core/interfaces/wrapper_expression.py deleted file mode 100644 index 626dd1544..000000000 --- a/src/codegen/sdk/core/interfaces/wrapper_expression.py +++ /dev/null @@ -1,55 +0,0 @@ -from abc import abstractmethod -from collections.abc import Generator -from typing import TYPE_CHECKING, Self, final, override - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.interfaces.chainable import Chainable -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import noapidoc - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.importable import Importable - - -class IWrapper(Chainable, Editable): - """Any expression or statement that contains another expression. - - This is a simple interface to unwrap the nested expression. - """ - - @property - @abstractmethod - @reader - def value(self) -> Expression | None: - """The value of the object.""" - - @reader - @final - def resolve(self) -> Expression: - """Resolves the wrapper expression and returns the first concrete expression.""" - cur_val = self.value - while cur_val and isinstance(cur_val, IWrapper): - cur_val = cur_val.value - return cur_val - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - yield from self.with_resolution_frame(self.value) - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind = UsageKind.BODY, dest: HasName | None = None) -> None: - if self._value_node: - self.resolve()._compute_dependencies(usage_type, dest) - - @property - @noapidoc - def descendent_symbols(self) -> list["Importable"]: - return self.resolve().descendant_symbols diff --git a/src/codegen/sdk/core/node_id_factory.py b/src/codegen/sdk/core/node_id_factory.py deleted file mode 100644 index eb9f20558..000000000 --- a/src/codegen/sdk/core/node_id_factory.py +++ /dev/null @@ -1 +0,0 @@ -NodeId = int diff --git a/src/codegen/sdk/core/parser.py b/src/codegen/sdk/core/parser.py deleted file mode 100644 index 5ea44f27e..000000000 --- a/src/codegen/sdk/core/parser.py +++ /dev/null @@ -1,300 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Generic, Protocol, Self, TypeVar - -from rich.console import Console - -from codegen.sdk.core.expressions.placeholder_type import PlaceholderType -from codegen.sdk.core.expressions.value import Value -from codegen.sdk.core.statements.symbol_statement import SymbolStatement -from codegen.sdk.utils import find_first_function_descendant, find_import_node - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.codebase.node_classes.node_classes import NodeClasses - from codegen.sdk.core.expressions.type import Type - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.statements.statement import Statement - from codegen.sdk.core.symbol import Symbol - from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - - -Parent = TypeVar("Parent", bound="Editable") - - -class CanParse(Protocol, Generic[Parent]): - def __init__(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent) -> None: ... - - -Expression = TypeVar("Expression", bound="CanParse") -Parent = TypeVar("Parent", bound="Editable") - - -@dataclass -class Parser(Generic[Expression]): - symbol_map: dict[str, type[Symbol]] - expressions: dict[str, type[Expression]] - types: dict[str, type[Type] | dict[str, type[Type]]] - type_node: str - _uncovered_nodes: set[str] = field(default_factory=set) - _should_log: bool = False - _console: Console = field(default_factory=lambda: Console()) - - def _process_type(self, expr_type: type[Type] | dict[str, type[Type]], node: TSNode) -> tuple[type[Type], TSNode]: - if isinstance(expr_type, dict): - for child in node.named_children: - if v := expr_type.get(child.type, None): - return v, child - if node.type not in self._uncovered_nodes: - self.log(f"Cannot handle nested type {node.type}, {expr_type}, {node.named_children}") - self._uncovered_nodes.add(node.type) - return PlaceholderType, node - return expr_type, node - - @classmethod - def from_node_classes(cls, node_classes: NodeClasses, log_parse_warnings: bool = False) -> Self: - return cls(symbol_map=node_classes.symbol_map, expressions=node_classes.expression_map, types=node_classes.type_map, type_node=node_classes.type_node_type, _should_log=log_parse_warnings) - - def parse_expression(self, node: TSNode | None, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent, *args, default: type[Expression] = Value, **kwargs) -> Expression[Parent] | None: - if node is None: - return None - if node.type == self.type_node: - return self.parse_type(node, file_node_id, ctx, parent) - assert default is not None - if default == Value: - if previous := parent.file._range_index.get_canonical_for_range(node.range, node.kind_id): - return previous - if symbol_cls := self.symbol_map.get(node.type, None): - ret = symbol_cls(node, file_node_id, ctx, parent, *args, **kwargs) - else: - expr_type = self.expressions.get(node.type, default) - ret = expr_type(node, file_node_id, ctx, parent) - if default == Value: - ret.file._range_index.mark_as_canonical(ret) - if isinstance(ret, Value): - ret.children - return ret - - def log_unparsed(self, node: TSNode) -> None: - if self._should_log and node.is_named and node.type not in self._uncovered_nodes: - self._uncovered_nodes.add(node.type) - self.log(f"Encountered unimplemented node {node.type} with text {node.text.decode('utf-8')}") - - def parse_type(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent) -> Type: - if node.type == self.type_node: - return self.parse_type(node.named_children[0], file_node_id, ctx, parent) - if expr_type := self.types.get(node.type, None): - expr_type, node = self._process_type(expr_type, node) - return expr_type(node, file_node_id, ctx, parent) - self.log_unparsed(node) - from codegen.sdk.core.expressions.placeholder_type import PlaceholderType - - return PlaceholderType(node, file_node_id, ctx, parent) - - def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSCodeBlock) -> list[Statement]: - from codegen.sdk.core.statements.export_statement import ExportStatement - from codegen.sdk.core.statements.expression_statement import ExpressionStatement - from codegen.sdk.core.statements.return_statement import ReturnStatement - from codegen.sdk.core.statements.statement import Statement - from codegen.sdk.core.statements.symbol_statement import SymbolStatement - from codegen.sdk.typescript.function import _VALID_TYPE_NAMES - from codegen.sdk.typescript.statements.assignment_statement import TSAssignmentStatement - from codegen.sdk.typescript.statements.attribute import TSAttribute - from codegen.sdk.typescript.statements.comment import TSComment - from codegen.sdk.typescript.statements.for_loop_statement import TSForLoopStatement - from codegen.sdk.typescript.statements.if_block_statement import TSIfBlockStatement - from codegen.sdk.typescript.statements.import_statement import TSImportStatement - from codegen.sdk.typescript.statements.labeled_statement import TSLabeledStatement - from codegen.sdk.typescript.statements.switch_statement import TSSwitchStatement - from codegen.sdk.typescript.statements.try_catch_statement import TSTryCatchStatement - from codegen.sdk.typescript.statements.while_statement import TSWhileStatement - - statements = [] - - if node.type in self.expressions or node.type == "expression_statement": - return [ExpressionStatement(node, file_node_id, ctx, parent, 0, expression_node=node)] - - for child in node.named_children: - # =====[ Functions + Methods ]===== - if child.type in _VALID_TYPE_NAMES: - statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements))) - elif child.type == "import_statement": - statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements))) - # =====[ Classes ]===== - elif child.type in ("class_declaration", "abstract_class_declaration"): - statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements))) - - # =====[ Interface Declarations ]===== - elif child.type == "interface_declaration": - statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements))) - - # =====[ Type Alias Declarations ]===== - elif child.type == "type_alias_declaration": - if import_node := find_import_node(child): - statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements), source_node=import_node)) - else: - statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements))) - - # =====[ Enum Declarations ]===== - elif child.type == "enum_declaration": - statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements))) - - # =====[ Exports ]===== - elif child.type == "export_statement" or child.text.decode("utf-8") == "export *;": - statements.append(ExportStatement(child, file_node_id, ctx, parent, len(statements))) - - # =====[ Non-symbol statements ] ===== - elif child.type == "comment": - statements.append(TSComment.from_code_block(child, parent, pos=len(statements))) - elif child.type == "return_statement": - statements.append(ReturnStatement.from_code_block(child, parent, pos=len(statements))) - elif child.type == "if_statement": - statements.append(TSIfBlockStatement.from_code_block(child, parent, pos=len(statements))) - elif child.type in ["while_statement", "do_statement"]: - statements.append(TSWhileStatement.from_code_block(child, parent, pos=len(statements))) - elif child.type in ["for_statement", "for_in_statement"]: - statements.append(TSForLoopStatement.from_code_block(child, parent, pos=len(statements))) - elif child.type == "try_statement": - statements.append(TSTryCatchStatement.from_code_block(child, parent, pos=len(statements))) - elif child.type == "switch_statement": - statements.append(TSSwitchStatement.from_code_block(child, parent, pos=len(statements))) - elif child.type == "labeled_statement": - statements.append(TSLabeledStatement(child, file_node_id, ctx, parent, len(statements))) - elif child.type in ["lexical_declaration", "variable_declaration"]: - if function_node := find_first_function_descendant(child): - statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements), function_node)) - elif import_node := find_import_node(child): - statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements), source_node=import_node)) - else: - statements.append( - TSAssignmentStatement.from_assignment( - child, file_node_id, ctx, parent, pos=len(statements), assignment_node=next(var for var in child.named_children if var.type == "variable_declarator") - ) - ) - elif child.type in ["public_field_definition", "property_signature", "enum_assignment"]: - statements.append(TSAttribute(child, file_node_id, ctx, parent, pos=len(statements))) - elif child.type == "expression_statement": - if import_node := find_import_node(child): - statements.append(TSImportStatement(child, file_node_id, ctx, parent, pos=len(statements), source_node=import_node)) - continue - - for var in child.named_children: - if var.type == "string": - statements.append(TSComment.from_code_block(var, parent, pos=len(statements))) - elif var.type in ["assignment_expression", "augmented_assignment_expression"]: - statements.append(TSAssignmentStatement.from_assignment(child, file_node_id, ctx, parent, pos=len(statements), assignment_node=var)) - else: - statements.append(ExpressionStatement(child, file_node_id, ctx, parent, pos=len(statements), expression_node=var)) - elif child.type in self.expressions: - statements.append(ExpressionStatement(child, file_node_id, ctx, parent, len(statements), expression_node=child)) - else: - self.log("Couldn't parse statement with type: %s", child.type) - statements.append(Statement.from_code_block(child, parent, pos=len(statements))) - statements[-1].nested_code_blocks - - return statements - - def parse_py_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyCodeBlock) -> list[Statement]: - from codegen.sdk.core.statements.expression_statement import ExpressionStatement - from codegen.sdk.core.statements.raise_statement import RaiseStatement - from codegen.sdk.core.statements.return_statement import ReturnStatement - from codegen.sdk.core.statements.statement import Statement - from codegen.sdk.python.statements.assignment_statement import PyAssignmentStatement - from codegen.sdk.python.statements.attribute import PyAttribute - from codegen.sdk.python.statements.break_statement import PyBreakStatement - from codegen.sdk.python.statements.comment import PyComment - from codegen.sdk.python.statements.for_loop_statement import PyForLoopStatement - from codegen.sdk.python.statements.if_block_statement import PyIfBlockStatement - from codegen.sdk.python.statements.import_statement import PyImportStatement - from codegen.sdk.python.statements.match_statement import PyMatchStatement - from codegen.sdk.python.statements.pass_statement import PyPassStatement - from codegen.sdk.python.statements.try_catch_statement import PyTryCatchStatement - from codegen.sdk.python.statements.while_statement import PyWhileStatement - from codegen.sdk.python.statements.with_statement import WithStatement - - statements = [] - - # Handles a Tree sitter anomaly where comments in the block are not included in the block node - prev_sibling = node.prev_sibling - top_comments = [] - while prev_sibling and prev_sibling.type == "comment": - top_comments.insert(0, prev_sibling) - prev_sibling = prev_sibling.prev_sibling - - for comment in top_comments: - statements.append(PyComment.from_code_block(comment, parent, pos=len(statements))) - - for child in node.named_children: - # =====[ Decorated definitions ]===== - if child.type == "decorated_definition": - statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements))) - - # =====[ Functions ]===== - elif child.type == "function_definition": - statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements))) - - # =====[ Classes ]===== - elif child.type == "class_definition": - statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements))) - - # =====[ Imports ] ===== - elif child.type in ["import_statement", "import_from_statement", "future_import_statement"]: - statements.append(PyImportStatement(child, file_node_id, ctx, parent, len(statements))) - # =====[ Non-symbol statements ] ===== - elif child.type == "comment": - statements.append(PyComment.from_code_block(child, parent, pos=len(statements))) - elif child.type == "raise_statement": - statements.append(RaiseStatement.from_code_block(child, parent, pos=len(statements))) - elif child.type == "return_statement": - statements.append(ReturnStatement.from_code_block(child, parent, pos=len(statements))) - elif child.type == "if_statement": - statements.append(PyIfBlockStatement.from_code_block(child, parent, pos=len(statements))) - elif child.type == "with_statement": - statements.append(WithStatement.from_code_block(child, parent, pos=len(statements))) - elif child.type == "pass_statement": - statements.append(PyPassStatement.from_code_block(child, parent, pos=len(statements))) - elif child.type == "break_statement": - statements.append(PyBreakStatement.from_code_block(child, parent, pos=len(statements))) - elif child.type == "while_statement": - statements.append(PyWhileStatement.from_code_block(child, parent, pos=len(statements))) - elif child.type == "for_statement": - statements.append(PyForLoopStatement.from_code_block(child, parent, pos=len(statements))) - elif child.type == "match_statement": - statements.append(PyMatchStatement.from_code_block(child, parent, pos=len(statements))) - elif child.type == "try_statement": - statements.append(PyTryCatchStatement.from_code_block(child, parent, pos=len(statements))) - elif child.type == "expression_statement": - for var in child.named_children: - if var.type == "string": - statements.append(PyComment.from_code_block(var, parent, pos=len(statements))) - elif var.type in ["assignment", "augmented_assignment"]: - from codegen.sdk.core.class_definition import Class - - if isinstance(parent.parent, Class): - statements.append(PyAttribute(child, file_node_id, ctx, parent, len(statements), var)) - else: - statements.append(PyAssignmentStatement.from_assignment(child, file_node_id, ctx, parent, pos=len(statements), assignment_node=var)) - else: - statements.append(ExpressionStatement(child, file_node_id, ctx, parent, pos=len(statements), expression_node=var)) - else: - self.log("Couldn't parse statement with type: %s", node.type) - statements.append(Statement.from_code_block(child, parent, pos=len(statements))) - statements[-1].nested_code_blocks - return statements - - def report(self): - if self._uncovered_nodes: - self._console.print(f"Encountered unimplemented nodes {self._uncovered_nodes}") - - def log(self, message: str, *args): - if self._should_log: - try: - self._console.log(message % args) - except (KeyError, IndexError, ValueError, TypeError): - self._console.log(message, *args) - pass diff --git a/src/codegen/sdk/core/placeholder/__init__.py b/src/codegen/sdk/core/placeholder/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/sdk/core/placeholder/placeholder.py b/src/codegen/sdk/core/placeholder/placeholder.py deleted file mode 100644 index 0783ebe16..000000000 --- a/src/codegen/sdk/core/placeholder/placeholder.py +++ /dev/null @@ -1,73 +0,0 @@ -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Generic, Literal, TypeVar - -from codegen.sdk.core.autocommit import repr_func -from codegen.shared.decorators.docs import apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.editable import Editable - - -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class Placeholder(ABC, Generic[Parent]): - """A placeholder for a node that does not exist yet. - - Use bool checks (ie is node) to check if the node exists. You can call edit to replace the - placeholder with a real node and it will automatically insert formatting. - """ - - _parent_node: Parent - - def __init__(self, parent: Parent) -> None: - self._parent_node = parent - - def __bool__(self) -> Literal[False]: - return False - - def __str__(self) -> str: - return self.__repr__() - - @repr_func - def __repr__(self) -> str: - """Represents the object as a string for logging purposes. - - Returns: - str: The class name of the object. - """ - return f"{self.__class__.__name__}" - - def remove(self, *args, **kwargs) -> None: - """Removes this element from its parent container. - - Args: - *args: Variable length argument list. Unused. - **kwargs: Arbitrary keyword arguments. Unused. - - Returns: - None - """ - pass - - @abstractmethod - def edit(self, new_src: str, fix_indentation: bool = False, priority: int = 0, dedupe: bool = True) -> None: - """Replaces the content of a placeholder node with new source code. - - Modifies the parent node to include the new source code. Can optionally fix - indentation and handle deduplication. - - Args: - new_src (str): The new source code to replace the placeholder with. - fix_indentation (bool, optional): Whether to automatically fix the - indentation of the new source. Defaults to False. - priority (int, optional): Priority value for conflict resolution. - Defaults to 0. - dedupe (bool, optional): Whether to prevent duplicate insertions. - Defaults to True. - - Returns: - None - """ - pass diff --git a/src/codegen/sdk/core/placeholder/placeholder_stub.py b/src/codegen/sdk/core/placeholder/placeholder_stub.py deleted file mode 100644 index d99b375a9..000000000 --- a/src/codegen/sdk/core/placeholder/placeholder_stub.py +++ /dev/null @@ -1,32 +0,0 @@ -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.placeholder.placeholder import Placeholder -from codegen.shared.decorators.docs import apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.editable import Editable - -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class StubPlaceholder(Placeholder[Parent], Generic[Parent]): - """A placeholder for a stub that does not exist. - Can be populated using the `edit` method. - """ - - def edit(self, new_src: str, fix_indentation: bool = False, priority: int = 0, dedupe: bool = True) -> None: - """Edits the source code of this placeholder node. - - Modifies the source code with the provided new source code. - - Args: - new_src (str): The new source code to replace the current source code. - fix_indentation (bool, optional): Whether to automatically fix the indentation of the new source code. Defaults to False. - priority (int, optional): The priority of this edit operation. Higher priority edits are applied first. Defaults to 0. - dedupe (bool, optional): Whether to deduplicate this edit against other pending edits. Defaults to True. - - Returns: - None - """ - raise NotImplementedError diff --git a/src/codegen/sdk/core/placeholder/placeholder_type.py b/src/codegen/sdk/core/placeholder/placeholder_type.py deleted file mode 100644 index 5beeddf46..000000000 --- a/src/codegen/sdk/core/placeholder/placeholder_type.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.placeholder.placeholder import Placeholder -from codegen.shared.decorators.docs import apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.editable import Editable - -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class TypePlaceholder(Placeholder[Parent], Generic[Parent]): - """A placeholder for a Type node that does not exist. - Can be populated using the `edit` method. - """ - - def edit(self, new_src: str, fix_indentation: bool = False, priority: int = 0, dedupe: bool = True) -> None: - """Edits the type annotation of a placeholder node. - - Modifies the source code by adding or updating a type annotation after a node. - Handles cases where the parent node has children and adjusts spacing accordingly. - - Args: - new_src (str): The new type annotation text to be inserted. - fix_indentation (bool, optional): Whether to fix the indentation of the new source. - priority (int, optional): Priority of the edit operation. - dedupe (bool, optional): Whether to remove duplicate edits. - - Returns: - None - """ - if len(self._parent_node.children) == 0: - self._parent_node.insert_after(": " + new_src, newline=False) - else: - if len(self._parent_node.children) > 1 and " " in self._parent_node.source: - new_src = new_src + " " - self._parent_node.children[0].insert_after(": " + new_src, newline=False) diff --git a/src/codegen/sdk/core/plugins/__init__.py b/src/codegen/sdk/core/plugins/__init__.py deleted file mode 100644 index 4cf23493c..000000000 --- a/src/codegen/sdk/core/plugins/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from codegen.sdk.core.plugins.axios import AxiosApiFinder -from codegen.sdk.core.plugins.flask import FlaskApiFinder -from codegen.sdk.core.plugins.modal import ModalApiFinder - -PLUGINS = [ - FlaskApiFinder(), - AxiosApiFinder(), - ModalApiFinder(), -] diff --git a/src/codegen/sdk/core/plugins/axios.py b/src/codegen/sdk/core/plugins/axios.py deleted file mode 100644 index d4efc27ac..000000000 --- a/src/codegen/sdk/core/plugins/axios.py +++ /dev/null @@ -1,52 +0,0 @@ -from logging import getLogger -from typing import TYPE_CHECKING - -from codegen.sdk.core.detached_symbols.argument import Argument -from codegen.sdk.core.detached_symbols.function_call import FunctionCall -from codegen.sdk.core.expressions import String -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.plugins.plugin import Plugin -from codegen.sdk.core.symbol_groups.dict import Dict -from codegen.shared.enums.programming_language import ProgrammingLanguage - -if TYPE_CHECKING: - from codegen.sdk.core.codebase import TSCodebaseType - - -logger = getLogger(__name__) - - -class AxiosApiFinder(Plugin): - language: ProgrammingLanguage = ProgrammingLanguage.TYPESCRIPT - - def execute(self, codebase: "TSCodebaseType"): - logger.info("Scanning for Axios API calls") - api_calls = 0 - - def resolve_http(val) -> Editable: - if isinstance(val, Argument): - val = val.value - if isinstance(val, Dict): - return resolve_http(val.get("baseURL")) - return val.resolved_value - - for imp in codebase.imports: - if "axios" in imp.module.source: - for usage in imp.usages: - call = usage.match.parent.parent - if not isinstance(call, FunctionCall): - continue - if call.name in ("isAxiosError",): - continue - val = resolve_http(call.args[0]) - if isinstance(val, String): - url = val.content.rsplit("--")[-1] - for split in url.rsplit("/"): - if split: - url = split - break - url = url.removesuffix(".modal.run") - call.register_api_call(url) - api_calls += 1 - if api_calls > 0: - logger.info(f"Found {api_calls} Axios API calls") diff --git a/src/codegen/sdk/core/plugins/flask.py b/src/codegen/sdk/core/plugins/flask.py deleted file mode 100644 index 783cd7f9c..000000000 --- a/src/codegen/sdk/core/plugins/flask.py +++ /dev/null @@ -1,63 +0,0 @@ -from logging import getLogger -from typing import TYPE_CHECKING - -from codegen.sdk.core.plugins.plugin import Plugin -from codegen.shared.enums.programming_language import ProgrammingLanguage - -if TYPE_CHECKING: - from codegen.sdk.core.codebase import PyCodebaseType -logger = getLogger(__name__) - - -def is_flask_route(decorator): - return (decorator.call and decorator.call.name in ["route", "get", "post", "put", "delete"]) or (decorator.call and "route" in decorator.call.name) - - -def extract_route(decorator): - if decorator.call and decorator.call.args: - return decorator.call.args[0].value - return None - - -def extract_methods(decorator): - if decorator.call and len(decorator.call.args) > 1: - methods_arg = decorator.call.args[1] - if isinstance(methods_arg, list): - return [m.strip("'\"") for m in methods_arg.value.strip("[]").split(",")] - return None - - -class FlaskApiFinder(Plugin): - language: ProgrammingLanguage = ProgrammingLanguage.PYTHON - - def execute(self, codebase: "PyCodebaseType"): - logger.info("Scanning for flask endpoints") - endpoints = 0 - for func in codebase.functions: - for decorator in func.decorators: - if is_flask_route(decorator): - route = extract_route(decorator) - methods = extract_methods(decorator) or ["GET"] - if route: - func.register_api(route) - endpoints += 1 - - for cls in codebase.classes: - class_route = None - for decorator in cls.decorators: - if is_flask_route(decorator): - class_route = extract_route(decorator) - break - - for method in cls.methods: - if method.name.lower() in ["get", "post", "put", "delete", "patch"]: - route = class_route or "" - for decorator in method.decorators: - if is_flask_route(decorator): - route += extract_route(decorator) or "" - if route: - method.register_api(route) - endpoints += 1 - - if endpoints > 0: - logger.info(f"Found {endpoints} modal endpoints") diff --git a/src/codegen/sdk/core/plugins/modal.py b/src/codegen/sdk/core/plugins/modal.py deleted file mode 100644 index c81406df9..000000000 --- a/src/codegen/sdk/core/plugins/modal.py +++ /dev/null @@ -1,26 +0,0 @@ -from logging import getLogger -from typing import TYPE_CHECKING - -from codegen.sdk.core.plugins.plugin import Plugin -from codegen.shared.enums.programming_language import ProgrammingLanguage - -if TYPE_CHECKING: - from codegen.sdk.core.codebase import PyCodebaseType - -logger = getLogger(__name__) - - -class ModalApiFinder(Plugin): - language: ProgrammingLanguage = ProgrammingLanguage.PYTHON - - def execute(self, codebase: "PyCodebaseType"): - logger.info("Scanning for modal endpoints") - endpoints = 0 - for func in codebase.functions: - for decorator in func.decorators: - if decorator.full_name == "web_endpoint": - value = decorator.call.get_arg_by_parameter_name("label").value.content - func.register_api(value) - endpoints += 1 - if endpoints > 0: - logger.info(f"Found {endpoints} modal endpoints") diff --git a/src/codegen/sdk/core/plugins/plugin.py b/src/codegen/sdk/core/plugins/plugin.py deleted file mode 100644 index 93c7f3ed5..000000000 --- a/src/codegen/sdk/core/plugins/plugin.py +++ /dev/null @@ -1,17 +0,0 @@ -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -from codegen.sdk.core.interfaces.editable import Editable -from codegen.shared.enums.programming_language import ProgrammingLanguage - -if TYPE_CHECKING: - from codegen.sdk.core.codebase import Codebase - - -class Plugin(ABC): - language: ProgrammingLanguage - - @abstractmethod - def execute(self, codebase: "Codebase"): ... - def register_api(self, method: str, label: str, node: Editable): - pass diff --git a/src/codegen/sdk/core/statements/__init__.py b/src/codegen/sdk/core/statements/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/sdk/core/statements/assignment_statement.py b/src/codegen/sdk/core/statements/assignment_statement.py deleted file mode 100644 index c6e6880eb..000000000 --- a/src/codegen/sdk/core/statements/assignment_statement.py +++ /dev/null @@ -1,82 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.name import Name -from codegen.sdk.core.interfaces.has_value import HasValue -from codegen.sdk.core.statements.statement import Statement, StatementType -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.assignment import Assignment - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.expressions import Expression - from codegen.sdk.core.expressions.multi_expression import MultiExpression - from codegen.sdk.core.interfaces.has_block import HasBlock - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - - -TCodeBlock = TypeVar("TCodeBlock", bound="CodeBlock") -TAssignment = TypeVar("TAssignment", bound="Assignment") - - -@apidoc -class AssignmentStatement(Statement[TCodeBlock], HasValue, Generic[TCodeBlock, TAssignment]): - """A class that represents an assignment statement in a codebase, such as `x = 1`, `a, b = 1, 2`, `const {a: b} = myFunc(),`, etc. - - This includes potentially multiple Assignments via `statement.assignments`, which represent each assignment of a value to a variable within this statement. - - For example, assigning to a destructured object, or assigning multiple values to multiple variables in a single statement. - - Attributes: - assignments: A list of assignments within the statement. - left: The left-hand side expression of the first assignment. - right: The right-hand side expression of the first assignment, or None if not applicable. - """ - - statement_type = StatementType.ASSIGNMENT - assignments: list[TAssignment] - left: Expression[TAssignment] - right: Expression[TAssignment] | None - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TCodeBlock, pos: int, assignment_node: TSNode) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos=pos) - self.assignments = self._DEPRECATED_parse_assignments().expressions - if len(self.assignments) == 0: - msg = f"No assignments found: {self.ts_node}\n\n{self.source}" - raise ValueError(msg) - - first_assignment: TAssignment = self.assignments[0] - self._name_node = self.ctx.parser.parse_expression(first_assignment.ts_node, self.file_node_id, self.ctx, parent, default=Name) - self.left = first_assignment.left - self.right = first_assignment.value - self._value_node = self.right - - @abstractmethod - def _parse_assignments(self, ts_node: TSNode) -> MultiExpression[HasBlock, TAssignment]: ... - - @abstractmethod - def _DEPRECATED_parse_assignments(self) -> MultiExpression[HasBlock, TAssignment]: ... - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind = UsageKind.BODY, dest: HasName | None = None) -> None: - # We compute assignment dependencies separately - pass - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - """Returns the nested symbols of the importable object.""" - symbols = [] - for assignment in self.assignments: - symbols.extend(assignment.descendant_symbols) - return symbols diff --git a/src/codegen/sdk/core/statements/attribute.py b/src/codegen/sdk/core/statements/attribute.py deleted file mode 100644 index 1972f2398..000000000 --- a/src/codegen/sdk/core/statements/attribute.py +++ /dev/null @@ -1,98 +0,0 @@ -from __future__ import annotations - -import itertools -from abc import abstractmethod -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from codegen.sdk.core.autocommit import writer -from codegen.sdk.core.interfaces.chainable import Chainable -from codegen.sdk.core.interfaces.usable import Usable -from codegen.sdk.core.statements.assignment_statement import AssignmentStatement -from codegen.sdk.core.statements.statement import StatementType -from codegen.sdk.extensions.autocommit import reader -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from collections.abc import Generator - - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.assignment import Assignment - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.extensions.resolution import ResolutionStack - -TCodeBlock = TypeVar("TCodeBlock", bound="CodeBlock | None") -TAssignment = TypeVar("TAssignment", bound="Assignment") - - -@apidoc -class Attribute(AssignmentStatement[TCodeBlock, TAssignment], Usable, Chainable, Generic[TCodeBlock, TAssignment]): - """Abstract representation of an attribute on a class definition. - - Attributes: - assignment: The assignment associated with the attribute. - """ - - statement_type = StatementType.CLASS_ATTRIBUTE - assignment: TAssignment - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TCodeBlock, pos: int, assignment_node: TSNode) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos=pos, assignment_node=assignment_node) - self.assignment = self.assignments[0] - self._name_node = self.assignment.get_name() - - @abstractmethod - def _get_name_node(self) -> TSNode: - """Returns the ID node from the root node of the symbol.""" - - @property - @abstractmethod - def is_private(self) -> bool: - """Indicates whether the attribute is private. - - Determines if the attribute is a private class attribute by checking if it follows Python's private naming convention (i.e., starts with an underscore). - - Returns: - bool: True if the attribute is private (starts with underscore), False otherwise. - """ - ... - - @property - @abstractmethod - def is_optional(self) -> bool: - """Returns whether the attribute is optional. - - Determines if an attribute's type annotation indicates it is optional/nullable. For example, - if the attribute's type is `Optional[str]` or `str | None`, this will return True. - - Returns: - bool: True if the attribute is marked as optional/nullable, False otherwise. - """ - - @writer - def set_value(self, value: str) -> None: - """Sets the value of a node's assignment. - - Updates the value of a node's assignment to the specified string value. - - Args: - value (str): The new value to set for the assignment. - - Returns: - None - """ - self.assignment.set_value(value) - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - return list(itertools.chain.from_iterable(assignment.descendant_symbols for assignment in self.assignments)) - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - yield from self.with_resolution_frame(self.assignments[0]) diff --git a/src/codegen/sdk/core/statements/block_statement.py b/src/codegen/sdk/core/statements/block_statement.py deleted file mode 100644 index 2ca003c72..000000000 --- a/src/codegen/sdk/core/statements/block_statement.py +++ /dev/null @@ -1,86 +0,0 @@ -from __future__ import annotations - -from abc import ABC -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.interfaces.has_block import HasBlock -from codegen.sdk.core.statements.statement import Statement -from codegen.sdk.extensions.autocommit import commiter, reader -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - - -TCodeBlock = TypeVar("TCodeBlock", bound="CodeBlock") - - -@apidoc -class BlockStatement(Statement[TCodeBlock], HasBlock, ABC, Generic[TCodeBlock]): - """Statement which contains a block. - - Attributes: - code_block: The code block contained within the statement, if it exists. - """ - - code_block: TCodeBlock | None - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TCodeBlock, pos: int | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - self.code_block = self._parse_code_block() - if self.code_block: - self.code_block.parse() - - @property - @reader - def nested_code_blocks(self) -> list[TCodeBlock]: - """Returns all nested CodeBlocks within the statement. - - Gets all nested CodeBlocks contained within this BlockStatement. A BlockStatement may contain - at most one code block. - - Args: - None - - Returns: - list[TCodeBlock]: A list containing the statement's code block if it exists, otherwise an empty list. - """ - if self.code_block: - return [self.code_block] - return [] - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Gets all function calls within the statement's code block. - - Returns a list of FunctionCall instances contained within the statement's code block. If the statement does not have a code block, returns an empty list. - - Returns: - list[FunctionCall]: A list of function call instances within the code block. - """ - if self.code_block: - return self.code_block.function_calls - return [] - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - if self.code_block: - self.code_block._compute_dependencies(usage_type, dest) - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - symbols = super().descendant_symbols - if self.code_block: - symbols.extend(self.code_block.descendant_symbols) - return symbols diff --git a/src/codegen/sdk/core/statements/catch_statement.py b/src/codegen/sdk/core/statements/catch_statement.py deleted file mode 100644 index 6d7b36071..000000000 --- a/src/codegen/sdk/core/statements/catch_statement.py +++ /dev/null @@ -1,36 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, Self, TypeVar - -from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock -from codegen.sdk.core.statements.block_statement import BlockStatement -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.expressions import Expression - from codegen.sdk.core.interfaces.has_name import HasName - - -Parent = TypeVar("Parent", bound="CodeBlock") - - -@apidoc -class CatchStatement(ConditionalBlock, BlockStatement[Parent], Generic[Parent]): - """Abstract representation catch clause. - - Attributes: - code_block: The code block that may trigger an exception - condition: The condition which triggers this clause - """ - - condition: Expression[Self] | None = None - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - if self.condition: - self.condition._compute_dependencies(usage_type, dest) - super()._compute_dependencies(usage_type, dest) diff --git a/src/codegen/sdk/core/statements/comment.py b/src/codegen/sdk/core/statements/comment.py deleted file mode 100644 index 9bc84b1ca..000000000 --- a/src/codegen/sdk/core/statements/comment.py +++ /dev/null @@ -1,130 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.autocommit import commiter, reader, writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.statements.statement import Statement, StatementType -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.node_id_factory import NodeId - - -def lowest_indentation(text_blocks, skip_lines: int = 0): - if not text_blocks: - return 0 - - # Filter out empty strings and strings with only whitespace - non_empty_blocks = [block for block in text_blocks if block.strip()] - - # Skip the first n lines - non_empty_blocks = non_empty_blocks[skip_lines:] - - if not non_empty_blocks: - return 0 - - # Count leading spaces for each non-empty block - indentations = [len(block) - len(block.lstrip()) for block in non_empty_blocks] - - # Return the minimum indentation - return min(indentations) - - -TCodeBlock = TypeVar("TCodeBlock", bound="CodeBlock") - - -@apidoc -class Comment(Statement[TCodeBlock], Generic[TCodeBlock]): - """Abstract representation of comment statements.""" - - statement_type = StatementType.COMMENT - - @property - @reader - def nested_code_blocks(self: Statement[TCodeBlock]) -> list[TCodeBlock]: - """Returns a list of nested code blocks within the statement. - - A property that returns an empty list as comments, by default, do not have any nested code blocks. - - Args: - self: The statement instance. - - Returns: - list[TCodeBlock]: An empty list, as comments do not contain nested code blocks. - """ - return [] - - @noapidoc - @classmethod - @reader - def from_expression_statement(cls, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Statement, code_block: TCodeBlock, pos: int, comment_node: TSNode) -> Comment: - return cls(ts_node, file_node_id, ctx, code_block, pos) - - @property - @reader - def text(self) -> str: - """Returns the text content of the comment. - - Returns the actual text content of the comment without any comment delimiters (e.g., '#', '/* */'). For accessing - the complete comment including delimiters, use the `source` property instead. - - Returns: - str: The text content of the comment with delimiters removed. - """ - return self._parse_comment() - - @text.setter - @writer - def text(self, new_text: str) -> None: - """Replace the text content of a comment while preserving the comment delimiters and - autoformatting. - - Args: - new_text (str): The new text content to replace the existing comment. This should be - the raw text without comment delimiters. - - Returns: - None - """ - self.edit_text(new_text) - - @writer - def edit_text(self, new_text: str) -> None: - """Replace the text of a comment with new text. - - Updates the comment text while maintaining proper comment delimiters (e.g., `#` or `/* */`) and formatting. - - Args: - new_text (str): The new text content to replace the existing comment text. - - Returns: - None - """ - # Generate comment block with new source - new_src = self._unparse_comment(new_text) - super().edit(new_src, fix_indentation=True, dedupe=True) - - @noapidoc - @commiter - def _parse_comment(self) -> str: - """Parse out the comment into its text content.""" - msg = "This method should be implemented by the subclass" - raise NotImplementedError(msg) - - @noapidoc - @commiter - def _unparse_comment(self, new_src: str): - """Unparses cleaned text content into a comment block.""" - msg = "This method should be implemented by the subclass" - raise NotImplementedError(msg) - - @commiter - @noapidoc - def _compute_dependencies(self, usage_type: UsageKind = UsageKind.BODY, dest: HasName | None = None) -> None: - pass diff --git a/src/codegen/sdk/core/statements/export_statement.py b/src/codegen/sdk/core/statements/export_statement.py deleted file mode 100644 index c8a8d10df..000000000 --- a/src/codegen/sdk/core/statements/export_statement.py +++ /dev/null @@ -1,115 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, Self, TypeVar - -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.statements.statement import Statement, StatementType -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.sdk.extensions.autocommit import commiter -from codegen.sdk.typescript.export import TSExport -from codegen.sdk.typescript.statements.import_statement import TSImportStatement -from codegen.sdk.utils import find_first_ancestor -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.export import Export - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - -TExport = TypeVar("TExport", bound="Export") - - -@apidoc -class ExportStatement(Statement["TSCodeBlock"], Generic[TExport]): - """Abstract representation of a single export statement that appears in a file. One export - statement can export multiple symbols from a single source. - - Attributes: - exports: A list of the individual exports this statement represents - """ - - exports: Collection[TExport, Self] - statement_type = StatementType.EXPORT_STATEMENT - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSCodeBlock, pos: int): - super().__init__(ts_node, parent.file_node_id, parent.ctx, parent, pos) - export_node = self.ts_node - if node := self.child_by_field_types(["export_clause", "export_statement"]): - export_node = node.ts_node - self.exports = Collection(export_node, self.file_node_id, self.ctx, self, bracket_size=2) - if declaration := ts_node.child_by_field_name("declaration"): - exports = TSExport.from_export_statement_with_declaration(ts_node, declaration, file_node_id, ctx, self, pos) - elif value := ts_node.child_by_field_name("value"): - exports = TSExport.from_export_statement_with_value(self.ts_node, value, self.file_node_id, self.ctx, self, self.index) - else: - exports = [] - if source_node := ts_node.child_by_field_name("source"): - # ==== [ Re-export ] ==== - # e.g. export { name1, name2 } from './other-module'; - import_statement = TSImportStatement(ts_node, file_node_id, ctx, parent, pos, source_node=source_node) - for imp in import_statement.imports: - name_node = imp.alias.ts_node if imp.alias else None - export = TSExport( - ts_node=find_first_ancestor(imp._name_node.ts_node, ["export_statement", "export_clause", "export_specifier"]) if imp._name_node else imp.ts_node, - file_node_id=file_node_id, - ctx=ctx, - name_node=name_node, - declared_symbol=imp, - parent=self.exports, - ) - exports.append(export) - elif export_clause := next((child for child in ts_node.named_children if child.type == "export_clause"), None): - export_node = export_clause - # ==== [ Named export ] ==== - # e.g. export { variable, functionName, ClassName }; - for export_specifier in export_clause.named_children: - if export_specifier.type == "comment": - continue - name_node = export_specifier.child_by_field_name("name") - alias_node = export_specifier.child_by_field_name("alias") or name_node - export = TSExport(ts_node=export_specifier, file_node_id=file_node_id, ctx=ctx, name_node=alias_node, exported_symbol=name_node, parent=self.exports) - exports.append(export) - else: - # ==== [ Export assignment ] ==== - # Examples: `export = XYZ;`, `export = function foo() {}`, `export = function() {}`, `export = { f1, f2 }` - # No other named exports can exist alongside this type of export in the file - exports.extend(TSExport.from_export_statement_with_value(self.ts_node, ts_node.named_children[0], self.file_node_id, self.ctx, self, self.index)) - self.exports._init_children(exports) - for exp in self.exports: - exp.export_statement = self - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind = UsageKind.BODY, dest: HasName | None = None) -> None: - # We compute export dependencies separately - pass - - def _removed_child(self) -> None: - self.exports._removed_child() - - def _removed_child_commit(self) -> None: - self.exports._removed_child_commit() - - @property - def reexports(self) -> list[TSExport]: - """Retrieves a list of re-exported symbols from this export statement. - - Returns: - list[TSExport]: A list of re-exported symbols within the current export context, - excluding external exports. - """ - reexports = [] - for export in self.exports: - if export.is_reexport() and not export.is_external_export: - reexports.append(export) - return reexports - - def _smart_remove(self, child, *args, **kwargs) -> bool: - if self.exports.uncommitted_len == 1 and child.ts_node.is_named: - self.remove() - return True - return super()._smart_remove(child, *args, **kwargs) diff --git a/src/codegen/sdk/core/statements/expression_statement.py b/src/codegen/sdk/core/statements/expression_statement.py deleted file mode 100644 index 6aeda5f99..000000000 --- a/src/codegen/sdk/core/statements/expression_statement.py +++ /dev/null @@ -1,71 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.interfaces.has_value import HasValue -from codegen.sdk.core.interfaces.wrapper_expression import IWrapper -from codegen.sdk.core.statements.statement import Statement, StatementType -from codegen.sdk.extensions.autocommit import commiter, reader -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.interfaces.has_block import HasBlock - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.node_id_factory import NodeId - - -Parent = TypeVar("Parent", bound="HasBlock") -TCodeBlock = TypeVar("TCodeBlock", bound="CodeBlock") - - -@apidoc -class ExpressionStatement(Statement, HasValue, IWrapper, Generic[Parent, TCodeBlock]): - """Abstract representation of any expression statements that resolves to an expression. In some - languages without a statement delimiter, expression statement and the enclosed expression looks - the same in text. - - For example, in Python: - ```python - x = 1 - ``` - The above code is an expression statement, but its expression value is an assignment. - - In Typescript: - ```typescript - x = 1; - ``` - The above code is also an expression statement, but its expression value is an assignment excluding the semicolon. - """ - - statement_type = StatementType.EXPRESSION_STATEMENT - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent, pos: int, expression_node: TSNode) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos=pos) - self._value_node = self._parse_expression(expression_node) - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Get all function calls contained within this expression statement. - - Returns a list of function calls that are direct or nested within the expression of this statement. This retrieves function calls from the resolved value of the expression. - - Returns: - list[FunctionCall]: A list of FunctionCall objects representing all function calls contained within this statement. - """ - return self.resolve().function_calls - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None): - if self._value_node: - self.resolve()._compute_dependencies(usage_type, dest) - - def _smart_remove(self, child, *args, **kwargs) -> bool: - return self.parent._smart_remove(child, *args, **kwargs) diff --git a/src/codegen/sdk/core/statements/for_loop_statement.py b/src/codegen/sdk/core/statements/for_loop_statement.py deleted file mode 100644 index f8753cdac..000000000 --- a/src/codegen/sdk/core/statements/for_loop_statement.py +++ /dev/null @@ -1,60 +0,0 @@ -from __future__ import annotations - -from abc import ABC -from typing import TYPE_CHECKING, Generic, Self, TypeVar - -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.interfaces.chainable import Chainable -from codegen.sdk.core.interfaces.has_block import HasBlock -from codegen.sdk.core.statements.block_statement import BlockStatement -from codegen.sdk.core.statements.statement import StatementType -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from collections.abc import Generator - - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.expressions import Expression - from codegen.sdk.core.import_resolution import Import, WildcardImport - from codegen.sdk.core.symbol import Symbol - - -Parent = TypeVar("Parent", bound="CodeBlock") - - -@apidoc -class ForLoopStatement(BlockStatement[Parent], HasBlock, ABC, Generic[Parent]): - """Abstract representation of the for loop. - - Attributes: - item: The item being iterated over, if applicable. - iterable: The iterable expression that the loop iterates over. - """ - - statement_type = StatementType.FOR_LOOP_STATEMENT - item: Expression[Self] | None = None - iterable: Expression[Self] - - @noapidoc - @reader - def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: - if self.item and isinstance(self.iterable, Chainable): - if start_byte is None or start_byte > self.iterable.end_byte: - if name == self.item: - for frame in self.iterable.resolved_type_frames: - if frame.generics: - yield next(iter(frame.generics.values())) - return - yield frame.top.node - return - elif isinstance(self.item, Collection): - for idx, item in enumerate(self.item): - if item == name: - for frame in self.iterable.resolved_type_frames: - if frame.generics and len(frame.generics) > idx: - yield list(frame.generics.values())[idx] - return - yield frame.top.node - return - yield from super().resolve_name(name, start_byte, strict=strict) diff --git a/src/codegen/sdk/core/statements/if_block_statement.py b/src/codegen/sdk/core/statements/if_block_statement.py deleted file mode 100644 index 98e7fab8d..000000000 --- a/src/codegen/sdk/core/statements/if_block_statement.py +++ /dev/null @@ -1,308 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from functools import cached_property -from typing import TYPE_CHECKING, Generic, Self, TypeVar - -from codegen.sdk._proxy import proxy_property -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.function import Function -from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock -from codegen.sdk.core.statements.statement import Statement, StatementType -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from collections.abc import Sequence - - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.expressions import Expression - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.interfaces.importable import Importable - - -TIfBlockStatement = TypeVar("TIfBlockStatement", bound="IfBlockStatement") -TCodeBlock = TypeVar("TCodeBlock", bound="CodeBlock") - - -@apidoc -class IfBlockStatement(ConditionalBlock, Statement[TCodeBlock], Generic[TCodeBlock, TIfBlockStatement]): - """Abstract representation of the if/elif/else if/else statement block. - - For example, if there is a code block like: - if condition1: - block1 - elif condition2: - block2 - else: - block3 - This class represents the entire block, including the conditions and nested code blocks. - - Attributes: - condition: The condition expression for the if block. None if the block is an else block. - consequence_block: The code block that is executed if the condition is True. - """ - - statement_type = StatementType.IF_BLOCK_STATEMENT - condition: Expression[Self] | None - consequence_block: TCodeBlock - _alternative_blocks: list[TIfBlockStatement] | None # None if it is an elif or else block - _main_if_block: TIfBlockStatement - - @abstractmethod - def _parse_consequence_block(self) -> TCodeBlock: ... - - @abstractmethod - def _parse_alternative_blocks(self) -> list[TIfBlockStatement]: - """Returns the alternative blocks if they exist. - - Otherwise, returns empty list. This includes both elif and else blocks. - """ - - @commiter - @noapidoc - def _compute_dependencies(self, usage_type: UsageKind = UsageKind.BODY, dest: HasName | None = None) -> None: - # Compute dependencies for all statements in the nested code blocks - if self.condition: - self.condition._compute_dependencies(usage_type, dest) - - self.consequence_block._compute_dependencies(usage_type, dest) - - for alt_block in self.alternative_blocks: - if alt_block.condition: - alt_block.condition._compute_dependencies(usage_type, dest) - alt_block.consequence_block._compute_dependencies(usage_type, dest) - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Returns all function calls within the if block statement and its alternative blocks. - - Collects all function calls from the if block's condition, consequence block, and any alternative blocks (elif/else) - including their conditions and consequence blocks. - - Returns: - list[FunctionCall]: A list of function call objects found within this if block statement and its alternative blocks. - """ - fcalls = [] if self.condition is None else self.condition.function_calls - fcalls.extend(self.consequence_block.function_calls) - for alt_block in self.alternative_blocks: - if alt_block.condition: - fcalls.extend(alt_block.condition.function_calls) - fcalls.extend(alt_block.consequence_block.function_calls) - return fcalls - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - symbols = super().descendant_symbols - if self.condition: - symbols.extend(self.condition.descendant_symbols) - if self.consequence_block: - symbols.extend(self.consequence_block.descendant_symbols) - for alt_block in self.alternative_blocks: - if alt_block.condition: - symbols.extend(alt_block.condition.descendant_symbols) - if alt_block.consequence_block: - symbols.extend(alt_block.consequence_block.descendant_symbols) - return symbols - - @cached_property - @reader - def nested_code_blocks(self) -> list[TCodeBlock]: - """Returns all nested code blocks within an if/elif/else statement block. - - Returns a list of all CodeBlocks that are part of the current if/elif/else statement block, including the main if block's consequence block - and all alternative (elif/else) blocks' consequence blocks. - - Returns: - list[TCodeBlock]: A list of CodeBlock objects representing all nested code blocks within the statement. - """ - return [self.consequence_block] + [x.consequence_block for x in self.alternative_blocks] - - @property - @abstractmethod - def is_if_statement(self) -> bool: - """Returns whether the current block is an if block. - - A property that checks if the current block within an if/elif/else statement chain is an if block. - This includes the main if block but not elif or else blocks. - - Args: - None - - Returns: - bool: True if the current block is an if block, False if it is an elif or else block. - """ - - @property - @abstractmethod - def is_else_statement(self) -> bool: - """Indicates if the current block is an else block in an if/else statement chain. - - This property checks whether the current block represents an 'else' branch in a control flow statement. It helps in identifying and handling else - blocks differently from if/elif blocks, particularly when manipulating control flow structures. - - Returns: - bool: True if the current block is an else block, False otherwise. - """ - - @property - @abstractmethod - def is_elif_statement(self) -> bool: - """Indicates whether the current block is an elif block. - - A property that returns True if the current instance of IfBlockStatement is specifically an elif block, False for if or else blocks. - - Returns: - bool: True if the current block is an elif block, False for if or else blocks. - """ - - @property - @reader - def alternative_blocks(self) -> list[TIfBlockStatement]: - """Returns a list of alternative if/elif/else blocks for the current block. - - Gets the alternative blocks (elif/else blocks) based on the type of the current block: - - For if blocks: returns all alternative blocks - - For else blocks: returns empty list - - For elif blocks: returns all subsequent alternative blocks in the main if block - - Returns: - list[TIfBlockStatement]: A list of alternative if/elif/else blocks that are executed if the condition is False. - """ - if self.is_if_statement: - return self._alternative_blocks - if self.is_else_statement: - return [] - return [x for x in self._main_if_block.alternative_blocks if x.start_byte > self.start_byte] - - @proxy_property - @reader - def elif_statements(self) -> list[TIfBlockStatement]: - """Returns all elif blocks within the if block. - - Gets all alternative blocks that are specifically elif blocks (i.e., excluding else blocks) from an if statement. Can be called on any if/elif/else block to get subsequent elif blocks. - - Note: - This method can be called as both a property and a method. If used as a property, it is equivalent to invoking it without arguments. - - Returns: - list[TIfBlockStatement]: A list of elif block statements. Empty list if no elif blocks exist. - """ - return [alt for alt in self.alternative_blocks if alt.is_elif_statement] - - @property - @reader - def else_statement(self) -> TIfBlockStatement | None: - """Returns the else block within the if-statement. - - Gets the else block from the if-statement's alternative blocks if one exists. Only returns the else block, not elif blocks. - - Returns: - TIfBlockStatement | None: The else block statement if it exists, None otherwise. - """ - return next((alt for alt in self.alternative_blocks if alt.is_else_statement), None) - - @abstractmethod - def _else_if_to_if(self) -> None: - """Converts an elif block to an if block.""" - - @writer - def reduce_condition(self, bool_condition: bool, node: Editable | None = None) -> None: - """Simplifies a conditional block by reducing its condition to a boolean value. - - This method modifies the if/elif/else block structure based on the provided boolean value. - When reducing to True, it unwraps the consequence block and adjusts subsequent elif/else blocks. - When reducing to False, it handles different cases for elif statements and main if blocks. - - Args: - bool_condition (bool): The boolean value to reduce the condition to. - If True, unwraps the consequence block and adjusts alternative blocks. - If False, removes or modifies the current block based on its type. - - Returns: - None - - Raises: - ValueError: If attempting to reduce a condition on an IfBlockStatement that doesn't have a condition - (like an else block). - """ - if self.condition is None: - msg = "Cannot reduce condition of an IfBlockStatement without a condition." - raise ValueError(msg) - - first_elif = next((x for x in self.elif_statements()), None) - - # ==== [ Reduce condition to True ] ==== - if bool_condition: - # If condition is reduced to True, unwrap the consequence block. - # If the first alternative block is an elif block, change the elif to if. - # If the first alternative block is else, remove the else block. - self.consequence_block.unwrap() - if first_elif: - first_elif._else_if_to_if() - elif (else_block := self.else_statement) is not None: - remove_start = self.consequence_block._get_line_ends()[-1].end_byte - else_block.remove_byte_range(remove_start, else_block.end_byte) - - # If the last statement in the consequence block is a return statement, remove all the lines after it. - if isinstance(self.parent, Function): - last_statement = self.consequence_block.get_statements(max_level=self.consequence_block.level)[-1] - if last_statement.statement_type == StatementType.RETURN_STATEMENT: - self.consequence_block.remove_byte_range(last_statement.end_byte, self.parent.end_byte) - - # ==== [ Reduce condition to False ] ==== - elif self.is_elif_statement: - # If the current block is an elif block, remove the elif block and nothing else. - remove_end_byte = first_elif.start_byte if first_elif else self.ts_node.end_byte - self.remove_byte_range(self.ts_node.start_byte, remove_end_byte) - else: - # ==== [ Main block ] ==== - # If condition is reduced to False, remove the if block. - # If the first alternative block is an elif block, change the elif to else. - # If the first alternative block is else, unwrap the else block. - if first_elif: - self.remove_byte_range(self.ts_node.start_byte, first_elif.start_byte) - first_elif._else_if_to_if() - elif (else_block := self.else_statement) is not None: - else_block.consequence_block.unwrap() - remove_end = else_block.consequence_block._get_line_starts()[0].start_byte - self.remove_byte_range(self.ts_node.start_byte, remove_end) - else: - self.remove() - - @property - @noapidoc - def other_possible_blocks(self) -> Sequence[ConditionalBlock]: - if self.is_if_statement: - return self.alternative_blocks - elif self.is_elif_statement: - main = self._main_if_block - statements = [main] - if main.else_statement: - statements.append(main.else_statement) - for statement in main.elif_statements: - if statement != self: - statements.append(statement) - return statements - else: - main = self._main_if_block - return [main, *main.elif_statements] - - @property - @noapidoc - def end_byte_for_condition_block(self) -> int: - if self.is_if_statement: - return self.consequence_block.end_byte - return self.end_byte - - @property - @noapidoc - def start_byte_for_condition_block(self) -> int: - if self.is_if_statement: - return self.consequence_block.start_byte - return self.start_byte diff --git a/src/codegen/sdk/core/statements/import_statement.py b/src/codegen/sdk/core/statements/import_statement.py deleted file mode 100644 index 2a67ee06a..000000000 --- a/src/codegen/sdk/core/statements/import_statement.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, Self, TypeVar - -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.statements.statement import Statement -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.file import SourceFile - from codegen.sdk.core.import_resolution import Import - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.symbol_groups.collection import Collection - - -TSourceFile = TypeVar("TSourceFile", bound="SourceFile") -TImport = TypeVar("TImport", bound="Import") -TCodeBlock = TypeVar("TCodeBlock", bound="CodeBlock") - - -@apidoc -class ImportStatement(Statement[TCodeBlock], Generic[TSourceFile, TImport, TCodeBlock]): - """Abstract representation of a single import statement that appears in a file. One import - statement can import multiple symbols from a single source. - - Attributes: - imports: A collection of the individual imports this statement represents - """ - - imports: Collection[TImport, Self] - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TCodeBlock, pos: int) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind = UsageKind.BODY, dest: HasName | None = None) -> None: - # Skip computing dependencies for import statements, since it is done during import resolution step - pass - - def _smart_remove(self, child, *args, **kwargs) -> bool: - if self.imports.uncommitted_len == 1 and child.ts_node.is_named: - self.remove() - return True - return super()._smart_remove(child, *args, **kwargs) diff --git a/src/codegen/sdk/core/statements/raise_statement.py b/src/codegen/sdk/core/statements/raise_statement.py deleted file mode 100644 index ec267f88b..000000000 --- a/src/codegen/sdk/core/statements/raise_statement.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.interfaces.has_value import HasValue -from codegen.sdk.core.statements.statement import Statement, StatementType -from codegen.sdk.extensions.autocommit import commiter, reader -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.node_id_factory import NodeId - - -Parent = TypeVar("Parent", bound="CodeBlock") - - -@apidoc -class RaiseStatement(Statement[Parent], HasValue, Generic[Parent]): - """Abstract representation of raise statements, e.g. in Python: - - Example: - def f(x): - raise ValueError() - """ - - statement_type = StatementType.RAISE_STATEMENT - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent, pos: int | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - value_node = self._get_value_node() - self._value_node = self._parse_expression(value_node) if value_node else None - - def _get_value_node(self) -> TSNode | None: - if len(self.ts_node.children) == 1: - return None - return self.ts_node.children[1] - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Gets function calls within a raise statement's value expression. - - Returns: - list[FunctionCall]: A list of function calls in the raise statement's value expression, or an empty list if the value expression doesn't exist. - """ - if not self.value: - return [] - return self.value.function_calls - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - if self.value: - self.value._compute_dependencies(usage_type, dest) diff --git a/src/codegen/sdk/core/statements/return_statement.py b/src/codegen/sdk/core/statements/return_statement.py deleted file mode 100644 index d71c95134..000000000 --- a/src/codegen/sdk/core/statements/return_statement.py +++ /dev/null @@ -1,68 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.interfaces.has_value import HasValue -from codegen.sdk.core.statements.statement import Statement, StatementType -from codegen.sdk.extensions.autocommit import commiter, reader -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.interfaces.has_block import HasBlock - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.node_id_factory import NodeId - - -Parent = TypeVar("Parent", bound="HasBlock") -TCodeBlock = TypeVar("TCodeBlock", bound="CodeBlock") - - -@apidoc -class ReturnStatement(Statement, HasValue, Generic[Parent, TCodeBlock]): - """Abstract representation of return statements, e.g. in Python: - - Example: - def f(x): - if x: - return x**2 # ReturnStatement - else: - return 1 # ReturnStatement - """ - - statement_type = StatementType.RETURN_STATEMENT - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent, pos: int | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - value_node = self._get_value_node() - self._value_node = self._parse_expression(value_node) if value_node else None - - def _get_value_node(self) -> TSNode | None: - if len(self.ts_node.children) == 1: - return None - return self.ts_node.children[1] - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Returns a list of function calls contained within this return statement. - - If the return statement has no value, an empty list is returned. Otherwise, returns the function calls contained in the value expression of the return statement. - - Returns: - list[FunctionCall]: A list of function calls contained in the return statement's value expression. - """ - if not self.value: - return [] - return self.value.function_calls - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - if self.value: - self.value._compute_dependencies(usage_type, dest) diff --git a/src/codegen/sdk/core/statements/statement.py b/src/codegen/sdk/core/statements/statement.py deleted file mode 100644 index 75fb4f569..000000000 --- a/src/codegen/sdk/core/statements/statement.py +++ /dev/null @@ -1,164 +0,0 @@ -from __future__ import annotations - -from enum import StrEnum -from functools import cached_property -from typing import TYPE_CHECKING, Generic, Self, TypeVar, final - -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.expressions import Expression -from codegen.sdk.extensions.autocommit import commiter -from codegen.sdk.output.constants import ANGULAR_STYLE -from codegen.sdk.utils import find_all_descendants -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - import rich.repr - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.symbol_groups.multi_line_collection import MultiLineCollection - - -@apidoc -class StatementType(StrEnum): - """Enum representing the different types of statements that can be parsed. - - Attributes: - COMMENT: Represents a comment statement. - ASSIGNMENT: Represents an assignment expression. - EXPRESSION_STATEMENT: Represents an expression statement. - CLASS_ATTRIBUTE: Represents a class attribute. - RETURN_STATEMENT: Represents a return statement. - RAISE_STATEMENT: Represents a raise statement. - WITH_STATEMENT: Represents a with statement. - PASS_STATEMENT: Represents a pass statement. - BREAK_STATEMENT: Represents a break statement. - LABELED_STATEMENT: Represents a labeled statement. - TRY_CATCH_STATEMENT: Represents a try-catch statement. - IF_BLOCK_STATEMENT: Represents an if block statement. - FOR_LOOP_STATEMENT: Represents a for loop statement. - WHILE_STATEMENT: Represents a while statement. - SWITCH_STATEMENT: Represents a switch statement. - SYMBOL_STATEMENT: Represents a symbol statement. - UNSPECIFIED: Represents any unparsed code snippet or graph node statements. - EXPORT_STATEMENT: Represents an export statement. - IMPORT_STATEMENT: Represents an import statement. - """ - - COMMENT = "comment" - ASSIGNMENT = "assignment_expression" - EXPRESSION_STATEMENT = "expression_statement" - CLASS_ATTRIBUTE = "class_attribute" - RETURN_STATEMENT = "return_statement" - RAISE_STATEMENT = "raise_statement" - WITH_STATEMENT = "with_statement" - PASS_STATEMENT = "pass_statement" - BREAK_STATEMENT = "pass_statement" - LABELED_STATEMENT = "labeled_statement" - TRY_CATCH_STATEMENT = "try_catch_statement" - IF_BLOCK_STATEMENT = "if_block_statement" - FOR_LOOP_STATEMENT = "for_loop_statement" - WHILE_STATEMENT = "while_statement" - SWITCH_STATEMENT = "switch_statement" - SYMBOL_STATEMENT = "symbol_statement" - # Any unparsed code snippet, or graph node statements (e.g. function definition) - UNSPECIFIED = "unspecified" - EXPORT_STATEMENT = "export_statement" - IMPORT_STATEMENT = "import_statement" - - -Parent = TypeVar("Parent", bound="CodeBlock") - - -@apidoc -class Statement(Expression[Parent], Generic[Parent]): - """Represents a single code statement, e.g. a function definition, an assignment, an if/else statement, etc.""" - - statement_type: StatementType = StatementType.UNSPECIFIED - _pos: int - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent, pos: int | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent) - self._pos = pos - - def __rich_repr__(self) -> rich.repr.Result: - if self.parent: - yield "level", self.parent.level - yield from super().__rich_repr__() - - __rich_repr__.angular = ANGULAR_STYLE - - @property - def index(self) -> int: - """The 0-based index of the statement in the parent code block. - - Returns the sequential position of this statement within its containing code block. - - Returns: - int: The 0-based index of this statement within its parent code block. - """ - return self._pos - - @classmethod - @noapidoc - @final - def from_code_block(cls, ts_node: TSNode, code_block: CodeBlock, pos: int | None = None) -> Statement: - return cls(ts_node, code_block.file_node_id, code_block.ctx, parent=code_block, pos=pos) - - @cached_property - @reader - def nested_code_blocks(self) -> list[Parent]: - """Returns all nested code blocks within the statement. - - Finds and parses any immediate 'block' or 'statement_block' nodes within the statement. - - Returns: - list[TCodeBlock]: A list of parsed code blocks that are directly nested within this statement. Each block has a level one higher than its parent block. - """ - block_nodes = find_all_descendants(self.ts_node, {"block", "statement_block"}, max_depth=1) - - nested_blocks = [] - for block_node in block_nodes: - block = self.ctx.node_classes.code_block_cls(block_node, self.parent.level + 1, self.parent, self) - block.parse() - nested_blocks.append(block) - return nested_blocks - - @property - @reader - def nested_statements(self) -> list[MultiLineCollection[Statement[Self], Parent]]: - """Returns a list of statement collections within nested code blocks. - - Accesses and retrieves the statements from each code block nested within the current statement, - such as the statements within if/else branches or loop bodies. - - Returns: - A list where each element is a - collection of statements from one nested code block. Returns an empty list if there are no - nested code blocks. - """ - nested_code_blocks = self.nested_code_blocks - if len(nested_code_blocks) == 0: - return [] - - nested_statements = [] - for code_block in nested_code_blocks: - nested_statements.append(code_block.statements) - - return nested_statements - - def _get_indent(self) -> int: - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - - if isinstance(self.parent, CodeBlock): - return self.parent.level * 4 - return self.ts_node.start_point[1] - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None): - self._add_all_identifier_usages(usage_type, dest=dest) diff --git a/src/codegen/sdk/core/statements/switch_case.py b/src/codegen/sdk/core/statements/switch_case.py deleted file mode 100644 index 46f83af64..000000000 --- a/src/codegen/sdk/core/statements/switch_case.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, Self, TypeVar - -from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock -from codegen.sdk.core.statements.block_statement import BlockStatement -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.core.assignment import Assignment - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.expressions import Expression - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.statements.switch_statement import SwitchStatement - -Parent = TypeVar("Parent", bound="CodeBlock[SwitchStatement, Assignment]") - - -@apidoc -class SwitchCase(ConditionalBlock, BlockStatement[Parent], Generic[Parent]): - """Abstract representation for a switch case. - - Attributes: - code_block: The code block that is executed if the condition is met - condition: The condition which triggers this case - """ - - condition: Expression[Self] | None = None - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - if self.condition: - self.condition._compute_dependencies(usage_type, dest) - super()._compute_dependencies(usage_type, dest) - - @property - @noapidoc - def other_possible_blocks(self) -> list[ConditionalBlock]: - """Returns the end byte for the specific condition block""" - return [case for case in self.parent.cases if case != self] diff --git a/src/codegen/sdk/core/statements/switch_statement.py b/src/codegen/sdk/core/statements/switch_statement.py deleted file mode 100644 index c8af3431a..000000000 --- a/src/codegen/sdk/core/statements/switch_statement.py +++ /dev/null @@ -1,84 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from codegen.sdk.core.statements.statement import Statement, StatementType -from codegen.sdk.extensions.autocommit import commiter, reader -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.expressions import Expression - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.statements.switch_case import SwitchCase - - -Parent = TypeVar("Parent", bound="CodeBlock") -TCodeBlock = TypeVar("TCodeBlock", bound="CodeBlock") -TSwitchCase = TypeVar("TSwitchCase", bound="SwitchCase") - - -@apidoc -class SwitchStatement(Statement[Parent], Generic[Parent, TCodeBlock, TSwitchCase]): - """Abstract representation of the switch statement. - - Attributes: - value: The value to switch on. - cases: A list of switch cases. - """ - - statement_type = StatementType.SWITCH_STATEMENT - value: Expression[Self] - cases: list[TSwitchCase] = [] - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Returns all function calls within the switch statement. - - Gets the function calls from the value expression and all switch cases. - - Returns: - list[FunctionCall]: A list of all function calls found within the switch statement, - including those in the value expression and all switch cases. - """ - fcalls = self.value.function_calls - for case in self.cases: - fcalls.extend(case.function_calls) - return fcalls - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - self.value._compute_dependencies(usage_type, dest) - for case in self.cases: - case._compute_dependencies(usage_type, dest) - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - symbols = self.value.descendant_symbols - for case in self.cases: - symbols.extend(case.descendant_symbols) - return symbols - - @property - @reader - @override - def nested_code_blocks(self) -> list[TCodeBlock]: - """Returns all nested CodeBlocks within the switch statement. - - Gets all code blocks from the switch statement's cases. Only includes code blocks - that are not None. - - Returns: - list[TCodeBlock]: A list of code blocks from all cases in the switch statement. - """ - nested_blocks = [] - for case in self.cases: - if case.code_block: - nested_blocks.append(case.code_block) - return nested_blocks diff --git a/src/codegen/sdk/core/statements/symbol_statement.py b/src/codegen/sdk/core/statements/symbol_statement.py deleted file mode 100644 index 76064fe4f..000000000 --- a/src/codegen/sdk/core/statements/symbol_statement.py +++ /dev/null @@ -1,69 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.statements.statement import Statement, StatementType -from codegen.sdk.extensions.autocommit import reader -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.symbol import Symbol - - -Parent = TypeVar("Parent", bound="CodeBlock") -Child = TypeVar("Child", bound="Symbol") - - -@apidoc -class SymbolStatement(Statement[Parent], Generic[Parent, Child]): - """A statement that represents a symbol definition in a codeblock. - - Examples include: - - a function definition, class definition, global variable assignment - - Attributes: - symbol: The symbol associated with this statement, representing a code element. - """ - - statement_type = StatementType.SYMBOL_STATEMENT - symbol: Child - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent, pos: int, symbol_node: TSNode | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - self.symbol = self.ctx.parser.parse_expression(symbol_node or ts_node, file_node_id, ctx, parent=self) - - def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: - pass - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Returns all function calls contained within the symbol associated with this statement. - - This property retrieves all function call nodes from the statement's underlying symbol. This is useful for tasks - like renaming function invocations or analyzing call patterns. Note that this operation may trigger a reparse of - the file and could be slow. - - Returns: - list[FunctionCall]: A list of FunctionCall objects representing all function calls within the symbol. - - Note: - Consider using function.call_sites instead if you already know which specific function you're looking for, - as it will be more performant. - """ - return self.symbol.function_calls - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - """Returns the nested symbols of the importable object.""" - return self.symbol.descendant_symbols diff --git a/src/codegen/sdk/core/statements/try_catch_statement.py b/src/codegen/sdk/core/statements/try_catch_statement.py deleted file mode 100644 index eca344b61..000000000 --- a/src/codegen/sdk/core/statements/try_catch_statement.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import annotations - -from abc import ABC -from typing import TYPE_CHECKING, Generic, TypeVar, override - -from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock -from codegen.sdk.core.interfaces.has_block import HasBlock -from codegen.sdk.core.statements.block_statement import BlockStatement -from codegen.sdk.core.statements.statement import StatementType -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - - -Parent = TypeVar("Parent", bound="CodeBlock") - - -@apidoc -class TryCatchStatement(ConditionalBlock, BlockStatement[Parent], HasBlock, ABC, Generic[Parent]): - """Abstract representation of the try catch statement block. - - Attributes: - code_block: The code block that may trigger an exception - finalizer: The code block executed regardless of if an exception is thrown or not - """ - - statement_type = StatementType.TRY_CATCH_STATEMENT - finalizer: BlockStatement | None = None - - @noapidoc - @override - def is_true_conditional(self, descendant) -> bool: - if descendant.is_child_of(self.finalizer): - return False - return True - - @property - @noapidoc - def end_byte_for_condition_block(self) -> int: - if self.code_block: - return self.code_block.end_byte - else: - return self.end_byte - - @property - @noapidoc - def start_byte_for_condition_block(self) -> int: - if self.code_block: - return self.code_block.start_byte - 1 - else: - return self.start_byte diff --git a/src/codegen/sdk/core/statements/while_statement.py b/src/codegen/sdk/core/statements/while_statement.py deleted file mode 100644 index b203ad4cb..000000000 --- a/src/codegen/sdk/core/statements/while_statement.py +++ /dev/null @@ -1,86 +0,0 @@ -from __future__ import annotations - -from abc import ABC -from typing import TYPE_CHECKING, Generic, Self, TypeVar - -from codegen.sdk.core.interfaces.has_block import HasBlock -from codegen.sdk.core.statements.statement import Statement, StatementType -from codegen.sdk.extensions.autocommit import commiter, reader -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.expressions import Expression - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - - -TCodeBlock = TypeVar("TCodeBlock", bound="CodeBlock") - - -@apidoc -class WhileStatement(Statement[TCodeBlock], HasBlock, ABC, Generic[TCodeBlock]): - """Abstract representation of the while statement block. - - Attributes: - condition: The condition expression of the while statement. - code_block: The code block that represents the body of the while statement. - """ - - statement_type = StatementType.WHILE_STATEMENT - condition: Expression[Self] - code_block: TCodeBlock - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TCodeBlock, pos: int | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - self.code_block = self._parse_code_block() - self.code_block.parse() - - @property - @reader - def nested_code_blocks(self) -> list[TCodeBlock]: - """Returns all nested CodeBlocks within the statement. - - Returns all code blocks that are nested within the while statement. For while statements, - this will always be a list containing only the single code block associated with the - while statement's body. - - Returns: - list[TCodeBlock]: A list containing the code blocks associated with this while - statement. - """ - return [self.code_block] - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Returns all function calls within the while statement block. - - Collects all function calls from both the condition expression and the code block. - - Returns: - list[FunctionCall]: A list of function calls found in the while statement's condition and code block. - """ - fcalls = self.condition.function_calls - fcalls.extend(self.code_block.function_calls) - return fcalls - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - self.condition._compute_dependencies(usage_type, dest) - self.code_block._compute_dependencies(usage_type, dest) - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - symbols = super().descendant_symbols - symbols.extend(self.condition.descendant_symbols) - symbols.extend(self.code_block.descendant_symbols) - return symbols diff --git a/src/codegen/sdk/core/symbol.py b/src/codegen/sdk/core/symbol.py deleted file mode 100644 index cc0238b45..000000000 --- a/src/codegen/sdk/core/symbol.py +++ /dev/null @@ -1,443 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from typing import TYPE_CHECKING, Generic, Literal, TypeVar - -from rich.markup import escape - -from codegen.sdk.core.autocommit import commiter, reader, writer -from codegen.sdk.core.dataclasses.usage import UsageKind, UsageType -from codegen.sdk.core.detached_symbols.argument import Argument -from codegen.sdk.core.detached_symbols.function_call import FunctionCall -from codegen.sdk.core.expressions import Name, Value -from codegen.sdk.core.expressions.chained_attribute import ChainedAttribute -from codegen.sdk.core.expressions.defined_name import DefinedName -from codegen.sdk.core.interfaces.usable import Usable -from codegen.sdk.core.statements.statement import Statement -from codegen.sdk.enums import ImportType, NodeType, SymbolType -from codegen.sdk.extensions.sort import sort_editables -from codegen.sdk.output.constants import ANGULAR_STYLE -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - import rich.repr - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.export import Export - from codegen.sdk.core.file import SourceFile - from codegen.sdk.core.import_resolution import Import - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.interfaces.has_block import HasBlock - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.symbol_groups.comment_group import CommentGroup - -Parent = TypeVar("Parent", bound="HasBlock") -TCodeBlock = TypeVar("TCodeBlock", bound="CodeBlock") - - -@apidoc -class Symbol(Usable[Statement["CodeBlock[Parent, ...]"]], Generic[Parent, TCodeBlock]): - """Abstract representation of a Symbol in a Codebase. A Symbol is a top-level entity in a file, e.g. a Function, Class, GlobalVariable, etc. - - Attributes: - symbol_type: The type of the symbol. - node_type: The type of the node, set to NodeType.SYMBOL. - """ - - symbol_type: SymbolType - node_type: Literal[NodeType.SYMBOL] = NodeType.SYMBOL - - def __init__( - self, - ts_node: TSNode, - file_id: NodeId, - ctx: CodebaseContext, - parent: Statement[CodeBlock[Parent, ...]], - name_node: TSNode | None = None, - name_node_type: type[Name] = DefinedName, - ) -> None: - super().__init__(ts_node, file_id, ctx, parent) - name_node = self._get_name_node(ts_node) if name_node is None else name_node - self._name_node = self._parse_expression(name_node, default=name_node_type) - from codegen.sdk.core.interfaces.has_block import HasBlock - - if isinstance(self, HasBlock): - self.code_block = self._parse_code_block() - self.parse(ctx) - if isinstance(self, HasBlock): - self.code_block.parse() - - def __rich_repr__(self) -> rich.repr.Result: - yield escape(self.filepath) + "::" + (self.full_name if self.full_name else "") - - __rich_repr__.angular = ANGULAR_STYLE - - @property - @noapidoc - def parent_symbol(self) -> Symbol | SourceFile | Import | Export: - """Returns the parent symbol of the symbol.""" - from codegen.sdk.core.export import Export - - parent = super().parent_symbol - if parent == self.file or isinstance(parent, Export): - # Top level symbol - return self - return parent - - @staticmethod - @noapidoc - def _get_name_node(ts_node: TSNode) -> TSNode | None: - """Returns the ID node from the root node of the symbol.""" - return ts_node.child_by_field_name("name") - - @property - @reader(cache=False) - def extended_nodes(self) -> list[Editable]: - """Returns a list of Editable nodes associated with this symbol, including extended symbols. - - Extended symbols include `export`, `public`, `decorator`, comments, and inline comments. - - Args: - self: The symbol instance. - - Returns: - list[Editable]: A list of Editable nodes containing the current symbol and its extended symbols, - sorted in the correct order. - """ - from codegen.sdk.core.interfaces.has_block import HasBlock - - comment_nodes = self.comment.symbols if self.comment else [] - inline_comment_nodes = self.inline_comment.symbols if self.inline_comment else [] - nodes = [self, *comment_nodes, *inline_comment_nodes] - new_ts_node = self.ts_node - - if isinstance(self, HasBlock) and self.is_decorated: - new_ts_node = self.ts_node.parent - - extended_nodes = [(Value(new_ts_node, self.file_node_id, self.ctx, self.parent) if node.ts_node == self.ts_node else node) for node in nodes] - return sort_editables(extended_nodes) - - @writer - def edit(self, new_src: str, fix_indentation: bool = False, priority: int = 0, dedupe: bool = True) -> None: - """Replace the source of this node with new_src. - - Edits the source code of this node by replacing it with the provided new source code. If specified, the indentation of - the new source can be adjusted to match the current text's indentation. - - Args: - new_src (str): The new source code to replace the current source with. - fix_indentation (bool): If True, adjusts the indentation of new_src to match the current text's indentation. Defaults to False. - priority (int): The priority of this edit. Higher priority edits take precedence. Defaults to 0. - dedupe (bool): If True, prevents duplicate edits. Defaults to True. - - Returns: - None - """ - self.extended.edit(new_src, fix_indentation=fix_indentation, priority=priority, dedupe=dedupe) - - @property - @reader - def source(self) -> str: - """Returns the source code of the symbol. - - Gets the source code of the symbol from its extended representation, which includes any comments, docstrings, access identifiers, or decorators. - - Returns: - str: The complete source code of the symbol including any extended nodes. - """ - return self.extended.source - - @source.setter - @writer - def source(self, value) -> None: - """Sets the source code text of this Symbol. - - Replaces the current source code text with a new value by calling the edit method. - - Args: - value (str): The new source code text to replace the current text with. - - Returns: - None - """ - if self.source != value: - self.edit(value) - - @property - @abstractmethod - @reader - def comment(self) -> CommentGroup | None: - """Returns the comment group associated with the symbol, if any. - - Returns: - CommentGroup | None: The comment group containing all comments associated with the symbol if it exists, None otherwise. - """ - - @property - @abstractmethod - @reader - def inline_comment(self) -> CommentGroup | None: - """Returns the inline comment group associated with the symbol, if any. - - Returns: - CommentGroup | None: The inline comment group object associated with the symbol, or None if no inline comment exists. - """ - - @abstractmethod - @writer - def set_comment(self, comment: str) -> None: - """Sets a comment to the symbol. - - Updates or creates a comment for the symbol. If a comment already exists, it will be overridden. - If no comment exists, a new comment group will be created. - - Args: - comment (str): The comment text to set. - - Returns: - None - """ - - @abstractmethod - @writer - def add_comment(self, comment: str) -> None: - """Adds a comment to the symbol. - - Adds a comment to the top of a symbol. If a comment group already exists, the new comment will be appended - to the existing comment group. If no comment group exists, a new comment group will be created. - - Args: - comment (str): The comment text to add. - - Returns: - None - """ - - @abstractmethod - @writer - def set_inline_comment(self, comment: str) -> None: - """Sets an inline comment to the symbol. - - Adds or updates an inline comment for the symbol with the provided text. If an inline comment already exists, - it will be overridden. If no inline comment exists, a new inline comment will be created. - - Args: - comment (str): The text of the inline comment to be added or updated. - - Returns: - None - """ - - @noapidoc - @commiter - def parse(self, ctx: CodebaseContext) -> None: - """Adds itself as a symbol node in the graph, and an edge from the parent file to itself.""" - - #################################################################################################################### - # MANIPULATIONS - #################################################################################################################### - @writer - def insert_before(self, new_src: str, fix_indentation: bool = False, newline: bool = True, priority: int = 0, dedupe: bool = True, extended: bool = True) -> None: - """Inserts text before the current symbol node in the Abstract Syntax Tree. - - Handles insertion of new source code before a symbol, with special handling for extended nodes like comments and decorators. - The insertion can be done either before the symbol itself or before its extended nodes. - - Args: - new_src (str): The source code text to insert. - fix_indentation (bool): Whether to adjust the indentation of new_src to match current text. Defaults to False. - newline (bool): Whether to add a newline after insertion. Defaults to True. - priority (int): Priority of this edit operation. Higher priority edits are applied first. Defaults to 0. - dedupe (bool): Whether to remove duplicate insertions. Defaults to True. - extended (bool): Whether to insert before extended nodes like comments and decorators. Defaults to True. - - Returns: - None - """ - if extended: - first_node = self.extended_nodes[0] - # Skip extension for the child node - if isinstance(first_node, Symbol): - return first_node.insert_before(new_src, fix_indentation, newline, priority, dedupe, extended=False) - else: - return first_node.insert_before(new_src, fix_indentation, newline, priority, dedupe) - return super().insert_before(new_src, fix_indentation, newline, priority, dedupe) - - def move_to_file( - self, - file: SourceFile, - include_dependencies: bool = True, - strategy: Literal["add_back_edge", "update_all_imports", "duplicate_dependencies"] = "update_all_imports", - ) -> None: - """Moves the given symbol to a new file and updates its imports and references. - - This method moves a symbol to a new file and updates all references to that symbol throughout the codebase. The way imports are handled can be controlled via the strategy parameter. - - Args: - file (SourceFile): The destination file to move the symbol to. - include_dependencies (bool): If True, moves all dependencies of the symbol to the new file. If False, adds imports for the dependencies. Defaults to True. - strategy (str): The strategy to use for updating imports. Can be either 'add_back_edge' or 'update_all_imports'. Defaults to 'update_all_imports'. - - 'add_back_edge': Moves the symbol and adds an import in the original file - - 'update_all_imports': Updates all imports and usages of the symbol to reference the new file - - Returns: - None - - Raises: - AssertionError: If an invalid strategy is provided. - """ - encountered_symbols = {self} - self._move_to_file(file, encountered_symbols, include_dependencies, strategy) - - @noapidoc - def _move_to_file( - self, - file: SourceFile, - encountered_symbols: set[Symbol | Import], - include_dependencies: bool = True, - strategy: Literal["add_back_edge", "update_all_imports", "duplicate_dependencies"] = "update_all_imports", - ) -> tuple[NodeId, NodeId]: - """Helper recursive function for `move_to_file`""" - from codegen.sdk.core.import_resolution import Import - - # =====[ Arg checking ]===== - if file == self.file: - return file.file_node_id, self.node_id - if imp := file.get_import(self.name): - encountered_symbols.add(imp) - imp.remove() - - if include_dependencies: - # =====[ Move over dependencies recursively ]===== - for dep in self.dependencies: - if dep in encountered_symbols: - continue - - # =====[ Symbols - move over ]===== - if isinstance(dep, Symbol) and dep.is_top_level: - encountered_symbols.add(dep) - dep._move_to_file( - file=file, - encountered_symbols=encountered_symbols, - include_dependencies=include_dependencies, - strategy=strategy, - ) - - # =====[ Imports - copy over ]===== - elif isinstance(dep, Import): - if dep.imported_symbol: - file.add_import(imp=dep.imported_symbol, alias=dep.alias.source) - else: - file.add_import(imp=dep.source) - else: - for dep in self.dependencies: - # =====[ Symbols - add back edge ]===== - if isinstance(dep, Symbol) and dep.is_top_level: - file.add_import(imp=dep, alias=dep.name, import_type=ImportType.NAMED_EXPORT, is_type_import=False) - elif isinstance(dep, Import): - if dep.imported_symbol: - file.add_import(imp=dep.imported_symbol, alias=dep.alias.source) - else: - file.add_import(imp=dep.source) - - # =====[ Make a new symbol in the new file ]===== - file.add_symbol(self) - import_line = self.get_import_string(module=file.import_module_name) - - # =====[ Checks if symbol is used in original file ]===== - # Takes into account that it's dependencies will be moved - is_used_in_file = any( - usage.file == self.file and usage.node_type == NodeType.SYMBOL and usage not in encountered_symbols and (usage.start_byte < self.start_byte or usage.end_byte > self.end_byte) # HACK - for usage in self.symbol_usages - ) - - # ======[ Strategy: Duplicate Dependencies ]===== - if strategy == "duplicate_dependencies": - # If not used in the original file. or if not imported from elsewhere, we can just remove the original symbol - if not is_used_in_file and not any(usage.kind is UsageKind.IMPORTED and usage.usage_symbol not in encountered_symbols for usage in self.usages): - self.remove() - - # ======[ Strategy: Add Back Edge ]===== - # Here, we will add a "back edge" to the old file importing the symbol - elif strategy == "add_back_edge": - if is_used_in_file or any(usage.kind is UsageKind.IMPORTED and usage.usage_symbol not in encountered_symbols for usage in self.usages): - self.file.add_import(imp=import_line) - # Delete the original symbol - self.remove() - - # ======[ Strategy: Update All Imports ]===== - # Update the imports in all the files which use this symbol to get it from the new file now - elif strategy == "update_all_imports": - for usage in self.usages: - if isinstance(usage.usage_symbol, Import) and usage.usage_symbol.file != file: - # Add updated import - usage.usage_symbol.file.add_import(import_line) - usage.usage_symbol.remove() - elif usage.usage_type == UsageType.CHAINED: - # Update all previous usages of import * to the new import name - if usage.match and "." + self.name in usage.match: - if isinstance(usage.match, FunctionCall) and self.name in usage.match.get_name(): - usage.match.get_name().edit(self.name) - if isinstance(usage.match, ChainedAttribute): - usage.match.edit(self.name) - usage.usage_symbol.file.add_import(imp=import_line) - - # Add the import to the original file - if is_used_in_file: - self.file.add_import(imp=import_line) - # Delete the original symbol - self.remove() - - @property - @reader - @noapidoc - def is_top_level(self) -> bool: - """Is this symbol a top-level symbol: does it have a level of 0?""" - from codegen.sdk.core.file import File - - parent = self.parent - while not isinstance(parent, Symbol | Argument): - if isinstance(parent, File): - return True - parent = parent.parent - return False - - @writer - def add_keyword(self, keyword: str) -> None: - """Insert a keyword in the appropriate place before this symbol if it doesn't already exist. - - This method adds a keyword (e.g., 'public', 'async', 'static') in the syntactically appropriate - position relative to other keywords. If the keyword already exists, no action is taken. - - Args: - keyword (str): The keyword to be inserted. Must be a valid keyword in the language context. - - Raises: - AssertionError: If the provided keyword is not in the language's valid keywords list. - """ - assert keyword in self.ctx.node_classes.keywords - to_insert_onto = None - to_insert_idx = self.ctx.node_classes.keywords.index(keyword) - for node in self.children_by_field_types(self.ctx.node_classes.keywords): - idx = self.ctx.node_classes.keywords.index(node) - if node == keyword: - return - if idx < to_insert_idx: - to_insert_onto = node - if to_insert_onto is not None: - to_insert_onto.insert_after(" " + keyword, newline=False) - else: - self.insert_before(keyword + " ", newline=False, extended=False) - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - from codegen.sdk.core.interfaces.has_block import HasBlock - - symbols = [self] - if isinstance(self, HasBlock): - symbols.extend(self.code_block.descendant_symbols) - return symbols diff --git a/src/codegen/sdk/core/symbol_group.py b/src/codegen/sdk/core/symbol_group.py deleted file mode 100644 index 4d72b1d9f..000000000 --- a/src/codegen/sdk/core/symbol_group.py +++ /dev/null @@ -1,281 +0,0 @@ -from __future__ import annotations - -from collections.abc import Collection, Iterator -from typing import TYPE_CHECKING, Generic, TypeVar, override - -from codegen.sdk.core.autocommit import reader, repr_func, writer -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.output.ast import AST - - -Child = TypeVar("Child", bound="Editable") -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class SymbolGroup(Editable[Parent], Collection[Child], Generic[Child, Parent]): - """These are groups of symbols that form some kind of logical grouping, like a class or module, - that do not follow the traditional tree structure. - """ - - _symbols: list[Child] - - def __init__(self, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent, node: TSNode | None = None, children: list[Child] | None = None) -> None: - self._symbols = children - if node is None: - # For backwards compatibility, assure that the first node is the main node - node = children[0].ts_node - super().__init__(node, file_node_id, ctx, parent) - - def __repr__(self) -> str: - return f"Collection({self.symbols})" if self.symbols is not None else super().__repr__() - - def _init_children(self): ... - - @repr_func # HACK - def __hash__(self): - return super().__hash__() - # return hash(hash(node) for node in self.symbols) if self.symbols is not None else super().__hash__() - - def __eq__(self, other: object) -> bool: - if other is None: - return False - if isinstance(other, SymbolGroup): - return self.symbols == other.symbols - if isinstance(other, list): - return self.symbols == other - return super().__eq__(other) - - @property - @reader - def symbols(self) -> list[Child]: - """Returns the list of symbols in the group. - - Gets the list of symbols associated with this SymbolGroup. These symbols can be code elements like functions, classes, or variables that form a logical grouping. - - Returns: - list[Child]: A list of symbol objects that belong to this group. - """ - return self._symbols - - @property - @reader - def source(self) -> str: - """Returns the concatenated source code of all symbols in the group. - - Returns: - str: The concatenated source code of all symbols in the group. - """ - # Use _source to avoid infinite recursion - return "\n".join([symbol._source for symbol in self.symbols]) - - @source.setter - @writer - def source(self, value) -> None: - """Sets the source code of the Editable instance. - - Updates the source code by calling the edit method with the provided value. - - Args: - value (str): The new source code to set for this Editable instance. - - Returns: - None - """ - self.edit(value) - - @property - @reader - def next_sibling(self) -> Editable | None: - """Returns the next sibling of the last symbol in the symbol group. - - Provides access to the next sibling node of the last symbol in this symbol group. - - Returns: - Editable | None: The next sibling node of the last symbol in the group, or None if there is no next sibling. - """ - return self.symbols[-1].next_sibling - - @property - @reader - def next_named_sibling(self) -> Editable | None: - """Returns the next named sibling of the last symbol in the group. - - Args: - None - - Returns: - Editable | None: The next named sibling node, or None if there is no next named sibling. - """ - return self.symbols[-1].next_named_sibling - - @writer - def find_string_literals(self, strings_to_match: list[str], fuzzy_match: bool = False) -> list[Editable]: - """Search for string literals matching given strings in the SymbolGroup. - - Iterates through all symbols in the group and aggregates the results of - finding string literals in each symbol. - - Args: - strings_to_match (list[str]): List of strings to search for in string literals. - fuzzy_match (bool, optional): If True, performs fuzzy matching instead of exact matching. - - Returns: - list[Editable]: List of Editable nodes representing the matching string literals found within the symbols. - """ - return [node for symbol in self.symbols for node in symbol.find_string_literals(strings_to_match, fuzzy_match)] - - @writer - def replace(self, old: str, new: str, count: int = -1, priority: int = 0) -> int: - """Replaces all instances of a string with a new string in all symbols within the group. - - Args: - old (str): The string to be replaced. - new (str): The string to replace with. - count (int, optional): Maximum number of replacements to make. Defaults to -1 (replace all). - priority (int, optional): Priority of the replacement operation. Defaults to 0. - - Returns: - int: Number of replacements made. - """ - for symbol in self.symbols: - symbol.replace(old, new, count, priority) - - @reader - def find(self, strings_to_match: list[str] | str, *, exact: bool = False) -> list[Editable]: - """Search for substrings in the given symbols that match `strings_to_match`. - - Args: - strings_to_match (list[str] | str): The string or list of strings to search for. - exact (bool): If True, only return nodes that exactly match the query. - - Returns: - list[Editable]: A list of Editable objects representing each match found. - """ - return [node for symbol in self.symbols for node in symbol.find(strings_to_match, exact)] - - @reader - def search(self, regex_pattern: str, include_strings: bool = True, include_comments: bool = True) -> list[Editable]: - """Searches for regex matches in the codebase. - - Searches through the source code to find text matching a regex pattern, with options to exclude string literals and comments from the search. - - Args: - regex_pattern (str): The regular expression pattern to search for. - include_strings (bool, optional): Whether to include string literals in the search. Defaults to True. - include_comments (bool, optional): Whether to include comments in the search. Defaults to True. - - Returns: - list[Editable]: A list of Editable objects representing matched text nodes in the codebase. - """ - return [node for symbol in self.symbols for node in symbol.search(regex_pattern, include_strings, include_comments)] - - @writer - def insert_before(self, new_src: str, fix_indentation: bool = False, newline: bool = True, priority: int = 0, dedupe: bool = True) -> None: - """Inserts source code before this symbol group. - - Inserts the provided source code before the first symbol in the group, while maintaining proper code formatting. - - Args: - new_src (str): The source code to insert. - fix_indentation (bool, optional): Whether to adjust the indentation of the inserted code to match the current code. Defaults to False. - newline (bool, optional): Whether to add a newline after the inserted code. Defaults to True. - priority (int, optional): The priority of this edit operation. Higher priority edits are applied first. Defaults to 0. - dedupe (bool, optional): Whether to prevent duplicate insertions of the same code. Defaults to True. - - Returns: - None - """ - super().insert_before(new_src, fix_indentation, newline, priority, dedupe) - - @writer - def insert_after(self, new_src: str, fix_indentation: bool = False, newline: bool = True, priority: int = 0, dedupe: bool = True) -> None: - """Inserts source code after this node in the codebase. - - Args: - new_src (str): The source code to insert. - fix_indentation (bool, optional): Adjust indentation to match current text. - newline (bool, optional): Add a newline before the inserted code. - priority (int, optional): Priority of the edit operation. - dedupe (bool, optional): Deduplicate identical edits. - - Returns: - None - """ - if len(self.symbols) == 0 or self.ts_node != self.symbols[0].ts_node: - super().insert_after(new_src, fix_indentation, newline, priority, dedupe) - else: - self.symbols[-1].insert_after(new_src, fix_indentation, newline, priority, dedupe) - - @writer - def edit(self, new_src: str, fix_indentation: bool = False, priority: int = 0, dedupe: bool = True) -> None: - """Replace the source of this node with new text. - - Replaces the source of this SymbolGroup with new text by replacing the first symbol's source and removing all other symbols. - - Args: - new_src (str): The new source text to replace the current text with. - fix_indentation (bool, optional): Adjusts the indentation of new_src to match the current text's indentation. Defaults to False. - priority (int, optional): Priority of the edit operation. Higher priority edits take precedence. Defaults to 0. - dedupe (bool, optional): Prevents duplicate edits at the same location. Defaults to True. - - Returns: - None - """ - self.symbols[0].edit(new_src, fix_indentation, priority, dedupe) - for symbol in self.symbols[1:]: - symbol.remove() - - @writer - def remove(self, delete_formatting: bool = True, priority: int = 0, dedupe: bool = True) -> None: - """Removes this node and its related extended nodes from the codebase. - - Args: - delete_formatting (bool, optional): Whether to delete related extended nodes like decorators and comments. Defaults to True. - priority (int, optional): Priority level of the removal operation. Defaults to 0. - dedupe (bool, optional): Whether to deduplicate removal operations. Defaults to True. - - Returns: - None - """ - for symbol in self.symbols: - symbol.remove(delete_formatting, priority, dedupe) - - @reader - def __iter__(self) -> Iterator[Child]: - return iter(self.symbols) - - @reader - def __contains__(self, __x) -> bool: - return __x in self.symbols - - @reader - def __len__(self) -> int: - return len(self.symbols) - - @reader - def __getitem__(self, item): - return self.symbols[item] - - def __bool__(self) -> bool: - return True - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - for symbol in self.symbols: - symbol._compute_dependencies(usage_type, dest) - - @override - def _get_ast_children(self) -> list[tuple[str | None, AST]]: - return [(None, symbol.ast()) for symbol in self.symbols] diff --git a/src/codegen/sdk/core/symbol_groups/__init__.py b/src/codegen/sdk/core/symbol_groups/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/sdk/core/symbol_groups/collection.py b/src/codegen/sdk/core/symbol_groups/collection.py deleted file mode 100644 index d9b962701..000000000 --- a/src/codegen/sdk/core/symbol_groups/collection.py +++ /dev/null @@ -1,281 +0,0 @@ -from collections import defaultdict -from collections.abc import Iterable, Iterator, MutableSequence -from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload - -from tree_sitter import Node as TSNode - -from codegen.sdk.codebase.transactions import TransactionPriority -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.core.symbol_group import SymbolGroup -from codegen.shared.decorators.docs import noapidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - - -Child = TypeVar("Child", bound="Editable") -Parent = TypeVar("Parent") - - -class Collection(SymbolGroup[Child, Parent], MutableSequence[Child], Generic[Child, Parent]): - """Ordered collection of nodes - Attributes: - _bracket_size: Number of characters wrapping the collection - """ - - _elements: int - _reversed: set[int] - _inserts: dict[int, int] - _pending_removes: int = 0 - - _delimiter: str - _indent: int = 0 - _bracket_size: int = 1 - _container_start_byte: int - _container_end_byte: int - - def __init__(self, node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent, delimiter: str = ",", children: list[Child] | None = None, *, bracket_size: int = 1) -> None: - super().__init__(file_node_id, ctx, parent, node) - self._delimiter = delimiter - self._reversed = set() - self._inserts = defaultdict(lambda: 0) - self._container_start_byte = self.ts_node.start_byte - self._container_end_byte = self.ts_node.end_byte - self._bracket_size = bracket_size - if children is not None: - self._init_children(children) - - def _init_children(self, symbols: list[Child]): - """Call this after setting self._symbols.""" - if self.ts_node.start_point[0] != self.ts_node.end_point[0] and symbols: - # This is a multiline collection. - self._indent = symbols[0].ts_node.start_point[1] - self._delimiter += "\n" - else: - self._delimiter += " " - self._elements = len(symbols) - self._symbols = symbols - self._original_children = symbols.copy() - - @overload - def __setitem__(self, key: int, value: str | Child) -> None: ... - - @overload - def __setitem__(self, key: slice, value: Iterable[Child] | Iterable[str]) -> None: ... - - @writer - def __setitem__(self, key: int | slice, value: str | Child | Iterable[Child] | Iterable[str]) -> None: - if isinstance(key, slice): - assert isinstance(value, Iterable) - for idx, item in zip(range(key.start, key.stop, key.step), value): - self[idx] = item - else: - assert not isinstance(value, Iterable) - if isinstance(value, Editable): - value = value.source - self.symbols[key].edit(value) - - @writer - def __delitem__(self, key: int | slice) -> None: - if isinstance(key, slice): - for i in reversed(range(key.start, key.stop, key.step)): - del self[i] - else: - self.symbols[key].remove(delete_formatting=True) - del self.symbols[key] - - def __iter__(self) -> Iterator[Child]: - return super().__iter__() - - @reader - def __len__(self) -> int: - return self._elements + self._inserts_till() - - @writer - def remove(self, value: Child | None = None, *args, **kwargs) -> None: - """Removes an element from a Collection. - - Deletes the specified element from the Collection by calling its remove method. If no value is specified, - delegates to the parent class's remove method. - - Args: - value (Child | None): The element to remove from the Collection. If None, delegates to parent class. - *args: Variable length argument list to pass to the remove method. - **kwargs: Arbitrary keyword arguments to pass to the remove method. - - Returns: - None: This method doesn't return anything. - """ - # Your custom remove logic goes here - # For example, let's remove all occurrences of the value instead of just the first one - if value is None: - super().remove(*args, **kwargs) - Editable.remove(self, *args, **kwargs) - else: - value.remove(*args, **kwargs) - - def _inserts_till(self, max_idx: int | None = None) -> int: - """Find the number of pending inserts until max_idx.""" - return sum(inserts for idx, inserts in self._inserts.items() if (max_idx is None or idx < max_idx)) - - @writer - def insert(self, index: int, value: str | Child) -> None: - """Adds `value` to the container that this node represents - Args: - value: source to add - index: If provided, the `value` will be inserted at that index, otherwise will default to end of the list. - """ - if index < 0: - index = len(self) - index - # If index is not specified, insert at the end of the list - if self._elements == 0: - insert_byte = self._container_start_byte + self._bracket_size - elif index - self._inserts_till(index) >= self._elements: - # If inserting at end of the list, insert before the closing container character - insert_byte = self._container_end_byte - self._bracket_size - else: - # If inserting in the middle of the list, insert before the next sibling - sibling_index = index - self._inserts_till(index) - insert_byte = self._get_insert_byte_from_next_sibling(sibling_index) - insert_idx = index - # insert_idx = min(index, len(self.symbols) - self.pending_removes) - self._incr_insert_size(insert_idx) - insert_number = self._inserts[insert_idx] - # Case 1: Insert occuring before the last element, should be reversed - if insert_byte < self._container_end_byte - self._bracket_size: - self._reversed.add(insert_idx) - elif len(self.source) > 1 and self._bracket_size > 0: - remaining = self.source[: -self._bracket_size].rstrip() - # Case 2: Last element ends with the delimiter, reverse for this insert - if remaining.endswith(self._delimiter.rstrip()): - self._reversed.add(insert_idx) - # Case 3: A spread element was deleted and we must respect that - elif insert_number == 1: - if (relative_byte := remaining.rfind(self._delimiter)) != -1: - delim_byte = relative_byte + self.start_byte + len(self._delimiter) - element_deleted = self.transaction_manager.get_transactions_at_range(self.file.path, delim_byte, self.start_byte + len(remaining), TransactionPriority.Remove, combined=True) - delimeter_deleted = self.transaction_manager.get_transactions_at_range(self.file.path, delim_byte - len(self._delimiter), delim_byte, TransactionPriority.Remove, combined=True) - if element_deleted and not delimeter_deleted: - # Adjust the insert to insert at the correct location - insert_byte = delim_byte - self._reversed.add(insert_idx) - - def get_source() -> str: - return self._get_insert_source(value, insert_idx) - - def incr_elements() -> None: - self._inserts[insert_idx] -= 1 - self._elements += 1 - self._mark_dirty() - - # We want right -> left ordering - # Therefore, we go by highest index then insert the lowest insert number on the same index - super().insert_at(insert_byte, get_source, priority=(-index, +insert_number), exec_func=incr_elements) - - def _get_insert_byte_from_next_sibling(self, sibling_index: int) -> int: - return self.symbols[sibling_index].start_byte - - def _get_insert_source(self, src: Any, insert_idx: int) -> str: - elements = self._elements - self._pending_removes - if elements == 0: - # Further inserts to this index are reversed - self._reversed.add(insert_idx) - # If list is empty, insert after the opening container character - return str(src) - # Check if this index is reversed - # Additionally, if it isn't, check if the next one is - elif insert_idx in self._reversed or (insert_idx + 1) in self._reversed: - self._reversed.add(insert_idx) - # Insert in the middle, reverse the delimiter - return f"{' ' * self._indent}{src}{self._delimiter}" - else: - # If inserting at the end of the list - return f"{self._delimiter}{src}" - - @noapidoc - def _incr_insert_size(self, index: int) -> None: - self._inserts[index] += 1 - - @noapidoc - def _removed_child_commit(self) -> None: - self._mark_dirty() - self._elements -= 1 - self._pending_removes -= 1 - - @noapidoc - def _removed_child(self) -> None: - self._mark_dirty() - self._pending_removes += 1 - - @property - @reader - def source(self) -> str: - """Get the source code content of the node. - - Retrieves the underlying source code content associated with this node as stored in the _source attribute. - - Returns: - str: The source code content of the node. - """ - return self._source - - @source.setter - @writer - def source(self, value) -> None: - """Set the source of the Editable instance by calling .edit(..)""" - if self.source != value: - self.edit(value) - - @writer - def edit(self, *args, **kwargs) -> None: - """Edit the source for this Collection instance. - - This method is used to update the source of a Collection while preserving its start and end brackets. It is primarily used internally by - Collection to maintain structural integrity during edits. - - Args: - *args: Variable length argument list passed to the parent Editable class's edit method. - **kwargs: Arbitrary keyword arguments passed to the parent Editable class's edit method. - - Returns: - None - """ - return Editable.edit(self, *args, **kwargs) # HACK: keep start/end brackets - - @property - @reader - @noapidoc - def uncommitted_len(self): - """Get the len of this list including pending removes and adds.""" - return len(self) - self._pending_removes - - @reader - def index(self, value: Child, start: int = 0, stop: int | None = None) -> int: - """Return the index of the first occurrence of value. - - Returns -1 if value is not present. - """ - if stop is None: - stop = len(self) - ts_node = value if isinstance(value, TSNode) else value.ts_node - try: - return [x.ts_node for x in self.symbols].index(ts_node, start, stop) - except ValueError: - return -1 - - @noapidoc - def _mark_dirty(self): - self.transaction_manager.pending_undos.add(self.reset) - - @noapidoc - def reset(self): - self._pending_removes = 0 - self._elements = len(self._original_children) - self._symbols = self._original_children.copy() - self._inserts.clear() - self._reversed.clear() - - def _smart_remove(self, child, *args, **kwargs) -> bool: - return self.parent._smart_remove(self, child, *args, **kwargs) diff --git a/src/codegen/sdk/core/symbol_groups/comment_group.py b/src/codegen/sdk/core/symbol_groups/comment_group.py deleted file mode 100644 index a6c64931e..000000000 --- a/src/codegen/sdk/core/symbol_groups/comment_group.py +++ /dev/null @@ -1,80 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, TypeVar - -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.statements.comment import Comment -from codegen.sdk.core.symbol_group import SymbolGroup -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.node_id_factory import NodeId - -Parent = TypeVar("Parent") - - -@apidoc -class CommentGroup(SymbolGroup[Comment, Parent]): - """A group of comments that form a larger comment block.""" - - _indentation: int # Indentation level of the comment block - - def __init__(self, children: list[Comment], file_node_id: NodeId, ctx: CodebaseContext, parent: Parent) -> None: - assert len(children) > 0, "CommentGroup must have at least one symbol" - super().__init__(file_node_id, ctx, parent, node=children[0].ts_node, children=children) - self._indentation = self._calculate_indentation() - - @property - @reader - def text(self) -> str: - """Return the text content of all comments in the comment block. - - Combines multiple comment lines with newlines, excluding comment delimiters. - - Returns: - str: The concatenated text content of all comments in the block. - """ - return "\n".join([comment.text for comment in self.symbols]) - - @text.setter - @writer - def text(self, new_text: str) -> None: - """Replace the text of a CommentGroup with new text. - - Updates the text of all comments in the group, maintaining proper comment delimiters like `#` or `/* */`. - After updating the first comment's text, all subsequent comments in the group are removed. - - Args: - new_text (str): The new text content to replace the existing comment text. Will be formatted with appropriate comment delimiters. - - Returns: - None - """ - self.edit_text(new_text) - - @writer - def edit_text(self, new_text: str) -> None: - """Replace the text content of a comment group with new text. - - Updates the comment text while preserving and auto-formatting comment delimiters. - Removes any additional comment lines from the comment group, leaving only the - first line with the new text. - - Args: - new_text (str): The new text content to replace the existing comment text. - The text should not include comment delimiters. - - Returns: - None - """ - # Generate comment block with new source - self.symbols[0].edit_text(new_text) - for symbol in self.symbols[1:]: - symbol.remove() - - @noapidoc - @reader - def _calculate_indentation(self) -> int: - """Calculate the indentation level of the comment block.""" - return self.symbols[0].ts_node.start_point[1] diff --git a/src/codegen/sdk/core/symbol_groups/dict.py b/src/codegen/sdk/core/symbol_groups/dict.py deleted file mode 100644 index 20bc3b984..000000000 --- a/src/codegen/sdk/core/symbol_groups/dict.py +++ /dev/null @@ -1,180 +0,0 @@ -from collections.abc import Iterator, MutableMapping -from typing import TYPE_CHECKING, Generic, Self, TypeVar - -from tree_sitter import Node as TSNode - -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.builtin import Builtin -from codegen.sdk.core.expressions.expression import Expression -from codegen.sdk.core.expressions.string import String -from codegen.sdk.core.expressions.unpack import Unpack -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.core.interfaces.has_value import HasValue -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.interfaces.importable import Importable - - -TExpression = TypeVar("TExpression", bound="Expression") -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class Pair(Editable[Parent], HasValue, Generic[TExpression, Parent]): - """An abstract representation of a key, value pair belonging to a `Dict`. - - Attributes: - key: The key expression of the pair, expected to be of type TExpression. - """ - - key: TExpression - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent) -> None: - super().__init__(ts_node, file_node_id, ctx, parent) - self.key, self._value_node = self._get_key_value() - if self.key is None: - self._log_parse(f"{self} {self.ts_node} in {self.filepath} has no key") - if self.ts_node_type != "shorthand_property_identifier" and self.value is None: - self._log_parse(f"{self} {self.ts_node} in {self.filepath} has no value") - - def _get_key_value(self) -> tuple[Expression[Self] | None, Expression[Self] | None]: - return self.child_by_field_name("key"), self.child_by_field_name("value") - - @property - def name(self) -> str: - """Returns the source text of the key expression in the pair. - - This property provides access to the textual representation of the pair's key, which is - stored in the `key` attribute. The key is expected to be an Expression type that has - a `source` property containing the original source code text. - - Returns: - str: The source text of the key expression. - - Note: - This property assumes that self.key has been properly initialized in __init__ - and has a valid `source` attribute. In cases where key initialization failed - (key is None), accessing this property may raise an AttributeError. - """ - return self.key.source - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - if self.key: - self.key._compute_dependencies(usage_type, dest) - if self.value and self.value is not self.key: - self.value._compute_dependencies(usage_type, dest) - - -TExpression = TypeVar("TExpression", bound="Expression") -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class Dict(Expression[Parent], Builtin, MutableMapping[str, TExpression], Generic[TExpression, Parent]): - """Represents a dict (object) literal the source code. - - Attributes: - unpack: An optional unpacking element, if present. - """ - - _underlying: Collection[Pair[TExpression, Self] | Unpack[Self], Parent] - unpack: Unpack[Self] | None = None - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent, delimiter: str = ",", pair_type: type[Pair] = Pair) -> None: - # TODO: handle spread_element - super().__init__(ts_node, file_node_id, ctx, parent) - children = [pair_type(child, file_node_id, ctx, self) for child in ts_node.named_children if child.type not in (None, "comment", "spread_element", "dictionary_splat") and not child.is_error] - if unpack := self.child_by_field_types({"spread_element", "dictionary_splat"}): - children.append(unpack) - self.unpack = unpack - if len(children) > 1: - first_child = children[0].ts_node.end_byte - ts_node.start_byte - second_child = children[1].ts_node.start_byte - ts_node.start_byte - delimiter = ts_node.text[first_child:second_child].decode("utf-8").rstrip() - self._underlying = Collection(ts_node, file_node_id, ctx, parent, delimiter=delimiter, children=children) - - def __bool__(self) -> bool: - return True - - def __len__(self) -> int: - return len(list(elem for elem in self._underlying if isinstance(elem, Pair))) - - def __iter__(self) -> Iterator[str]: - for pair in self._underlying: - if isinstance(pair, Pair): - if pair.key is not None: - if isinstance(pair.key, String): - yield pair.key.content - else: - yield pair.key.source - - def __getitem__(self, __key) -> TExpression: - for pair in self._underlying: - if isinstance(pair, Pair): - if isinstance(pair.key, String): - if pair.key.content == str(__key): - return pair.value - elif pair.key is not None: - if pair.key.source == str(__key): - return pair.value - msg = f"Key {__key} not found in {list(self.keys())} {self._underlying!r}" - raise KeyError(msg) - - def __setitem__(self, __key, __value: TExpression) -> None: - new_value = __value.source if isinstance(__value, Editable) else str(__value) - if value := self.get(__key, None): - value.edit(new_value) - else: - if not self.ctx.node_classes.int_dict_key: - try: - int(__key) - __key = f"'{__key}'" - except ValueError: - pass - self._underlying.append(f"{__key}: {new_value}") - - def __delitem__(self, __key) -> None: - for idx, pair in enumerate(self._underlying): - if isinstance(pair, Pair): - if isinstance(pair.key, String): - if pair.key.content == str(__key): - del self._underlying[idx] - return - elif pair.key is not None: - if pair.key.source == str(__key): - del self._underlying[idx] - return - msg = f"Key {__key} not found in {list(self.keys())} {self._underlying!r}" - raise KeyError(msg) - - def _removed_child_commit(self): - return self._underlying._removed_child_commit() - - def _removed_child(self): - return self._underlying._removed_child() - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - self._underlying._compute_dependencies(usage_type, dest) - - @property - @noapidoc - def descendant_symbols(self) -> list["Importable"]: - ret = [] - for child in self._underlying.symbols: - if child.value: - ret.extend(child.value.descendant_symbols) - return ret - - @property - def __class__(self): - return dict diff --git a/src/codegen/sdk/core/symbol_groups/expression_group.py b/src/codegen/sdk/core/symbol_groups/expression_group.py deleted file mode 100644 index 82f6da266..000000000 --- a/src/codegen/sdk/core/symbol_groups/expression_group.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.symbol_group import SymbolGroup -from codegen.sdk.extensions.autocommit import reader -from codegen.shared.decorators.docs import apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.expressions import Expression - -Parent = TypeVar("Parent") - - -TExpression = TypeVar("TExpression", bound="Expression") -Parent = TypeVar("Parent") - - -@apidoc -class ExpressionGroup(SymbolGroup[TExpression, Parent], Generic[TExpression, Parent]): - """Group of contiguous set of expressions.""" - - @property - @reader - def expressions(self) -> list[TExpression]: - """Returns all expressions in the group. - - A property that returns all expressions stored in the ExpressionGroup as a list. - - Returns: - list[TExpression]: A list of expressions contained in the group, where TExpression is a type variable bound to Expression. - """ - return self._symbols - - @property - @reader - def source(self) -> str: - """Returns the source code of the symbol group. - - Args: - None - - Returns: - str: The source code string for the symbol group, including all symbols within the group. - """ - # TODO: Use _source to avoid infinite recursion - return self.file.content[self.symbols[0].start_byte : self.symbols[-1].end_byte] - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Returns all function calls within the expression group. - - Retrieves all function calls from the expressions in this group, sets their - parent as this group, and returns them. - - Returns: - list[FunctionCall]: A list of all function calls found in the expressions - of this group. - """ - fcalls = [] - for expr in self.expressions: - for call in expr.function_calls: - fcalls.append(call) - return fcalls diff --git a/src/codegen/sdk/core/symbol_groups/list.py b/src/codegen/sdk/core/symbol_groups/list.py deleted file mode 100644 index 81099db28..000000000 --- a/src/codegen/sdk/core/symbol_groups/list.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import TYPE_CHECKING, TypeVar - -from tree_sitter import Node as TSNode - -from codegen.sdk.core.expressions.builtin import Builtin -from codegen.sdk.core.expressions.expression import Expression -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.shared.decorators.docs import apidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext -Parent = TypeVar("Parent", bound=Editable) - - -@apidoc -class List(Collection["Expression[Self, None]", Parent], Expression[Parent], Builtin): - """A list object. - - You can use standard operations to operate on this list (IE len, del, append, insert, etc) - """ - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent) -> None: - super().__init__(ts_node, file_node_id, ctx, parent) - self._init_children([self._parse_expression(child) for child in ts_node.named_children if child.type]) - - @property - def __class__(self): - return list diff --git a/src/codegen/sdk/core/symbol_groups/multi_line_collection.py b/src/codegen/sdk/core/symbol_groups/multi_line_collection.py deleted file mode 100644 index cfbbb63b9..000000000 --- a/src/codegen/sdk/core/symbol_groups/multi_line_collection.py +++ /dev/null @@ -1,98 +0,0 @@ -from collections import defaultdict -from collections.abc import Iterator -from typing import TYPE_CHECKING, Generic, TypeVar - -from tree_sitter import Node as TSNode - -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - - -Child = TypeVar("Child", bound=Editable) -Parent = TypeVar("Parent", bound="Editable") - - -@apidoc -class MultiLineCollection(Collection[Child, Parent], Generic[Child, Parent]): - """A list containing multi-line objects. - - Example: A list of function definitions, class definitions - You can use standard operations to operate on this list (IE len, del, append, insert, etc) - """ - - _inserts_max_size: dict[int, int] - _leading_delimiter: str = "\n" - _trailing_delimiter: str = "\n" - - def __init__( - self, - children: list[Child], - file_node_id: NodeId, - ctx: "CodebaseContext", - parent: Parent, - node: TSNode, - indent_size: int, - leading_delimiter: str = "\n", - trailing_delimiter: str = "\n", - start_byte: int | None = None, - end_byte: int | None = None, - ) -> None: - super().__init__(node, file_node_id, ctx, parent, trailing_delimiter, children=children, bracket_size=0) - self._inserts_max_size = defaultdict(lambda: 0) - self._leading_delimiter = leading_delimiter - self._trailing_delimiter = trailing_delimiter - self._indent = indent_size - self._container_start_byte = start_byte or self.ts_node.start_byte - self._container_end_byte = end_byte or self.ts_node.end_byte + 1 - - def __iter__(self) -> Iterator[Child]: - return super().__iter__() - - def __len__(self) -> int: - return super().__len__() - - def _get_insert_byte_from_next_sibling(self, sibling_index: int) -> int: - # If inserting into the first sibling and the container_start_byte was specified, - # insert at the start of the container - if sibling_index == 0: - return self._container_start_byte - # Otherwise, insert at the line start of the sibling - sibling = self.symbols[sibling_index] - return sibling.start_byte - sibling.start_point[1] - - def _get_insert_source(self, src: str | Child, insert_idx: int) -> str: - indent = " " * self._indent - - if isinstance(src, Child.__bound__): - indent_size = src.start_point[1] - src_lines = str(src.source).split("\n") - src_lines = [f"{indent}{line}" for line in src_lines[:1]] + [line if line.strip() == "" else f"{indent}{line[indent_size:]}" for line in src_lines[1:]] - elif isinstance(src, str): - src = src.strip() - src_lines = src.split("\n") - src_lines = [line if line == "" else f"{indent}{line}" for line in src_lines] - else: - msg = f"Invalid source type: {type(src)}" - raise ValueError(msg) - src = "\n".join(src_lines) - - # Only add the leading delimiter if it's inserted before or after existing elements - if insert_idx == 0 or insert_idx >= len(self.symbols): - src = f"{self._leading_delimiter}{src}{self._trailing_delimiter}" - else: - src = f"{src}{self._trailing_delimiter}" - - # If this is the last element to insert before an existing element, add a delimiter - if insert_idx == len(self.symbols) - 1 and self._inserts[insert_idx] == self._inserts_max_size[insert_idx]: - src = f"{src}{self._leading_delimiter}" - return src - - @noapidoc - def _incr_insert_size(self, index: int) -> None: - super()._incr_insert_size(index) - self._inserts_max_size[index] = max(self._inserts[index], self._inserts_max_size[index]) diff --git a/src/codegen/sdk/core/symbol_groups/parents.py b/src/codegen/sdk/core/symbol_groups/parents.py deleted file mode 100644 index a0857d045..000000000 --- a/src/codegen/sdk/core/symbol_groups/parents.py +++ /dev/null @@ -1,84 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.sdk.enums import EdgeType - -if TYPE_CHECKING: - from collections.abc import Iterator - - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.codebase.resolution_stack import ResolutionStack - from codegen.sdk.core.expressions.chained_attribute import ChainedAttribute - from codegen.sdk.core.expressions.name import Name - from codegen.sdk.core.expressions.type import Type - from codegen.sdk.core.interfaces.inherits import Inherits - from codegen.sdk.core.node_id_factory import NodeId - - -TType = TypeVar("TType", bound="Type") -Parent = TypeVar("Parent", bound="Inherits") - - -class Parents(Collection["TType", Parent], Generic[TType, Parent]): - type_arguments: list[Type] - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent) -> None: - super().__init__(ts_node, file_node_id, ctx, parent) - self._init_children([self._parse_type(child) for child in ts_node.named_children if child.type != "type_arguments"]) - self.type_arguments = [self._parse_type(child) for child in ts_node.children if child.type == "type_arguments"] - - def __iter__(self) -> Iterator[TType]: - return super().__iter__() - - def compute_superclass_dependencies(self) -> None: - """Compute superclass dependencies.""" - dest = self.parent - for superclass in self: - resolution: list[ResolutionStack] = superclass.resolved_types - if len(resolution) == 1 and self.ctx.has_node(getattr(resolution[0], "node_id", None)): - self.ctx.add_edge(dest.node_id, resolution[0].node_id, type=EdgeType.SUBCLASS) - else: - self._log_parse("%r is ambiguous with possibilities: %r.", superclass, resolution) - self.parent.__dict__.pop("superclasses", None) - self.parent.__dict__.pop("constructor", None) - - def _compute_dependencies(self, usage_type: UsageKind = UsageKind.SUBCLASS, dest: HasName | None = None) -> None: - if dest is None: - dest = self.parent - for superclass in self: - superclass._compute_dependencies(UsageKind.BODY, dest) - for type_argument in self.type_arguments: - type_argument._compute_dependencies(UsageKind.GENERIC, dest) - - @reader - def is_subclass_of(self, parent: str | HasName, max_depth: int | None = None) -> bool: - """Returns True if the class is a subclass of the given parent class.""" - from codegen.sdk.core.class_definition import Class - from codegen.sdk.core.interface import Interface - - if isinstance(parent, HasName): - parent = parent.name - to_search = parent.split(".")[-1] - if to_search in (c.source.split(".")[-1] for c in self.parent_class_names): - return True - for parent_class in self.parent._get_superclasses(max_depth=(max_depth if max_depth is None else max_depth - 1)): - if isinstance(parent_class, Class): - if to_search in (c.source.split(".")[-1] for c in parent_class.parent_class_names): - return True - if isinstance(parent_class, Interface) and parent_class.parent_interfaces is not None: - if to_search in (c.name for c in parent_class.parent_interfaces): - return True - return False - - @property - @reader - def parent_class_names(self) -> list[Name | ChainedAttribute]: - """Returns a list of the args passed to the class (the parent classes)""" - return [superclass.get_name() for superclass in self._symbols if isinstance(superclass, HasName)] diff --git a/src/codegen/sdk/core/symbol_groups/tuple.py b/src/codegen/sdk/core/symbol_groups/tuple.py deleted file mode 100644 index 47ecb18e2..000000000 --- a/src/codegen/sdk/core/symbol_groups/tuple.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import TYPE_CHECKING, TypeVar - -from tree_sitter import Node as TSNode - -from codegen.sdk.core.expressions.builtin import Builtin -from codegen.sdk.core.expressions.expression import Expression -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.shared.decorators.docs import apidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext -Parent = TypeVar("Parent", bound=Editable) - - -@apidoc -class Tuple(Collection["Expression[Self, None]", Parent], Expression[Parent], Builtin): - """A tuple object. - - You can use standard operations to operate on this list (IE len, del, append, insert, etc) - """ - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent) -> None: - super().__init__(ts_node, file_node_id, ctx, parent) - self._init_children([self._parse_expression(child) for child in ts_node.named_children if child.type]) - - @property - def __class__(self): - return tuple diff --git a/src/codegen/sdk/core/symbol_groups/type_parameters.py b/src/codegen/sdk/core/symbol_groups/type_parameters.py deleted file mode 100644 index aa83615c5..000000000 --- a/src/codegen/sdk/core/symbol_groups/type_parameters.py +++ /dev/null @@ -1,23 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.symbol_groups.collection import Collection - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.expressions.type import Type - from codegen.sdk.core.interfaces.supports_generic import SupportsGenerics - from codegen.sdk.core.node_id_factory import NodeId - - -TType = TypeVar("TType", bound="Type") -Parent = TypeVar("Parent", bound="SupportsGenerics") - - -class TypeParameters(Collection["TType", Parent], Generic[TType, Parent]): - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent) -> None: - super().__init__(ts_node, file_node_id, ctx, parent) - self._init_children([self._parse_type(child) for child in ts_node.named_children]) diff --git a/src/codegen/sdk/core/type_alias.py b/src/codegen/sdk/core/type_alias.py deleted file mode 100644 index 17171155c..000000000 --- a/src/codegen/sdk/core/type_alias.py +++ /dev/null @@ -1,76 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from typing import TYPE_CHECKING, Generic, TypeVar, override - -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.interfaces.has_attribute import HasAttribute -from codegen.sdk.core.interfaces.has_block import HasBlock -from codegen.sdk.core.interfaces.has_value import HasValue -from codegen.sdk.core.interfaces.supports_generic import SupportsGenerics -from codegen.sdk.enums import SymbolType -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.statements.attribute import Attribute - from codegen.sdk.core.statements.statement import Statement - - -TCodeBlock = TypeVar("TCodeBlock", bound="CodeBlock") -TAttribute = TypeVar("TAttribute", bound="Attribute") -Parent = TypeVar("Parent", bound="HasBlock") - - -@apidoc -class TypeAlias(SupportsGenerics, HasValue, HasBlock, HasAttribute[TAttribute], Generic[TCodeBlock, TAttribute]): - """Abstract representation of a Type object. - - Only applicable for some programming languages like TypeScript. - - Attributes: - symbol_type: The type of symbol, set to SymbolType.Interface. - code_block: The code block associated with this type alias. - """ - - symbol_type = SymbolType.Interface - code_block: TCodeBlock - - def __init__( - self, - ts_node: TSNode, - file_node_id: NodeId, - ctx: CodebaseContext, - parent: Statement[CodeBlock[Parent, ...]], - ) -> None: - super().__init__(ts_node, file_node_id, ctx, parent) - value_node = self.ts_node.child_by_field_name("value") - self._value_node = self._parse_type(value_node) if value_node else None - self.type_parameters = self.child_by_field_name("type_parameters") - - @property - @abstractmethod - @reader - def attributes(self) -> list[TAttribute]: - """List of expressions defined in this Type object.""" - - @reader - def get_attribute(self, name: str) -> TAttribute | None: - """Get attribute by name.""" - return next((x for x in self.attributes if x.name == name), None) - - @noapidoc - @reader - @override - def resolve_attribute(self, name: str) -> TAttribute | None: - return self.get_attribute(name) - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - return super().descendant_symbols + self.value.descendant_symbols diff --git a/src/codegen/sdk/core/utils/cache_utils.py b/src/codegen/sdk/core/utils/cache_utils.py deleted file mode 100644 index 60f7c4dbf..000000000 --- a/src/codegen/sdk/core/utils/cache_utils.py +++ /dev/null @@ -1,45 +0,0 @@ -import functools -from collections.abc import Iterator -from typing import Callable, Generic, ParamSpec, TypeVar - -from codegen.sdk.extensions.utils import lru_cache - -ItemType = TypeVar("ItemType") -GenParamSpec = ParamSpec("GenParamSpec") - - -class LazyGeneratorCache(Generic[ItemType]): - """A cache for a generator that is lazily evaluated.""" - - _cache: list[ItemType] - gen: Iterator[ItemType] - - def __init__(self, gen: Iterator[ItemType]): - self._cache = [] - self.gen = gen - - def __iter__(self) -> Iterator[ItemType]: - for item in self._cache: - yield item - - for item in self.gen: - self._cache.append(item) - yield item - - -def cached_generator(maxsize: int = 16, typed: bool = False) -> Callable[[Callable[GenParamSpec, Iterator[ItemType]]], Callable[GenParamSpec, Iterator[ItemType]]]: - """Decorator to cache the output of a generator function. - - The generator's output is fully consumed on the first call and stored as a list. - Subsequent calls with the same arguments yield values from the cached list. - """ - - def decorator(func: Callable[GenParamSpec, Iterator[ItemType]]) -> Callable[GenParamSpec, Iterator[ItemType]]: - @lru_cache(maxsize=maxsize, typed=typed) - @functools.wraps(func) - def wrapper(*args: GenParamSpec.args, **kwargs: GenParamSpec.kwargs) -> Iterator[ItemType]: - return LazyGeneratorCache(func(*args, **kwargs)) - - return wrapper - - return decorator diff --git a/src/codegen/sdk/enums.py b/src/codegen/sdk/enums.py deleted file mode 100644 index b67c5bd81..000000000 --- a/src/codegen/sdk/enums.py +++ /dev/null @@ -1,85 +0,0 @@ -from enum import IntEnum, auto -from typing import NamedTuple - -from codegen.sdk.core.dataclasses.usage import Usage -from codegen.shared.decorators.docs import apidoc - - -class NodeType(IntEnum): - """NodeType is an enumeration class that defines different types of nodes within the graph.""" - - REPO = auto() # Node representing the full repository - FILE = auto() # Node representing a file - IMPORT = auto() # Node representing an import statement - EXPORT = auto() # Node representing an export statement - SYMBOL = auto() # Node representing a symbol defined in a file - EXTERNAL = auto() # Node representing something external to the codebase, e.g. `datetime` - EXPRESSION = auto() # Node representing an expression within a statement. - - -class FileGraphNodeType(IntEnum): - # File graph nodes - STATEMENT = auto() # Node representing a statement in code block. - EXPRESSION = auto() # Node representing an expression within a statement. - - -class FileGraphEdgeType(IntEnum): - # File graph edges - STATEMENT_CONTAINS_EXPRESSION = auto() # Edge from statement to expression. - - -class EdgeType(IntEnum): - # === [ External Edges Between Files ] === - # Edge from Import => resolved Symbol. - # Should be added by the import, only after all the files have been parsed. - IMPORT_SYMBOL_RESOLUTION = auto() - EXPORT = auto() - SUBCLASS = auto() - # Edge from Symbol => used Symbol (or Import) referenced within the same file. - # Should be added by the parent symbol, only after all the file children node types have been added to the graph. - SYMBOL_USAGE = auto() - - -class SymbolType(IntEnum): - """TODO: names should be all uppercase""" - - Function = auto() - Class = auto() - GlobalVar = auto() - Interface = auto() - Type = auto() - Enum = auto() - Namespace = auto() - - -@apidoc -class ImportType(IntEnum): - """Import types for each import object. Determines what the import resolves to, and what symbols are imported. - - Attributes: - DEFAULT_EXPORT: Imports all default exports. Resolves to the file. - NAMED_EXPORT: Imports a named export. Resolves to the symbol export. - WILDCARD: Imports all named exports, and default exports as `default`. Resolves to the file. - MODULE: Imports the module, not doesn't actually allow access to any of the exports - SIDE_EFFECT: Imports the module, not doesn't actually allow access to any of the exports - UNKNOWN: Unknown import type. - """ - - # Imports all default exports. Resolves to the file. - DEFAULT_EXPORT = auto() - # Imports a named export. Resolves to the symbol export. - NAMED_EXPORT = auto() - # Imports all named exports, and default exports as `default`. Resolves to the file. - WILDCARD = auto() - # Imports all default and named exports. The default export is aliased as `default` and can be accessed by `moduleName.default` - # Resolves to the file. - MODULE = auto() - # Imports the module, not doesn't actually allow access to any of the exports - # Resolves to the file. - SIDE_EFFECT = auto() - UNKNOWN = auto() # TODO: get rid of this - mostly used to set default value. we should just set to None. - - -class Edge(NamedTuple): - type: EdgeType - usage: Usage | None diff --git a/src/codegen/sdk/extensions/autocommit.pyi b/src/codegen/sdk/extensions/autocommit.pyi deleted file mode 100644 index 51f4ebd81..000000000 --- a/src/codegen/sdk/extensions/autocommit.pyi +++ /dev/null @@ -1,31 +0,0 @@ -from collections.abc import Callable -from typing import Any, ParamSpec, TypeVar, overload - -from codegen.sdk.codebase.codebase_context import CodebaseContext -from codegen.sdk.core.interfaces.editable import Editable - -P = ParamSpec("P") -T = TypeVar("T") - -def is_outdated(c) -> bool: ... -@overload -def reader(wrapped: Callable[P, T]) -> Callable[P, T]: ... -@overload -def reader(wrapped: None = None, *, cache: bool | None = ...) -> Callable[[Callable[P, T]], Callable[P, T]]: ... - -class AutoCommitMixin: - """Support for autocommit""" - - autocommit_cache: dict[str, Any] - removed: bool - def __init__(self, ctx: CodebaseContext) -> None: ... - def update_generation(self, generation: int | None = None) -> None: ... - @property - def is_outdated(self) -> bool: ... - def is_same_version(self, other: AutoCommitMixin) -> bool: ... - -def update_dict(seen: set[Editable], obj: Editable, new_obj: Editable): ... -@overload -def commiter(wrapped: Callable[P, T]) -> Callable[P, T]: ... -@overload -def commiter(wrapped: None = None, *, reset: bool = ...) -> Callable[[Callable[P, T]], Callable[P, T]]: ... diff --git a/src/codegen/sdk/extensions/autocommit.pyx b/src/codegen/sdk/extensions/autocommit.pyx deleted file mode 100644 index d9d2b69eb..000000000 --- a/src/codegen/sdk/extensions/autocommit.pyx +++ /dev/null @@ -1,217 +0,0 @@ -import functools -from collections.abc import Callable -from typing import Any, ParamSpec, TypeVar, Union, overload - -import wrapt - -from codegen.sdk.core.autocommit.constants import AutoCommitState, OutdatedNodeError, enabled - -P = ParamSpec("P") -T = TypeVar("T") - - -def is_outdated(c) -> bool: - from codegen.sdk.core.interfaces.editable import Editable - - if isinstance(c, Editable): - return c.is_outdated - if isinstance(c, list): - return any(is_outdated(i) for i in c) - return False - - -@overload -def reader(wrapped: Callable[P, T]) -> Callable[P, T]: ... - - -@overload -def reader(wrapped: None = None, *, cache: bool | None = ...) -> Callable[[Callable[P, T]], Callable[P, T]]: ... - - -def reader(wrapped: Callable[P, T] | None = None, *, cache: bool | None = None) -> Callable[P, T] | Callable[[Callable[P, T]], Callable[P, T]]: - """Indicates this method is a read - - Args: - ---- - cache (bool): Whether to cache the result of the function. By default enabled for functions without arguments - - """ - if wrapped is None: - return functools.partial(reader, cache=cache) - - @wrapt.decorator(enabled=enabled) - def wrapper(wrapped: Callable[P, T], instance: Union["Editable", None] = None, args: P.args = None, kwargs: P.kwargs = None) -> T: - """Indicates this method is a reader and should be updated if there are any pending changes.""" - num_args = len(args) + len(kwargs) - if instance is None: - instance = args[0] - num_args -= 1 - name = wrapped.__name__ - autocommit = instance.ctx._autocommit - should_cache = cache - - def run_func(): - if should_cache and not instance.is_outdated: - if cached := instance.autocommit_cache.get(name, None): - if not is_outdated(cached): - return cached - ret = wrapped(*args, **kwargs) - if should_cache: - if is_outdated(ret): - raise OutdatedNodeError(instance) - instance.autocommit_cache[name] = ret - return ret - - if autocommit.state in (AutoCommitState.Special, AutoCommitState.Committing): - return run_func() - if num_args > 0: - if cache: - raise NotImplementedError("Cache doesn't support functions with arguments") - should_cache = False - elif cache is None: - should_cache = True - to_unlock = autocommit.try_lock_files({instance.filepath}) - old_state = autocommit.enter_state(AutoCommitState.Read) - # logger.debug("Reading node %r, %r", instance, wrapped) - try: - autocommit.check_update(instance, lock=to_unlock, must_be_updated=False) - ret = run_func() - finally: - autocommit.state = old_state - autocommit.unlock_files(to_unlock) - return ret - - wrapped._reader = True - return wrapper(wrapped) - - -class AutoCommitMixin: - """Support for autocommit""" - - _generation: int - autocommit_cache: dict[str, Any] - removed: bool = False - - def __init__(self, codebase_context: "CodebaseContext"): - self._generation = codebase_context.generation - self.autocommit_cache = {} - - def update_generation(self: "Editable", generation: int | None = None) -> None: - if generation is None: - generation = self.file._generation - self._generation = generation - - @property - def is_outdated(self: "Editable") -> bool: - if file := self.file: - return self._generation < file._generation - return False - - def is_same_version(self, other: "AutoCommitMixin") -> bool: - return self._generation == other._generation - - -def _delay_update(new_value) -> bool: - if isinstance(new_value, AutoCommitMixin): - return new_value.is_outdated - elif isinstance(new_value, list): - return any(v.is_outdated for v in new_value) - return False - - -def update_dict(seen: set["Editable"], obj: "Editable", new_obj: "Editable"): - from codegen.sdk.core.interfaces.editable import Editable - - if obj in seen or obj.removed: - return - if new_obj.is_outdated: - raise OutdatedNodeError(new_obj) - if new_obj.is_same_version(obj): - return - assert new_obj._generation > obj._generation - seen.add(obj) - - def update_child(v, new_value): - if isinstance(v, Editable): - update_dict(seen, v, new_value) - # elif isinstance(v, list): - # # This only will work for lists, as the others are non-ordered - # to_update = list(filter(lambda i: not i.removed, v)) - # if len(to_update) == len(new_value): - # for old, new in zip(to_update, new_value): - # if isinstance(old, Editable): - # update_dict(seen, old, new) - - for k, v in obj.__dict__.items(): - # Update all the detached symbols in the tree - if k in new_obj.__dict__: - new_value = new_obj.__dict__[k] - # assert new_value is not None, f"{k=}, {v=}, {new_value=}" - if new_value is not None: - update_child(v, new_value) - # If you put a breakpoint during this while loop, python may segfault - for k, v in obj.autocommit_cache.items(): - new_value = getattr(new_obj, k) - if isinstance(new_value, Callable) and not isinstance(new_value, Exception): - new_value = new_value() - update_child(v, new_value) - if isinstance(new_value, Editable): - new_obj.autocommit_cache[k] = v - # # If you put a breakpoint during this while loop, python may segfault - # while len(to_update) > 0: - # k = to_update.popleft() - # new_value = getattr(new_obj, k) - # if isinstance(new_value, Callable): - # new_value = new_value() - # if _delay_update(new_value): - # new_obj.autocommit_cache.clear() - # to_update.append(k) - # else: - # v = obj.autocommit_cache.get(k) - # update_child(v, new_value) - assert new_obj.__class__ == obj.__class__ - obj.__dict__ = new_obj.__dict__ - assert new_obj.ts_node == obj.ts_node - assert new_obj.is_same_version(obj) - assert not obj.is_outdated - - -@overload -def commiter(wrapped: Callable[P, T]) -> Callable[P, T]: ... - - -@overload -def commiter(wrapped: None = None, *, reset: bool = ...) -> Callable[[Callable[P, T]], Callable[P, T]]: ... - - -def commiter(wrapped: Callable[P, T] | None = None, *, reset: bool = False) -> Callable[P, T] | Callable[[Callable[P, T]], Callable[P, T]]: - """Indicates this method is part of a commit. There should be no writes within this method and reads will not be updated - - Args: - ---- - reset: Reset the autocommit state when done. Only useful in reset_graph() - - """ - if wrapped is None: - return functools.partial(commiter, reset=reset) - - @wrapt.decorator(enabled=enabled) - def wrapper(wrapped: Callable[P, T], instance: Union["Editable", "CodebaseContext", None] = None, args: P.args = None, kwargs: P.kwargs = None) -> T: - if instance is None: - instance = args[0] - from codegen.sdk.codebase.codebase_context import CodebaseContext - - if isinstance(instance, CodebaseContext): - autocommit = instance._autocommit - else: - autocommit = instance.ctx._autocommit - old_state = autocommit.enter_state(AutoCommitState.Committing) - try: - ret = wrapped(*args, **kwargs) - finally: - autocommit.state = old_state - if reset: - autocommit.reset() - return ret - - return wrapper(wrapped) diff --git a/src/codegen/sdk/extensions/py.typed b/src/codegen/sdk/extensions/py.typed deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/sdk/extensions/resolution.pyi b/src/codegen/sdk/extensions/resolution.pyi deleted file mode 100644 index 9a8cb07d9..000000000 --- a/src/codegen/sdk/extensions/resolution.pyi +++ /dev/null @@ -1,48 +0,0 @@ -from dataclasses import dataclass, field -from functools import cached_property as cached_property -from typing import Generic - -from typing_extensions import TypeVar - -from codegen.sdk.codebase.codebase_context import CodebaseContext -from codegen.sdk.core.dataclasses.usage import UsageKind, UsageType -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.interfaces.has_name import HasName - -NodeType = TypeVar("NodeType") - -@dataclass -class ResolutionStack(Generic[NodeType]): - """Represents the resolution stack from a symbol to a usage - - Symbol - ... - - - Attributes: - aliased: If this was aliased at any point - parent_frame: The frame above this frame - """ - - node: NodeType = ... - parent_frame: ResolutionStack | None = ... - direct: bool = True - aliased: bool = False - chained: bool = False - generics: dict = field(default_factory=dict) - - def with_frame(self, node, direct: bool = True, aliased: bool = False, chained: bool = False, generics: dict | None = None) -> ResolutionStack: - """Adds node to the Resolution stack and returns it as a new frame.""" - ... - - def usage_type(self, direct: bool, aliased: bool) -> UsageType: ... - def add_usage(self, match: Editable, usage_type: UsageKind, dest: HasName, codebase_context: CodebaseContext, *, direct: bool = True, aliased: bool = False, chained: bool = False) -> None: - """Add the resolved type to the graph. Also adds any intermediate nodes as usages as well if they are on the graph.""" - - @cached_property - def top(self) -> ResolutionStack: ... - @cached_property - def is_direct_usage(self) -> bool: ... - def with_new_base(self, base, *args, **kwargs) -> ResolutionStack: ... - def with_new_base_frame(self, base: ResolutionStack) -> ResolutionStack: ... - def __init__(self, node, parent_frame=..., aliased=..., direct=..., _seen=...) -> None: ... diff --git a/src/codegen/sdk/extensions/resolution.pyx b/src/codegen/sdk/extensions/resolution.pyx deleted file mode 100644 index 4e8d8bebf..000000000 --- a/src/codegen/sdk/extensions/resolution.pyx +++ /dev/null @@ -1,107 +0,0 @@ -from collections.abc import Generator -from dataclasses import dataclass, field -from functools import cached_property -from typing import TYPE_CHECKING, Generic - -from typing_extensions import TypeVar - -from codegen.sdk.core.dataclasses.usage import Usage, UsageKind, UsageType -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.enums import Edge, EdgeType - -if TYPE_CHECKING: - from codegen.sdk.core.import_resolution import Import - -NodeType = TypeVar("NodeType", bound=Editable) - - -@dataclass(frozen=True) -class ResolutionStack(Generic[NodeType]): - """Represents the resolution stack from a symbol to a usage - - Symbol - ... - - - Attributes: - aliased: If this was aliased at any point - parent_frame: The frame above this frame - """ - - node: NodeType = field(repr=False) - parent_frame: "ResolutionStack | None" = None - direct: bool = True - aliased: bool = False - chained: bool = False - generics: dict = field(default_factory=dict) - - def with_frame(self, node, direct: bool = True, aliased: bool = False, chained: bool = False, generics: dict | None = None) -> "ResolutionStack": - """Adds node to the Resolution stack and returns it as a new frame.""" - assert node is not None - if not generics: - generics = self.generics - return ResolutionStack(node, self, direct, aliased, chained, generics=generics) - - def usage_type(self, direct: bool, aliased: bool, chained: bool) -> UsageType: - if chained: - return UsageType.CHAINED - elif direct: - return UsageType.DIRECT - elif aliased: - return UsageType.ALIASED - else: - return UsageType.INDIRECT - - def get_edges( - self, - match: "Editable", - usage_type: UsageKind, - dest: "HasName", - codebase_context: "CodebaseContext", - *, - direct: bool = True, - aliased: bool = False, - chained: bool = False, - imported_by: Import | None = None, - ) -> Generator[(int, int, Edge), None, None]: - """Get usage edges for a given node.""" - # Only add nodes that are already on the graph - edge_usage_type = self.usage_type(direct, aliased, chained) - if hasattr(self.node, "node_id") and codebase_context.has_node(getattr(self.node, "node_id")): - usage = Usage(kind=usage_type, match=match, usage_type=edge_usage_type, usage_symbol=dest.parent_symbol, imported_by=imported_by) - yield dest.node_id, self.node.node_id, Edge(type=EdgeType.SYMBOL_USAGE, usage=usage) - if self.parent_frame is not None: - from codegen.sdk.core.import_resolution import Import - - if isinstance(self, Import): - imported_by = self - aliased = self.aliased or aliased - direct = self.direct and direct - chained = self.chained or (chained and self.direct) - yield from self.parent_frame.get_edges(match, usage_type, dest, codebase_context, direct=direct, aliased=aliased, chained=chained, imported_by=imported_by) - - def add_usage(self, match: "Editable", usage_type: UsageKind, dest: "HasName", codebase_context: "CodebaseContext", *, direct: bool = True, aliased: bool = False, chained: bool = False) -> None: - """Add the resolved type to the graph. Also adds any intermediate nodes as usages as well if they are on the graph.""" - # Only add nodes that are already on the graph - codebase_context.add_edges(list(self.get_edges(match, usage_type, dest, codebase_context, direct=direct, aliased=aliased, chained=chained))) - - @cached_property - def top(self) -> ResolutionStack: - if self.parent_frame is not None: - return self.parent_frame.top - return self - - @cached_property - def is_direct_usage(self) -> bool: - return self.direct and (self.parent_frame is None or self.parent_frame.is_direct_usage) - - def with_new_base(self, base: Editable, *args, **kwargs) -> ResolutionStack: - new_parent = ResolutionStack(base, *args, **kwargs) - return self.with_new_base_frame(new_parent) - - def with_new_base_frame(self, base: ResolutionStack) -> ResolutionStack: - if self.parent_frame is not None: - new_parent = self.parent_frame.with_new_base_frame(base) - else: - new_parent = base - return new_parent.with_frame(self.node, direct=self.direct, aliased=self.aliased) diff --git a/src/codegen/sdk/extensions/sort.pyx b/src/codegen/sdk/extensions/sort.pyx deleted file mode 100644 index 573ff0d9c..000000000 --- a/src/codegen/sdk/extensions/sort.pyx +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import annotations - -from _operator import attrgetter -from collections.abc import Iterable, Sequence - -from tree_sitter import Node as TSNode -from typing_extensions import TypeVar - -from codegen.sdk.core.interfaces.editable import Editable - -E = TypeVar("E", bound=Editable) - - -def sort_editables(nodes: Iterable[E | None] | Iterable[E], *, reverse: bool = False, dedupe: bool = True, alphabetical: bool = False, by_file: bool = False, by_id: bool = False) -> Sequence[E]: - """Sort a list of Editables. - - Args: - reverse: Reverse the order of the nodes in the list. - dedupe: Filter out duplicate nodes. - alphabetical: Sort nodes alphabetically instead of by start byte - by_file: Sort nodes by file name then either alphabetically or by start byte - """ - if dedupe: - nodes = dict.fromkeys(nodes) - sort_keys = ["name" if alphabetical else "ts_node.start_byte"] - if by_file: - sort_keys.insert(0, "filepath") - if by_id: - sort_keys.append("node_id") - return sorted(filter(lambda node: node is not None, nodes), key=attrgetter(*sort_keys), reverse=reverse) - - -def sort_nodes(nodes: Iterable[TSNode | None] | Iterable[TSNode], *, reverse: bool = False, dedupe: bool = True) -> list[TSNode]: - """Sort a list of ts_nodes. - - Args: - reverse: Reverse the order of the nodes in the list. - dedupe: Filter out duplicate nodes. - """ - if dedupe: - nodes = dict.fromkeys(nodes) - return sorted(filter(lambda node: node is not None, nodes), key=attrgetter("start_byte"), reverse=reverse) diff --git a/src/codegen/sdk/extensions/utils.pyi b/src/codegen/sdk/extensions/utils.pyi deleted file mode 100644 index c75f0ce03..000000000 --- a/src/codegen/sdk/extensions/utils.pyi +++ /dev/null @@ -1,27 +0,0 @@ -from collections.abc import Generator, Iterable -from functools import cached_property as functools_cached_property -from functools import lru_cache as functools_lru_cache - -from tree_sitter import Node as TSNode - -def get_all_identifiers(node: TSNode) -> list[TSNode]: - """Get all the identifiers in a tree-sitter node. Recursive implementation""" - -def iter_all_descendants(node: TSNode, type_names: Iterable[str] | str, max_depth: int | None = None, nested: bool = True) -> Generator[TSNode, None, None]: ... -def find_all_descendants( - node: TSNode, - type_names: Iterable[str] | str, - max_depth: int | None = None, - nested: bool = True, - stop_at_first: str | None = None, -) -> list[TSNode]: ... -def find_line_start_and_end_nodes(node: TSNode) -> list[tuple[TSNode, TSNode]]: - """Returns a list of tuples of the start and end nodes of each line in the node""" - -def find_first_descendant(node: TSNode, type_names: list[str], max_depth: int | None = None) -> TSNode | None: ... - -cached_property = functools_cached_property -lru_cache = functools_lru_cache - -def uncache_all(): ... -def is_descendant_of(node: TSNode, possible_parent: TSNode) -> bool: ... diff --git a/src/codegen/sdk/extensions/utils.pyx b/src/codegen/sdk/extensions/utils.pyx deleted file mode 100644 index 73da6ce59..000000000 --- a/src/codegen/sdk/extensions/utils.pyx +++ /dev/null @@ -1,162 +0,0 @@ -from collections import Counter -from collections.abc import Generator, Iterable -from functools import cached_property as functools_cached_property -from functools import lru_cache as functools_lru_cache - -from tabulate import tabulate -from tree_sitter import Node as TSNode - - -def get_all_identifiers(node: TSNode) -> list[TSNode]: - """Get all the identifiers in a tree-sitter node. Recursive implementation""" - identifiers = [] - - def traverse(current_node: TSNode): - if current_node is None: - return - if current_node.type in ("identifier", "shorthand_property_identifier_pattern"): - identifiers.append(current_node) - return - - elif current_node.type == "attribute": - value_node = current_node.child_by_field_name("value") - if value_node: - traverse(value_node) - return - - for child in current_node.children: - traverse(child) - - traverse(node) - return sorted(dict.fromkeys(identifiers), key=lambda x: x.start_byte) - - -def find_all_descendants(node: TSNode, type_names: Iterable[str] | str, max_depth: int | None = None, nested: bool = True, stop_at_first: str | None = None) -> list[TSNode]: - if isinstance(type_names, str): - type_names = [type_names] - descendants = [] - - def traverse(current_node: TSNode, depth=0): - if max_depth is not None and depth > max_depth: - return - - if current_node.type in type_names: - descendants.append(current_node) - if not nested and current_node != node: - return - - if stop_at_first and current_node.type == stop_at_first: - return - - for child in current_node.children: - traverse(child, depth + 1) - - traverse(node) - return descendants - - -def iter_all_descendants(node: TSNode, type_names: Iterable[str] | str, max_depth: int | None = None, nested: bool = True) -> Generator[TSNode, None, None]: - if isinstance(type_names, str): - type_names = [type_names] - type_names = frozenset(type_names) - - def traverse(current_node: TSNode, depth=0): - if max_depth is not None and depth > max_depth: - return - - if current_node.type in type_names: - yield current_node - if not nested and current_node != node: - return - - for child in current_node.children: - yield from traverse(child, depth + 1) - - yield from traverse(node) - - -def find_line_start_and_end_nodes(node: TSNode) -> list[tuple[TSNode, TSNode]]: - line_to_start_node = {} - line_to_end_node = {} - - def collect_start_and_end_nodes(current_node: TSNode) -> None: - start_row = current_node.start_point[0] - if start_row not in line_to_start_node or line_to_start_node[start_row].start_point[1] >= current_node.start_point[1]: - line_to_start_node[start_row] = current_node - - if current_node.start_point[0] != current_node.end_point[0]: - # We only care about multi-line nodes - for child in current_node.children: - collect_start_and_end_nodes(child) - end_row = current_node.end_point[0] - if end_row not in line_to_end_node or line_to_end_node[end_row].end_point[1] <= current_node.end_point[1]: - line_to_end_node[end_row] = current_node - - collect_start_and_end_nodes(node) - return list(zip(line_to_start_node.values(), line_to_end_node.values())) - - -def find_first_descendant(node: TSNode, type_names: list[str], max_depth: int | None = None) -> TSNode | None: - def find(current_node: TSNode, depth: int = 0) -> TSNode | None: - if current_node.type in type_names: - return current_node - if max_depth is not None and depth >= max_depth: - return - for child in current_node.children: - if ret := find(child, depth + 1): - return ret - - return find(node) - - -to_uncache = [] -lru_caches = [] -counter = Counter() - - -class cached_property(functools_cached_property): - def __get__(self, instance, owner=None): - ret = super().__get__(instance) - if instance is not None: - to_uncache.append((instance, self.attrname)) - counter[self.attrname] += 1 - return ret - - -def lru_cache(func=None, *, maxsize=128, typed=False): - """A wrapper around functools.lru_cache that tracks the cached function so that its cache - can be cleared later via uncache_all(). - """ - if func is None: - # return decorator - return lambda f: lru_cache(f, maxsize=maxsize, typed=typed) - - # return decorated - cached_func = functools_lru_cache(maxsize=maxsize, typed=typed)(func) - lru_caches.append(cached_func) - return cached_func - - -def uncache_all(): - for instance, name in to_uncache: - try: - del instance.__dict__[name] - except KeyError: - pass - - for cached_func in lru_caches: - cached_func.cache_clear() - - -def report(): - print(tabulate(counter.most_common(10))) - - -def is_descendant_of(node: TSNode, possible_parent: TSNode) -> bool: - """Helper to check if node is inside possible_parent in the AST""" - current = node - while current: - if current == possible_parent: - return True - current = current.parent - return False diff --git a/src/codegen/sdk/output/ast.py b/src/codegen/sdk/output/ast.py deleted file mode 100644 index 3f9fae898..000000000 --- a/src/codegen/sdk/output/ast.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import Self - -from openai import BaseModel -from pydantic.config import ConfigDict - -from codegen.sdk.codebase.span import Span - - -class AST(BaseModel): - model_config = ConfigDict(frozen=True) - codegen_sdk_type: str - span: Span - tree_sitter_type: str - children: list[tuple[str | None, Self]] diff --git a/src/codegen/sdk/output/constants.py b/src/codegen/sdk/output/constants.py deleted file mode 100644 index b2dd5553a..000000000 --- a/src/codegen/sdk/output/constants.py +++ /dev/null @@ -1,3 +0,0 @@ -ANGULAR_STYLE = False -MAX_EDITABLE_LINES = 10 -MAX_STRING_LENGTH = 10000 diff --git a/src/codegen/sdk/output/inspect.py b/src/codegen/sdk/output/inspect.py deleted file mode 100644 index 1ef0a4571..000000000 --- a/src/codegen/sdk/output/inspect.py +++ /dev/null @@ -1,25 +0,0 @@ -import inspect -from collections.abc import Callable -from inspect import Parameter - -from codegen.shared.decorators.docs import DocumentedObject, no_apidoc_signatures - - -def only_default_args(method: Callable) -> bool: - sig = inspect.signature(method) - for param in sig.parameters: - if not isinstance(param, Parameter): - return False - if param.default != Parameter.empty: - return False - return True - - -def is_noapidoc(obj: object, attr: str) -> bool: - module = inspect.getmodule(obj) - module_name = module.__name__ if module else "" - - if module_name: - doc_obj = DocumentedObject(name=attr, module=module_name, object=obj) - return doc_obj.signature() in no_apidoc_signatures - return False diff --git a/src/codegen/sdk/output/jsonable.py b/src/codegen/sdk/output/jsonable.py deleted file mode 100644 index 2da4efd1f..000000000 --- a/src/codegen/sdk/output/jsonable.py +++ /dev/null @@ -1,88 +0,0 @@ -from abc import ABC, abstractmethod -from functools import cached_property - -from tree_sitter import Node as TSNode - -from codegen.sdk._proxy import ProxyProperty -from codegen.sdk.codebase.span import Span -from codegen.sdk.output.inspect import is_noapidoc, only_default_args -from codegen.sdk.output.placeholder import Placeholder -from codegen.sdk.output.utils import safe_getattr -from codegen.sdk.types import JSON -from codegen.shared.decorators.docs import noapidoc - -BLACKLIST = ["json", "G", "viz", "autocommit_cache", "ts_node", "symbol_usages", "usages"] - - -@noapidoc -class JSONable(ABC): - ts_node: TSNode - - @noapidoc - def _list_members(self, include_methods: bool = True) -> dict[str, object]: - """Lists all valid members (properties/attributes/methods) of this object.""" - members = {} - for attr in dir(self): - if attr in BLACKLIST or attr.startswith("_"): - continue - if is_noapidoc(self, attr): - continue - val = safe_getattr(self, attr, None) - if val is None: - continue - if callable(val) and not isinstance(val, ProxyProperty): - if not include_methods: - continue - if not safe_getattr(val, "_apidoc", True): - continue - if safe_getattr(val, "_reader", False): - if not only_default_args(val): - continue - attr += "()" - val = val() - members[attr] = val - return members - - @noapidoc - def json(self, max_depth: int = 2, methods: bool = True) -> JSON: - if max_depth < 0: - self._add_to_index - return self.placeholder.model_dump() - - res = {} - for attr, val in self._list_members(include_methods=methods).items(): - depth = max_depth - 1 - - if isinstance(val, JSONable): - val = val.json(depth, methods) - if isinstance(val, list): - val = [elem.json(depth, methods) if isinstance(elem, JSONable) else elem for elem in val] - if isinstance(val, dict): - val = {key: elem.json(depth, methods) if isinstance(elem, JSONable) else elem for key, elem in val.items()} - if isinstance(val, dict | str | list | int | float | bool | None): - res[attr] = val - - return res - - @property - @noapidoc - def placeholder(self) -> Placeholder: - """Property that returns a placeholder representation of the current object. - - Creates a Placeholder object representing the current object, typically when a full JSON - representation cannot be provided due to depth limitations. - - Returns: - Placeholder: A simplified representation containing the object's span, string representation, - kind_id from the TreeSitter node, and class name. - """ - return Placeholder(span=self.span, preview=repr(self), kind_id=self.ts_node.kind_id, name=self.__class__.__name__) - - @property - @abstractmethod - @noapidoc - def span(self) -> Span: ... - @cached_property - @abstractmethod - @noapidoc - def _add_to_index(self) -> None: ... diff --git a/src/codegen/sdk/output/placeholder.py b/src/codegen/sdk/output/placeholder.py deleted file mode 100644 index b6bbfc9a0..000000000 --- a/src/codegen/sdk/output/placeholder.py +++ /dev/null @@ -1,12 +0,0 @@ -from pydantic import BaseModel -from pydantic.config import ConfigDict - -from codegen.sdk.codebase.span import Span - - -class Placeholder(BaseModel): - model_config = ConfigDict(frozen=True) - preview: str - span: Span - kind_id: int - name: str diff --git a/src/codegen/sdk/output/utils.py b/src/codegen/sdk/output/utils.py deleted file mode 100644 index 5dfe8a21d..000000000 --- a/src/codegen/sdk/output/utils.py +++ /dev/null @@ -1,84 +0,0 @@ -import json -import sys -from decimal import Decimal -from os import PathLike -from pathlib import Path - -from rich.console import Console, RenderResult -from rich.syntax import Syntax -from rich.text import Text -from tree_sitter import Node as TSNode -from tree_sitter import Point - -from codegen.sdk.output.constants import MAX_EDITABLE_LINES - - -def style_editable(ts_node: TSNode, filepath: PathLike, file_node: TSNode) -> RenderResult: - start_line = ts_node.start_point[0] + 1 # 1 based - start_col = ts_node.start_point[1] - end_line = ts_node.end_point[0] + 1 # 1 based - end_col = ts_node.end_point[1] - truncated = 0 - truncated_len = start_line + MAX_EDITABLE_LINES - 1 - if end_line > truncated_len: - truncated = end_line - start_line + 1 - for child in ts_node.children: - if child.end_point[0] + 1 < truncated_len: - end_line = child.end_point[0] + 1 - syntax = _stylize_range(end_col, end_line, file_node, filepath, start_col, start_line) - yield syntax - if truncated: - yield Text(f"\nTruncated from {truncated} lines") - - -def _stylize_range(end_col, end_line, file_node, filepath, start_col, start_line): - syntax = Syntax.from_path(filepath, line_numbers=True, line_range=(start_line, end_line)) - syntax.stylize_range(style="dim", start=(start_line, 0), end=(start_line, start_col)) - syntax.stylize_range(style="dim", start=(end_line, end_col), end=(file_node.end_point[0] + 1, file_node.end_point[1])) - syntax.stylize_range(style="dim", start=(end_line, end_col), end=(end_line + 1, 0)) - return syntax - - -def stylize_error(path: PathLike, start: tuple[int, int] | Point, end: tuple[int, int] | Point, file_node: TSNode, content: str, message: str): - Path(path).write_text(content) - source = _stylize_range(end[1], end[0] + 1, file_node, path, start[1], start[0] + 1) - console = Console(file=sys.stderr) - console.print(f"Syntax Error {message} at:") - console.print(source) - - -def safe_getattr(obj, attr, default=None): - try: - return getattr(obj, attr, default) - except (AttributeError, NotImplementedError): - return default - - -class DeterministicJSONEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, float): - return f"{obj:.10f}" - if isinstance(obj, Decimal): - return f"{obj:.10f}" - if isinstance(obj, set): - return sorted(list(obj)) - if hasattr(obj, "__dict__"): - return {key: self.default(value) for key, value in obj.__dict__.items()} - return super().default(obj) - - -def deterministic_json_dumps(data, **kwargs): - def sort_dict(item): - if isinstance(item, dict): - return {key: sort_dict(value) for key, value in sorted(item.items())} - elif isinstance(item, list): - if len(item) > 0 and isinstance(item[0], dict): - # Sort list of dictionaries based on all keys - return sorted([sort_dict(i) for i in item], key=lambda x: json.dumps(x, sort_keys=True)) - else: - return [sort_dict(i) for i in item] - else: - return item - - sorted_data = sort_dict(data) - return json.dumps(sorted_data, cls=DeterministicJSONEncoder, **kwargs) diff --git a/src/codegen/sdk/py.typed b/src/codegen/sdk/py.typed deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/sdk/python/__init__.py b/src/codegen/sdk/python/__init__.py deleted file mode 100644 index 1c007891c..000000000 --- a/src/codegen/sdk/python/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from codegen.sdk.python.assignment import PyAssignment -from codegen.sdk.python.class_definition import PyClass -from codegen.sdk.python.file import PyFile -from codegen.sdk.python.function import PyFunction -from codegen.sdk.python.import_resolution import PyImport -from codegen.sdk.python.symbol import PySymbol - -__all__ = [ - "PyAssignment", - "PyClass", - "PyFile", - "PyFunction", - "PyImport", - "PySymbol", -] diff --git a/src/codegen/sdk/python/assignment.py b/src/codegen/sdk/python/assignment.py deleted file mode 100644 index 2614b6d43..000000000 --- a/src/codegen/sdk/python/assignment.py +++ /dev/null @@ -1,168 +0,0 @@ -from __future__ import annotations - -from collections.abc import Collection -from typing import TYPE_CHECKING - -from codegen.sdk.codebase.transactions import RemoveTransaction, TransactionPriority -from codegen.sdk.core.assignment import Assignment -from codegen.sdk.core.autocommit.decorators import remover -from codegen.sdk.core.expressions.multi_expression import MultiExpression -from codegen.sdk.core.statements.assignment_statement import AssignmentStatement -from codegen.sdk.extensions.autocommit import reader -from codegen.sdk.python.symbol import PySymbol -from codegen.sdk.python.symbol_groups.comment_group import PyCommentGroup -from codegen.shared.decorators.docs import noapidoc, py_apidoc -from codegen.shared.logging.get_logger import get_logger - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.python.statements.assignment_statement import PyAssignmentStatement - -logger = get_logger(__name__) - - -@py_apidoc -class PyAssignment(Assignment["PyAssignmentStatement"], PySymbol): - """An abstract representation of a assignment in python. - - This includes assignments of variables to functions, other variables, class instantiations, etc. - """ - - @noapidoc - @classmethod - def from_assignment(cls, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyAssignmentStatement) -> MultiExpression[PyAssignmentStatement, PyAssignment]: - if ts_node.type not in ["assignment", "augmented_assignment"]: - msg = f"Unknown assignment type: {ts_node.type}" - raise ValueError(msg) - - left_node = ts_node.child_by_field_name("left") - right_node = ts_node.child_by_field_name("right") - assignments = cls._from_left_and_right_nodes(ts_node, file_node_id, ctx, parent, left_node, right_node) - return MultiExpression(ts_node, file_node_id, ctx, parent, assignments) - - @classmethod - def from_named_expression(cls, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyAssignmentStatement) -> MultiExpression[PyAssignmentStatement, PyAssignment]: - """Creates a MultiExpression from a Python named expression. - - Creates assignments from a named expression node ('walrus operator' :=) by parsing its name and value fields. - - Args: - ts_node (TSNode): The TreeSitter node representing the named expression. - file_node_id (NodeId): The identifier of the file containing this node. - ctx (CodebaseContext): The codebase context instance. - parent (Parent): The parent node that contains this expression. - - Returns: - MultiExpression[Parent, PyAssignment]: A MultiExpression containing the assignments created from the named expression. - - Raises: - ValueError: If the provided ts_node is not of type 'named_expression'. - """ - if ts_node.type != "named_expression": - msg = f"Unknown assignment type: {ts_node.type}" - raise ValueError(msg) - - left_node = ts_node.child_by_field_name("name") - right_node = ts_node.child_by_field_name("value") - assignments = cls._from_left_and_right_nodes(ts_node, file_node_id, ctx, parent, left_node, right_node) - return MultiExpression(ts_node, file_node_id, ctx, parent, assignments) - - @property - @reader - def comment(self) -> PyCommentGroup | None: - """Returns the comment group associated with the symbol. - - Retrieves and returns any comments associated with the symbol. These comments are typically - located above or adjacent to the symbol in the source code. - - Args: - self: The symbol instance to retrieve comments for. - - Returns: - PyCommentGroup | None: A comment group object containing the symbol's comments if they exist, - None otherwise. - """ - # HACK: This is a temporary solution until comments are fixed - return PyCommentGroup.from_symbol_comments(self) - - @property - @reader - def inline_comment(self) -> PyCommentGroup | None: - """A property that retrieves the inline comment group associated with a symbol. - - Retrieves any inline comments that are associated with this symbol. Inline comments are comments that appear on the same line as the code. - - Args: - None - - Returns: - PyCommentGroup | None: The inline comment group associated with the symbol, if one exists. Returns None if there are no inline comments. - """ - # HACK: This is a temporary solution until comments are fixed - return PyCommentGroup.from_symbol_inline_comments(self, self.ts_node.parent) - - @noapidoc - def _partial_remove_when_tuple(self, name, delete_formatting: bool = True, priority: int = 0, dedupe: bool = True): - idx = self.parent.left.index(name) - value = self.value[idx] - self.parent._values_scheduled_for_removal.append(value) - # Special case for removing brackets of value - if len(self.value) - len(self.parent._values_scheduled_for_removal) == 1: - remainder = str(next(x for x in self.value if x not in self.parent._values_scheduled_for_removal and x != value)) - r_t = RemoveTransaction(self.value.start_byte, self.value.end_byte, self.file, priority=priority) - self.transaction_manager.add_transaction(r_t) - self.value.insert_at(self.value.start_byte, remainder, priority=priority) - else: - # Normal just remove one value - value.remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe) - # Remove assignment name - name.remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe) - - @noapidoc - def _active_transactions_on_assignment_names(self, transaction_order: TransactionPriority) -> int: - return [ - any(self.transaction_manager.get_transactions_at_range(self.file.path, start_byte=asgnmt.get_name().start_byte, end_byte=asgnmt.get_name().end_byte, transaction_order=transaction_order)) - for asgnmt in self.parent.assignments - ].count(True) - - @remover - def remove(self, delete_formatting: bool = True, priority: int = 0, dedupe: bool = True) -> None: - """Deletes this assignment and its related extended nodes (e.g. decorators, comments). - - - Removes the current node and its extended nodes (e.g. decorators, comments) from the codebase. - After removing the node, it handles cleanup of any surrounding formatting based on the context. - - Args: - delete_formatting (bool): Whether to delete surrounding whitespace and formatting. Defaults to True. - priority (int): Priority of the removal transaction. Higher priority transactions are executed first. Defaults to 0. - dedupe (bool): Whether to deduplicate removal transactions at the same location. Defaults to True. - - Returns: - None - """ - if self.ctx.config.unpacking_assignment_partial_removal: - if isinstance(self.parent, AssignmentStatement) and len(self.parent.assignments) > 1: - # Unpacking assignments - name = self.get_name() - if isinstance(self.value, Collection): - if len(self.parent._values_scheduled_for_removal) < len(self.parent.assignments) - 1: - self._partial_remove_when_tuple(name, delete_formatting, priority, dedupe) - return - else: - self.parent._values_scheduled_for_removal = [] - else: - if name.source == "_": - logger.warning("Attempting to remove '_' in unpacking, command will be ignored. If you wish to remove the statement, remove the other remaining variable(s)!") - return - transaction_count = self._active_transactions_on_assignment_names(TransactionPriority.Edit) - throwaway = [asgnmt.name == "_" for asgnmt in self.parent.assignments].count(True) - # Only edit if we didn't already omit all the other assignments, otherwise just remove the whole thing - if transaction_count + throwaway < len(self.parent.assignments) - 1: - name.edit("_", priority=priority, dedupe=dedupe) - return - - super().remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe) diff --git a/src/codegen/sdk/python/class_definition.py b/src/codegen/sdk/python/class_definition.py deleted file mode 100644 index d4c6e2394..000000000 --- a/src/codegen/sdk/python/class_definition.py +++ /dev/null @@ -1,118 +0,0 @@ -from typing import Self - -from tree_sitter import Node as TSNode - -from codegen.sdk.codebase.codebase_context import CodebaseContext -from codegen.sdk.core.autocommit import commiter, reader, writer -from codegen.sdk.core.class_definition import Class -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.generic_type import GenericType -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.core.symbol_group import SymbolGroup -from codegen.sdk.core.symbol_groups.multi_line_collection import MultiLineCollection -from codegen.sdk.core.symbol_groups.parents import Parents -from codegen.sdk.extensions.utils import cached_property -from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock -from codegen.sdk.python.detached_symbols.decorator import PyDecorator -from codegen.sdk.python.detached_symbols.parameter import PyParameter -from codegen.sdk.python.expressions.type import PyType -from codegen.sdk.python.function import PyFunction -from codegen.sdk.python.interfaces.has_block import PyHasBlock -from codegen.sdk.python.symbol import PySymbol -from codegen.shared.decorators.docs import noapidoc, py_apidoc - - -@py_apidoc -class PyClass(Class[PyFunction, PyDecorator, PyCodeBlock, PyParameter, PyType], PyHasBlock, PySymbol): - """Extends Class for Python codebases - - Attributes: - constructor_keyword: The keyword used to identify the constructor method in Python classes. - """ - - _decorated_node: TSNode | None - constructor_keyword = "__init__" - - def __init__(self, ts_node: TSNode, file_id: NodeId, ctx: CodebaseContext, parent: PyHasBlock, decorated_node: TSNode | None = None) -> None: - super().__init__(ts_node, file_id, ctx, parent) - self._decorated_node = decorated_node - - if superclasses_node := self.ts_node.child_by_field_name("superclasses"): - self.parent_classes = Parents(superclasses_node, self.file_node_id, self.ctx, self) - if self.constructor is not None and len(self.constructor.parameters) > 1: - self._parameters = SymbolGroup(self.file_node_id, self.ctx, self, children=self.constructor.parameters[1:]) - self.type_parameters = self.child_by_field_name("type_parameters") - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - dest = dest or self.self_dest - # =====[ Decorated functions ]===== - for decorator in self.decorators: - decorator._compute_dependencies(usage_type, dest) - - # =====[ Superclasses ]===== - # e.g. class A(B, c.D): - if self.parent_classes is not None: - self.parent_classes._compute_dependencies(UsageKind.SUBCLASS, dest) - if self.type_parameters: - self.type_parameters._compute_dependencies(UsageKind.GENERIC, dest) - # =====[ Code Block ]===== - self.code_block._compute_dependencies(usage_type, dest) - - @reader - def _parse_methods(self) -> MultiLineCollection[PyFunction, Self]: - methods = [m.symbol for m in self.code_block.symbol_statements if isinstance(m.symbol, PyFunction) and not m.symbol.is_overload] - block_node = self.code_block.ts_node - indent_size = block_node.named_children[0].start_point[1] - if len(methods) > 0: - # Set start byte at column=0 of first method - start_byte = methods[0].start_byte - methods[0].start_point[1] - elif len(self.code_block.statements) > 0: - # Set start byte at next byte after the last statement in code block - # Assumption is that the next byte is column=0 of the statement's next line - start_byte = self.code_block.statements[-1].end_byte + 1 - else: - # Set start byte at column=0 of start of the code block - start_byte = block_node.start_byte - block_node.start_point[1] - return MultiLineCollection(children=methods, file_node_id=self.file_node_id, ctx=self.ctx, parent=self, node=self.code_block.ts_node, indent_size=indent_size, start_byte=start_byte) - - #################################################################################################################### - # MANIPULATIONS - #################################################################################################################### - - @writer - def add_source(self, source: str) -> None: - """Adds source code to the class definition. - - Adds the provided source code string to the body of the class definition. The method handles - proper indentation of the source code within the class body. - - Args: - source (str): The source code to be added to the class definition. If the source doesn't - start with a newline, it will be indented with 4 spaces. - - Raises: - ValueError: If the class body cannot be found. - """ - class_body = self.child_by_field_name("body") - if class_body is None: - msg = "Could not find class body" - raise ValueError(msg) - # Mimic previous behaviour - source = source if source.startswith("\n") else " " + source - # TODO: use real fix_indentation behaviour - class_body.insert_after("\n" + source, fix_indentation=False, newline=False) - - @cached_property - @noapidoc - def generics(self) -> dict[str, PyType]: - ret = super().generics - if self.parent_classes: - for supercls in self.parent_classes: - if isinstance(supercls, GenericType): - if supercls.name == "Generic": - for param in supercls.parameters: - ret[param.name] = param - return ret diff --git a/src/codegen/sdk/python/detached_symbols/code_block.py b/src/codegen/sdk/python/detached_symbols/code_block.py deleted file mode 100644 index f251fd324..000000000 --- a/src/codegen/sdk/python/detached_symbols/code_block.py +++ /dev/null @@ -1,77 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, Self, TypeVar - -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.detached_symbols.code_block import CodeBlock -from codegen.sdk.core.statements.block_statement import BlockStatement -from codegen.sdk.core.statements.import_statement import ImportStatement -from codegen.sdk.core.statements.statement import Statement, StatementType -from codegen.sdk.core.symbol_groups.multi_line_collection import MultiLineCollection -from codegen.shared.decorators.docs import noapidoc, py_apidoc - -if TYPE_CHECKING: - from codegen.sdk.python.assignment import PyAssignment - from codegen.sdk.python.interfaces.has_block import PyHasBlock - from codegen.sdk.python.statements.with_statement import WithStatement - - -Parent = TypeVar("Parent", bound="PyHasBlock") - - -@py_apidoc -class PyCodeBlock(CodeBlock[Parent, "PyAssignment"], Generic[Parent]): - """Extends CodeBlock for Python codebases.""" - - @noapidoc - @reader - def _parse_statements(self) -> MultiLineCollection[Statement, Self]: - statements: list[Statement] = self.ctx.parser.parse_py_statements(self.ts_node, self.file_node_id, self.ctx, self) - collection = MultiLineCollection( - children=statements, - file_node_id=self.file_node_id, - ctx=self.ctx, - parent=self, - node=self.ts_node, - indent_size=self.start_point[1], - leading_delimiter="", - start_byte=self.start_byte - self.start_point[1], - ) - return collection - - @property - @reader - def with_statements(self) -> list[WithStatement]: - """Returns a list of all 'with' statements within the code block. - - Retrieves all with statements in the code block, including those at all nested levels. - - Returns: - A list of with statement objects found within this code block. - """ - return [x for x in self.statements if x.statement_type == StatementType.WITH_STATEMENT] - - @reader - def get_with_statements(self, level: int) -> list[WithStatement]: - """Gets with statements at a specific block level. - - Filters the with statements in this code block to only include those at the specified block level. - - Args: - level (int): The block level to filter by. 0 represents the top level. - - Returns: - list[WithStatement]: A list of WithStatement objects at the specified block level. - """ - return [x for x in self.with_statements if x.parent.level == level] - - def _smart_remove(self, child, *args, **kwargs) -> bool: - if len(self.statements) <= 1 and not isinstance(child, ImportStatement): - if isinstance(self.parent, BlockStatement): - self.parent.remove(*args, **kwargs) - return True - else: - self.remove_byte_range(self.start_byte, self.end_byte) - self.parent.insert_after("pass", newline=False) - return True - return False diff --git a/src/codegen/sdk/python/detached_symbols/decorator.py b/src/codegen/sdk/python/detached_symbols/decorator.py deleted file mode 100644 index 861dfea0c..000000000 --- a/src/codegen/sdk/python/detached_symbols/decorator.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.detached_symbols.decorator import Decorator -from codegen.sdk.core.detached_symbols.function_call import FunctionCall -from codegen.shared.decorators.docs import py_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.python.class_definition import PyClass - from codegen.sdk.python.detached_symbols.parameter import PyParameter - from codegen.sdk.python.function import PyFunction - - -@py_apidoc -class PyDecorator(Decorator["PyClass", "PyFunction", "PyParameter"]): - """Extends Decorators for Python codebases.""" - - @reader - def _get_name_node(self) -> TSNode: - """Returns the name of the decorator.""" - for child in self.ts_node.children: - # =====[ Identifier ]===== - # Just `@dataclass` etc. - if child.type == "identifier": - return child - - # =====[ Attribute ]===== - # e.g. `@a.b` - elif child.type == "attribute": - return child - - # =====[ Call ]===== - # e.g. `@a.b()` - elif child.type == "call": - func = child.child_by_field_name("function") - return func - - msg = f"Could not find decorator name within {self.source}" - raise ValueError(msg) - - @property - @reader - def call(self) -> FunctionCall | None: - """Gets the function call node from the decorator if the decorator is a call. - - This property retrieves the FunctionCall instance if the decorator is a function call - (e.g., @decorator()), otherwise returns None for simple decorators (e.g., @decorator). - - Args: - None - - Returns: - FunctionCall | None: A FunctionCall instance if the decorator is a function call, - None if it's a simple decorator. - """ - if call_node := next((x for x in self.ts_node.named_children if x.type == "call"), None): - return FunctionCall(call_node, self.file_node_id, self.ctx, self.parent) - return None diff --git a/src/codegen/sdk/python/detached_symbols/parameter.py b/src/codegen/sdk/python/detached_symbols/parameter.py deleted file mode 100644 index 7eccd7aae..000000000 --- a/src/codegen/sdk/python/detached_symbols/parameter.py +++ /dev/null @@ -1,77 +0,0 @@ -from typing import TYPE_CHECKING - -from typing_extensions import deprecated - -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.detached_symbols.parameter import Parameter -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.sdk.python.expressions.type import PyType -from codegen.shared.decorators.docs import py_apidoc - -if TYPE_CHECKING: - from codegen.sdk.python.function import PyFunction - - -@py_apidoc -class PyParameter(Parameter[PyType, Collection["PyParameter", "PyFunction"]]): - """Extends Parameter for Python codebases.""" - - @property - @reader - def is_optional(self) -> bool: - """Determines if the parameter is optional in Python code. - - A parameter is considered optional if it has a default value or if it is a list/dictionary splat pattern. - This includes default parameters, typed default parameters, and list/dictionary splat patterns. - - Returns: - bool: True if the parameter is optional, False otherwise. - """ - return ( - self.ts_node.type == "default_parameter" or self.ts_node.type == "typed_default_parameter" or self.ts_node.type == "list_splat_pattern" or self.ts_node.type == "dictionary_splat_pattern" - ) - - @property - @reader - def is_variadic(self) -> bool: - """Determines if a parameter is a variadic parameter. - - Checks if this parameter is defined as a variadic parameter using the splat operator (*args or **kwargs). - - Returns: - bool: True if the parameter is variadic (uses * or ** syntax), False otherwise. - """ - return self.ts_node.type == "list_splat_pattern" or self.ts_node.type == "dictionary_splat_pattern" - - @deprecated("Use `type.edit` instead") - @writer - def set_type_annotation(self, type_annotation: str, include_comment: str = "") -> None: - """Sets the type annotation of a parameter. - - Sets or updates the type annotation for this parameter. This method is deprecated in favor of using `type.edit` directly. - - Args: - type_annotation (str): The type annotation to set for the parameter. - include_comment (str, optional): A comment to include with the type annotation. Defaults to "". - - Returns: - None - - Deprecated: - Use `type.edit` instead. - """ - self.type.edit(type_annotation) - - @writer - def add_trailing_comment(self, comment: str) -> None: - """Add a trailing comment to a parameter in a function signature. - - Adds a trailing comment after the specified parameter in the parent function's signature, followed by a newline. - - Args: - comment (str): The comment text to be added after the parameter. - - Returns: - None - """ - self.parent_function.edit(self.parent_function.source.replace(self.source + ",", self.source + "," + f"# {comment} \n\n")) diff --git a/src/codegen/sdk/python/expressions/chained_attribute.py b/src/codegen/sdk/python/expressions/chained_attribute.py deleted file mode 100644 index 19639f8d7..000000000 --- a/src/codegen/sdk/python/expressions/chained_attribute.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.expressions import Expression, Name -from codegen.sdk.core.expressions.chained_attribute import ChainedAttribute -from codegen.shared.decorators.docs import py_apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.editable import Editable - -Parent = TypeVar("Parent", bound="Editable") - - -@py_apidoc -class PyChainedAttribute(ChainedAttribute[Expression, Name, Parent], Generic[Parent]): - """Abstract representation of a python chained attribute. - This includes methods of python classes and module functions. - """ - - def __init__(self, ts_node, file_node_id, ctx, parent: Parent): - super().__init__(ts_node, file_node_id, ctx, parent=parent, object=ts_node.child_by_field_name("object"), attribute=ts_node.child_by_field_name("attribute")) diff --git a/src/codegen/sdk/python/expressions/conditional_expression.py b/src/codegen/sdk/python/expressions/conditional_expression.py deleted file mode 100644 index 16e52c28a..000000000 --- a/src/codegen/sdk/python/expressions/conditional_expression.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import TYPE_CHECKING, TypeVar - -from codegen.sdk.core.expressions.ternary_expression import TernaryExpression -from codegen.shared.decorators.docs import py_apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.editable import Editable - -Parent = TypeVar("Parent", bound="Editable") - - -@py_apidoc -class PyConditionalExpression(TernaryExpression[Parent]): - """Conditional Expressions (A if condition else B)""" - - def __init__(self, ts_node, file_node_id, ctx, parent: Parent) -> None: - super().__init__(ts_node, file_node_id, ctx, parent=parent) - self.consequence = self.children[0] - self.condition = self.children[1] - self.alternative = self.children[2] diff --git a/src/codegen/sdk/python/expressions/generic_type.py b/src/codegen/sdk/python/expressions/generic_type.py deleted file mode 100644 index 0c685d4c4..000000000 --- a/src/codegen/sdk/python/expressions/generic_type.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import TYPE_CHECKING, Generic, Self, TypeVar - -from tree_sitter import Node as TSNode - -from codegen.sdk.core.expressions.generic_type import GenericType -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.sdk.python.expressions.named_type import PyNamedType -from codegen.shared.decorators.docs import py_apidoc -from codegen.shared.logging.get_logger import get_logger - -if TYPE_CHECKING: - from codegen.sdk.python.expressions.type import PyType - -logger = get_logger(__name__) - - -Parent = TypeVar("Parent") - - -@py_apidoc -class PyGenericType(PyNamedType[Parent], GenericType["PyType", Parent], Generic[Parent]): - """Generic python type. - - Examples: - list[int] - """ - - def _get_name_node(self) -> TSNode | None: - if self.ts_node_type == "subscript": - return self.ts_node.child_by_field_name("value") - if self.ts_node_type == "generic_type": - return self.child_by_field_types(["identifier", "attribute"]).ts_node - return self.ts_node - - def _get_parameters(self) -> Collection["PyType", Self] | None: - if self.ts_node_type == "subscript": - types = [self._parse_type(child) for child in self.ts_node.children_by_field_name("subscript")] - return Collection(node=self.ts_node, file_node_id=self.file_node_id, ctx=self.ctx, parent=self, children=types) - elif self.ts_node_type == "generic_type": - type_parameter = self.ts_node.named_children[1] - assert type_parameter.type == "type_parameter" - types = [self._parse_type(child) for child in type_parameter.named_children] - return Collection(node=type_parameter, file_node_id=self.file_node_id, ctx=self.ctx, parent=self, children=types) - logger.warning(f"Type {self.ts_node_type} not implemented") - return None diff --git a/src/codegen/sdk/python/expressions/named_type.py b/src/codegen/sdk/python/expressions/named_type.py deleted file mode 100644 index b2d1bd604..000000000 --- a/src/codegen/sdk/python/expressions/named_type.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Generic, TypeVar - -from tree_sitter import Node as TSNode - -from codegen.sdk.core.expressions.named_type import NamedType -from codegen.shared.decorators.docs import py_apidoc - -Parent = TypeVar("Parent") - - -@py_apidoc -class PyNamedType(NamedType[Parent], Generic[Parent]): - """Named python type - - Examples: - int,str (builtin types) - Path (classes) - """ - - def _get_name_node(self) -> TSNode | None: - return self.ts_node diff --git a/src/codegen/sdk/python/expressions/string.py b/src/codegen/sdk/python/expressions/string.py deleted file mode 100644 index 7c717f11f..000000000 --- a/src/codegen/sdk/python/expressions/string.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import TYPE_CHECKING, Generic, TypeVar - -from tree_sitter import Node as TSNode - -from codegen.sdk.core.expressions import Expression, String -from codegen.sdk.core.node_id_factory import NodeId -from codegen.shared.decorators.docs import py_apidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - - -Parent = TypeVar("Parent", bound="Expression") - - -@py_apidoc -class PyString(String, Generic[Parent]): - """An abstract representation of a python string.""" - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent) -> None: - super().__init__(ts_node, file_node_id, ctx, parent=parent) - substitutions = [x for x in ts_node.named_children if x.type == "interpolation"] - self.expressions = [self._parse_expression(x.child_by_field_name("expression")) for x in substitutions] diff --git a/src/codegen/sdk/python/expressions/type.py b/src/codegen/sdk/python/expressions/type.py deleted file mode 100644 index 2e7b32aa9..000000000 --- a/src/codegen/sdk/python/expressions/type.py +++ /dev/null @@ -1,2 +0,0 @@ -PyType = "PyUnionType[Parent] | PyNamedType[Parent] | PyGenericType[Parent] | NoneType" -__all__ = ["PyType"] diff --git a/src/codegen/sdk/python/expressions/union_type.py b/src/codegen/sdk/python/expressions/union_type.py deleted file mode 100644 index d6181989e..000000000 --- a/src/codegen/sdk/python/expressions/union_type.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.expressions.union_type import UnionType -from codegen.shared.decorators.docs import py_apidoc - -if TYPE_CHECKING: - from codegen.sdk.python.expressions.type import PyType - -Parent = TypeVar("Parent") - - -@py_apidoc -class PyUnionType(UnionType["PyType", Parent], Generic[Parent]): - """Union type - - Examples: - str | int - """ - - pass diff --git a/src/codegen/sdk/python/file.py b/src/codegen/sdk/python/file.py deleted file mode 100644 index 3b1fc9f93..000000000 --- a/src/codegen/sdk/python/file.py +++ /dev/null @@ -1,276 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.file import SourceFile -from codegen.sdk.core.interface import Interface -from codegen.sdk.core.symbol import Symbol -from codegen.sdk.enums import ImportType -from codegen.sdk.extensions.utils import cached_property -from codegen.sdk.python import PyAssignment -from codegen.sdk.python.class_definition import PyClass -from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock -from codegen.sdk.python.expressions.type import PyType -from codegen.sdk.python.function import PyFunction -from codegen.sdk.python.import_resolution import PyImport -from codegen.sdk.python.interfaces.has_block import PyHasBlock -from codegen.sdk.python.statements.attribute import PyAttribute -from codegen.shared.decorators.docs import noapidoc, py_apidoc -from codegen.shared.enums.programming_language import ProgrammingLanguage - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.import_resolution import Import, WildcardImport - from codegen.sdk.python.symbol import PySymbol - - -@py_apidoc -class PyFile(SourceFile[PyImport, PyFunction, PyClass, PyAssignment, Interface[PyCodeBlock, PyAttribute, PyFunction, PyType], PyCodeBlock], PyHasBlock): - """SourceFile representation for Python codebase - - Attributes: - programming_language: The programming language of the file. Set to ProgrammingLanguage.PYTHON. - """ - - programming_language = ProgrammingLanguage.PYTHON - - @staticmethod - def get_extensions() -> list[str]: - """Returns the file extensions associated with Python files. - - Gets the list of file extensions that are considered Python files. - - Returns: - list[str]: A list containing '.py' as the only Python file extension. - """ - return [".py"] - - def symbol_can_be_added(self, symbol: PySymbol) -> bool: - """Checks if a Python symbol can be added to this Python source file. - - Verifies whether a given Python symbol is compatible with and can be added to this Python source file. Currently always returns True as Python files can contain any Python symbol type. - - Args: - symbol (PySymbol): The Python symbol to check for compatibility with this file. - - Returns: - bool: Always returns True as Python files can contain any Python symbol type. - """ - return True - - #################################################################################################################### - # GETTERS - #################################################################################################################### - - @noapidoc - def get_import_module_name_for_file(self, filepath: str, ctx: CodebaseContext) -> str: - """Returns the module name that this file gets imported as - - For example, `my/package/name.py` => `my.package.name` - """ - base_path = ctx.projects[0].base_path - module = filepath.replace(".py", "") - if module.endswith("__init__"): - module = "/".join(module.split("/")[:-1]) - module = module.replace("/", ".") - # TODO - FIX EDGE CASE WITH REPO BASE!! - if base_path and module.startswith(base_path): - module = module.replace(f"{base_path}.", "", 1) - # TODO - FIX EDGE CASE WITH SRC BASE - if module.startswith("src."): - module = module.replace("src.", "", 1) - return module - - @reader - def get_import_string(self, alias: str | None = None, module: str | None = None, import_type: ImportType = ImportType.UNKNOWN, is_type_import: bool = False) -> str: - """Generates an import string for a symbol. - - Constructs a Python import statement based on the provided parameters, handling different import types and module paths. - - Args: - alias (str | None, optional): Alias to use for the imported symbol. Defaults to None. - module (str | None, optional): Module path to import from. If None, uses module name from source. Defaults to None. - import_type (ImportType, optional): Type of import statement to generate. Defaults to ImportType.UNKNOWN. - is_type_import (bool, optional): Whether this is a type import. Currently unused. Defaults to False. - - Returns: - str: A formatted import string in the form of 'from {module} import {symbol}' with optional alias or wildcard syntax. - """ - symbol_name = self.name - module = module if module is not None else self.import_module_name - # Case: importing dir/file.py - if f".{symbol_name}" in module: - module = module.replace(f".{symbol_name}", "") - # Case: importing file.py, symbol and module will be the same - if symbol_name == module: - module = "." - - if import_type == ImportType.WILDCARD: - return f"from {module} import * as {symbol_name}" - elif alias is not None and alias != self.name: - return f"from {module} import {symbol_name} as {alias}" - else: - return f"from {module} import {symbol_name}" - - @reader - def get_import_insert_index(self, import_string) -> int | None: - """Determines the index position where a new import statement should be inserted in a Python file. - - The function determines the optimal position for inserting a new import statement, following Python's import ordering conventions. - Future imports are placed at the top of the file, followed by all other imports. - - Args:z - import_string (str): The import statement to be inserted. - - Returns: - int | None: The index where the import should be inserted. Returns 0 for future imports or if there are no existing imports after future imports. - Returns None if there are no imports in the file. - """ - if not self.imports: - return None - - # Case: if the import is a future import, add to top of file - if "__future__" in import_string: # TODO: parse this into an import module and import name - return 0 - - # Case: file already had future imports, add import after the last one - future_imp_idxs = [idx for idx, imp in enumerate(self.imports) if "__future__" in imp.source] - if future_imp_idxs: - return future_imp_idxs[-1] + 1 - - # Case: default add import to top of file - return 0 - - #################################################################################################################### - # MANIPULATIONS - #################################################################################################################### - - @writer - def add_import(self, imp: Symbol | str, *, alias: str | None = None, import_type: ImportType = ImportType.UNKNOWN, is_type_import: bool = False) -> Import | None: - """Adds an import to the file. - - This method adds an import statement to the file. It can handle both string imports and symbol imports. - If the import already exists in the file, or is pending to be added, it won't be added again. - Future imports are placed at the top, followed by regular imports. - - Args: - imp (Symbol | str): Either a Symbol to import or a string representation of an import statement. - alias (str | None): Optional alias for the imported symbol. Only used when imp is a Symbol. Defaults to None. - import_type (ImportType): The type of import to use. Only used when imp is a Symbol. Defaults to ImportType.UNKNOWN. - is_type_import (bool): Whether this is a type-only import. Only used when imp is a Symbol. Defaults to False. - - Returns: - Import | None: The existing import for the symbol if found, otherwise None. - """ - # Handle Symbol imports - if isinstance(imp, Symbol): - imports = self.imports - match = next((x for x in imports if x.imported_symbol == imp), None) - if match: - return match - - # Convert symbol to import string - import_string = imp.get_import_string(alias, import_type=import_type, is_type_import=is_type_import) - else: - # Handle string imports - import_string = str(imp) - - # Check for duplicate imports - if any(import_string.strip() in str(imp.source) for imp in self.imports): - return None - if import_string.strip() in self._pending_imports: - return None - - # Add to pending imports - self._pending_imports.add(import_string.strip()) - self.transaction_manager.pending_undos.add(lambda: self._pending_imports.clear()) - - # Insert at correct location - if self.imports: - import_insert_index = self.get_import_insert_index(import_string) or 0 - if import_insert_index < len(self.imports): - self.imports[import_insert_index].insert_before(import_string, priority=1) - else: - self.imports[-1].insert_after(import_string, priority=1) - else: - self.insert_before(import_string, priority=1) - - return None - - @noapidoc - def remove_unused_exports(self) -> None: - """Removes unused exports from the file. NO-OP for python""" - pass - - @cached_property - @noapidoc - @reader(cache=True) - def valid_import_names(self) -> dict[str, PySymbol | PyImport | WildcardImport[PyImport]]: - """Returns a dict mapping name => Symbol (or import) in this file that can be imported from - another file. - """ - if self.name == "__init__": - ret = super().valid_import_names - if self.directory: - for file in self.directory: - if file.name == "__init__": - continue - if isinstance(file, PyFile): - ret[file.name] = file - return ret - return super().valid_import_names - - @noapidoc - def get_node_from_wildcard_chain(self, symbol_name: str) -> PySymbol | None: - """Recursively searches for a symbol through wildcard import chains. - - Attempts to find a symbol by name in the current file, and if not found, recursively searches - through any wildcard imports (from x import *) to find the symbol in imported modules. - - Args: - symbol_name (str): The name of the symbol to search for. - - Returns: - PySymbol | None: The found symbol if it exists in this file or any of its wildcard - imports, None otherwise. - """ - node = None - if node := self.get_node_by_name(symbol_name): - return node - - if wildcard_imports := {imp for imp in self.imports if imp.is_wildcard_import()}: - for wildcard_import in wildcard_imports: - if imp_resolution := wildcard_import.resolve_import(): - node = imp_resolution.from_file.get_node_from_wildcard_chain(symbol_name=symbol_name) - - return node - - @noapidoc - def get_node_wildcard_resolves_for(self, symbol_name: str) -> PyImport | PySymbol | None: - """Finds the wildcard import that resolves a given symbol name. - - Searches for a symbol by name, first in the current file, then through wildcard imports. - Unlike get_node_from_wildcard_chain, this returns the wildcard import that contains - the symbol rather than the symbol itself. - - Args: - symbol_name (str): The name of the symbol to search for. - - Returns: - PyImport | PySymbol | None: - - PySymbol if the symbol is found directly in this file - - PyImport if the symbol is found through a wildcard import - - None if the symbol cannot be found - """ - node = None - if node := self.get_node_by_name(symbol_name): - return node - - if wildcard_imports := {imp for imp in self.imports if imp.is_wildcard_import()}: - for wildcard_import in wildcard_imports: - if imp_resolution := wildcard_import.resolve_import(): - if imp_resolution.from_file.get_node_from_wildcard_chain(symbol_name=symbol_name): - node = wildcard_import - - return node diff --git a/src/codegen/sdk/python/function.py b/src/codegen/sdk/python/function.py deleted file mode 100644 index 0ab63f114..000000000 --- a/src/codegen/sdk/python/function.py +++ /dev/null @@ -1,265 +0,0 @@ -from __future__ import annotations - -import re -from typing import TYPE_CHECKING, override - -from codegen.sdk.core.autocommit import commiter, reader, writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.function import Function -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.sdk.extensions.utils import cached_property -from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock -from codegen.sdk.python.detached_symbols.decorator import PyDecorator -from codegen.sdk.python.detached_symbols.parameter import PyParameter -from codegen.sdk.python.expressions.type import PyType -from codegen.sdk.python.interfaces.has_block import PyHasBlock -from codegen.sdk.python.placeholder.placeholder_return_type import PyReturnTypePlaceholder -from codegen.sdk.python.symbol import PySymbol -from codegen.shared.decorators.docs import noapidoc, py_apidoc -from codegen.shared.logging.get_logger import get_logger - -if TYPE_CHECKING: - from collections.abc import Generator - - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.import_resolution import Import, WildcardImport - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.symbol import Symbol - -logger = get_logger(__name__) - - -@py_apidoc -class PyFunction(Function[PyDecorator, PyCodeBlock, PyParameter, PyType], PyHasBlock, PySymbol): - """Extends Function for Python codebases.""" - - _decorated_node: TSNode | None - - def __init__(self, ts_node: TSNode, file_id: NodeId, ctx: CodebaseContext, parent: PyHasBlock, decorated_node: TSNode | None = None) -> None: - super().__init__(ts_node, file_id, ctx, parent) - self._decorated_node = decorated_node - - @cached_property - @reader - def is_private(self) -> bool: - """Determines if a method is a private method. - - Private methods in Python start with an underscore and are not magic methods. - - Returns: - bool: True if the method is private (starts with '_' and is not a magic method), False otherwise. - """ - return self.name.startswith("_") and not self.is_magic - - @cached_property - @reader - def is_magic(self) -> bool: - """Determines if a method is a magic method. - - A magic method in Python is a method that starts and ends with double underscores, such as `__init__` or `__str__`. - This property checks if the current method's name matches this pattern. - - Returns: - bool: True if the method is a magic method (name starts and ends with double underscores), False otherwise. - """ - return self.name.startswith("__") and self.name.endswith("__") - - @property - @reader - def is_overload(self) -> bool: - """Determines whether a function is decorated with an overload decorator. - - Checks if the function has any of the following decorators: - - @overload - - @typing.overload - - @typing_extensions.overload - - Returns: - bool: True if function has an overload decorator, False otherwise. - """ - return any(dec in ("@overload", "@typing.overload", "@typing_extensions.overload") for dec in self.decorators) - - @property - @reader - def is_property(self) -> bool: - """Determines if the function is a property. - - Checks the decorators list to see if the function has a `@property` or `@cached_property` decorator. - - Returns: - bool: True if the function has a `@property` or `@cached_property` decorator, False otherwise. - """ - return any(dec in ("@property", "@cached_property") for dec in self.decorators) - - @property - @reader - def is_static_method(self) -> bool: - """Determines if the function is a static method. - - Checks the function's decorators to determine if it is decorated with the @staticmethod decorator. - - Returns: - bool: True if the function is decorated with @staticmethod, False otherwise. - """ - return "@staticmethod" in self.decorators - - @property - @reader - def is_class_method(self) -> bool: - """Indicates whether the current function is decorated with @classmethod. - - Args: - self: The PyFunction instance. - - Returns: - bool: True if the function is decorated with @classmethod, False otherwise. - """ - return "@staticmethod" in self.decorators - - @noapidoc - @reader - def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: - if self.is_method: - if not self.is_static_method: - if len(self.parameters.symbols) > 0: - if name == self.parameters[0].name: - yield self.parent_class - return - if name == "super()": - yield self.parent_class - return - yield from super().resolve_name(name, start_byte, strict=strict) - - @noapidoc - @commiter - def parse(self, ctx: CodebaseContext) -> None: - super().parse(ctx) - self.return_type = self.child_by_field_name("return_type", placeholder=PyReturnTypePlaceholder) - if parameters_node := self.ts_node.child_by_field_name("parameters"): - params = [ - x - for x in parameters_node.children - if x.type - in ( - "identifier", - "typed_parameter", - "default_parameter", - "typed_default_parameter", - "list_splat_pattern", - "dictionary_splat_pattern", - ) - ] - self._parameters = Collection(parameters_node, self.file_node_id, self.ctx, self) - self._parameters._init_children([PyParameter(x, i, self._parameters) for (i, x) in enumerate(params)]) - else: - logger.warning(f"Couldn't find parameters for {self!r}") - self._parameters = [] - self.type_parameters = self.child_by_field_name("type_parameters") - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - dest = dest or self.self_dest - - # =====[ Decorated functions ]===== - for decorator in self.decorators: - decorator._compute_dependencies(usage_type, dest) - - # =====[ Identifiers in Body ]===== - self.code_block._compute_dependencies(usage_type, dest) - if self.type_parameters: - self.type_parameters._compute_dependencies(UsageKind.GENERIC, dest) - # =====[ Return type ]===== - if self.return_type: - # Need to parse all the different types - self.return_type._compute_dependencies(UsageKind.RETURN_TYPE, dest) - - @property - @reader - def function_signature(self) -> str: - """Returns the function signature as a string. - - Gets the string representation of the function's signature, including name, parameters, and return type. - - Args: - None - - Returns: - str: A string containing the complete function signature including the function name, - parameters (if any), return type annotation (if present), and a colon. - """ - func_def_src = f"def {self.name}" - if self.parameters is not None: - func_def_src += self.parameters.source - if self.return_type: - func_def_src += " -> " + self.return_type.source - func_def_src += ":" - return func_def_src - - @property - @reader - def body(self) -> str: - """Returns the body of the function as a string. - - Gets the source code of the function's body, excluding the docstring if present. - - Returns: - str: The function's body content as a string, with any docstring removed and whitespace stripped. - """ - text = self.code_block.source - if self.docstring is not None: - return text.replace(self.docstring.extended_source, "").strip() - return text - - #################################################################################################################### - # MANIPULATIONS - #################################################################################################################### - - @writer - def prepend_statements(self, lines: str) -> None: - """Prepends statements to the start of the function body. - - Given a string of code statements, adds them at the beginning of the function body, right after any existing docstring. The method handles indentation automatically. - - Args: - lines (str): The code statements to prepend to the function body. - - Returns: - None: This method modifies the function in place. - """ - statements = self.code_block.statements - first_statement = statements[0] if self.docstring is None else statements[1] - first_statement.insert_before(lines, fix_indentation=True) - - @writer - @override - def set_return_type(self, new_return_type: str) -> None: - """Sets or modifies the return type annotation of a function. - - Updates the function's return type annotation by either modifying an existing return type or adding a new one. - If an empty string is provided as the new return type, removes the existing return type annotation. - - Args: - new_return_type (str): The new return type annotation to set. Provide an empty string to remove the return type annotation. - - Returns: - None - """ - # Clean any leading -> from new_return_type - new_return_type = new_return_type.removeprefix(" -> ") - - if self.return_type: - # Case: return type node DOES exist, and new_return_type is not empty, replace return type - if new_return_type: - self.return_type.edit(new_return_type) - # Case: return type node DOES exist, and new_return_type is empty, remove return type - else: - # TODO: instead use prev sibling to find where the -> is? - new_source = re.sub(r" -> .+:", ":", self.source, 1) - self.edit(new_source) - else: - # Case: return type node DOES NOT exist - self.return_type.edit(new_return_type) diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py deleted file mode 100644 index bf8e1cf49..000000000 --- a/src/codegen/sdk/python/import_resolution.py +++ /dev/null @@ -1,368 +0,0 @@ -from __future__ import annotations - -import os -import sys -from typing import TYPE_CHECKING - -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.expressions import Name -from codegen.sdk.core.import_resolution import ExternalImportResolver, Import, ImportResolution -from codegen.sdk.enums import ImportType, NodeType -from codegen.shared.decorators.docs import noapidoc, py_apidoc -from codegen.shared.logging.get_logger import get_logger - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.file import SourceFile - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.interfaces.exportable import Exportable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.statements.import_statement import ImportStatement - from codegen.sdk.python.file import PyFile - - -logger = get_logger(__name__) - - -@py_apidoc -class PyImport(Import["PyFile"]): - """Extends Import for Python codebases.""" - - @reader - def is_module_import(self) -> bool: - """Determines if the import is a module-level or wildcard import. - - Checks whether the import is either a module import (e.g. 'import foo') or a wildcard import (e.g. 'from foo import *'). - - Returns: - bool: True if the import is a module-level or wildcard import, False otherwise. - """ - return self.import_type in [ImportType.MODULE, ImportType.WILDCARD] - - @property - @reader - def namespace(self) -> str | None: - """Returns the namespace of the import if it imports a file, or None otherwise. - - This property determines the namespace for file imports. It returns None for wildcard imports. For file - imports (where resolved_symbol is a FILE), it returns the alias source. For all other cases, it returns None. - - Returns: - str | None: The namespace string for file imports, None for wildcard imports or non-file imports. - """ - if self.is_wildcard_import(): - return None - - resolved_symbol = self.resolved_symbol - if resolved_symbol is not None and resolved_symbol.node_type == NodeType.FILE: - return self.alias.source - return None - - @property - @reader - def imported_exports(self) -> list[Exportable]: - """Returns a list of exports from an import statement. - - Returns the enumerated list of symbols imported from a module import. If the import is - not a module import, returns a list containing just the single imported symbol. - For imports that don't resolve to any symbol, returns an empty list. - - Returns: - list[Exportable]: A list of exported symbols. For module imports, contains all symbols - and imports from the imported module. For non-module imports, contains a single imported - symbol. For unresolved imports, returns empty list. - """ - if self.imported_symbol is None: - return [] - - if not self.is_module_import(): - return [self.imported_symbol] - - return self.imported_symbol.symbols + self.imported_symbol.imports - - @noapidoc - @reader - def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[PyFile] | None: - try: - base_path = base_path or self.ctx.projects[0].base_path or "" - module_source = self.module.source if self.module else "" - symbol_name = self.symbol_name.source if self.symbol_name else "" - if add_module_name: - module_source += f".{symbol_name}" - symbol_name = add_module_name - # If import is relative, convert to absolute path - if module_source.startswith("."): - module_source = self._relative_to_absolute_import(module_source) - - # =====[ Check if we are importing an entire file ]===== - if self.is_module_import(): - # covers `import a.b.c` case and `from a.b.c import *` case - filepath = os.path.join(base_path, module_source.replace(".", "/") + ".py") - else: - # This is the case where you do: - # `from a.b.c import foo` - filepath = os.path.join( - base_path, - module_source.replace(".", "/") + "/" + symbol_name + ".py", - ) - - # =====[ Check if we are importing an entire file with custom resolve path or sys.path enabled ]===== - if len(self.ctx.config.import_resolution_paths) > 0 or self.ctx.config.py_resolve_syspath: - # Handle resolve overrides first if both is set - resolve_paths: list[str] = self.ctx.config.import_resolution_paths + (sys.path if self.ctx.config.py_resolve_syspath else []) - if file := self._file_by_custom_resolve_paths(resolve_paths, filepath): - return ImportResolution(from_file=file, symbol=None, imports_file=True) - - # =====[ Default path ]===== - if file := self.ctx.get_file(filepath): - return ImportResolution(from_file=file, symbol=None, imports_file=True) - - filepath = filepath.replace(".py", "/__init__.py") - if file := self.ctx.get_file(filepath): - # TODO - I think this is another edge case, due to `dao/__init__.py` etc. - # You can't do `from a.b.c import foo` => `foo.utils.x` right now since `foo` is just a file... - return ImportResolution(from_file=file, symbol=None, imports_file=True) - - # =====[ Check if `module.py` file exists in the graph with custom resolve path or sys.path enabled ]===== - filepath = module_source.replace(".", "/") + ".py" - if len(self.ctx.config.import_resolution_paths) > 0 or self.ctx.config.py_resolve_syspath: - # Handle resolve overrides first if both is set - resolve_paths: list[str] = self.ctx.config.import_resolution_paths + (sys.path if self.ctx.config.py_resolve_syspath else []) - if file := self._file_by_custom_resolve_paths(resolve_paths, filepath): - symbol = file.get_node_by_name(symbol_name) - return ImportResolution(from_file=file, symbol=symbol) - - # =====[ Check if `module.py` file exists in the graph ]===== - filepath = os.path.join(base_path, filepath) - if file := self.ctx.get_file(filepath): - symbol = file.get_node_by_name(symbol_name) - if symbol is None: - if file.get_node_from_wildcard_chain(symbol_name): - return ImportResolution(from_file=file, symbol=None, imports_file=True) - else: - # This is most likely a broken import - return ImportResolution(from_file=file, symbol=None) - else: - return ImportResolution(from_file=file, symbol=symbol) - - # =====[ Check if `module/__init__.py` file exists in the graph with custom resolve path or sys.path enabled ]===== - filepath = filepath.replace(".py", "/__init__.py") - if len(self.ctx.config.import_resolution_paths) > 0 or self.ctx.config.py_resolve_syspath: - # Handle resolve overrides first if both is set - resolve_paths: list[str] = self.ctx.config.import_resolution_paths + (sys.path if self.ctx.config.py_resolve_syspath else []) - if from_file := self._file_by_custom_resolve_paths(resolve_paths, filepath): - symbol = from_file.get_node_by_name(symbol_name) - if symbol is None: - if from_file.get_node_from_wildcard_chain(symbol_name): - return ImportResolution(from_file=from_file, symbol=None, imports_file=True) - else: - # This is most likely a broken import - return ImportResolution(from_file=from_file, symbol=None) - - else: - return ImportResolution(from_file=from_file, symbol=symbol) - - # =====[ Check if `module/__init__.py` file exists in the graph ]===== - if from_file := self.ctx.get_file(filepath): - symbol = from_file.get_node_by_name(symbol_name) - if symbol is None: - if from_file.get_node_from_wildcard_chain(symbol_name): - return ImportResolution(from_file=from_file, symbol=None, imports_file=True) - else: - # This is most likely a broken import - return ImportResolution(from_file=from_file, symbol=None) - - else: - return ImportResolution(from_file=from_file, symbol=symbol) - - # =====[ Case: Can't resolve the import ]===== - if base_path == "": - # Try to resolve with "src" as the base path - return self.resolve_import(base_path="src", add_module_name=add_module_name) - if base_path == "src": - # Try "test" next - return self.resolve_import(base_path="test", add_module_name=add_module_name) - - # if not G_override: - # for resolver in ctx.import_resolvers: - # if imp := resolver.resolve(self): - # return imp - - return None - # # =====[ Check if we are importing an external module in the graph ]===== - # if ext := self.ctx.get_external_module(self.source, self._unique_node.source): - # return ImportResolution(symbol=ext) - # # Implies we are not importing the symbol from the current repo. - # # In these cases, consider the import as an ExternalModule and add to graph - # ext = ExternalModule.from_import(self) - # return ImportResolution(symbol=ext) - except AssertionError: - # Codebase is probably trying to import file from outside repo - return None - - @noapidoc - @reader - def _file_by_custom_resolve_paths(self, resolve_paths: list[str], filepath: str) -> SourceFile | None: - """Check if a certain file import can be found within a set sys.path - - Returns either None or the SourceFile. - """ - for resolve_path in resolve_paths: - filepath_new: str = os.path.join(resolve_path, filepath) - try: - file = self.ctx.get_file(filepath_new) - except AssertionError as e: - file = None - if file: - return file - - return None - - @noapidoc - @reader - def _relative_to_absolute_import(self, relative_import: str) -> str: - """Helper to go from a relative import to an absolute one. - Ex: ".foo.bar" in "src/file.py" would be -> "src.foo.bar" - Ex: "..foo.bar" in "project/src/file.py" would be -> "project.foo.bar" - """ - # Get the directory of the current file - current_dir = os.path.dirname(self.to_file.file_path) - - # Count the number of dots at the start of the relative import - dot_count = 0 - while relative_import.startswith("."): - dot_count += 1 - relative_import = relative_import[1:] - - # Go up in the directory structure based on the number of dots - for _ in range(dot_count - 1): - current_dir = os.path.dirname(current_dir) - - # Convert the remaining path to a Python import path - base_path = os.path.normpath(current_dir).replace(os.sep, ".") - - # Remove any leading '.' from the base_path - while base_path.startswith("."): - base_path = base_path[1:] - - # Combine the base path with the relative import - if relative_import: - return f"{base_path}.{relative_import}" if base_path else relative_import - else: - return base_path - - @classmethod - @noapidoc - def from_import_statement(cls, import_statement: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: ImportStatement) -> list[PyImport]: - imports = [] - for module_node in import_statement.children_by_field_name("name"): - if module_node.type == "dotted_name": - imports.append(cls(import_statement, file_node_id, ctx, parent, module_node=module_node, name_node=module_node, alias_node=module_node, import_type=ImportType.MODULE)) - elif module_node.type == "aliased_import": - module = module_node.child_by_field_name("name") - symbol_name = module - alias = module_node.child_by_field_name("alias") - imports.append(cls(import_statement, file_node_id, ctx, parent, module_node=module, name_node=symbol_name, alias_node=alias, import_type=ImportType.MODULE)) - else: - logger.error(f"Unsupported import statement: {import_statement.text.decode('utf-8')}") - return imports - - @classmethod - @noapidoc - def from_import_from_statement(cls, import_statement: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: ImportStatement) -> list[PyImport]: - module_node = import_statement.child_by_field_name("module_name") - import_symbols = import_statement.children_by_field_name("name") - if len(import_symbols) == 0: - wildcard_import = next((node for node in import_statement.children if node.type == "wildcard_import"), None) - if wildcard_import is None: - msg = f"Unsupported import statement: {import_statement.text.decode('utf-8')}" - raise ValueError(msg) - return [cls(import_statement, file_node_id, ctx, parent, module_node=module_node, name_node=module_node, alias_node=module_node, import_type=ImportType.WILDCARD)] - - imports = [] - for import_symbol in import_symbols: - if import_symbol.type == "dotted_name": - imp = cls(import_statement, file_node_id, ctx, parent, module_node=module_node, name_node=import_symbol, alias_node=import_symbol, import_type=ImportType.NAMED_EXPORT) - elif import_symbol.type == "aliased_import": - symbol_name = import_symbol.child_by_field_name("name") - alias = import_symbol.child_by_field_name("alias") - imp = cls(import_statement, file_node_id, ctx, parent, module_node=module_node, name_node=symbol_name, alias_node=alias, import_type=ImportType.NAMED_EXPORT) - else: - msg = f"Unsupported import statement: {import_statement.text.decode('utf-8')}" - raise ValueError(msg) - imports.append(imp) - return imports - - @classmethod - @noapidoc - def from_future_import_statement(cls, import_statement: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: ImportStatement) -> list[PyImport]: - imports = [] - for module_node in import_statement.children_by_field_name("name"): - imp = cls(import_statement, file_node_id, ctx, parent, module_node=module_node, name_node=module_node, alias_node=module_node, import_type=ImportType.SIDE_EFFECT) - imports.append(imp) - return imports - - @property - @reader - def import_specifier(self) -> Editable: - """Retrieves the import specifier node for this import. - - Finds and returns the import specifier node that matches either the alias or symbol name of this import. - - Args: - None - - Returns: - Editable: The import specifier node as a Name object if found, None otherwise. - """ - import_specifiers = self.ts_node.children_by_field_name("name") - for import_specifier in import_specifiers: - if import_specifier.type == "aliased_import": - is_match = self.alias.source == import_specifier.child_by_field_name("alias").text.decode("utf-8") - else: - is_match = self.symbol_name.source == import_specifier.text.decode("utf-8") - if is_match: - return Name(import_specifier, self.file_node_id, self.ctx, self) - - @reader - def get_import_string( - self, - alias: str | None = None, - module: str | None = None, - import_type: ImportType = ImportType.UNKNOWN, - is_type_import: bool = False, - ) -> str: - """Generates an import string for a Python import statement. - - Creates a formatted import statement string based on the provided parameters. The generated string can represent different types of imports including wildcard imports and aliased imports. - - Args: - alias (str | None): Optional alias name for the imported symbol. - module (str | None): Optional module name to import from. If not provided, uses the file's import module name. - import_type (ImportType): Type of import to generate. Defaults to UNKNOWN. - is_type_import (bool): Whether this is a type import. Defaults to False. - - Returns: - str: A formatted import statement string. - """ - import_module = module if module is not None else self.file.import_module_name - if import_type == ImportType.WILDCARD: - file_as_module = self.file.name - return f"from {import_module} import * as {file_as_module}" - elif alias is not None and alias != self.name: - return f"from {import_module} import {self.name} as {alias}" - else: - return f"from {import_module} import {self.name}" - - -class PyExternalImportResolver(ExternalImportResolver): - def __init__(self, from_alias: str, to_context: CodebaseContext) -> None: - self.from_alias = from_alias - self.to_context = to_context - - def resolve(self, imp: PyImport) -> str | None: - module_source = imp.module.source if imp.module else "" - if module_source.startswith(self.from_alias): - return imp.resolve_import(G_override=self.to_context) diff --git a/src/codegen/sdk/python/interfaces/has_block.py b/src/codegen/sdk/python/interfaces/has_block.py deleted file mode 100644 index 2871196eb..000000000 --- a/src/codegen/sdk/python/interfaces/has_block.py +++ /dev/null @@ -1,91 +0,0 @@ -from functools import cached_property - -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.interfaces.has_block import HasBlock -from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock -from codegen.sdk.python.detached_symbols.decorator import PyDecorator -from codegen.sdk.python.statements.comment import PyComment, PyCommentType -from codegen.sdk.python.symbol_groups.comment_group import PyCommentGroup -from codegen.shared.decorators.docs import py_apidoc - - -@py_apidoc -class PyHasBlock(HasBlock[PyCodeBlock, PyDecorator]): - """Extends HasBlock for Python codebases.""" - - @property - @reader - def is_decorated(self) -> bool: - """Returns whether the symbol is decorated with decorators. - - Checks if the symbol has a parent and if that parent's type is a decorated definition. - - Returns: - bool: True if the symbol has decorators, False otherwise. - """ - if self.parent is None: - return False - return self.ts_node.parent.type == "decorated_definition" - - @property - @reader - def decorators(self) -> list[PyDecorator]: - """Returns a list of decorators associated with this symbol. - - Retrieves all decorator nodes from the symbol's parent TreeSitter node and converts them into PyDecorator objects. - - Args: - None - - Returns: - list[PyDecorator]: A list of PyDecorator objects representing the decorators on the symbol. Returns an empty list if the symbol is not decorated. - - Note: - This property should be used in conjunction with is_decorated to check if the symbol has any decorators. - """ - if self.is_decorated: - decorators = [x for x in self.ts_node.parent.children if x.type == "decorator"] - return [PyDecorator(x, self) for x in decorators] - return [] - - @cached_property - @reader - def docstring(self) -> PyCommentGroup | None: - """Gets the function's docstring. - - Retrieves the docstring of the function as a PyCommentGroup object. If the function has no docstring, returns None. - - Returns: - PyCommentGroup | None: The docstring of the function as a PyCommentGroup, or None if no docstring exists. - """ - return PyCommentGroup.from_docstring(self) - - @writer - def set_docstring(self, docstring: str, auto_format: bool = True, clean_format: bool = True, force_multiline: bool = False) -> None: - """Sets or updates a docstring for a Python function or class. - - Updates the existing docstring if one exists, otherwise creates a new docstring. The docstring can be automatically formatted and cleaned before being set. - - Args: - docstring (str): The docstring content to set. - auto_format (bool, optional): Whether to format the text into a proper docstring format. Defaults to True. - clean_format (bool, optional): Whether to clean and normalize the docstring format before insertion. Defaults to True. - force_multiline (bool, optional): Whether to force single-line comments to be converted to multi-line format. Defaults to False. - - Returns: - None - """ - # Clean the docstring if needed - if clean_format: - docstring = PyComment.clean_comment(docstring) - - # Add the docstring to the function - if self.docstring: - if auto_format: - self.docstring.edit_text(docstring) - else: - self.docstring.edit(docstring) - else: - if auto_format: - docstring = PyComment.generate_comment(docstring, PyCommentType.MULTI_LINE_DOUBLE_QUOTE, force_multiline=force_multiline) - self.code_block.insert_before(docstring) diff --git a/src/codegen/sdk/python/placeholder/placeholder_return_type.py b/src/codegen/sdk/python/placeholder/placeholder_return_type.py deleted file mode 100644 index 7e2c92a3d..000000000 --- a/src/codegen/sdk/python/placeholder/placeholder_return_type.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.placeholder.placeholder import Placeholder -from codegen.shared.decorators.docs import py_apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.editable import Editable - -Parent = TypeVar("Parent", bound="Editable") - - -@py_apidoc -class PyReturnTypePlaceholder(Placeholder[Parent], Generic[Parent]): - """A placeholder for a python return type that does not exist. - Can be populated using the `edit` method. - """ - - def edit(self, new_src: str, fix_indentation: bool = False, priority: int = 0, dedupe: bool = True) -> None: - """Edits or creates a return type annotation for a method or function. - - Used to modify or create a return type annotation in Python functions and methods. If the new source is not empty, - it will be appended after the parameters with the ' -> ' prefix. - - Args: - new_src (str): The new return type annotation text to be added. - fix_indentation (bool, optional): Whether to fix the indentation of the new source. Defaults to False. - priority (int, optional): Priority of the edit operation. Defaults to 0. - dedupe (bool, optional): Whether to deduplicate the edit operation. Defaults to True. - - Returns: - None - """ - new_src = new_src.removeprefix(" -> ") - # Case: return type node DOES NOT exist and new_return_type is not empty, append return type - if new_src: - new_return_type = " -> " + new_src # Add -> prefix b/c it will be missing if return type node does not exist - param_node = self._parent_node.child_by_field_name("parameters") - param_node.insert_after(new_return_type, newline=False) diff --git a/src/codegen/sdk/python/statements/__init__.py b/src/codegen/sdk/python/statements/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/sdk/python/statements/assignment_statement.py b/src/codegen/sdk/python/statements/assignment_statement.py deleted file mode 100644 index f91066e4d..000000000 --- a/src/codegen/sdk/python/statements/assignment_statement.py +++ /dev/null @@ -1,91 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from codegen.sdk.core.expressions.multi_expression import MultiExpression -from codegen.sdk.core.statements.assignment_statement import AssignmentStatement -from codegen.sdk.extensions.utils import find_all_descendants -from codegen.sdk.python.assignment import PyAssignment -from codegen.shared.decorators.docs import py_apidoc -from codegen.shared.logging.get_logger import get_logger - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock - from codegen.sdk.python.interfaces.has_block import PyHasBlock - - -logger = get_logger(__name__) - - -@py_apidoc -class PyAssignmentStatement(AssignmentStatement["PyCodeBlock", PyAssignment]): - """A class that represents a Python assignment statement in a codebase, such as `x = 1` or `a, b = 1, 2`. - - This includes potentially multiple Assignments via `statement.assignments`, which represent each assignment of a value to a variable within this statement. - """ - - assignment_types = {"assignment", "augmented_assignment", "named_expression"} - - def __init__(self, ts_node, file_node_id, ctx, parent, pos, assignment_node): - super().__init__(ts_node, file_node_id, ctx, parent, pos, assignment_node) - self._values_scheduled_for_removal = [] - - @classmethod - def from_assignment(cls, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyCodeBlock, pos: int, assignment_node: TSNode) -> PyAssignmentStatement: - """Creates a PyAssignmentStatement instance from a TreeSitter assignment node. - - Factory method to create appropriate assignment statement objects based on the node type and parent context. - If the parent is a PyClass, creates a PyAttribute, otherwise creates a PyAssignmentStatement. - - Args: - ts_node (TSNode): The TreeSitter node representing the entire statement. - file_node_id (NodeId): The ID of the file containing this node. - ctx (CodebaseContext): The codebase context instance. - parent (PyHasBlock): The parent block containing this statement. - code_block (PyCodeBlock): The code block containing this statement. - pos (int): The position of this statement within its code block. - assignment_node (TSNode): The TreeSitter node representing the assignment operation. - - Returns: - PyAssignmentStatement: A new assignment statement instance, either PyAttribute or PyAssignmentStatement. - - Raises: - ValueError: If the assignment_node type is not one of the supported assignment types. - """ - if assignment_node.type not in cls.assignment_types: - msg = f"Invalid assignment node type: {assignment_node.type}" - raise ValueError(msg) - - from codegen.sdk.python.class_definition import PyClass - - if isinstance(parent, PyClass): - from codegen.sdk.python.statements.attribute import PyAttribute - - return PyAttribute(ts_node, file_node_id, ctx, parent, pos, assignment_node=assignment_node) - return cls(ts_node, file_node_id, ctx, parent, pos, assignment_node=assignment_node) - - def _parse_assignments(self, assignment_node: TSNode) -> MultiExpression[PyHasBlock, PyAssignment]: - if assignment_node.type in ["assignment", "augmented_assignment"]: - return PyAssignment.from_assignment(assignment_node, self.file_node_id, self.ctx, self.parent) - elif assignment_node.type == "named_expression": - return PyAssignment.from_named_expression(assignment_node, self.file_node_id, self.ctx, self.parent) - - logger.info(f"Unknown assignment type: {assignment_node.type}") - return MultiExpression(assignment_node, self.file_node_id, self.ctx, self.parent, [self.parent._parse_expression(assignment_node)]) - - def _DEPRECATED_parse_assignments(self) -> MultiExpression[PyHasBlock, PyAssignment]: - assignments = [] - for assignment in find_all_descendants(self.ts_node, {"assignment", "augmented_assignment"}, max_depth=5): - left = assignment.child_by_field_name("left") - right = assignment.child_by_field_name("right") - if left.type == "pattern_list": - for identifier in find_all_descendants(left, {"identifier", "attribute"}): - assignments.append(PyAssignment(assignment, self.file_node_id, self.ctx, self, left, right, identifier)) - else: - assignments.append(PyAssignment(assignment, self.file_node_id, self.ctx, self, left, right, left)) - - return MultiExpression(self.ts_node, self.file_node_id, self.ctx, self.parent, assignments) diff --git a/src/codegen/sdk/python/statements/attribute.py b/src/codegen/sdk/python/statements/attribute.py deleted file mode 100644 index f8f24375f..000000000 --- a/src/codegen/sdk/python/statements/attribute.py +++ /dev/null @@ -1,112 +0,0 @@ -from typing import TYPE_CHECKING, Self - -from tree_sitter import Node as TSNode - -from codegen.sdk._proxy import proxy_property -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.statements.attribute import Attribute -from codegen.sdk.python.assignment import PyAssignment -from codegen.sdk.python.statements.assignment_statement import PyAssignmentStatement -from codegen.shared.decorators.docs import noapidoc, py_apidoc -from codegen.shared.exceptions.api import APINotApplicableForLanguageError - -if TYPE_CHECKING: - from codegen.sdk.python.class_definition import PyClass - from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock - - -@py_apidoc -class PyAttribute(Attribute["PyCodeBlock", "PyAssignment"], PyAssignmentStatement): - """Python implementation of Attribute detached symbol.""" - - @reader - def _parse_assignment(self, assignment_node: TSNode | None = None) -> PyAssignment: - """Parses the assignment in the expression""" - if not assignment_node: - assignment_node = next(x for x in self.ts_node.named_children if x.type == "assignment") - return self._parse_expression(assignment_node) - - @reader - def _get_name_node(self) -> TSNode: - """Returns the ID node from the root node of the symbol""" - assignment_node = next(x for x in self.ts_node.named_children if x.type == "assignment") - return assignment_node.child_by_field_name("left") - - @property - @reader - def is_private(self) -> bool: - """Determines if this attribute is private by checking if its name starts with an underscore. - - Args: - None - - Returns: - bool: True if the attribute name starts with an underscore, False otherwise. - """ - return self.name.startswith("_") - - @proxy_property - @reader - def local_usages(self) -> list[Editable[Self]]: - """Returns all instances where this attribute is used within its parent code block. - - Finds all references to this attribute that are prefixed with 'self.' within the code block, excluding the initial assignment. - - Note: - This method can be called as both a property or a method. If used as a property, it is equivalent to invoking it without arguments. - - Returns: - list[Editable[Self]]: A sorted list of unique attribute references. Each reference is an Editable object representing a usage of this attribute. - """ - usages = [] - for statement in self.parent.statements: - var_references = statement.find(f"self.{self.name}", exact=True) - for var_reference in var_references: - # Exclude the variable usage in the assignment itself - if self.ts_node.byte_range[0] <= var_reference.ts_node.start_byte and self.ts_node.byte_range[1] >= var_reference.ts_node.end_byte: - continue - usages.append(var_reference) - return sorted(dict.fromkeys(usages), key=lambda x: x.ts_node.start_byte) - - @property - def is_optional(self) -> bool: - """Check if the attribute is optional. - - Returns `True` if the attribute is marked as optional, `False` otherwise. Not applicable for Python and will raise an error. - - Returns: - bool: Whether the attribute is optional. - - Raises: - APINotApplicableForLanguageError: Always raised as Python does not have explicit optional attribute syntax. - """ - msg = "Python doesn't have an explicit syntax for optional attributes" - raise APINotApplicableForLanguageError(msg) - - @property - @reader - @noapidoc - def attribute_docstring(self) -> str: - """Definition of the attribute. Ex: `type: TType`""" - attr_def_source = f"{self.name}" - if self.assignment.type: - attr_def_source += ": " + self.assignment.type.source - return attr_def_source - - @noapidoc - @reader - def docstring(self, base_class: "PyClass") -> str | None: - """Parse the docstring of the attribute from it's parent class docstrings.""" - from codegen.sdk.python.class_definition import PyClass - - to_search = [base_class] - to_search.extend(base_class.superclasses()) - for superclass in to_search: - if isinstance(superclass, PyClass): - if docstring := superclass.docstring: - parsed = docstring.parse() - for param in parsed.params: - if param.arg_name == self.name: - return param.description - return None diff --git a/src/codegen/sdk/python/statements/block_statement.py b/src/codegen/sdk/python/statements/block_statement.py deleted file mode 100644 index 0482fda76..000000000 --- a/src/codegen/sdk/python/statements/block_statement.py +++ /dev/null @@ -1,26 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.statements.block_statement import BlockStatement -from codegen.sdk.extensions.autocommit import reader -from codegen.sdk.python.interfaces.has_block import PyHasBlock -from codegen.shared.decorators.docs import py_apidoc - -if TYPE_CHECKING: - from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock - -Parent = TypeVar("Parent", bound="PyCodeBlock") - - -@py_apidoc -class PyBlockStatement(BlockStatement[Parent], PyHasBlock, Generic[Parent]): - """Statement which contains a block.""" - - @reader - def _parse_code_block(self) -> PyCodeBlock | None: - body_node = self.ts_node.child_by_field_name("body") - if body_node is None: - body_node = next(filter(lambda node: node.type == "block", self.ts_node.named_children)) - if body_node: - return super()._parse_code_block(body_node) diff --git a/src/codegen/sdk/python/statements/break_statement.py b/src/codegen/sdk/python/statements/break_statement.py deleted file mode 100644 index f19b64b3e..000000000 --- a/src/codegen/sdk/python/statements/break_statement.py +++ /dev/null @@ -1,25 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, override - -from codegen.sdk.core.statements.statement import Statement, StatementType -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import noapidoc, py_apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock - - -@py_apidoc -class PyBreakStatement(Statement["PyCodeBlock"]): - """An abstract representation of a python break statement.""" - - statement_type = StatementType.BREAK_STATEMENT - - @noapidoc - @commiter - @override - def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: - pass diff --git a/src/codegen/sdk/python/statements/catch_statement.py b/src/codegen/sdk/python/statements/catch_statement.py deleted file mode 100644 index f5b36bd2b..000000000 --- a/src/codegen/sdk/python/statements/catch_statement.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from codegen.sdk.core.statements.catch_statement import CatchStatement -from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock -from codegen.sdk.python.statements.block_statement import PyBlockStatement -from codegen.shared.decorators.docs import noapidoc, py_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as PyNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock - from codegen.sdk.core.node_id_factory import NodeId - - -@py_apidoc -class PyCatchStatement(CatchStatement[PyCodeBlock], PyBlockStatement): - """Python catch clause. - - Attributes: - code_block: The code block that may trigger an exception - condition: The condition which triggers this clause - """ - - def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyCodeBlock, pos: int | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - self.condition = self.children[0] - - @property - @noapidoc - def other_possible_blocks(self) -> list[ConditionalBlock]: - return [clause for clause in self.parent.except_clauses if clause != self] + [self.parent] diff --git a/src/codegen/sdk/python/statements/comment.py b/src/codegen/sdk/python/statements/comment.py deleted file mode 100644 index 3675a713a..000000000 --- a/src/codegen/sdk/python/statements/comment.py +++ /dev/null @@ -1,162 +0,0 @@ -from __future__ import annotations - -from enum import StrEnum - -from codegen.sdk.core.autocommit import commiter, reader -from codegen.sdk.core.statements.comment import Comment, lowest_indentation -from codegen.shared.decorators.docs import noapidoc, py_apidoc - - -@py_apidoc -class PyCommentType(StrEnum): - """Enum representing different types of comments. - - Attributes: - SINGLE_LINE: Represents a single line comment. - MULTI_LINE_QUOTE: Represents a multi-line comment using single quotes. - MULTI_LINE_DOUBLE_QUOTE: Represents a multi-line comment using double quotes. - UNKNOWN: Represents an unknown type of comment. - """ - - SINGLE_LINE = "SINGLE_LINE" - MULTI_LINE_QUOTE = "MULTI_LINE_QUOTE" - MULTI_LINE_DOUBLE_QUOTE = "MULTI_LINE_DOUBLE_QUOTE" - UNKNOWN = "UNKNOWN" - - -@py_apidoc -class PyComment(Comment): - """Abstract representation of python comments""" - - @property - @reader - def comment_type(self) -> PyCommentType: - """Determines the type of Python comment based on its syntax. - - Parses the comment and determines its type based on the leading characters. - For Python comments, it identifies if it is a single-line comment (#), - a multi-line comment with single quotes ('''), or a multi-line comment with double quotes (\"\"\"). - - Returns: - PyCommentType: The type of comment, one of: - - SINGLE_LINE: For comments starting with '#' - - MULTI_LINE_QUOTE: For comments wrapped in ''' - - MULTI_LINE_DOUBLE_QUOTE: For comments wrapped in \"\"\" - - UNKNOWN: If the comment type cannot be determined - """ - if self.source.startswith("#"): - return PyCommentType.SINGLE_LINE - elif self.source.startswith("'''"): - return PyCommentType.MULTI_LINE_QUOTE - elif self.source.startswith('"""'): - return PyCommentType.MULTI_LINE_DOUBLE_QUOTE - return PyCommentType.UNKNOWN - - @property - @reader - def google_style(self) -> bool: - """Determines if a Python docstring follows Google style formatting. - - Checks if a multi-line docstring follows Google style conventions by starting with descriptive text - immediately after the opening quotes rather than on a new line. - - Returns: - bool: True if the docstring follows Google style formatting, False otherwise. - """ - if self.comment_type == PyCommentType.MULTI_LINE_QUOTE or self.comment_type == PyCommentType.MULTI_LINE_DOUBLE_QUOTE: - return (self.source.startswith('"""') and not self.source.startswith('"""\n')) or (self.source.startswith("'''") and not self.source.startswith("'''\n")) - return False - - @noapidoc - @commiter - def _parse_comment(self) -> str: - """Parse out the comment block into its text content""" - if self.comment_type == PyCommentType.SINGLE_LINE: - if self.source.startswith("# "): - return self.source[2:] - elif self.source.startswith("#"): - return self.source[1:] - else: - return self.source - elif self.comment_type == PyCommentType.MULTI_LINE_QUOTE or self.comment_type == PyCommentType.MULTI_LINE_DOUBLE_QUOTE: - # Handle edge case with google style docstrings - skip_lines = 1 if self.google_style else 0 - # Remove the triple quotes and extract the text content - text_block = self.source[3:-3] - # Parse the text block into lines - text_lines = [] - for line in text_block.lstrip("\n").split("\n"): - text_lines.append(line) - # Get indentation level - padding = lowest_indentation(text_lines, skip_lines=skip_lines) - # Remove indentation - formatted_lines = text_lines[:skip_lines] + [line[padding:] for line in text_lines[skip_lines:]] - return "\n".join(formatted_lines).rstrip() - else: - # Return the source if the comment type is unknown - return self.source - - @noapidoc - @reader - def _unparse_comment(self, new_src: str): - """Unparses cleaned text content into a comment block""" - return self.generate_comment(new_src, self.comment_type, google_style=self.google_style) - - @staticmethod - def generate_comment(new_src: str, comment_type: PyCommentType, force_multiline: bool = False, google_style: bool = True) -> str: - """Converts text content into a Python comment block. - - Takes a string of text content and converts it into a Python comment block based on the specified comment type. - Supports single-line comments and multi-line comments with either single or double quotes. - - Args: - new_src (str): The text content to be converted into a comment. - comment_type (PyCommentType): The type of comment to generate (SINGLE_LINE, MULTI_LINE_QUOTE, or MULTI_LINE_DOUBLE_QUOTE). - force_multiline (bool, optional): When True, forces multi-line format even for single-line content. Defaults to False. - google_style (bool, optional): When True, formats multi-line comments in Google style without newline after opening quotes. Defaults to True. - - Returns: - str: The formatted comment block with appropriate comment syntax. - """ - # Generate the comment block based on the comment type - if comment_type.value == PyCommentType.SINGLE_LINE.value: - # Add the comment character to each line - new_src = "\n".join([f"# {line}" for line in new_src.split("\n")]) - elif comment_type.value == PyCommentType.MULTI_LINE_DOUBLE_QUOTE.value: - # Add triple quotes to the text - if "\n" in new_src or force_multiline: - new_src = '"""' + ("" if google_style else "\n") + new_src + '\n"""' - else: - new_src = '"""' + new_src + '"""' - elif comment_type.value == PyCommentType.MULTI_LINE_QUOTE.value: - # Add triple quotes to the text - if "\n" in new_src or force_multiline: - new_src = "'''" + ("" if google_style else "\n") + new_src + "\n'''" - else: - new_src = "'''" + new_src + "'''" - return new_src - - @staticmethod - def clean_comment(comment: str) -> str: - """Cleans a comment block by removing comment symbols, leading/trailing whitespace, and standardizing indentation. - - Takes a comment string and processes it to extract just the content by removing comment symbols (# or triple quotes), - adjusting indentation, and stripping excess whitespace. - - Args: - comment (str): The raw comment block to be cleaned. Can be a single-line comment or multi-line docstring. - - Returns: - str: The cleaned comment text with comment symbols and excess whitespace removed. - """ - # Remove leading whitespace - indent = lowest_indentation(comment.split("\n")) - comment = ("\n".join([line[indent:] for line in comment.split("\n")])).strip() - - if comment.startswith("#"): - comment = comment[1:] - if comment.startswith("'''") or comment.startswith('"""'): - comment = comment[3:] - if comment.endswith("'''") or comment.endswith('"""'): - comment = comment[:-3] - return comment.strip() diff --git a/src/codegen/sdk/python/statements/for_loop_statement.py b/src/codegen/sdk/python/statements/for_loop_statement.py deleted file mode 100644 index 84f716cd3..000000000 --- a/src/codegen/sdk/python/statements/for_loop_statement.py +++ /dev/null @@ -1,71 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement -from codegen.sdk.extensions.autocommit import commiter, reader -from codegen.sdk.python.statements.block_statement import PyBlockStatement -from codegen.shared.decorators.docs import noapidoc, py_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.expressions import Expression - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock - - -@py_apidoc -class PyForLoopStatement(ForLoopStatement["PyCodeBlock"], PyBlockStatement): - """Abstract representation of the for loop in Python - - Attributes: - item: An item in the iterable object - iterable: The iterable that is being iterated over - """ - - item: Expression[PyForLoopStatement] - iterable: Expression[PyForLoopStatement] - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyCodeBlock, pos: int | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - self.item = self.child_by_field_name("left") - self.iterable = self.child_by_field_name("right") - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Gets all function calls within this for loop statement. - - A property that retrieves all function calls from the iterable expression and combines them with any function - calls from the parent class implementation. This includes function calls within the iterable expression and - any function calls in the loop body. - - Returns: - list[FunctionCall]: A list of all function calls within the for loop statement, including those from - both the iterable expression and the parent class implementation. - """ - fcalls = self.iterable.function_calls - fcalls.extend(super().function_calls) - return fcalls - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - self.item._compute_dependencies(usage_type, dest) - self.iterable._compute_dependencies(usage_type, dest) - super()._compute_dependencies(usage_type, dest) - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - symbols = [] - symbols.extend(self.item.descendant_symbols) - symbols.extend(self.iterable.descendant_symbols) - symbols.extend(super().descendant_symbols) - return symbols diff --git a/src/codegen/sdk/python/statements/if_block_statement.py b/src/codegen/sdk/python/statements/if_block_statement.py deleted file mode 100644 index dc73b21dd..000000000 --- a/src/codegen/sdk/python/statements/if_block_statement.py +++ /dev/null @@ -1,106 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.statements.if_block_statement import IfBlockStatement -from codegen.sdk.core.statements.statement import StatementType -from codegen.shared.decorators.docs import apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock - -Parent = TypeVar("Parent", bound="PyCodeBlock") - - -@apidoc -class PyIfBlockStatement(IfBlockStatement[Parent, "PyIfBlockStatement"], Generic[Parent]): - """Pythons implementation of the if/elif/else statement block. - - For example, if there is a code block like: - if condition1: - block1 - elif condition2: - block2 - else: - block3 - This class represents the entire block, including the conditions and nested code blocks. - """ - - statement_type = StatementType.IF_BLOCK_STATEMENT - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent, pos: int, main_if_block: PyIfBlockStatement | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - self._main_if_block = main_if_block - self.condition = self.child_by_field_name("condition") - self.consequence_block = self._parse_consequence_block() - self._alternative_blocks = self._parse_alternative_blocks() if self.is_if_statement else None - self.consequence_block.parse() - - @reader - def _parse_consequence_block(self) -> PyCodeBlock: - from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock - - body_node = self.ts_node.child_by_field_name("body") if self.is_else_statement else self.ts_node.child_by_field_name("consequence") - return PyCodeBlock(body_node, self.parent.level + 1, self.parent, self) - - @reader - def _parse_alternative_blocks(self) -> list[PyIfBlockStatement]: - # If the current block is the top main if block, iterate through all the children alternative blocks - alt_blocks = [] - if self.is_if_statement: - for alt_node in self.ts_node.children_by_field_name("alternative"): - alt_block = PyIfBlockStatement(alt_node, self.file_node_id, self.ctx, self.parent, self.index, main_if_block=self._main_if_block or self) - alt_blocks.append(alt_block) - return alt_blocks - - @property - @reader - def is_if_statement(self) -> bool: - """Check if the current block is an if statement. - - Returns: - bool: True if the current block is an if statement, False otherwise. - """ - return self.ts_node.type == "if_statement" - - @property - @reader - def is_else_statement(self) -> bool: - """Determines if the current block is an else block. - - A property that checks if the current TreeSitter node represents an else clause in an if-elif-else statement chain. - - Returns: - bool: True if the current block is an else block, False otherwise. - """ - return self.ts_node.type == "else_clause" - - @property - @reader - def is_elif_statement(self) -> bool: - """Determines if the current block is an 'elif' clause. - - Returns: - bool: True if the current block is an 'elif' clause, False otherwise. - """ - return self.ts_node.type == "elif_clause" - - @writer - def _else_if_to_if(self) -> None: - """Converts an 'elif' block to an 'if' block if applicable. - - Args: - None - - Returns: - None - """ - if not self.is_elif_statement: - return - - self.remove_byte_range(self.ts_node.start_byte, self.ts_node.start_byte + len("el")) diff --git a/src/codegen/sdk/python/statements/import_statement.py b/src/codegen/sdk/python/statements/import_statement.py deleted file mode 100644 index 5b84c213e..000000000 --- a/src/codegen/sdk/python/statements/import_statement.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from codegen.sdk.core.statements.import_statement import ImportStatement -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock -from codegen.sdk.python.import_resolution import PyImport -from codegen.shared.decorators.docs import py_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.python.file import PyFile - - -@py_apidoc -class PyImportStatement(ImportStatement["PyFile", PyImport, PyCodeBlock]): - """An abstract representation of a python import statement.""" - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyCodeBlock, pos: int) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - imports = [] - if ts_node.type == "import_statement": - imports.extend(PyImport.from_import_statement(ts_node, file_node_id, ctx, self)) - elif ts_node.type == "import_from_statement": - imports.extend(PyImport.from_import_from_statement(ts_node, file_node_id, ctx, self)) - elif ts_node.type == "future_import_statement": - imports.extend(PyImport.from_future_import_statement(ts_node, file_node_id, ctx, self)) - self.imports = Collection(ts_node, file_node_id, ctx, self, delimiter="\n", children=imports) - for imp in self.imports: - imp.import_statement = self diff --git a/src/codegen/sdk/python/statements/match_case.py b/src/codegen/sdk/python/statements/match_case.py deleted file mode 100644 index 1140ccc38..000000000 --- a/src/codegen/sdk/python/statements/match_case.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import TYPE_CHECKING - -from tree_sitter import Node as PyNode - -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.core.statements.switch_case import SwitchCase -from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock -from codegen.sdk.python.statements.block_statement import PyBlockStatement -from codegen.shared.decorators.docs import noapidoc, py_apidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock - from codegen.sdk.python.statements.match_statement import PyMatchStatement - - -@py_apidoc -class PyMatchCase(SwitchCase[PyCodeBlock["PyMatchStatement"]], PyBlockStatement): - """Python match case.""" - - def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: PyCodeBlock, pos: int | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - self.condition = self.child_by_field_name("alternative") - - @property - @noapidoc - def other_possible_blocks(self) -> list["ConditionalBlock"]: - return [case for case in self.parent.cases if case != self] diff --git a/src/codegen/sdk/python/statements/match_statement.py b/src/codegen/sdk/python/statements/match_statement.py deleted file mode 100644 index 59f01164c..000000000 --- a/src/codegen/sdk/python/statements/match_statement.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from codegen.sdk.core.statements.switch_statement import SwitchStatement -from codegen.sdk.python.statements.match_case import PyMatchCase -from codegen.shared.decorators.docs import py_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as PyNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock - - -@py_apidoc -class PyMatchStatement(SwitchStatement["PyCodeBlock", "PyCodeBlock", PyMatchCase]): - """Abstract representation of the match block""" - - def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyCodeBlock, pos: int | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - self.value = self.child_by_field_name("subject") - code_block = self.ts_node.child_by_field_name("body") - self.cases = [] - for node in code_block.children_by_field_name("alternative"): - self.cases.append(PyMatchCase(node, file_node_id, ctx, self, self.index)) diff --git a/src/codegen/sdk/python/statements/pass_statement.py b/src/codegen/sdk/python/statements/pass_statement.py deleted file mode 100644 index fb73d2758..000000000 --- a/src/codegen/sdk/python/statements/pass_statement.py +++ /dev/null @@ -1,25 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, override - -from codegen.sdk.core.statements.statement import Statement, StatementType -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import noapidoc, py_apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock - - -@py_apidoc -class PyPassStatement(Statement["PyCodeBlock"]): - """An abstract representation of a python pass statement.""" - - statement_type = StatementType.PASS_STATEMENT - - @noapidoc - @commiter - @override - def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: - pass diff --git a/src/codegen/sdk/python/statements/try_catch_statement.py b/src/codegen/sdk/python/statements/try_catch_statement.py deleted file mode 100644 index 9d02300cf..000000000 --- a/src/codegen/sdk/python/statements/try_catch_statement.py +++ /dev/null @@ -1,106 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Self, override - -from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement -from codegen.sdk.extensions.autocommit import commiter, reader -from codegen.sdk.python.statements.block_statement import PyBlockStatement -from codegen.sdk.python.statements.catch_statement import PyCatchStatement -from codegen.shared.decorators.docs import noapidoc, py_apidoc - -if TYPE_CHECKING: - from collections.abc import Sequence - - from tree_sitter import Node as PyNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock - - -@py_apidoc -class PyTryCatchStatement(TryCatchStatement["PyCodeBlock"], PyBlockStatement): - """Abstract representation of the try/catch/finally block in Python. - - Attributes: - except_clauses: The exception handlers. - """ - - except_clauses: list[PyCatchStatement[Self]] - - def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyCodeBlock, pos: int | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - self.except_clauses = [] - for node in self.ts_node.named_children: - if node.type == "finally_clause": - self.finalizer = PyBlockStatement(node, file_node_id, ctx, self, self.index) - elif node.type == "except_clause": - self.except_clauses.append(PyCatchStatement(node, file_node_id, ctx, self, self.index)) - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Gets a list of all function calls contained within the try-catch statement. - - Returns a list of function calls from all parts of the try-catch statement, including the main block, all except clauses, and the finally block if it exists. - - Returns: - list[FunctionCall]: A list of all function calls found in the try-catch statement, its except clauses, and finally block. - """ - fcalls = super().function_calls - for clause in self.except_clauses: - fcalls.extend(clause.function_calls) - if self.finalizer: - fcalls.extend(self.finalizer.function_calls) - return fcalls - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - super()._compute_dependencies(usage_type, dest) - for clause in self.except_clauses: - clause._compute_dependencies(usage_type, dest) - if self.finalizer: - self.finalizer._compute_dependencies(usage_type, dest) - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - symbols = super().descendant_symbols - for clause in self.except_clauses: - symbols.extend(clause.descendant_symbols) - if self.finalizer: - symbols.extend(self.finalizer.descendant_symbols) - return symbols - - @property - @reader - @override - def nested_code_blocks(self) -> list[PyCodeBlock]: - """Returns all CodeBlocks nested within this try-catch statement. - - Retrieves a list of code blocks from the try block, except clauses, and finally block (if present). - - Returns: - list[PyCodeBlock]: A list containing all nested code blocks in the following order: - - try block - - nested blocks within finally block (if present) - - except clause blocks - - finally block (if present) - """ - nested_blocks = [self.code_block, *self.finalizer.nested_code_blocks] if self.finalizer else [self.code_block] - for except_clause in self.except_clauses: - nested_blocks.append(except_clause.code_block) - if self.finalizer: - nested_blocks.append(self.finalizer.code_block) - return nested_blocks - - @property - @noapidoc - def other_possible_blocks(self) -> Sequence[ConditionalBlock]: - return self.except_clauses diff --git a/src/codegen/sdk/python/statements/while_statement.py b/src/codegen/sdk/python/statements/while_statement.py deleted file mode 100644 index 2ae5e80c5..000000000 --- a/src/codegen/sdk/python/statements/while_statement.py +++ /dev/null @@ -1,89 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from codegen.sdk.core.statements.while_statement import WhileStatement -from codegen.sdk.extensions.autocommit import commiter, reader -from codegen.sdk.extensions.sort import sort_editables -from codegen.sdk.python.interfaces.has_block import PyHasBlock -from codegen.sdk.python.statements.if_block_statement import PyIfBlockStatement -from codegen.shared.decorators.docs import noapidoc, py_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock - - -@py_apidoc -class PyWhileStatement(WhileStatement["PyCodeBlock"], PyHasBlock): - """An abstract representation of a python while statement. - - Attributes: - else_statement (PyIfBlockStatement | None): the statement that will run if the while loop completes, if any. - """ - - else_statement: PyIfBlockStatement[PyCodeBlock[PyWhileStatement]] | None = None - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyCodeBlock, pos: int | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - self.condition = self.child_by_field_name("condition") - if else_block := ts_node.child_by_field_name("alternative"): - self.else_statement = PyIfBlockStatement(else_block, file_node_id, ctx, self.code_block, self.index, main_if_block=self) - else: - self.else_statement = None - - @property - @reader - def nested_code_blocks(self) -> list[PyCodeBlock]: - """Returns a list of all code blocks nested within the while statement. - - Returns all code blocks contained within this while statement, including blocks from the else statement - if it exists. The first block in the list is always the main while statement's code block. - - Returns: - list[PyCodeBlock]: A list of code blocks contained within this statement, including those in the else branch. - """ - blocks = [self.code_block] - if self.else_statement: - blocks.extend(self.else_statement.nested_code_blocks) - return blocks - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Returns all function calls within the while statement and its else block. - - Returns a list of FunctionCall objects representing all function calls found in both the while statement's - code block and its else block (if it exists). Function calls are sorted but not deduplicated. - - Returns: - list[FunctionCall]: A sorted list of FunctionCall objects representing all function calls within the - while statement and its else block. - """ - fcalls = super().function_calls - if self.else_statement: - fcalls.extend(self.else_statement.function_calls) - return sort_editables(fcalls, dedupe=False) - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - super()._compute_dependencies(usage_type, dest) - if self.else_statement: - self.else_statement._compute_dependencies(usage_type, dest) - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - symbols = super().descendant_symbols - symbols.extend(self.code_block.descendant_symbols) - if self.else_statement: - symbols.extend(self.else_statement.descendant_symbols) - return symbols diff --git a/src/codegen/sdk/python/statements/with_statement.py b/src/codegen/sdk/python/statements/with_statement.py deleted file mode 100644 index 64c2f76bd..000000000 --- a/src/codegen/sdk/python/statements/with_statement.py +++ /dev/null @@ -1,88 +0,0 @@ -from __future__ import annotations - -from functools import cached_property -from typing import TYPE_CHECKING - -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.statements.statement import Statement, StatementType -from codegen.sdk.core.symbol_groups.expression_group import ExpressionGroup -from codegen.sdk.extensions.autocommit import commiter -from codegen.sdk.extensions.sort import sort_editables -from codegen.sdk.python.interfaces.has_block import PyHasBlock -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock - - -@apidoc -class WithStatement(Statement["PyCodeBlock"], PyHasBlock): - """Pythons implementation of the with statement. - - Examples: - with feature_flag_enabled(...): - # code block - - with open("file.txt") as file: - # code block - - with (context_manager1 as var1, - context_manager2 as var2, - context_manager3 as var3): - # code block - - Attributes: - code_block: The code block of the with statement. - clause: The expression of the with clause. - """ - - statement_type = StatementType.WITH_STATEMENT - code_block: PyCodeBlock[WithStatement] - clause: ExpressionGroup - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyCodeBlock, pos: int | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - self.code_block = self._parse_code_block() - self.code_block.parse() - clause = next(x for x in self.ts_node.children if x.type == "with_clause") - items = [self._parse_expression(item.child_by_field_name("value")) for item in clause.children if item.type == "with_item"] - self.clause = ExpressionGroup(self.file_node_id, self.ctx, self, children=items) - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Returns all function calls in the code block and within the with clause. - - Retrieves all function calls present in both the statement's code block and its with clause. - - Returns: - list[FunctionCall]: A list of all function calls in the code block and with clause, ordered by their position in the code. - """ - fcalls = super().function_calls - fcalls.extend(self.clause.function_calls) - return sort_editables(fcalls, dedupe=False) - - @cached_property - @reader - def nested_code_blocks(self) -> list[PyCodeBlock]: - """Returns all nested code blocks within the statement. - - Retrieves a list containing all code blocks that are nested within this statement. For a with statement, this includes its main code block. - - Returns: - list[PyCodeBlock]: A list containing the code block associated with this statement. - """ - return [self.code_block] - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - self.clause._compute_dependencies(usage_type, dest) - self.code_block._compute_dependencies(usage_type, dest) diff --git a/src/codegen/sdk/python/symbol.py b/src/codegen/sdk/python/symbol.py deleted file mode 100644 index 0e026213d..000000000 --- a/src/codegen/sdk/python/symbol.py +++ /dev/null @@ -1,224 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Self, Unpack - -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.symbol import Symbol -from codegen.sdk.enums import ImportType -from codegen.sdk.python.statements.comment import PyComment, PyCommentType -from codegen.sdk.python.symbol_groups.comment_group import PyCommentGroup -from codegen.shared.decorators.docs import noapidoc, py_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.codebase.flagging.code_flag import CodeFlag - from codegen.sdk.codebase.flagging.enums import FlagKwargs - from codegen.sdk.core.interfaces.has_block import HasBlock - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock - from codegen.sdk.python.interfaces.has_block import PyHasBlock - - -@py_apidoc -class PySymbol(Symbol["PyHasBlock", "PyCodeBlock"]): - """Extends `Symbol` for Python codebases.""" - - @classmethod - @noapidoc - def from_decorated_definition(cls, ts_node: TSNode, file_id: NodeId, ctx: CodebaseContext, parent: HasBlock) -> Symbol: - definition = ts_node.child_by_field_name("definition") - return ctx.parser.parse_expression(definition, file_id, ctx, parent, decorated_node=ts_node) - - @property - @reader - def is_exported(self) -> bool: - """Indicates whether a Python symbol is exported. - - In Python, all symbols are exported by default, so this property always returns True. - - Returns: - bool: Always True, as Python symbols are exported by default. - """ - return True - - @reader - def get_import_string( - self, - alias: str | None = None, - module: str | None = None, - import_type: ImportType = ImportType.UNKNOWN, - is_type_import: bool = False, - ) -> str: - """Generates an import string for a Python symbol. - - Returns a string representation of how to import this symbol, with support for different import types and aliasing. - - Args: - alias (str | None): Optional alias name for the import. If provided and different from symbol name, creates aliased import. - module (str | None): Optional module name to import from. If not provided, uses the symbol's file's module name. - import_type (ImportType): Type of import to generate. If WILDCARD, generates star import. Defaults to UNKNOWN. - is_type_import (bool): Whether this is a type import. Currently unused. Defaults to False. - - Returns: - str: The formatted import string. Will be one of: - - "from {module} import * as {file_name}" (for WILDCARD imports) - - "from {module} import {name} as {alias}" (for aliased imports) - - "from {module} import {name}" (for standard imports) - """ - import_module = module if module is not None else self.file.import_module_name - if import_type == ImportType.WILDCARD: - file_as_module = self.file.name - return f"from {import_module} import * as {file_as_module}" - elif alias is not None and alias != self.name: - return f"from {import_module} import {self.name} as {alias}" - else: - return f"from {import_module} import {self.name}" - - @property - @reader - def comment(self) -> PyCommentGroup | None: - """Retrieves the comment group associated with a Python symbol. - - A read-only property that returns the non-inline comment group (if any) that is associated with this symbol. - Comments are considered associated with a symbol if they appear immediately before the symbol's definition. - - Returns: - PyCommentGroup | None: A comment group object containing the symbol's comments, or None if no comments exist. - """ - return PyCommentGroup.from_symbol_comments(self) - - @property - @reader - def inline_comment(self) -> PyCommentGroup | None: - """Returns the inline comment group associated with this symbol. - - Retrieves any inline comments attached to this symbol. An inline comment appears on the same line as the code it comments on. - - Args: - self (PySymbol): The Python symbol to check for inline comments. - - Returns: - PyCommentGroup | None: A comment group containing the inline comments if they exist, None otherwise. - """ - return PyCommentGroup.from_symbol_inline_comments(self) - - @writer - def set_comment(self, comment: str, auto_format: bool = True, clean_format: bool = True, comment_type: PyCommentType = PyCommentType.SINGLE_LINE) -> None: - """Sets a comment for the Python symbol. - - Adds or modifies a comment associated with the Python symbol. If a comment already exists, - it will be edited. If no comment exists, a new comment group will be created. - - Args: - comment (str): The comment text to be added or set. - auto_format (bool, optional): If True, automatically formats the text as a comment. - Defaults to True. - clean_format (bool, optional): If True, cleans the format of the comment before - inserting. Defaults to True. - comment_type (PyCommentType, optional): Type of comment to add (e.g., single line, - multi line). Defaults to PyCommentType.SINGLE_LINE. - - Returns: - None: This method modifies the symbol's comment in place. - """ - if clean_format: - comment = PyComment.clean_comment(comment) - - # If comment already exists, add the comment to the existing comment group - if self.comment: - if auto_format: - self.comment.edit_text(comment) - else: - self.comment.edit(comment, fix_indentation=True) - else: - if auto_format: - comment = PyComment.generate_comment(comment, comment_type) - self.insert_before(comment, fix_indentation=True) - - @writer - def add_comment(self, comment: str, auto_format: bool = True, clean_format: bool = True, comment_type: PyCommentType = PyCommentType.SINGLE_LINE) -> None: - """Adds a new comment to the symbol. - - Appends a comment to the symbol either adding it to an existing comment group or creating a new one. - - Args: - comment (str): The comment text to be added. - auto_format (bool): Whether to automatically format the text into a proper comment format. - Defaults to True. - clean_format (bool): Whether to clean and normalize the comment text before adding. - Defaults to True. - comment_type (PyCommentType): The style of comment to add (e.g., single-line, multi-line). - Defaults to PyCommentType.SINGLE_LINE. - - Returns: - None - - Raises: - None - """ - if clean_format: - comment = PyComment.clean_comment(comment) - if auto_format: - comment = PyComment.generate_comment(comment, comment_type) - - # If comment already exists, add the comment to the existing comment group - if self.comment: - self.comment.insert_after(comment, fix_indentation=True) - else: - self.insert_before(comment, fix_indentation=True) - - @writer - def set_inline_comment(self, comment: str, auto_format: bool = True, clean_format: bool = True) -> None: - """Sets an inline comment to the symbol. - - Adds or replaces an inline comment for a Python symbol. If an inline comment exists, - it will be replaced with the new comment. If no inline comment exists, a new one - will be created at the end of the line. - - Args: - comment (str): The inline comment text to add. - auto_format (bool, optional): If True, formats the text into a proper inline - comment with appropriate prefixes and spacing. Defaults to True. - clean_format (bool, optional): If True, cleans the comment text before insertion - by removing extra whitespace and comment markers. Defaults to True. - - Returns: - None - """ - if clean_format: - comment = PyComment.clean_comment(comment) - - if self.comment: - if auto_format: - self.comment.edit_text(comment) - else: - self.comment.edit(comment) - else: - if auto_format: - comment = " " + PyComment.generate_comment(comment, PyCommentType.SINGLE_LINE) - self.insert_after(comment, fix_indentation=False, newline=False) - - @writer - def flag(self, **kwargs: Unpack[FlagKwargs]) -> CodeFlag[Self]: - """Flags a Python symbol by adding a flag comment and returning a CodeFlag. - - This implementation first creates the CodeFlag through the standard flagging system, - then adds a Python-specific comment to visually mark the flagged code. - - Args: - **kwargs: Flag keyword arguments including optional 'message' - - Returns: - CodeFlag[Self]: The code flag object for tracking purposes - """ - # First create the standard CodeFlag through the base implementation - code_flag = super().flag(**kwargs) - - # Add a Python comment to visually mark the flag - message = kwargs.get("message", "") - if message: - self.set_inline_comment(f"🚩 {message}") - - return code_flag diff --git a/src/codegen/sdk/python/symbol_groups/comment_group.py b/src/codegen/sdk/python/symbol_groups/comment_group.py deleted file mode 100644 index a35cdf3dd..000000000 --- a/src/codegen/sdk/python/symbol_groups/comment_group.py +++ /dev/null @@ -1,250 +0,0 @@ -from __future__ import annotations - -import re -from typing import TYPE_CHECKING - -from docstring_parser import Docstring, DocstringStyle, parse - -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.statements.statement import StatementType -from codegen.sdk.core.symbol_groups.comment_group import CommentGroup -from codegen.sdk.enums import SymbolType -from codegen.sdk.python.statements.comment import PyComment -from codegen.shared.decorators.docs import noapidoc, py_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.python.function import PyFunction - from codegen.sdk.python.symbol import PySymbol - - -@py_apidoc -class PyCommentGroup(CommentGroup): - """A group of related symbols that represent a comment or docstring in Python - - For example: - ``` - # Comment 1 - # Comment 2 - # Comment 3 - ``` - would be 3 individual comments (accessible via `symbols`), but together they form a `CommentGroup` (accessible via `self`). - """ - - _text: str # Actual text content of the comment - - @classmethod - @noapidoc - def from_symbol_comments(cls, symbol: PySymbol): - siblings = symbol.parent.parent.statements - comments = [] - # Iterate backwards from the function node to collect all preceding comment nodes - for i in range(symbol.parent.index - 1, -1, -1): - if siblings[i].statement_type == StatementType.COMMENT: - # Check if the comment is directly above each other - if siblings[i].end_point[0] == siblings[i + 1].start_point[0] - 1: - comments.insert(0, siblings[i]) - else: - break # Stop if there is a break in the comments - else: - break # Stop if a non-comment node is encountered - - from codegen.sdk.python.class_definition import PyClass - - # Check if the function node is a method - if symbol.symbol_type == SymbolType.Function: - if isinstance(symbol.parent_class, PyClass): - # Filter out the class docstring if it exists - if symbol.parent_class.docstring: - docstring_comments = set(symbol.parent_class.docstring.symbols) - comments = [c for c in comments if c not in docstring_comments] - - if not comments: - return None - - return cls(comments, symbol.file_node_id, symbol.ctx, symbol) - - @classmethod - @noapidoc - def from_symbol_inline_comments(cls, symbol: PySymbol, node: TSNode | None = None): - statement = symbol.parent - index = statement.index - siblings = statement.parent.statements - comment_nodes = [] - # Check if there are any comments after the function node - if index + 1 < len(siblings): - if siblings[index + 1].statement_type == StatementType.COMMENT: - # Check if the comment is on the same line - if siblings[index].end_point[0] == siblings[index + 1].start_point[0]: - comment_nodes.append(siblings[index + 1]) - - if not comment_nodes: - return None - - return cls(comment_nodes, symbol.file_node_id, symbol.ctx, symbol) - - @classmethod - @noapidoc - def from_docstring(cls, symbol: PySymbol): - # Check if there is an expression node above the symbol - top_child = symbol.code_block.ts_node.children[0] - if top_child.type == "expression_statement": - string_node = top_child.children[0] - if string_node.type == "string": - text = string_node.text.decode("utf-8") - comment_node = PyComment.from_code_block(string_node, symbol) - return cls([comment_node], symbol.file_node_id, symbol.ctx, symbol) - return None - - def to_google_docstring(self, function: PyFunction) -> str: # pragma: no cover - """Convert a comment group into a Google-style docstring. - - Processes the text content of the comment group and converts it into a properly formatted Google-style docstring, - incorporating existing function signature information and merging any existing docstring content with the new format. - - Args: - function (PyFunction): The Python function whose signature will be used to extract parameter and return type information. - - Returns: - str: A formatted Google-style docstring string that includes the function's description, parameters, and return value information. - """ - NAME_OF_PARAMETERS_SECTION = "Parameters:" - NAME_OF_ARGS_SECTION = "Args:" - NAME_OF_RETURNS_SECTION = "Returns:" - - def parse_google_block(section_header: str, first_line: str, docstring_iter) -> str: - """Parse the parameters section of the docstring""" - unrelated_strings = [] - parameters = {} - - # Catch edge case where there is content in the first line - if first_line_formatted := first_line.replace(section_header, "").strip(): - unrelated_strings.append(first_line_formatted) - - param_pattern = re.compile(r"^\s*(\w+)(\s+\([^)]+\))?:\s*(.+)$") - - while line := next(docstring_iter, None): - match = param_pattern.match(line) - if match: - param_name = match.group(1) - param_type = match.group(2).strip("() ") if match.group(2) else None - description = match.group(3).strip() - parameters[param_name] = (param_type, description) - else: - unrelated_strings.append(line.strip()) - - return unrelated_strings, parameters - - def merge_codebase_docstring(codebase_doc, parsed_doc): - """Merge the codebase docstring with the parsed docstring""" - for param_name, (param_type, param_description) in codebase_doc.items(): - if param_name in parsed_doc: - # Merge the types and descriptions - parsed_type, parsed_description = parsed_doc[param_name] - if not param_type: - param_type = parsed_type - if not param_description: - param_description = parsed_description - # Update the codebase docstring - codebase_doc[param_name] = (param_type, param_description) - return codebase_doc - - # Build the new docstring - new_docstring = "" - # Parse the docstring - parsed_parameters_unrelated_strings, parsed_parameters = [], {} - parsed_args_unrelated_strings, parsed_args = [], {} - parsed_returns_unrelated_strings, parsed_returns = [], {} - - # Iterate over the docstring - docstring_iter = iter(self.text.split("\n")) - while (line := next(docstring_iter, None)) is not None: - # Check if the line is a section header - if line.strip().lower().startswith(NAME_OF_PARAMETERS_SECTION.lower()): - parsed_parameters_unrelated_strings, parsed_parameters = parse_google_block(NAME_OF_PARAMETERS_SECTION, line, docstring_iter) - elif line.strip().lower().startswith(NAME_OF_ARGS_SECTION.lower()): - parsed_args_unrelated_strings, parsed_args = parse_google_block(NAME_OF_ARGS_SECTION, line, docstring_iter) - elif line.strip().lower().startswith(NAME_OF_RETURNS_SECTION.lower()): - parsed_returns_unrelated_strings, parsed_returns = parse_google_block(NAME_OF_RETURNS_SECTION, line, docstring_iter) - else: - # Add the line to the new docstring - new_docstring += line + "\n" - - # Remove extra newlines - new_docstring = new_docstring.rstrip() - - # Merge parameters and args together - parsed_args_unrelated_strings += parsed_parameters_unrelated_strings - parsed_args.update(parsed_parameters) - - # Create args section - if (args := [param for param in function.parameters if param.name != "self"]) or parsed_args_unrelated_strings or parsed_args: - args_doc = {param.name: (param.type, None) for param in args} - # Merge codebase args with parsed parameters - args_doc = merge_codebase_docstring(args_doc, parsed_args) - - new_docstring += f"\n\n{NAME_OF_ARGS_SECTION}\n" - # Generate and add the args section - if args_doc: - for arg_name, (arg_type, arg_description) in args_doc.items(): - # Add the arg to the docstring - # Add Padding and name - new_docstring += f" {arg_name}" - # Add type if it exists - if arg_type: - new_docstring += f" ({arg_type})" - # Add description if it exists - if arg_description: - new_docstring += f": {arg_description}" - # Add newline - new_docstring += "\n" - # Add a newline if there are unrelated strings - if parsed_args_unrelated_strings: - new_docstring += "\n" - # Add the unrelated strings - if parsed_args_unrelated_strings: - for unrelated_string in parsed_args_unrelated_strings: - new_docstring += f" {unrelated_string}\n" - - # Create returns section - if ((return_type := function.return_type) and return_type.source != "None") or parsed_returns_unrelated_strings or parsed_returns: - new_docstring += f"\n{NAME_OF_RETURNS_SECTION}\n" - - # Merge codebase return type with parsed return type - if (return_type := function.return_type) and return_type.source != "None": - ret_doc = {return_type: (None, None)} - ret_doc = merge_codebase_docstring(ret_doc, parsed_returns) - else: - ret_doc = parsed_returns - - # Generate and add the returns section - if ret_doc: - ret_name, (ret_type, ret_description) = next(iter(ret_doc.items())) - # Edge case: If there is no description, and parsed_returns_unrelated_strings is one line, add it to the description - if not ret_description and len(parsed_returns_unrelated_strings) == 1: - ret_description = parsed_returns_unrelated_strings.pop() - - # Add the return to the docstring - # Add Padding and name - new_docstring += f" {ret_name}" - # Add description if it exists - if ret_description: - new_docstring += f": {ret_description}" - # Add newline - new_docstring += "\n" - - # Add a newline if there are unrelated strings - if parsed_returns_unrelated_strings: - new_docstring += "\n" - # Add the unrelated strings - if parsed_returns_unrelated_strings: - for unrelated_string in parsed_returns_unrelated_strings: - new_docstring += f" {unrelated_string}\n" - - return new_docstring - - @noapidoc - @reader - def parse(self) -> Docstring: - return parse(self.source, style=DocstringStyle.GOOGLE) diff --git a/src/codegen/sdk/system-prompt.txt b/src/codegen/sdk/system-prompt.txt deleted file mode 100644 index cbf343cf1..000000000 --- a/src/codegen/sdk/system-prompt.txt +++ /dev/null @@ -1,12304 +0,0 @@ ---- -title: "Codegen" -sidebarTitle: "Overview" -icon: "code" -iconType: "solid" ---- - -[Codegen](https://github.com/codegen-sh/codegen-sdk) is a python library for manipulating codebases. - -It provides a scriptable interface to a powerful, multi-lingual language server built on top of [Tree-sitter](https://tree-sitter.github.io/tree-sitter/). - -```python -from codegen import Codebase - -# Codegen builds a complete graph connecting -# functions, classes, imports and their relationships -codebase = Codebase("./") - -# Work with code without dealing with syntax trees or parsing -for function in codebase.functions: - # Comprehensive static analysis for references, dependencies, etc. - if not function.usages: - # Auto-handles references and imports to maintain correctness - function.remove() - -# Fast, in-memory code index -codebase.commit() -``` - - - -Codegen handles complex refactors while maintaining correctness, enabling a broad set of advanced code manipulation programs. - - -Codegen works with both Python and Typescript/JSX codebases. Learn more about language support [here](/building-with-codegen/language-support). - -## Quick Started - - -Codegen requires Python 3.12 - 3.13 (recommended: Python 3.13+). - - -### Using UV (Recommended) -```bash -uv tool install codegen --python 3.13 -``` - -### Using Pipx - - -Pipx is not officially supported by Codegen, but it should still work. - - -```bash -pipx install codegen -``` - - -For further & more in depth installation instructions, see the [installation guide](/introduction/installation). - - -## What can I do with Codegen? - -Codegen's simple yet powerful APIs enable a range of applications, including: - - - - Create an intelligent agent that can analyze and manipulate your codebase using natural language. - - - Generate interactive visualizations of your codebase's structure, dependencies, and relationships. - - - Create high-quality training data for fine-tuning LLMs on your codebase. - - - Create powerful code transformations to automate large-scale changes. - - - -See below for an example call graph visualization generated with Codegen. - - - - - -View source code on [modal/modal-client](https://github.com/modal-labs/modal-client/blob/cbac0d80dfd98588027ecd21850152776be3ab82/modal/client.py#L70). View codemod on [codegen.sh](https://www.codegen.sh/codemod/66e2e195-ceec-4935-876a-ed4cfc1731c7/public/diff) - - -## Get Started - -import { - COMMUNITY_SLACK_URL, - CODEGEN_SDK_GITHUB_URL, -} from "/snippets/links.mdx"; - - - - Follow our step-by-step tutorial to start manipulating code with Codegen. - - - Learn how to use Codegen for common code transformation tasks. - - - Star us on GitHub and contribute to the project. - - - Get help and connect with the Codegen community. - - - -## Why Codegen? - -Many software engineering tasks - refactors, enforcing patterns, analyzing control flow, etc. - are fundamentally programmatic operations. Yet the tools we use to express these transformations often feel disconnected from how we think about code. - -Codegen was engineered backwards from real-world refactors we performed for enterprises at [Codegen, Inc.](/introduction/about). Instead of starting with theoretical abstractions, we built the set of APIs that map directly to how humans and AI think about code changes: - -- **Natural Mental Model**: Express transformations through high-level operations that match how you reason about code changes, not low-level text or AST manipulation. -- **Clean Business Logic**: Let the engine handle the complexities of imports, references, and cross-file dependencies. -- **Scale with Confidence**: Make sweeping changes across large codebases consistently across Python, TypeScript, JavaScript, and React. - -As AI becomes increasingly sophisticated, we're seeing a fascinating shift: AI agents aren't bottlenecked by their ability to understand code or generate solutions. Instead, they're limited by their ability to efficiently manipulate codebases. The challenge isn't the "brain" - it's the "hands." - -We built Codegen with a key insight: future AI agents will need to ["act via code,"](/blog/act-via-code) building their own sophisticated tools for code manipulation. Rather than generating diffs or making direct text changes, these agents will: - -1. Express transformations as composable programs -2. Build higher-level tools by combining primitive operations -3. Create and maintain their own abstractions for common patterns - -This creates a shared language that both humans and AI can reason about effectively, making code changes more predictable, reviewable, and maintainable. Whether you're a developer writing a complex refactoring script or an AI agent building transformation tools, Codegen provides the foundation for expressing code changes as they should be: through code itself. - - ---- -title: "Getting Started" -sidebarTitle: "Getting Started" -icon: "bolt" -iconType: "solid" ---- - -A quick tour of Codegen in a Jupyter notebook. - -## Installation - -Install [codegen](https://pypi.org/project/codegen/) on Pypi via [uv](https://github.com/astral-sh/uv): - -```bash -uv tool install codegen -``` - -## Quick Start with Jupyter - -The [codegen notebook](/cli/notebook) command creates a virtual environment and opens a Jupyter notebook for quick prototyping. This is often the fastest way to get up and running. - -```bash -# Launch Jupyter with a demo notebook -codegen notebook --demo -``` - - - - The `notebook --demo` comes pre-configured to load [FastAPI](https://github.com/fastapi/fastapi)'s codebase, so you can start - exploring right away! - - - - Prefer working in your IDE? See [IDE Usage](/introduction/ide-usage) - - -## Initializing a Codebase - -Instantiating a [Codebase](/api-reference/core/Codebase) will automatically parse a codebase and make it available for manipulation. - -```python -from codegen import Codebase - -# Clone + parse fastapi/fastapi -codebase = Codebase.from_repo('fastapi/fastapi') - -# Or, parse a local repository -codebase = Codebase("path/to/git/repo") -``` - - - This will automatically infer the programming language of the codebase and - parse all files in the codebase. Learn more about [parsing codebases here](/building-with-codegen/parsing-codebases) - - -## Exploring Your Codebase - -Let's explore the codebase we just initialized. - -Here are some common patterns for code navigation in Codegen: - -- Iterate over all [Functions](/api-reference/core/Function) with [Codebase.functions](/api-reference/core/Codebase#functions) -- View class inheritance with [Class.superclasses](/api-reference/core/Class#superclasses) -- View function usages with [Function.usages](/api-reference/core/Function#usages) -- View inheritance hierarchies with [inheritance APIs](https://docs.codegen.com/building-with-codegen/class-api#working-with-inheritance) -- Identify recursive functions by looking at [FunctionCalls](https://docs.codegen.com/building-with-codegen/function-calls-and-callsites) -- View function call-sites with [Function.call_sites](/api-reference/core/Function#call-sites) - -```python -# Print overall stats -print("🔍 Codebase Analysis") -print("=" * 50) -print(f"📚 Total Classes: {len(codebase.classes)}") -print(f"⚡ Total Functions: {len(codebase.functions)}") -print(f"🔄 Total Imports: {len(codebase.imports)}") - -# Find class with most inheritance -if codebase.classes: - deepest_class = max(codebase.classes, key=lambda x: len(x.superclasses)) - print(f"\n🌳 Class with most inheritance: {deepest_class.name}") - print(f" 📊 Chain Depth: {len(deepest_class.superclasses)}") - print(f" ⛓️ Chain: {' -> '.join(s.name for s in deepest_class.superclasses)}") - -# Find first 5 recursive functions -recursive = [f for f in codebase.functions - if any(call.name == f.name for call in f.function_calls)][:5] -if recursive: - print(f"\n🔄 Recursive functions:") - for func in recursive: - print(f" - {func.name}") -``` - -## Analyzing Tests - -Let's specifically drill into large test files, which can be cumbersome to manage. - -```python -from collections import Counter - -# Filter to all test functions and classes -test_functions = [x for x in codebase.functions if x.name.startswith('test_')] -test_classes = [x for x in codebase.classes if x.name.startswith('Test')] - -print("🧪 Test Analysis") -print("=" * 50) -print(f"📝 Total Test Functions: {len(test_functions)}") -print(f"🔬 Total Test Classes: {len(test_classes)}") -print(f"📊 Tests per File: {len(test_functions) / len(codebase.files):.1f}") - -# Find files with the most tests -print("\n📚 Top Test Files by Class Count") -print("-" * 50) -file_test_counts = Counter([x.file for x in test_classes]) -for file, num_tests in file_test_counts.most_common()[:5]: - print(f"🔍 {num_tests} test classes: {file.filepath}") - print(f" 📏 File Length: {len(file.source)} lines") - print(f" 💡 Functions: {len(file.functions)}") -``` - -## Splitting Up Large Test Files - -Lets split up the largest test files into separate modules for better organization. - -This uses Codegen's [codebase.move_to_file(...)](/building-with-codegen/moving-symbols), which will: -- update all imports -- (optionally) move dependencies -- do so very fast ⚡️ - -While maintaining correctness. - -```python -filename = 'tests/test_path.py' -print(f"📦 Splitting Test File: {filename}") -print("=" * 50) - -# Grab a file -file = codebase.get_file(filename) -base_name = filename.replace('.py', '') - -# Group tests by subpath -test_groups = {} -for test_function in file.functions: - if test_function.name.startswith('test_'): - test_subpath = '_'.join(test_function.name.split('_')[:3]) - if test_subpath not in test_groups: - test_groups[test_subpath] = [] - test_groups[test_subpath].append(test_function) - -# Print and process each group -for subpath, tests in test_groups.items(): - print(f"\\n{subpath}/") - new_filename = f"{base_name}/{subpath}.py" - - # Create file if it doesn't exist - if not codebase.has_file(new_filename): - new_file = codebase.create_file(new_filename) - file = codebase.get_file(new_filename) - - # Move each test in the group - for test_function in tests: - print(f" - {test_function.name}") - test_function.move_to_file(new_file, strategy="add_back_edge") - -# Commit changes to disk -codebase.commit() -``` - - - In order to commit changes to your filesystem, you must call - [codebase.commit()](/api-reference/core/Codebase#commit). Learn more about - [commit() and reset()](/building-with-codegen/commit-and-reset). - - -### Finding Specific Content - -Once you have a general sense of your codebase, you can filter down to exactly what you're looking for. Codegen's graph structure makes it straightforward and performant to find and traverse specific code elements: - -```python -# Grab specific content by name -my_resource = codebase.get_symbol('TestResource') - -# Find classes that inherit from a specific base -resource_classes = [ - cls for cls in codebase.classes - if cls.is_subclass_of('Resource') -] - -# Find functions with specific decorators -test_functions = [ - f for f in codebase.functions - if any('pytest' in d.source for d in f.decorators) -] - -# Find files matching certain patterns -test_files = [ - f for f in codebase.files - if f.name.startswith('test_') -] -``` - -## Safe Code Transformations - -Codegen guarantees that code transformations maintain correctness. It automatically handles updating imports, references, and dependencies. Here are some common transformations: - -```python -# Move all Enum classes to a dedicated file -for cls in codebase.classes: - if cls.is_subclass_of('Enum'): - # Codegen automatically: - # - Updates all imports that reference this class - # - Maintains the class's dependencies - # - Preserves comments and decorators - # - Generally performs this in a sane manner - cls.move_to_file(f'enums.py') - -# Rename a function and all its usages -old_function = codebase.get_function('process_data') -old_function.rename('process_resource') # Updates all references automatically - -# Change a function's signature -handler = codebase.get_function('event_handler') -handler.get_parameter('e').rename('event') # Automatically updates all call-sites -handler.add_parameter('timeout: int = 30') # Handles formatting and edge cases -handler.add_return_type('Response | None') - -# Perform surgery on call-sites -for fcall in handler.call_sites: - arg = fcall.get_arg_by_parameter_name('env') - # f(..., env={ data: x }) => f(..., env={ data: x or None }) - if isinstance(arg.value, Collection): - data_key = arg.value.get('data') - data_key.value.edit(f'{data_key.value} or None') -``` - - - When moving symbols, Codegen will automatically update all imports and - references. See [Moving Symbols](/building-with-codegen/moving-symbols) to - learn more. - - -## Leveraging Graph Relations - -Codegen's graph structure makes it easy to analyze relationships between code elements across files: - -```python -# Find dead code -for func in codebase.functions: - if len(func.usages) == 0: - print(f'🗑️ Dead code: {func.name}') - func.remove() - -# Analyze import relationships -file = codebase.get_file('api/endpoints.py') -print("\nFiles that import endpoints.py:") -for import_stmt in file.inbound_imports: - print(f" {import_stmt.file.path}") - -print("\nFiles that endpoints.py imports:") -for import_stmt in file.imports: - if import_stmt.resolved_symbol: - print(f" {import_stmt.resolved_symbol.file.path}") - -# Explore class hierarchies -base_class = codebase.get_class('BaseModel') -if base_class: - print(f"\nClasses that inherit from {base_class.name}:") - for subclass in base_class.subclasses: - print(f" {subclass.name}") - # We can go deeper in the inheritance tree - for sub_subclass in subclass.subclasses: - print(f" └─ {sub_subclass.name}") -``` - - - Learn more about [dependencies and - references](/building-with-codegen/dependencies-and-usages) or [imports](/building-with-codegen/imports) and [exports](/building-with-codegen/exports). - - -## Advanced Settings - -Codegen also supports a number of advanced settings that can be used to customize the behavior of the graph construction process. - -These flags are helpful for debugging problematic repos, optimizing Codegen’s performance, or testing unreleased or experimental (potentially backwards-breaking) features. - -```python -from codegen import Codebase -from codegen.configs import CodebaseConfig - -# Initialize a Codebase with custom configuration -codebase = Codebase( - "path/to/git/repo"", - config=CodebaseConfig( - verify_graph=True, - method_usages=False, - sync_enabled=True, - generics=False, - import_resolution_overrides={ - "old_module": "new_module" - }, - ts_language_engine=True, - v8_ts_engine=True - ) -) -``` - -To learn more about available settings, see the [Advanced Settings](/introduction/advanced-settings) page. - - -These are considered experimental and unstable features that may be removed or changed in the future. - - -## What's Next? - - - - Follow step-by-step tutorials for common code transformation tasks like - modernizing React codebases or migrating APIs. - - - Understand key concepts like working with files, functions, imports, and the - call graph to effectively manipulate code. - - - Iterate locally with your favorite IDE, work with a debugger and build sophisticated codemods - - - Learn how to use Codegen with Cursor, Devin, Windsurf, and more. - - - - - ---- -title: "Installation" -sidebarTitle: "Installation" -icon: "download" -iconType: "solid" ---- - -Install and set up Codegen in your development environment. - -#### We currently support: -- Running Codegen in Python 3.12 - 3.13 (recommended: Python 3.13+) -- macOS and Linux - - macOS is supported - - Linux is supported on x86_64 and aarch64 with glibc 2.34+ - - Windows is supported via WSL. See [here](https://docs.codegen.com/building-with-codegen/codegen-with-wsl) for more details. -- Python, Typescript, Javascript and React codebases - -## Prerequisites - -We recommend using [uv](https://github.com/astral-sh/uv) for installation. If you haven't installed `uv` yet: -```bash -curl -LsSf https://astral.sh/uv/install.sh | sh -``` - -## Installing Codegen - -```bash -uv tool install codegen --python 3.13 -``` - - - -This makes the `codegen` command available globally in your terminal, while keeping its dependencies isolated. - - -## Quick Start - -Let's walk through a minimal example of using Codegen in a project: - -1. Navigate to your repository: - ```bash - cd path/to/your/project - ``` - -2. Initialize Codegen in your project with [codegen init](/cli/init): - ```bash - codegen init - ``` - - This creates a `.codegen/` directory with: - ```bash - .codegen/ - ├── .venv/ # Python virtual environment (gitignored) - ├── config.toml # Project configuration - ├── codemods/ # Your codemod implementations - ├── jupyter/ # Jupyter notebooks for exploration - └── codegen-system-prompt.txt # AI system prompt - ``` - -3. Create your first codemod with [codegen create](/cli/create): - ```bash - codegen create organize-imports \ - -d "Sort and organize imports according to PEP8" - ``` - - The `-d` flag in `codegen create` generates an AI-powered implementation. This requires a Github account registered on [codegen.sh](https://codegen.sh) - - - - -4. Run your codemod with [codegen run](/cli/run): - ```bash - codegen run organize-imports - ``` - -5. Reset any filesystem changes (excluding `.codegen/*`) with [codegen reset](/cli/reset): - ```bash - codegen reset - ``` - -## Troubleshooting - -Having issues? Here are some common problems and their solutions: - -- **I'm hitting an UV error related to `[[ packages ]]`**: This means you're likely using an outdated version of UV. Try updating to the latest version with: `uv self update`. -- **I'm hitting an error about `No module named 'codegen.sdk.extensions.utils'`**: The compiled cython extensions are out of sync. Update them with `uv sync --reinstall-package codegen`. -- **I'm hitting a `RecursionError: maximum recursion depth exceeded` error while parsing my codebase**: If you are using python 3.12, try upgrading to 3.13. If you are already on 3.13, try upping the recursion limit with `sys.setrecursionlimit(10000)`. - - -For more help, join our [community Slack](/introduction/community) or check the [FAQ](/introduction/faq). - - -## Next Steps - - - - Learn how to use Codegen effectively in VSCode, Cursor, and other IDEs. - - - Follow step-by-step tutorials for common code transformation tasks. - - - Leverage AI assistants like Copilot, Cursor and Devin - - - Learn more about building with Codegen - - - - - ---- -title: "Using Codegen in Your IDE" -sidebarTitle: "IDE Usage" -icon: "window" -iconType: "solid" ---- - -Get up and running with Codegen programs in IDEs like VSCode, Cursor and PyCharm. - -Make sure to [install and initialize](/introduction/installation) Codegen with `codegen init` - -## Configuring your IDE Interpreter - -Codegen creates a custom Python environment in `.codegen/.venv`. Configure your IDE to use this environment for the best development experience. - - - - 1. Install the VSCode Python Extensions for LSP and debugging support. We recommend Python, Pylance and Python Debugger for the best experience. - - 2. Open the Command Palette (Cmd/Ctrl + Shift + P) - 3. Type "Python: Select Interpreter" - - 4. Choose "Enter interpreter path" - 5. Navigate to and select: - ```bash - .codegen/.venv/bin/python - ``` - - Alternatively, create a `.vscode/settings.json`: - ```json - { - "python.defaultInterpreterPath": "${workspaceFolder}/.codegen/.venv/bin/python", - "python.analysis.extraPaths": [ - "${workspaceFolder}/.codegen/.venv/lib/python3.12/site-packages" - ] - } - ``` - - - - 1. Open PyCharm Settings/Preferences - 2. Navigate to "Project > Python Interpreter" - 3. Click the gear icon ⚙️ and select "Add" - 4. Choose "Existing Environment" - 5. Set interpreter path to: - ```bash - .codegen/.venv/bin/python - ``` - - - - - -## MCP Server Setup -This is an optional step but highly recommended if your IDE supports MCP support and you use AI Agents. -The MCP server is a local server that allows your AI Agent to interact with the Codegen specific tools, -it will allow an agent to: -- ask an expert to create a codemod -- improve a codemod -- get setup instructions - -### IDE Configuration -#### Cline -Add this to your cline_mcp_settings.json: -```json -{ - "mcpServers": { - "codegen-cli": { - "command": "uv", - "args": [ - "--directory", - "/codegen-sdk/src/codegen/cli/mcp", - "run", - "server.py" - ] - } - } -} -``` - - -#### Cursor: -Under the `Settings` > `Feature` > `MCP Servers` section, click "Add New MCP Server" and add the following: - -``` -Name: codegen-mcp -Type: Command -Command: uv --directory /codegen-sdk/src/codegen/cli/mcp run server.py -``` - - -## Index Codegen Docs -#### Cursor: -If you use Cursor you'll be able to configure the IDE to index the Codegen docs. To do so go to `Settings` > `Features` > `Docs` -and then click on `Add new docs`. We recommend using this url to index the API reference: -``` -https://docs.codegen.com/api-reference/index -``` - - -## Create a New Codemod - -Generate the boilerplate for a new code manipulation program using [codegen create](/cli/create): - -```bash -codegen create organize-types \ - -d "Move all TypeScript types to \ - into a centralized types.ts file" -``` - - - Passing in `-d --description` will get an LLM expert to compose an initial version for you. This requires a Github account registered on [codegen.sh](https://codegen.sh) - - -This will: -1. Create a new codemod in `.codegen/codemods/organize_types/` -2. Generate a custom `system-prompt.txt` based on your task -3. Set up the basic structure for your program - - -The generated codemod includes type hints and docstrings, making it easy to get IDE autocompletion and documentation. - - -## Iterating with Chat Assistants - -When you do `codegen init`, you will receive a [system prompt optimized for AI consumption](/introduction/work-with-ai) at `.codegen/codegen-system-prompt.txt`. - -If you reference this file in "chat" sessions with Copilot, Cursor, Cody, etc., the assistant will become fluent in Codegen. - - - - Collaborating with Cursor's assistant and the Codegen system prompt - - -In addition, when you [create](/cli/create) a codemod with "-d", Codegen generates an optimized system prompt in `.codegen/codemods/{name}/{name}-system-prompt.txt`. This prompt contains: -- Relevant Codegen API documentation -- Examples of relevant transformations -- Context about your specific task - - -You can also drag and drop the system prompt ([available here](/introduction/work-with-ai))file directly into chat windows like ChatGPT or Claude for standalone help. - - -## Running and Testing Codemods - -```bash -# Run => write changes to disk -codegen run organize-types - -# Reset changes on disk -codegen reset -``` - -You can also run the program directly via `.codegen/.venv/bin/python path/to/codemod.py` or via your editor's debugger - -## Viewing Changes - -We recommend viewing changes in your IDE's native diff editor. - - -## What's Next - - - - See real-world examples of codemods in action. - - - Learn about Codegen's core concepts and features - - - - ---- -title: "Working with AI" -sidebarTitle: "AI Integration" -icon: "microchip" -iconType: "solid" ---- - -Codegen is designed to be used with AI assistants. This document describes how to use Codegen with common AI tools, including Copilot, Cursor, Devin and more. - -## System Prompt - -Codegen provides a `.txt` file that you can drag-and-drop into any chat assistant. This is roughly 60k tokens and will enable chat assistants like, ChatGPT, Claude 3.5 etc. to build effectively with Codegen. - -import { - CODEGEN_SYSTEM_PROMPT -} from "/snippets/links.mdx"; - - - Download System Prompt - - -Learn about leveraging this in IDE chat assistants like Cursor [here](/introduction/ide-usage#iterating-with-chat-assistants) - -## Generating System Prompts - -The [Codegen CLI](/cli/about) provides commands to generate `.md` files that can be fed to any AI assistant for more accurate and contextual help. - -When you create a new codemod via [codegen create](/cli/create): - -```bash -codegen create delete-dead-imports --description "Delete unused imports" -``` - -Codegen automatically generates an optimized ["system prompt"](https://news.ycombinator.com/item?id=37880023) that includes: - -- An introduction to Codegen -- Codegen API documentation -- Examples of relevant transformations - -You can find this generated prompt in the `.codegen/prompts/-system-prompt.md` file. - - - All contents of the `.codegen/prompts` directory are by default ignored the - `.gitignore` file. after running [codegen init](/cli/init) - - -This `.md` file can be used with any AI assistant (Claude, GPT-4, etc.) to get more accurate and contextual help. - -## Example Workflow - - - - Use the [create command](/cli/create) with a detailed description of what you want to accomplish: - ```bash - codegen create modernize-components --description "Convert class components to functional components with hooks" - ``` - - - Check the AI context that Codegen generated for your transformation: ```bash - cat codegen-sh/codemods/modernize-components/prompt.md ``` - - - - Reference your codemod when asking questions to get contextual help: ``` - @codegen-sh/codemods/modernize-components How should I handle - componentDidMount? ``` - - - - The AI will understand you're working on React modernization and provide relevant suggestions about using useEffect hooks and other modern React patterns. - - - -## Copilot, Cursor and Windsurf (IDEs) - -When using IDE chat assistants, you can leverage Codegen's context by mentioning your codemod in composer mode: - -```bash -@.codegen/codemods/upgrade-react18 @.codegen/prompts/system-prompt.md -``` - -This will ensure that the IDE's native chat model is aware of the APIs and common patterns for Codegen. - -## Devin, OpenHands and Semi-autonomous Code Agents - -Coming soon! - - ---- -title: "Under the Hood" -sidebarTitle: "How it Works" -icon: "gear" -iconType: "solid" -subtitle: "How Codegen's codebase graph works" ---- - -Codegen performs advanced static analysis to build a rich graph representation of your codebase. This pre-computation step analyzes dependencies, references, types, and control flow to enable fast and reliable code manipulation operations. - - - Codegen is built on top of - [Tree-sitter](https://tree-sitter.github.io/tree-sitter/) and - [rustworkx](https://github.com/Qiskit/rustworkx) and has implemented most - language server features from scratch. - - - Codegen is open source. Check out the [source - code](https://github.com/codegen-sh/codegen-sdk) to learn more! - - -## The Codebase Graph - -At the heart of Codegen is a comprehensive graph representation of your code. When you initialize a [Codebase](/api-reference/core/Codebase), it performs static analysis to construct a rich graph structure connecting code elements: - -```python -# Initialize and analyze the codebase -from codegen import Codebase -codebase = Codebase("./") - -# Access pre-computed relationships -function = codebase.get_symbol("process_data") -print(f"Dependencies: {function.dependencies}") # Instant lookup -print(f"Usages: {function.usages}") # No parsing needed -``` - -### Building the Graph - -Codegen's graph construction happens in two stages: - -1. **AST Parsing**: We use [Tree-sitter](https://tree-sitter.github.io/tree-sitter/) as our foundation for parsing code into Abstract Syntax Trees. Tree-sitter provides fast, reliable parsing across multiple languages. - -2. **Multi-file Graph Construction**: Custom parsing logic, implemented in [rustworkx](https://github.com/Qiskit/rustworkx) and Python, analyzes these ASTs to construct a more sophisticated graph structure. This graph captures relationships between [symbols](/building-with-codegen/symbol-api), [files](/building-with-codegen/files-and-directories), [imports](/building-with-codegen/imports), and more. - -### Performance Through Pre-computation - -Pre-computing a rich index enables Codegen to make certain operations very fast that that are relevant to refactors and code analysis: - -- Finding all usages of a symbol -- Detecting circular dependencies -- Analyzing the dependency graphs -- Tracing call graphs -- Static analysis-based code retrieval for RAG -- ...etc. - - - Pre-parsing the codebase enables constant-time lookups rather than requiring - re-parsing or real-time analysis. - - -## Multi-Language Support - -One of Codegen's core principles is that many programming tasks are fundamentally similar across languages. - -Currently, Codegen supports: - -- [Python](/api-reference/python) -- [TypeScript](/api-reference/typescript) -- [React & JSX](/building-with-codegen/react-and-jsx) - - - Learn about how Codegen handles language specifics in the [Language - Support](/building-with-codegen/language-support) guide. - - -We've started with these ecosystems but designed our architecture to be extensible. The graph-based approach provides a consistent interface across languages while handling language-specific details under the hood. - -## Build with Us - -Codegen is just getting started, and we're excited about the possibilities ahead. We enthusiastically welcome contributions from the community, whether it's: - -- Adding support for new languages -- Implementing new analysis capabilities -- Improving performance -- Expanding the API -- Adding new transformations -- Improving documentation - -Check out our [community guide](/introduction/community) to get involved! - - ---- -title: "Advanced Settings" -sidebarTitle: "Advanced Settings" -icon: "memory" -iconType: "solid" ---- - -Codegen's [Codebase](/api-reference/core/Codebase) constructor accepts a `CodebaseConfig` object which is used to configure more advanced behaviors of the graph construction process. - -These flags are helpful for debugging problematic repos, optimizing Codegen's performance, or testing unreleased or experimental (potentially backwards-breaking) features. - - -**These are considered experimental features and may change in the future!** - -As such, they may have little to no testing or documentation. Many of these flags may also be unsupported in the future! - -If you need help, please visit our [community](/introduction/community). - - - -These configuration options are defined in [src/codegen/configs/models/codebase.py](https://github.com/codegen-sh/codegen/blob/develop/src/codegen/configs/models/codebase.py). - - -# Usage - -You can customize the behavior of the graph construction process when initializing a [Codebase](/api-reference/core/Codebase) by passing a `CodebaseConfig` object with the desired configuration flags. - -```python -from codegen import Codebase -from codegen.configs import CodebaseConfig - -# Initialize a Codebase with custom configuration -codebase = Codebase( - "", - config=CodebaseConfig( - flag1=..., - flag2=..., - ... - ) -) -``` - -# Table of Contents - -- [debug](#flag-debug) -- [verify-graph](#flag-verify-graph) -- [track-graph](#flag-track-graph) -- [method-usages](#flag-method-usages) -- [sync-enabled](#flag-sync-enabled) -- [full-range-index](#flag-full-range-index) -- [ignore-process-errors](#flag-ignore-process-errors) -- [disable-graph](#flag-disable-graph) -- [disable-file-parse](#flag-disable-file-parse) -- [exp-lazy-graph](#flag-exp-lazy-graph) -- [generics](#flag-generics) -- [import-resolution-paths](#flag-import-resolution-paths) -- [import-resolution-overrides](#flag-import-resolution-overrides) -- [py-resolve-syspath](#flag-py-resolve-syspath) -- [ts-dependency-manager](#flag-ts-dependency-manager) -- [ts-language-engine](#flag-ts-language-engine) -- [v8-ts-engine](#flag-v8-ts-engine) -- [unpacking-assignment-partial-removal](#flag-unpacking-assignment-partial-removal) - -# Configuration Flags - -## Flag: `debug` -> **Default: `False`** - -Enables verbose logging for debugging purposes. In its current form, it enables: -- Verbose logging when adding nodes to the graph -- Verbose logging during initial file parsing -- Additional assertions on graph creation -- Additional (costly) debug metrics on codebase construction -- etc. - - -This flag may be very noisy and significantly impact performance. It is generally not recommended to use. - - -## Flag: `verify_graph` -> **Default: `False`** - -Adds assertions for graph state during reset resync. Used to test and debug graph desyncs after a codebase reset. - -Runs `post_reset_validation` after a reset resync. - - -This is an internal debug flag. - - -## Flag: `track_graph` -> **Default: `False`** - -Keeps a copy of the original graph before a resync. Used in conjunction with `verify_graph` to test and debug graph desyncs. - -Original graph is saved as `ctx.old_graph`. - - -This is an internal debug flag. - - -## Flag: `method_usages` -> **Default: `True`** - -Enables and disables resolving method usages. - -**Example Codebase:** -```python -class Foo: - def bar(): - ... - -obj = Foo() -obj.bar() # Method Usage -``` - -**Codemod with `method_usages` on:** -```python -bar_func = codebase.get_class("Foo").get_method("bar") -len(bar_func.usages) # 1 -bar_func.usages # [obj.bar()] -``` - -**Codemod with `method_usages` off:** -```python -bar_func = codebase.get_class("Foo").get_method("bar") -len(bar_func.usages) # 0 -bar_func.usages # [] -``` - -Method usage resolution could be disabled for a marginal performance boost. However, it is generally recommended to leave it enabled. - -## Flag: `sync_enabled` -> **Default: `False`** - -Enables or disables graph sync during `codebase.commit`. - - -Implementation-specific details on sync graph can be found [here](https://github.com/codegen-sh/codegen/blob/develop/architecture/6.%20incremental-computation/C.%20Graph%20Recomputation.md). - - -This section won't go into the specific details of sync graph, but the general idea is that enabling sync graph will update the Codebase object to whatever new changes were made. - -**Example with `sync_enabled` on:** -```python -file = codebase.get_file(...) -file.insert_after("foobar = 1") -codebase.commit() - -foobar = codebase.get_symbol("foobar") -assert foobar # foobar is available after commit / graph sync -``` - -**Example with `sync_enabled` disabled:** -```python -file = codebase.get_file(...) -file.insert_after("foobar = 1") - -foobar = codebase.get_symbol("foobar", optional=True) -assert not foobar # foobar is not available after commit -``` - - -Enabling sync graph will have a performance impact on codebase commit, but will also unlock a bunch of operations that were previously not possible. - - -## Flag: `full_range_index` -> **Default: `False`** - -By default, Codebase maintains an internal range-to-node index for fast lookups. (i.e. `bytes 120 to 130 maps to node X`). -For optimization purposes, this only applies to nodes defined and handled by `parser.py`. - -Enabling `full_range_index` will create an additional index that maps **all** tree-sitter ranges to nodes. -This can be useful for debugging or when you need to build any applications that require a full range-to-node index (i.e. a codebase tree lookup). - - -This flag **significantly** increases memory usage! - - -## Flag: `ignore_process_errors` -> **Default: `True`** - -Controls whether to ignore errors that occur during external process execution (such as dependency manager or language engine). - -Disabling `ignore_process_errors` would make Codegen fail on errors that would otherwise be logged then ignored. - -## Flag: `disable_graph` -> **Default: `False`** - -Disables the graph construction process. Any operations that require the graph will no longer work. (In other words, this turns off import resolution and usage/dependency resolution) - -Functions that operate purely on AST such as getting and editing parameters or modifying function and class definitions will still work. - - -For codemods that do not require the graph (aka only AST/Syntax-level changes), **disabling graph parse could yield a 30%-40% decrease in parse time and memory usage**! - - -## Flag: `disable_file_parse` -> **Default: `False`** - -Disables **ALL** parsing, including file and graph parsing. This essentially treats all codebases as the "UNSUPPORTED" language mode. - -Nearly all functions except for editing primitives like `codebase.get_file` and `file.edit` will no longer work. - - -This flag is useful for any usages of Codegen that do **NOT** require any AST/CST/Graph parsing. (i.e. using Codegen purely as a file editing harness) - -If this is your use case, this **could decrease parse and memory usage by 95%.** - - -## Flag: `exp_lazy_graph` -> **Default: `False`** - -This experimental flag pushes the graph creation back until the graph is needed. This is an experimental feature and may have some unintended consequences. - -**Example Codemod:** -```python -from codegen import Codebase -from codegen.configs import CodebaseConfig - -# Enable lazy graph parsing -codebase = Codebase("", config=CodebaseConfig(exp_lazy_graph=True)) - -# The codebase object will be created immediately with no parsing done -# These all do not require graph parsing -codebase.files -codebase.directories -codebase.get_file("...") - -# These do require graph parsing, and will create the graph only if called -codebase.get_function("...") -codebase.get_class("...") -codebase.imports -``` - - -This may have a very slight performance boost. Use at your own risk! - - -## Flag: `generics` -> **Default: `True`** - -Enables and disables generic type resolution. - -**Example Codebase:** -```python -class Point: - def scale(cls, n: int): - pass - -class List[T](): - def pop(self) -> T: - ... - -l: List[Point] = [] -l.pop().scale(1) # Generic Usage -``` - -**Codemod with `generics` on:** -```python -bar_func = codebase.get_class("Point").get_method("scale") -len(bar_func.usages) # 1 -bar_func.usages # [l.pop().scale(1)] -``` - -**Codemod with `generics` off:** -```python -bar_func = codebase.get_class("Point").get_method("scale") -len(bar_func.usages) # 0 -bar_func.usages # [] -``` - - -Generic resolution is still largely WIP and experimental, and may not work in all cases. In some rare circumstances, disabling generics may result in a significant performance boost. - - -## Flag: `import_resolution_paths` -> **Default: `[]`** - -Controls alternative paths to resolve imports from. - -**Example Codebase:** -```python -# a/b/c/src.py -def update(): - pass - -# consumer.py -from c import src as operations - -operations.update() -``` - -**Codemod:** -```python -codebase.ctx.config.import_resolution_paths = ["a/b"] -``` - -## Flag: `import_resolution_overrides` -> **Default: `{}`** - -Controls import path overrides during import resolution. - -**Example** -`from a.b.c import d` with the override `a/b` -> `foo/bar` will internally resolve the import as `from foo.bar.c import d`. - -## Flag: `py_resolve_syspath` -> **Default: `False`** - -Enables and disables resolution of imports from `sys.path`. - - -For this to properly work, you must also set `allow_external` to `True`. - - -## Flag: `allow_external` -> **Default: `False`** - -Enables resolving imports, files, modules, and directories from outside of the repo path. - - -Turning this flag off may allow for bad actors to access files outside of the repo path! Use with caution! - - -## Flag: `ts_dependency_manager` -> **Default: `False`** - - -**This is an internal flag used for Codegen Cloud and should not be used externally!** - -This flag **WILL** nuke any existing `node_modules` folder! - - - -This flag also assumes many constants for Codegen Cloud. Very likely this will not work if run locally. - -Instead, just install `node_modules` as normal (either through `npm`, `pnpm`, or `yarn`) and skip this setting! - - -Enables Codegen's internal dependency installer for TypeScript. This will modify `package.json` and install the bare minimum set of installable dependencies. - - -More documentation on TypeScript dependency manager can be found [here](https://github.com/codegen-sh/codegen/blob/develop/architecture/external/dependency-manager.md) - - -## Flag: `ts_language_engine` -> **Default: `False`** - - -This feature was built primarily with Codegen Cloud in mind. As such, this assumes a valid NodeJS and TypeScript environment. - - -Enables using the TypeScript compiler to extract information from the codebase. Enables commands such as `inferred_return_type`. - - -This will increase memory usage and parsing time. Larger repos may even hit resource constraints with the bundled TypeScript compiler integration. - - -## Flag: `v8_ts_engine` -> **Default: `False`** - - -This feature flag requires `ts_language_engine` to be enabled as well. - - -Enables using the **V8-based TypeScript compiler** to extract information from the codebase. Enables commands such as `inferred_return_type`. - -The V8 implementation (as opposed to the default external-process based implementation) is less stable, but provides the entire TypeScript API to be used from within Codegen. - - -This will increase memory usage and parsing time. Larger repos may even hit resource constraints with the V8-based TypeScript compiler integration. - - -## Flag: `unpacking_assignment_partial_removal` -> **Default: `False`** - -Enables smarter removal of unpacking assignments. - -**Example Codebase:** -```python -a, b, c = (1, 2, 3) -``` - -**Codemod with `unpacking_assignment_partial_removal` on:** -```python -file = codebase.get_file(...) -b = file.get_symbol("b") -b.remove() -codebase.commit() - -file.symbols # [a, c] -file.source # "a, c = (1, 3)" -``` - -**Codemod with `unpacking_assignment_partial_removal` off:** -```python -file = codebase.get_file(...) -b = file.get_symbol("b") -b.remove() -codebase.commit() - -file.symbols # [] -file.source # "" -``` - - ---- -title: "Guiding Principles" -sidebarTitle: "Principles" -icon: "compass" -iconType: "solid" ---- - -Codegen was developed by working backwards from real-world, large-scale codebase migrations. Instead of starting with abstract syntax trees and parser theory, we started with the question: "How do developers actually think about code changes?" - -This practical origin led to four core principles that shape Codegen's design: - -## Intuitive APIs - -Write code that reads like natural language, without worrying about abstract syntax trees or parser internals. Codegen provides high-level APIs that map directly to the transformations developers want to perform: - -```python -# Methods that read like English -function.rename("new_name") # Not ast.update_node(function_node, "name", "new_name") -function.move_to_file("new_file.py") # Not ast.relocate_node(function_node, "new_file.py") - -# Clean, readable properties -if function.is_async: # Not ast.get_node_attribute(function_node, "async") - print(function.name) # Not ast.get_node_name(function_node) - -# Natural iteration patterns -for usage in function.usages: # Not ast.find_references(function_node) - print(f"Used in {usage.file.name}") -``` - -## No Sharp Edges - -Focus on your high-level intent while Codegen handles the intricate details. - -Codegen operations handle the edge cases - it should be hard to break lint. - -```python -# Moving a function? Codegen handles: -function.move_to_file("new_file.py") -# ✓ Updating all import statements -# ✓ Preserving dependencies -# ✓ Maintaining references -# ✓ Fixing relative imports -# ✓ Resolving naming conflicts - -# Renaming a symbol? Codegen manages: -class_def.rename("NewName") -# ✓ Updating all usages -# ✓ Handling string references -# ✓ Preserving docstrings -# ✓ Maintaining inheritance -``` - -## Performance through Pre-Computation - -Codegen frontloads as much as possible to enable fast, efficient transformations. - -It is built with the insight that each codebase only needs to be parsed once per commit. - - - Learn more about parsing the codebase graph in the [How it - Works](/introduction/how-it-works) guide. - - -## Python-First Composability - -Codegen embraces Python's strength as a "glue language" - its ability to seamlessly integrate different tools and APIs. This makes it natural to compose Codegen with your existing toolchain: - -- Build complex transforms by combining simpler operations -- Integrate Codegen with your existing tools (linters, type checkers, test frameworks, AI tools) - - - Python's rich ecosystem makes it ideal for code manipulation tasks. Codegen is - designed to be one tool in your toolbox, not a replacement for your entire - workflow. - - - ---- -title: "Community & Contributing" -sidebarTitle: "Community" -icon: "people-group" -iconType: "solid" ---- - -import { - COMMUNITY_SLACK_URL, - CODEGEN_SDK_GITHUB_URL, -} from "/snippets/links.mdx"; - -Join the growing Codegen community! We're excited to have you be part of our journey to make codebase manipulation and transformation more accessible. - - - - Connect with the community, get help, and share your Codegen projects in our - active Slack workspace. - - - Star us on GitHub, report issues, submit PRs, and contribute to the project. - - - Follow us for updates, tips, and community highlights. - - - Learn how to use Codegen effectively with our comprehensive guides. - - - - - Please help us improve this library and documentation by submitting a PR! - - -## Contributing - -We welcome contributions of all kinds! Whether you're fixing a typo in documentation, reporting a bug, or implementing a new feature, we appreciate your help in making Codegen better. - -Check out our [Contributing Guide](https://github.com/codegen-sh/codegen-sdk/blob/develop/CONTRIBUTING.md) on GitHub to learn how to: - -- Set up your development environment -- Submit pull requests -- Report issues -- Contribute to documentation - - ---- -title: "Codegen, Inc." -sidebarTitle: "About Us" -icon: "building" -iconType: "solid" ---- - - - -## Our Mission - -Our mission is to build fully-autonomous software engineering - the equivalent of self-driving cars for code. - -We believe the highest leverage path to autonomous development is enabling AI agents to "act via code." - -Just as self-driving cars need sophisticated sensors and controls to navigate the physical world, AI agents need powerful, precise tools to manipulate codebases. We're building that foundational layer: a programmatic interface that lets AI agents express complex code transformations through code itself. - -This approach creates a shared language that both humans and AI can use to: - -- Express powerful changes with precision and predictability -- Build sophisticated tools from primitive operations -- Create and maintain their own abstractions -- Scale transformations across massive codebases - -## The Team - -Based in San Francisco, we're a team of engineers and researchers passionate about: - -- Making large-scale code changes more accessible -- Building tools that work the way developers think -- Creating the infrastructure for AI-powered code manipulation -- Advancing the state of the art in program transformation - -## Open Source - -We believe in the power of open source software. Our core library, [codegen](https://github.com/codegen-sh/codegen-sdk), is freely available and open to contributions from the community. - -## Join Us - - - - We're hiring! Join us in building the future of code transformation. - - - Connect with other developers and share your Codegen experiences. - - - -## Connect with Us - - - - Follow us for updates and announcements - - - Connect with our team and stay updated on company news - - - - - Want to learn more about what we're building? Check out our [getting started - guide](/introduction/getting-started) or join our [community - Slack](https://community.codegen.com). - - - ---- -title: "Frequently Asked Questions" -sidebarTitle: "FAQ" -icon: "square-question" -iconType: "solid" ---- - - - - Codegen currently parses two languages: - - [Python](/api-reference/python) - - [TypeScript](/api-reference/typescript) - - We're actively working on expanding language support based on community needs. - - Learn more about how Codegen handles language specifics in the [Language - Support](/building-with-codegen/language-support) guide. - - - Interested in adding support for your language? [Let us know](https://x.com/codegen) or [contribute](/introduction/community)! - - - - - Pretty much! Codegen is roughly on par with `mypy` and `tsc`. There are always edge cases in static analysis that are provably impossible to get (for example doing `eval()` on a string), but all of Codegen's APIs are intended to be exact unless otherwise specified. Please reach out if you find an edge case and we will do our best to patch it. - - - Yes! Codegen was developed on multmillion-line Python and Typescript codebases - and includes optimizations for handling large-scale transformations. - - For enterprise support, please reach out to [team@codegen.com](mailto:team@codegen.com) - - - - Yes - [by design](/introduction/guiding-principles#python-first-composability). - - Codegen works like any other python package. It works alongside your IDE, version control system, and other development tools. - - - Start by trying out Codegen, joining our [Slack community](https://community.codegen.com), and looking for - issues labeled "good first issue" on [GitHub](https://github.com/codegen-sh/codegen-sdk). We welcome contributions to - documentation, examples, and code improvements. - - - Yes, Codegen is [open source](https://github.com/codegen-sh/codegen-sdk) and free to use under the [Apache 2.0 - license](https://github.com/codegen-sh/codegen-sdk?tab=Apache-2.0-1-ov-file). - You can use it for both personal and commercial projects. - - - The best places to get help are: - 1. Our community [Slack channel](https://community.codegen.com) - 2. [GitHub issues](https://github.com/codegen-sh/codegen-sdk) for bug reports - 3. Reach out to us on [Twitter](https://x.com/codegen) - - - - ---- -title: "Building with Codegen" -sidebarTitle: "At a Glance" -icon: "book" -iconType: "solid" ---- - -Learn how to use Codegen's core APIs to analyze and transform code. - -## Core Concepts - - - - Understand how Codegen parses and analyzes different programming languages. - - - Learn how to work with files, directories, and navigate the codebase - structure. - - - Learn how to safely modify code while preserving formatting and comments. - - - Master the core abstractions for manipulating code safely and effectively. - - - - -## Navigating the Code Graph - - - - Analyze relationships between code elements and track symbol references. - - - Understand function call patterns and manipulate call sites. - - - Work with module imports and manage dependencies. - - - Navigate function call relationships and analyze code flow. - - - -## Code Manipulation - - - - Relocate functions, classes, and other symbols while updating references. - - - Work with code blocks, control flow, and statement manipulation. - - - Handle variable declarations, assignments, and scope. - - - Work with groups of related code elements like functions, classes, and - imports. - - - -## Special Features - - - - Work with React components, JSX syntax, and component transformations. - - - Analyze and manipulate local variable usage and scope. - - - Integrate AI assistance into your code transformations. - - - Visualize code relationships and dependencies. - - - - - Each guide includes practical examples and best practices. Start with core - concepts or jump directly to the topics most relevant to your needs. - - - ---- -title: "Parsing Codebases" -sidebarTitle: "Parsing Codebases" -icon: "power-off" -iconType: "solid" ---- - -The primary entrypoint to programs leveraging Codegen is the [Codebase](/api-reference/core/Codebase) class. - -## Local Codebases - -Construct a Codebase by passing in a path to a local `git` repository or any subfolder within it. The path must be within a git repository (i.e., somewhere in the parent directory tree must contain a `.git` folder). - -```python -from codegen import Codebase - -# Parse from a git repository root -codebase = Codebase("path/to/repository") - -# Parse from a subfolder within a git repository -codebase = Codebase("path/to/repository/src/subfolder") - -# Parse from current directory (must be within a git repo) -codebase = Codebase("./") - -# Specify programming language (instead of inferring from file extensions) -codebase = Codebase("./", language="typescript") -``` - - - By default, Codegen will automatically infer the programming language of the codebase and - parse all files in the codebase. You can override this by passing the `language` parameter - with a value from the `ProgrammingLanguage` enum. - - - - The initial parse may take a few minutes for large codebases. This - pre-computation enables constant-time operations afterward. [Learn more - here.](/introduction/how-it-works) - - -## Remote Repositories - -To fetch and parse a repository directly from GitHub, use the `from_repo` function. - -```python -from codegen import Codebase -# Fetch and parse a repository (defaults to /tmp/codegen/{repo_name}) -codebase = Codebase.from_repo('fastapi/fastapi') - -# Customize temp directory, clone depth, specific commit, or programming language -codebase = Codebase.from_repo( - 'fastapi/fastapi', - tmp_dir='/custom/temp/dir', # Optional: custom temp directory - commit='786a8ada7ed0c7f9d8b04d49f24596865e4b7901', # Optional: specific commit - shallow=False, # Optional: full clone instead of shallow - language="python" # Optional: override language detection -) -``` - - - Remote repositories are cloned to the `/tmp/codegen/{repo_name}` directory by - default. The clone is shallow by default for better performance. - - -## Configuration Options - -You can customize the behavior of your Codebase instance by passing a `CodebaseConfig` object. This allows you to configure secrets (like API keys) and toggle specific features: - -```python -from codegen import Codebase -from codegen.configs.models.codebase import CodebaseConfig -from codegen.configs.models.secrets import SecretsConfig - -codebase = Codebase( - "path/to/repository", - config=CodebaseConfig(debug=True), - secrets=SecretsConfig(openai_api_key="your-openai-key") # For AI-powered features -) -``` - -- `CodebaseConfig` and `SecretsConfig` allow you to configure - - `config`: Toggle specific features like language engines, dependency management, and graph synchronization - - `secrets`: API keys and other sensitive information needed by the codebase - -For a complete list of available feature flags and configuration options, see the [source code on GitHub](https://github.com/codegen-sh/codegen-sdk/blob/develop/src/codegen/sdk/codebase/config.py). - -## Advanced Initialization - -For more complex scenarios, Codegen supports an advanced initialization mode using `ProjectConfig`. This allows for fine-grained control over: - -- Repository configuration -- Base path and subdirectory filtering -- Multiple project configurations - -Here's an example: - -```python -from codegen import Codebase -from codegen.git.repo_operator.local_repo_operator import LocalRepoOperator -from codegen.git.schemas.repo_config import BaseRepoConfig -from codegen.sdk.codebase.config import ProjectConfig - -codebase = Codebase( - projects = [ - ProjectConfig( - repo_operator=LocalRepoOperator( - repo_path="/tmp/codegen-sdk", - repo_config=BaseRepoConfig(), - bot_commit=True - ), - language="typescript", - base_path="src/codegen/sdk/typescript", - subdirectories=["src/codegen/sdk/typescript"] - ) - ] -) -``` - -For more details on advanced configuration options, see the [source code on GitHub](https://github.com/codegen-sh/codegen-sdk/blob/develop/src/codegen/sdk/core/codebase.py). - -## Supported Languages - -Codegen currently supports: - -- [Python](/api-reference/python) -- [TypeScript/JavaScript](/api-reference/typescript) -- [React/JSX](/building-with-codegen/react-and-jsx) - - ---- -title: "Reusable Codemods" -sidebarTitle: "Reusable Codemods" -icon: "arrows-rotate" -iconType: "solid" ---- - -Codegen enables you to create reusable code transformations using Python functions decorated with `@codegen.function`. These codemods can be shared, versioned, and run by your team. - -## Creating Codemods - -The easiest way to create a new codemod is using the CLI [create](/cli/create) command: - -```bash -codegen create rename-function -``` - -This creates a new codemod in your `.codegen/codemods` directory: - -```python -import codegen -from codegen import Codebase - -@codegen.function("rename-function") -def run(codebase: Codebase): - """Add a description of what this codemod does.""" - # Add your code here - pass -``` - - - Codemods are stored in `.codegen/codemods/name/name.py` and are tracked in Git for easy sharing. - - -### AI-Powered Generation with `-d` - -You can use AI to generate an initial implementation by providing a description: - -```bash -codegen create rename-function -d "Rename the getUserData function to fetchUserProfile" -``` - -This will: -1. Generate an implementation based on your description -2. Create a custom system prompt that you can provide to an IDE chat assistant (learn more about [working with AI](/introduction/work-with-ai)) -3. Place both files in the codemod directory - -## Running Codemods - -Once created, run your codemod using: - -```bash -codegen run rename-function -``` - -The execution flow: -1. Codegen parses your codebase into a graph representation -2. Your codemod function is executed against this graph -3. Changes are tracked and applied to your filesystem -4. A diff preview shows what changed - - -## Codemod Structure - -A codemod consists of three main parts: - -1. The `@codegen.function` decorator that names your codemod -2. A `run` function that takes a `Codebase` parameter -3. Your transformation logic using the Codebase API - -```python -import codegen -from codegen import Codebase - -@codegen.function("update-imports") -def run(codebase: Codebase): - """Update import statements to use new package names.""" - for file in codebase.files: - for imp in file.imports: - if imp.module == "old_package": - imp.rename("new_package") - codebase.commit() -``` - -## Arguments - -Codemods can accept arguments using Pydantic models: - -```python -from pydantic import BaseModel - -class RenameArgs(BaseModel): - old_name: str - new_name: str - -@codegen.function("rename-function") -def run(codebase: Codebase, arguments: RenameArgs): - """Rename a function across the codebase.""" - old_func = codebase.get_function(arguments.old_name) - if old_func: - old_func.rename(arguments.new_name) - codebase.commit() -``` - -Run it with: -```bash -codegen run rename-function --arguments '{"old_name": "getUserData", "new_name": "fetchUserProfile"}' -``` - -## Directory Structure - -Your codemods live in a dedicated directory structure: - -``` -.codegen/ -└── codemods/ - └── rename_function/ - ├── rename_function.py # The codemod implementation - └── rename_function_prompt.md # System prompt (if using AI) -``` - ---- -title: "The .codegen Directory" -sidebarTitle: ".codegen Directory" -icon: "folder" -iconType: "solid" ---- - -The `.codegen` directory contains your project's Codegen configuration, codemods, and supporting files. It's automatically created when you run `codegen init`. - -## Directory Structure - -```bash -.codegen/ -├── .venv/ # Python virtual environment (gitignored) -├── config.toml # Project configuration -├── codemods/ # Your codemod implementations -├── jupyter/ # Jupyter notebooks for exploration -└── codegen-system-prompt.txt # AI system prompt -``` - -## Initialization - -The directory is created and managed using the `codegen init` command: - -```bash -codegen init [--fetch-docs] [--repo-name NAME] [--organization-name ORG] -``` - - -The `--fetch-docs` flag downloads API documentation and examples specific to your project's programming language. - - -## Virtual Environment - -Codegen maintains its own virtual environment in `.codegen/.venv/` to ensure consistent package versions and isolation from your project's dependencies. This environment is: - -- Created using `uv` for fast, reliable package management -- Initialized with Python 3.13 -- Automatically managed by Codegen commands -- Used for running codemods and Jupyter notebooks -- Gitignored to avoid committing environment-specific files - -The environment is created during `codegen init` and used by commands like `codegen run` and `codegen notebook`. - -To debug codemods, you will need to set the python virtual environment in your IDE to `.codegen/.venv` - -### Configuration - -The `.env` file stores your project settings: - -```env -REPOSITORY_OWNER = "your-org" -REPOSITORY_PATH = "/root/git/your-repo" -REPOSITORY_LANGUAGE = "python" # or other supported language -``` - -This configuration is used by Codegen to provide language-specific features and proper repository context. - -## Git Integration - -Codegen automatically adds appropriate entries to your `.gitignore`: - -```gitignore -# Codegen -.codegen/.venv/ -.codegen/docs/ -.codegen/jupyter/ -.codegen/codegen-system-prompt.txt -``` - - -- While most directories are ignored, your codemods in `.codegen/codemods/` and `config.toml` are tracked in Git -- The virtual environment and Jupyter notebooks are gitignored to avoid environment-specific issues - - -## Working with Codemods - -The `codemods/` directory is where your transformation functions live. You can create new codemods using: - -```bash -codegen create my-codemod [--description "what it does"] -``` - -This will: -1. Create a new file in `.codegen/codemods/` -2. Generate a system prompt in `.codegen/prompts/` (if using `--description`) -3. Set up the necessary imports and decorators - - -Use `codegen list` to see all codemods in your project. - - -## Jupyter Integration - -The `jupyter/` directory contains notebooks for interactive development: - -```python -from codegen import Codebase - -# Initialize codebase -codebase = Codebase('../../') - -# Print stats -print(f"📚 Total Files: {len(codebase.files)}") -print(f"⚡ Total Functions: {len(codebase.functions)}") -``` - - -A default notebook is created during initialization to help you explore your codebase. - - -## Next Steps - -After initializing your `.codegen` directory: - -1. Create your first codemod: -```bash -codegen create my-codemod -d "describe what you want to do" -``` - -2. Run it: -```bash -codegen run my-codemod --apply-local -``` - -3. Deploy it for team use: -```bash -codegen deploy my-codemod -``` - - ---- -title: Function Decorator -sidebarTitle: "@codegen.function" -icon: "at" -iconType: "solid" ---- - -# Function Decorator - -The `function` decorator is used to define codegen functions within your application. It allows you to specify a name for the function that will be ran making it easier to run specific codemods - -## Usage - -To use the `function` decorator, simply annotate your function with `@codegen.function` and provide a name as an argument. - -### Example - -```python -@codegen.function('my-function') -def run(codebase): - pass -``` - -In this example, the function `run` is decorated with `@codegen.function` and given the name `'my-function'`. This name will be used when the function is ran. - -## Parameters - -- `name` (str): The name of the function to be used when ran. - -## Description - -The `function` decorator is part of the codegen SDK CLI and is used to mark functions that are intended to be ran as part of a code generation process. It ensures that the function is properly registered and can be invoked with the specified name. - - -## CLI Examples - -### Running a Function - -To run a deployed function using the CLI, use the following command: - -```bash -codegen run my-function -``` - -This command runs the function named `my-function`. - -## See Also - -- [Codebase Visualization](./codebase-visualization.mdx): For visualizing codebases in your application. -- [CLI Init Command](../cli/init.mdx): For initializing projects or environments related to the function decorator. -- [CLI Create Command](../cli/create.mdx): For creating new functions or projects using the CLI. -- [CLI Run Command](../cli/run.mdx): For running code or scripts using the CLI. - - ---- -title: "Language Support" -sidebarTitle: "Language Support" -icon: "binary" -iconType: "solid" ---- - -Codegen provides first-class support for both Python and TypeScript codebases. The language is automatically inferred when you initialize a codebase. - -## Language Detection - -When you create a new `Codebase` instance, Codegen automatically detects the programming language: - -```python -from codegen import Codebase - -# Automatically detects Python or TypeScript -codebase = Codebase("./") - -# View language with `codebase.language` -print(codebase.language) # "python" or "typescript" -``` - - - Learn more about codebase initialization options in [Parsing - Codebases](/building-with-codegen/parsing-codebases). - - -## Type System - -Codegen uses specialized types for each language. These are defined as type aliases: - -```python -# Python codebases use PyCodebaseType -PyCodebaseType = Codebase[ - PyFile, Directory, PySymbol, PyClass, PyFunction, - PyImport, PyAssignment, Interface, TypeAlias, - PyParameter, PyCodeBlock -] - -# TypeScript codebases use TSCodebaseType -TSCodebaseType = Codebase[ - TSFile, Directory, TSSymbol, TSClass, TSFunction, - TSImport, TSAssignment, TSInterface, TSTypeAlias, - TSParameter, TSCodeBlock -] -``` - -Every code element has both a Python and TypeScript implementation that inherits from a common base class. For example: - -- [Function](/api-reference/core/Function) - - [PyFunction](/api-reference/python/PyFunction) - - [TSFunction](/api-reference/typescript/TSFunction) -- [Class](/api-reference/core/Class) - - [PyClass](/api-reference/python/PyClass) - - [TSClass](/api-reference/typescript/TSClass) -- [Import](/api-reference/core/Import) - - [PyImport](/api-reference/python/PyImport) - - [TSImport](/api-reference/typescript/TSImport) - -... - -```python -# Base class (core/function.py) -class Function: - """Abstract representation of a Function.""" - pass - -# Python implementation (python/function.py) -class PyFunction(Function): - """Extends Function for Python codebases.""" - pass - -# TypeScript implementation (typescript/function.py) -class TSFunction(Function): - """Extends Function for TypeScript codebases.""" - pass -``` - -This inheritance pattern means that most Codegen programs can work with either Python or TypeScript without modification, since they share the same API structure. - -```python -# Works for both Python and TypeScript -for function in codebase.functions: - print(f"Function: {function.name}") - print(f"Parameters: {[p.name for p in function.parameters]}") - print(f"Return type: {function.return_type}") -``` - -## TypeScript-Specific Features - -Some features are only available in TypeScript codebases: - -- **Types and Interfaces**: TypeScript's rich type system ([TSTypeAlias](/api-reference/typescript/TSTypeAlias), [TSInterface](/api-reference/typescript/TSInterface)) -- **Exports**: Module exports and re-exports ([TSExport](/api-reference/typescript/TSExport)) -- **JSX/TSX**: React component handling (see [React and JSX](/building-with-codegen/react-and-jsx)) - -Example of TypeScript-specific features: - -```python -# Only works with TypeScript codebases -if isinstance(codebase, TSCodebaseType): - # Work with TypeScript interfaces - for interface in codebase.interfaces: - print(f"Interface: {interface.name}") - print(f"Extends: {[i.name for i in interface.parent_interfaces]}") - - # Work with type aliases - for type_alias in codebase.type_aliases: - print(f"Type alias: {type_alias.name}") -``` - - ---- -title: "Commit and Reset" -sidebarTitle: "Commit and Reset" -icon: "arrows-rotate" -iconType: "solid" ---- - -Codegen requires you to explicitly commit changes by calling [codebase.commit()](/api-reference/core/Codebase#commit). - - - Keeping everything in memory enables fast, large-scale writes. See the [How it - Works](/introduction/how-it-works) guide to learn more. - - -You can manage your codebase's state with two core APIs: - -- [Codebase.commit()](/api-reference/core/Codebase#commit) - Commit changes to disk -- [Codebase.reset()](/api-reference/core/Codebase#reset) - Reset the `codebase` and filesystem to its initial state - -## Committing Changes - -When you make changes to your codebase through Codegen's APIs, they aren't immediately written to disk. You need to explicitly commit them with [codebase.commit()](/api-reference/core/Codebase#commit): - -```python -from codegen import Codebase - -codebase = Codebase("./") - -# Make some changes -file = codebase.get_file("src/app.py") -file.before("# 🌈 hello, world!") - -# Changes aren't on disk yet -codebase.commit() # Now they are! -``` - -This transaction-like behavior helps ensure your changes are atomic and consistent. - -## Resetting State - -The [codebase.reset()](/api-reference/core/Codebase#reset) method allows you to revert the codebase to its initial state: - -```python -# Make some changes -codebase.get_file("src/app.py").remove() -codebase.create_file("src/new_file.py", "x = 1") - -# Check the changes -assert codebase.get_file("src/app.py", optional=True) is None -assert codebase.get_file("src/new_file.py") is not None - -# Reset everything -codebase.reset() - -# Changes are reverted -assert codebase.get_file("src/app.py") is not None -assert codebase.get_file("src/new_file.py", optional=True) is None -``` - - - `reset()` reverts both the in-memory state and any uncommitted filesystem - changes. However, it preserves your codemod implementation in `.codegen/`. - - - ---- -title: "Git Operations" -sidebarTitle: "Git Operations" -icon: "code-branch" ---- - -Many workflows require Git operations. Codegen provides a high-level API for common Git operations through the [Codebase](/api-reference/core/Codebase) class, including: - -- [Codebase.git_commit(...)](/api-reference/core/Codebase#git_commit) -- [Codebase.checkout(...)](/api-reference/core/Codebase#checkout) - -## Committing Changes to Git - -You can commit changes to Git using the [Codebase.git_commit(...)](/api-reference/core/Codebase#git_commit): - -```python -# Make some changes and call `commit()` to sync them to disk -codebase.functions[0].rename('foo') -codebase.commit() - -# Commit all staged changes to git with a message -commit = codebase.git_commit("feat: update function signatures") - -# You can also verify the commit (runs pre-commit hooks) -commit = codebase.git_commit("feat: update signatures", verify=True) - -# The method returns the commit object if changes were committed, None otherwise -if commit: - print(f"Created commit: {commit.hexsha}") -``` - - - `git_commit` will only commit changes that have been synced to the filesystem - by calling [Codebase.commit()](/api-reference/core/Codebase#commit). See - [Commit and Reset](/building-with-codegen/commit-and-reset) for more - details. - - -## Checking Current Git State - -Codegen provides properties to check the current Git state: - -```python -# Get the default branch (e.g. 'main' or 'master') -default = codebase.default_branch -print(f"Default branch: {default}") - -# Get the current commit -current = codebase.current_commit -if current: - print(f"Current commit: {current.hexsha}") -``` - -## Checking Out Branches and Commits - -The [Codebase.checkout(...)](/api-reference/core/Codebase#checkout) method allows you to switch between branches and commits. - -This will automatically re-parse the codebase to reflect the new state. - -```python -# Checkout a branch -result = codebase.checkout(branch="feature/new-api") - -# Create a new branch if it doesn't exist -result = codebase.checkout(branch="feature/new-api", create_if_missing=True) - -# Checkout a specific commit -result = codebase.checkout(commit="abc123") - -# Checkout and pull from remote -result = codebase.checkout(branch="main", remote=True) -``` - - ---- -title: "Files and Directories" -sidebarTitle: "Files & Directories" -icon: "folder-tree" -iconType: "solid" ---- - -Codegen provides three primary abstractions for working with your codebase's file structure: - -- [File](/api-reference/core/File) - Represents a file in the codebase (e.g. README.md, package.json, etc.) -- [SourceFile](/api-reference/core/SourceFile) - Represents a source code file (e.g. Python, TypeScript, React, etc.) -- [Directory](/api-reference/core/Directory) - Represents a directory in the codebase - - - [SourceFile](/api-reference/core/SourceFile) is a subclass of [File](/api-reference/core/File) that provides additional functionality for source code files. - - - -## Accessing Files and Directories - -You typically access files from the [codebase](/api-reference/core/Codebase) object with two APIs: - -- [codebase.get_file(...)](/api-reference/core/Codebase#get-file) - Get a file by its path -- [codebase.files](/api-reference/core/Codebase#files) - Enables iteration over all files in the codebase - -```python -# Get a file from the codebase -file = codebase.get_file("path/to/file.py") - -# Iterate over all files in the codebase -for file in codebase.files: - pass - -# Check if a file exists -exists = codebase.has_file("path/to/file.py") - -``` - - -These APIs are similar for [Directory](/api-reference/core/Directory), which provides similar methods for accessing files and subdirectories. - -```python -# Get a directory -dir = codebase.get_directory("path/to/dir") - -# Iterate over all files in the directory -for file in dir.files: - pass - -# Get the directory containing a file: -dir = file.directory - -# Check if a directory exists -exists = codebase.has_directory("path/to/dir") -``` - -## Differences between SourceFile and File - -- [File](/api-reference/core/File) - a general purpose class that represents any file in the codebase including non-code files like README.md, .env, .json, image files, etc. -- [SourceFile](/api-reference/core/SourceFile) - a subclass of [File](/api-reference/core/File) that provides additional functionality for source code files written in languages supported by the [codegen-sdk](/introduction/overview) (Python, TypeScript, JavaScript, React). - -The majority of intended use cases involve using exclusively [SourceFile](/api-reference/core/SourceFile) objects as these contain code that can be parsed and manipulated by the [codegen-sdk](/introduction/overview). However, there may be cases where it will be necessary to work with non-code files. In these cases, the [File](/api-reference/core/File) class can be used. - -By default, the `codebase.files` property will only return [SourceFile](/api-reference/core/SourceFile) objects. To include non-code files the `extensions='*'` argument must be used. - -```python -# Get all source files in the codebase -source_files = codebase.files - -# Get all files in the codebase (including non-code files) -all_files = codebase.files(extensions="*") -``` - - -When getting a file with `codebase.get_file`, files ending in `.py, .js, .ts, .jsx, .tsx` are returned as [SourceFile](/api-reference/core/SourceFile) objects while other files are returned as [File](/api-reference/core/File) objects. - -Furthermore, you can use the `isinstance` function to check if a file is a [SourceFile](/api-reference/core/SourceFile): - -```python -py_file = codebase.get_file("path/to/file.py") -if isinstance(py_file, SourceFile): - print(f"File {py_file.filepath} is a source file") - -# prints: `File path/to/file.py is a source file` - -mdx_file = codebase.get_file("path/to/file.mdx") -if not isinstance(mdx_file, SourceFile): - print(f"File {mdx_file.filepath} is a non-code file") - -# prints: `File path/to/file.mdx is a non-code file` -``` - - - Currently, the codebase object can only parse source code files of one language at a time. This means that if you want to work with both Python and TypeScript files, you will need to create two separate codebase objects. - - -## Accessing Code - -[SourceFiles](/api-reference/core/SourceFile) and [Directories](/api-reference/core/Directory) provide several APIs for accessing and iterating over their code. - -See, for example: - -- `.functions` ([SourceFile](/api-reference/core/SourceFile#functions) / [Directory](/api-reference/core/Directory#functions)) - All [Functions](/api-reference/core/Function) in the file/directory -- `.classes` ([SourceFile](/api-reference/core/SourceFile#classes) / [Directory](/api-reference/core/Directory#classes)) - All [Classes](/api-reference/core/Class) in the file/directory -- `.imports` ([SourceFile](/api-reference/core/SourceFile#imports) / [Directory](/api-reference/core/Directory#imports)) - All [Imports](/api-reference/core/Import) in the file/directory -- `.get_function(...)` ([SourceFile](/api-reference/core/SourceFile#get-function) / [Directory](/api-reference/core/Directory#get-function)) - Get a specific function by name -- `.get_class(...)` ([SourceFile](/api-reference/core/SourceFile#get-class) / [Directory](/api-reference/core/Directory#get-class)) - Get a specific class by name -- `.get_global_var(...)` ([SourceFile](/api-reference/core/SourceFile#get-global-var) / [Directory](/api-reference/core/Directory#get-global-var)) - Get a specific global variable by name - - -```python -# Get all functions in a file -for function in file.functions: - print(f"Found function: {function.name}") - print(f"Parameters: {[p.name for p in function.parameters]}") - print(f"Return type: {function.return_type}") - -# Get all classes -for cls in file.classes: - print(f"Found class: {cls.name}") - print(f"Methods: {[m.name for m in cls.methods]}") - print(f"Attributes: {[a.name for a in cls.attributes]}") - -# Get imports (can also do `file.import_statements`) -for imp in file.imports: - print(f"Import from: {imp.module}") - print(f"Imported symbol: {[s.name for s in imp.imported_symbol]}") - -# Get specific symbols -main_function = file.get_function("main") -user_class = file.get_class("User") -config = file.get_global_var("CONFIG") - -# Access code blocks -if main_function: - for statement in main_function.code_block.statements: - print(f"Statement type: {statement.statement_type}") - -# Get local variables in a function -if main_function: - local_vars = main_function.code_block.get_local_var_assignments() - for var in local_vars: - print(f"Local var: {var.name} = {var.value}") -``` - -## Working with Non-Code Files (README, JSON, etc.) - -By default, Codegen focuses on source code files (Python, TypeScript, etc). However, you can access all files in your codebase, including documentation, configuration, and other non-code [files](/api-reference/core/File) like README.md, package.json, or .env: - -```python -# Get all files in the codebase (including README, docs, config files) -files = codebase.files(extensions="*") - -# Print files that are not source code (documentation, config, etc) -for file in files: - if not file.filepath.endswith(('.py', '.ts', '.js')): - print(f"📄 Non-code file: {file.filepath}") -``` - -You can also filter for specific file types: - -```python -# Get only markdown documentation files -docs = codebase.files(extensions=[".md", ".mdx"]) - -# Get configuration files -config_files = codebase.files(extensions=[".json", ".yaml", ".toml"]) -``` - -These APIs are similar for [Directory](/api-reference/core/Directory), which provides similar methods for accessing files and subdirectories. - -## Raw Content and Metadata - -```python -# Grab raw file string content -content = file.content # For text files -print('Length:', len(content)) -print('# of functions:', len(file.functions)) - -# Access file metadata -name = file.name # Base name without extension -extension = file.extension # File extension with dot -filepath = file.filepath # Full relative path -dir = file.directory # Parent directory - -# Access directory metadata -name = dir.name # Base name without extension -path = dir.path # Full relative path from repository root -parent = dir.parent # Parent directory -``` - -## Editing Files Directly - -Files themselves are [Editable](/api-reference/core/Editable.mdx) objects, just like Functions and Classes. - - - Learn more about the [Editable API](/building-with-codegen/the-editable-api). - - -This means they expose many useful operations, including: - -- [File.search](/api-reference/core/File#search) - Search for all functions named "main" -- [File.edit](/api-reference/core/File#edit) - Edit the file -- [File.replace](/api-reference/core/File#replace) - Replace all instances of a string with another string -- [File.insert_before](/api-reference/core/File#insert-before) - Insert text before a specific string -- [File.insert_after](/api-reference/core/File#insert-after) - Insert text after a specific string -- [File.remove](/api-reference/core/File#remove) - Remove a specific string - -```python -# Get a file -file = codebase.get_file("path/to/file.py") - -# Replace all instances of a string -file.replace("name", "new_name") -file.replace("name", "new_name", include_comments=False) # Don't edit comments - -# Replace entire text of the file -file.edit('hello, world!') - -# Get + delete all instances of a string -for editable in file.search("foo"): - editable.remove() - -# Insert text at the top of the file -file.insert_before("def main():\npass") -# ... or at the bottom -file.insert_after("def end():\npass") - -# Delete the file -file.remove() -``` - -You can frequently do bulk modifictions via the [.edit(...)](/api-reference/core/Editable#edit) method or [.replace(...)](/api-reference/core/File#replace) method. - - - Most useful operations will have bespoke APIs that handle edge cases, update - references, etc. - - -## Moving and Renaming Files - -Files can be manipulated through methods like [File.update_filepath()](/api-reference/core/File#update-filepath), [File.rename()](/api-reference/core/File#rename), and [File.remove()](/api-reference/core/File#remove): - -```python -# Move/rename a file -file.update_filepath("/path/to/foo.py") # Move to new location -file.rename("bar") # Rename preserving extension, e.g. `bar.py` - -# Remove a file (potentially destructive) -file.remove() - -# Move all tests to a tests directory -for file in codebase.files: - if 'test_' in file.name: - # This will handle updating imports and other references - file.update_filepath('tests/' + file.filepath.replace("test_", "")) -``` - - - Removing files is a potentially breaking operation. Only remove files if they - have no external usages. - - -## Directories - -[Directories](/api-reference/core/Directory) expose a similar API to the [File](/api-reference/core/File.mdx) class, with the addition of the `subdirectories` property. - -```python -# Get a directory -dir = codebase.get_directory("path/to/dir") - -# Iterate over all directories in the codebase -for directory in codebase.directories: - print(f"Found directory: {directory.path}") - -# Check directory existence -exists = codebase.has_directory("path/to/dir") - -# Access metadata -name = dir.name # Directory name -path = dir.path # Full path -parent = dir.parent # Parent directory - -# Get specific items -file = dir.get_file("file.py") -subdir = dir.get_subdirectory("subdir") - -# Get all ancestor subdirectories -subdirs = dir.subdirectories - -# Get the parent directory -parent_dir = dir.parent - -# Find all child directories -for subdir in dir.subdirectories: - if dir.parent == subdir: - print(f"Found child subdirectory: {subdir.path}") - -# Move to new location -dir.update_filepath("new/path") - -# Rename directory in place -dir.rename("new_name") - -# Remove a directory and all contents (potentially destructive) -dir.remove() -``` - - - Removing directories is a potentially destructive operation. Only remove - directories if they have no external usages. - - - ---- -title: "The Editable API" -sidebarTitle: "Editables" -icon: "pencil" -iconType: "solid" ---- - -Every code element in Codegen is an [Editable](../api-reference/core/Editable) - meaning it can be manipulated while maintaining correctness. - -All higher-level code manipulation APIs are built on top of the atomic Editable API. - -## Core Concepts - -Every Editable provides: - -- Information about the source code: - - [source](../api-reference/core/Editable#source) - the text content of the Editable - - [extended_source](../api-reference/core/Editable#extended_source) - includes relevant content like decorators, comments, etc. -- Information about the file that contains the Editable: - - [file](../api-reference/core/Editable#file) - the [SourceFile](../api-reference/core/SourceFile) that contains this Editable -- Relationship tracking - - [parent_class](../api-reference/core/Editable#parent-class) - the [Class](../api-reference/core/Class) that contains this Editable - - [parent_function](../api-reference/core/Editable#parent-function) - the [Function](../api-reference/core/Function) that contains this Editable - - [parent_statement](../api-reference/core/Editable#parent-statement) - the [Statement](../api-reference/core/Statement) that contains this Editable -- Safe modification operations - -## Basic Editing - -There are several fundamental ways to modify code with Editables: - -```python -# 1. edit() - Replace entire source with new content -function = codebase.get_function("process_data") -function.edit(""" -def process_data(input_data: dict) -> dict: - return transform(input_data) -""") - -# 2. Replace - Substitute text while preserving context -class_def = codebase.get_class("UserModel") -class_def.replace("user_id", "account_id") # Updates all occurrences - -# 3. Remove - Safely delete code with proper cleanup -unused_import = file.get_import("from utils import deprecated_func") -unused_import.remove() # Handles formatting, commas, etc - -# 4. Insert - Add code before or after an element -function.insert_before("# Process user input") # Adds comment before function -function.insert_after(""" -def validate_data(data: dict) -> bool: - return all(required in data for required in REQUIRED_FIELDS) -""") # Adds new function after -``` - -## Finding and Searching - -Editables provide powerful search capabilities: - -```python -# Find string literals -results = function.find_string_literals(["error", "warning"]) -results = function.find_string_literals(["error"], fuzzy_match=True) - -# Search with regex -matches = function.search(r"data\['[^']*'\]") # Find dict access -matches = function.search("TODO:", include_comments=True) - -# Find specific patterns -variables = function.get_variable_usages("config") -function_calls = function.function_calls # All function calls within this node -``` - -## Smart Formatting - -Codegen handles formatting details automatically: - -```python -# Adding to import statements -import_stmt = file.get_import("from mylib import func1") -import_stmt.add_symbol("func2") # Handles comma placement -import_stmt.add_symbol("func3") # Maintains proper formatting - -# Multi-line formatting is preserved -from mylib import ( - func1, - func2, # New imports maintain - func3 # existing style -) -``` - -## Safe Removals - -Removing code elements is safe and clean: - -```python -# Remove a function and its decorators -function.remove() # Removes associated comments and formatting - -# Remove imports cleanly -import_stmt.remove() # Handles commas and whitespace -``` - -## Working with References - -Editables track their relationships to other code elements: - -```python -# Find and update all references -function = codebase.get_function("old_name") -function.rename("new_name") # Updates all usages - -# Navigate relationships -print(function.parent_function) # Containing function -print(function.parent_class) # Containing class -print(function.parent_statement) # Containing statement -``` - -## Understanding Context - -Editables provide rich information about their location and context in the code: - -### Parent Relationships - -```python -# Get containing elements -function = codebase.get_function("process_data") -print(function.parent_class) # Class containing this function -print(function.parent_function) # Function containing this function (for nested functions) -print(function.parent_statement) # Statement containing this function - -# Check if top-level -is_top_level = function.parent_class is None and function.parent_function is None -``` - -### Statement Containment - -The `is_wrapped_in` method lets you check if an Editable is contained within specific types of statements: - -```python -# Check containment in statement types -is_in_try = function.is_wrapped_in("try") -is_in_if = function.is_wrapped_in("if") -is_in_while = function.is_wrapped_in("while") - -# Get the first parent statements of a certain type -if_block = function.parent_of_type(IfStatement) - -# Common patterns -if function.is_wrapped_in(IfStatement): - print("This is in an IfBlock") - -if variable.is_wrapped_in(WithStatement): - print("Variable used in WithStatement") -``` - -### Common Use Cases - -```python -# Move nested functions to module level -for func in file.functions: - if func.parent_function: # This is a nested function - func.parent_function.insert_before(func.source) # Move to module level - func.remove() # Remove the nested function - -# Find variables defined in unsafe blocks -for var in function.code_block.get_local_var_assignments(): - if var.is_wrapped_in(TryStatement): - print(f"Warning: {var.name} defined in try block") -``` - - - ---- -title: "The Symbol API" -sidebarTitle: "Symbols" -icon: "shapes" -iconType: "solid" ---- - -The [Symbol](/api-reference/core/Symbol) is the primary way developers interact with code in Codegen. It maps to how developers think about code - as functions, classes, variables, and other named entities. - -Both the [Function](/api-reference/core/Function) and [Class](/api-reference/core/Class) symbols are subclasses of the [Symbol](/api-reference/core/Symbol) class. - -## Accessing Symbols - -The [Codebase](/api-reference/core/Codebase) class provides getters and iterators for functions, classes and symbols: - -```python -# Core symbol types -symbol = codebase.get_symbol("process_data") # will return a Function, Class, etc. -function = codebase.get_function("process_data") -class_def = codebase.get_class("DataProcessor") - -# Iterate over all symbols (includes functions + classes) -for symbol in codebase.symbols: - print(symbol.name) - -# Iterate over all functions and classes -for symbol in codebase.functions + codebase.classes: - print(symbol.name) -``` - -## Shared APIs - -All symbols share common APIs for manipulation: - -- The [Editable](/api-reference/core/Editable) API -- Metadata - - [symbol.name](/api-reference/core/Symbol#name) - - [symbol.source](/api-reference/core/Symbol#source) - - [symbol.docstring](/api-reference/core/Symbol#docstring) -- Edit operations - - [symbol.set_docstring](/api-reference/core/Symbol#set-docstring) - - [symbol.move_to_file](/api-reference/core/Symbol#move-to-file) (see [Moving Symbols](/building-with-codegen/moving-symbols)) -- Graph relations (See [Usages and Dependencies](/building-with-codegen/dependencies-and-usages)) - - [symbol.usages](/api-reference/core/Symbol#usages) - - [symbol.dependencies](/api-reference/core/Symbol#dependencies) - -## Name operations - -```python -# Name operations -print(symbol.name) -symbol.rename("new_name") - -# Source code -print(symbol.source) # Get source code -symbol.edit("new source code") # Modify source - -# Documentation -print(symbol.docstring) # Get docstring -symbol.set_docstring("New documentation") - -# Move symbol to new file -symbol.move_to_file(new_file) - -# Add before/after other symbols -symbol.insert_before("# deprecated") -symbol.insert_after("# end deprecated") -``` - -## Function Statement Manipulation - -Functions provide special APIs for adding statements to their body: - -- [Function.prepend_statements](/api-reference/core/Function#prepend_statements) - add statements to the start of the function body -- [Function.add_statements](/api-reference/core/Function#add_statements) - add statements to the end of the function body - -```python -# Add statements at the start of a function -function.prepend_statements("print('Starting function')") -method.prepend_statements("self.validate_input()") - -# Add statements at the end of a function -function.add_statements("print('Done')") -method.add_statements("return self.result") -``` - - - The statement manipulation APIs (`prepend_statements` and `add_statements`) - are only available on Function objects. For other symbols, use the general - Editable APIs like `insert_before` and `insert_after`. - - -## Common Patterns - -Most Codegen programs focus on finding and manipulating symbols: - -```python -# Find and modify functions -for function in codebase.functions: - if function.name.startswith("old_"): - # Rename function - function.rename(function.name.replace("old_", "new_")) - # Update docstring - function.set_docstring("Updated version of function") - -# Update class methods -for method in class_def.methods: - # Add logging - method.prepend_statements("logger.info('Called {}'".format(method.name)) -``` - - - The Symbol API is designed to be intuitive and match how developers think - about code. Most transformations start with finding relevant symbols and then - applying changes to them. - - - ---- -title: "The Class API" -sidebarTitle: "Classes" -icon: "cube" -iconType: "solid" ---- - -The [Class](/api-reference/core/Class) API extends the [Symbol](/building-with-codegen/symbol-api) API to support methods, attributes, and inheritance hierarchies. - -## Methods and Method Usages - -Classes provide access to their methods and method [usages](/building-with-codegen/dependencies-and-usages) through an intuitive API: - -```python -# Access methods -for method in class_def.methods: - print(f"Method: {method.name}") - # Find all usages of this method - for usage in method.usages: - print(f"Used in {usage.file.name}") - -# Get specific methods -init_method = class_def.constructor # Get __init__ method -process_method = class_def.get_method("process_data") - -# Filter methods -public_methods = class_def.methods(private=False) # Exclude private methods -regular_methods = class_def.methods(magic=False) # Exclude magic methods -``` - - - Methods are typed as [Function](/api-reference/core/Function) objects. - - -## Class Attributes - -[Attributes](/api-reference/core/Attribute) can be accessed and modified easily: - -```python -# Access all attributes -for attr in class_def.attributes: - print(f"Attribute: {attr.name}") - -# Add new attributes -class_def.add_attribute_from_source("count: int = 0") - -# Get specific attribute -name_attr = class_def.get_attribute("name") - -# Add attribute from another class -other_class = codebase.get_class("OtherClass") -class_def.add_attribute( - other_class.get_attribute("config"), - include_dependencies=True # Also adds required imports -) -``` - -### Manipulating Attributes - -[Attributes](/api-reference/core/Attribute) expose their own API for modification and analysis: - -```python -# Modify attribute values and types -attr = class_def.get_attribute("count") -attr.set_value("42") # Change value -attr.assignment.set_type_annotation("float") # Change type -attr.assignment.type.remove() # Remove type annotation - -# Find attribute usages -for usage in attr.usages: - print(f"Used in {usage.file.name}") - -# Find local usages (within the class) -for usage in attr.local_usages: - print(f"Used in method: {usage.parent_function.name}") - -# Rename attributes (updates all references) -attr.rename("new_name") # Also updates self.count -> self.new_name - -# Remove attributes -attr.remove() # Removes the attribute definition - -# Check attribute properties -if attr.is_private: # Starts with underscore - print("Private attribute") -if attr.is_optional: # Optional[Type] or Type | None - print("Optional attribute") - -# Access underlying value -if attr.value: # The expression assigned to the attribute - print(f"Default value: {attr.value.source}") -``` - - - Attribute operations automatically handle all references, including - `self.attribute` usages in methods and string references. - - -### Working with Inheritance - -You can navigate inheritance hierarchies with APIs including: - -- [Class.superclasses](/api-reference/core/Class#superclasses) -- [Class.subclasses](/api-reference/core/Class#subclasses) -- [Class.is_subclass_of](/api-reference/core/Class#is-subclass-of) - -```python -class_def = codebase.get_class("Cube") - -# View ancestors -all_ancestors = class_def.superclasses # All classes inherited -immediate_parents = class_def.superclasses(max_depth=1) # Direct parents only - -# Inheritance-aware method lookup -method = class_def.get_method("process") # Searches up inheritance chain -if method.parent_class != class_def: - print(f"Method inherited from {method.parent_class.name}") - -# Handle external dependencies -if class_def.is_subclass_of("Enum"): # Works with stdlib/external classes - print("This is an enum class") -``` - -Likewise, you can modify inheritance by accessing: - -- [Class.parent_class_names](/api-reference/core/Class#parent-class-names) -- [Class.get_parent_class(cls_name)](/api-reference/core/Class#get-parent-class) - -Which return lists of [Name](/api-reference/core/Name) objects. - -```python -# Modify inheritance -parent_names = class_def.parent_class_names -if parent_names[0] == 'BaseClass': - parent_names[0].edit("NewBaseClass") # Change parent class - -# Get specific parent class -parent_class = class_def.get_parent_class("BaseClass") -if parent_class: - parent_class.edit("NewBaseClass") # Change parent class -``` - - - When working with inheritance, use `max_depth` to control how far up the - inheritance chain to look. `max_depth=0` means current class only, - `max_depth=None` means traverse entire hierarchy. - - - - Codegen handles both internal and external parent classes (like stdlib - classes). The `superclasses` property follows the language's MRO rules for - method resolution. - - -## Method Resolution Order (MRO) - -Codegen follows the target language's method resolution order (MRO) for inheritance: - -```python -# Access superclasses -for parent in class_def.superclasses: - print(f"Parent: {parent.name}") - -# Check inheritance -if class_def.is_subclass_of("BaseClass"): - print("This is a subclass of BaseClass") - -# Get all subclasses -for child in class_def.subclasses: - print(f"Child class: {child.name}") - -# Access inherited methods/attributes -all_methods = class_def.methods(max_depth=None) # Include inherited methods -all_attrs = class_def.attributes(max_depth=None) # Include inherited attributes -``` - - ---- -title: "The Import API" -sidebarTitle: "Imports" -icon: "file-import" -iconType: "solid" ---- - -The [Import](/api-reference/core/Import) API provides tools for working with imports and managing dependencies between files. - -## Accessing Imports - -You can access these through [File.imports](/api-reference/core/File#imports) and [File.import_statements](/api-reference/core/File#import-statements): - -```python -# Direct access to imports via file -for imp in file.imports: - ... - -# Grab by name of symbol being imported -imp = file.get_import('math') - -# Grab and filter from a codebase -from codegen.sdk import ExternalModule - -external_imports = [i for i in codebase.imports if isinstance(i, ExternalModule)] -``` - -## Common Operations - -The Import API provides several methods for modifying imports: - -```python -# Get a specific import -import_stmt = file.get_import("MyComponent") - -# Change import source -import_stmt.set_module("./new/path") - -# Add/update alias -import_stmt.set_alias("MyAlias") # import X as MyAlias - -# TypeScript-specific operations -import_stmt.make_type_import() # Convert to 'import type' -import_stmt.make_value_import() # Remove 'type' modifier - -# Update multiple properties -import_stmt.update( - module="./new/path", - alias="NewAlias", - is_type=True -) -``` - -## Import Resolution - -Imports can be traced to their original symbols: - -```python -# Follow import chain to source -import_stmt = file.get_import("MyComponent") -original = import_stmt.resolved_symbol - -if original: - print(f"Defined in: {original.file.filepath}") - print(f"Original name: {original.name}") - -# Get file relationships -print(f"From file: {import_stmt.from_file.filepath}") -print(f"To file: {import_stmt.to_file.filepath}") -``` - - -With Python one can specify the `PYTHONPATH` environment variable which is then considered when resolving -packages. - - -## Working with External Modules - -You can determine if an import references an [ExternalModule](/api-reference/core/ExternalModule) by checking the type of [Import.imported_symbol](/api-reference/core/Import#imported-symbol), like so: - -```python -# Check if import is from external package -for imp in file.imports: - if isinstance(imp.imported_symbol, ExternalModule): - print(f"External import: {imp.name} from {imp.module}") - else: - print(f"Local import: {imp.name}") -``` - -Learn more about [external modules here](/building-with-codegen/external-modules) - - -## Bulk Operations - -Here are patterns for working with multiple imports: - -```python -# Update imports from a specific module -old_path = "./old/path" -new_path = "./new/path" - -for imp in file.imports: - if imp.module == old_path: - imp.set_module(new_path) - -# Remove unused imports (excluding external) -for imp in file.imports: - if not imp.usages and not isinstance(imp.resolved_symbol, ExternalModule): - print(f"Removing: {imp.name}") - imp.remove() - -# Consolidate duplicate imports -from collections import defaultdict - -module_imports = defaultdict(list) -for imp in file.imports: - module_imports[imp.module].append(imp) - -for module, imports in module_imports.items(): - if len(imports) > 1: - # Create combined import - symbols = [imp.name for imp in imports] - file.add_import( - f"import {{ {', '.join(symbols)} }} from '{module}'" - ) - # Remove old imports - for imp in imports: - imp.remove() -``` - - -Always check if imports resolve to external modules before modification to avoid breaking third-party package imports. - - -## Import Statements vs Imports - -Codegen provides two levels of abstraction for working with imports: - -- [ImportStatement](/api-reference/core/ImportStatement) - Represents a complete import statement -- [Import](/api-reference/core/Import) - Represents individual imported symbols - - -```python Python -# One ImportStatement containing multiple Import objects -from math import sin, cos as cosine -# Creates: -# - Import for 'sin' -# - Import for 'cos' with alias 'cosine' -``` - -```typescript Typescript -// One ImportStatement containing multiple Import objects -import { sin, cos as cosine } from 'math'; -// Creates: -// - Import for 'sin' -// - Import for 'cos' with alias 'cosine' -``` - - -You can access these through [File.imports](/api-reference/core/File#imports) and [File.import_statements](/api-reference/core/File#import-statements): - -```python -# Direct access to imports -for imp in file.imports: - ... - -# Access to imports via statements -for stmt in file.import_statements: - for imp in stmt.imports: - ... -``` - - -ImportStatement inherits from [Statement](/building-with-codegen/statements-and-code-blocks), providing operations like `remove()` and `insert_before()`. - - ---- -title: "The Export API" -sidebarTitle: "Exports" -icon: "file-export" -iconType: "solid" ---- - -The [Export](/api-reference/core/Export) API provides tools for managing exports and module boundaries in TypeScript codebases. - -Exports are a TS-only language feature - -## Export Statements vs Exports - -Similar to imports, Codegen provides two levels of abstraction for working with exports: - -- [ExportStatement](/api-reference/core/ExportStatement) - Represents a complete export statement -- [Export](/api-reference/core/Export) - Represents individual exported symbols - -```typescript -// One ExportStatement containing multiple Export objects -export { foo, bar as default, type User }; -// Creates: -// - Export for 'foo' -// - Export for 'bar' as default -// - Export for 'User' as a type - -// Direct exports create one ExportStatement per export -export const value = 42; -export function process() {} -``` - -You can access these through your file's collections: - -```python -# Access all exports in the codebase -for export in codebase.exports: - ... - -# Access all export statements -for stmt in file.export_statements: - for exp in stmt.exports: - ... -``` - - -ExportStatement inherits from [Statement](/building-with-codegen/statements-and-code-blocks), providing operations like `remove()` and `insert_before()`. This is particularly useful when you want to manipulate the entire export declaration. - - -## Common Operations - -Here are common operations for working with exports: - -```python -# Add exports from source code -file.add_export_from_source("export { MyComponent };") -file.add_export_from_source("export type { MyType } from './types';") - -# Export existing symbols -component = file.get_function("MyComponent") -file.add_export(component) # export { MyComponent } -file.add_export(component, alias="default") # export { MyComponent as default } - -# Convert to type export -export = file.get_export("MyType") -export.make_type_export() - -# Remove exports -export = file.get_export("MyComponent") -export.remove() # Removes export but keeps the symbol - -# Remove multiple exports -for export in file.exports: - if not export.is_type_export(): - export.remove() - -# Update export properties -export.update( - name="NewName", - is_type=True, - is_default=False -) - -# Export from another file -other_file = codebase.get_file("./components.ts") -component = other_file.get_class("Button") -file.add_export(component, from_file=other_file) # export { Button } from './components'; - -# Analyze symbols being exported -for export in file.exports: - if isinstance(export.exported_symbol, ExternalModule): - print('Exporting ExternalModule') - else: - ... -``` - - -When adding exports, you can: -- Add from source code with `add_export_from_source()` -- Export existing symbols with `add_export()` -- Re-export from other files by specifying `from_file` - -The export will automatically handle adding any required imports. - - -## Export Types - -Codegen supports several types of exports: - -```typescript -// Direct exports -export const value = 42; // Value export -export function myFunction() {} // Function export -export class MyClass {} // Class export -export type MyType = string; // Type export -export interface MyInterface {} // Interface export -export enum MyEnum {} // Enum export - -// Re-exports -export { foo, bar } from './other-file'; // Named re-exports -export type { Type } from './other-file'; // Type re-exports -export * from './other-file'; // Wildcard re-exports -export * as utils from './other-file'; // Namespace re-exports - -// Aliased exports -export { foo as foop }; // Basic alias -export { foo as default }; // Default export alias -export { bar as baz } from './other-file'; // Re-export with alias -``` - -## Identifying Export Types - -The Export API provides methods to identify and filter exports: -- [.is_type_export()](/api-reference/typescript/TSExport#is-type-export) -- [.is_default_export()](/api-reference/typescript/TSExport#is-default-export) -- [.is_wildcard_export()](/api-reference/typescript/TSExport#is-wildcard-export) - - -```python -# Check export types -for exp in file.exports: - if exp.is_type_export(): - print(f"Type export: {exp.name}") - elif exp.is_default_export(): - print(f"Default export: {exp.name}") - elif exp.is_wildcard_export(): - print(f"Wildcard export from: {exp.from_file.filepath}") -``` - -## Export Resolution - -You can trace exports to their original symbols: - -```python -for exp in file.exports: - if exp.is_reexport(): - # Get original and current symbols - current = exp.exported_symbol - original = exp.resolved_symbol - - print(f"Re-exporting {original.name} from {exp.from_file.filepath}") - print(f"Through: {' -> '.join(e.file.filepath for e in exp.export_chain)}") -``` - -## Managing Re-exports - -You can manage re-exports with the [TSExport.is_reexport()](/api-reference/typescript/TSExport#is-reexport) API: - -```python -# Create public API -index_file = codebase.get_file("index.ts") - -# Re-export from internal files -for internal_file in codebase.files: - if internal_file.name != "index": - for symbol in internal_file.symbols: - if symbol.is_public: - index_file.add_export( - symbol, - from_file=internal_file - ) - -# Convert default to named exports -for exp in file.exports: - if exp.is_default_export(): - exp.make_named_export() - -# Consolidate re-exports -from collections import defaultdict - -file_exports = defaultdict(list) -for exp in file.exports: - if exp.is_reexport(): - file_exports[exp.from_file].append(exp) - -for from_file, exports in file_exports.items(): - if len(exports) > 1: - # Create consolidated re-export - names = [exp.name for exp in exports] - file.add_export_from_source( - f"export {{ {', '.join(names)} }} from '{from_file.filepath}'" - ) - # Remove individual exports - for exp in exports: - exp.remove() -``` - - -When managing exports, consider the impact on your module's public API. Not all symbols that can be exported should be exported. - - ---- -title: "Inheritable Behaviors" -sidebarTitle: "Inheritable Behaviors" -icon: "puzzle-piece" -iconType: "solid" ---- - -Codegen uses a set of core behaviors that can be inherited by code elements. These behaviors provide consistent APIs across different types of symbols. - - -## Core Behaviors - -- [HasName](/api-reference/core/HasName): For elements with [Names](/api-reference/core/Name) (Functions, Classes, Assignments, etc.) -- [HasValue](/api-reference/core/HasValue): For elements with [Values](/api-reference/core/Value) (Arguments, Assignments, etc.) -- [HasBlock](/api-reference/core/HasBlock): For elements containing [CodeBlocks](/api-reference/core/CodeBlock) (Files, Functions, Classes) -- [Editable](/api-reference/core/Editable): For elements that can be safely modified ([learn more](/building-with-codegen/the-editable-api)) - -These "behaviors" are implemented as inherited classes. - -## Working with Names - -The [HasName](/api-reference/core/HasName) behavior provides APIs for working with named elements: - -```python -# Access the name -print(function.name) # Base name without namespace -print(function.full_name) # Full qualified name with namespace - -# Modify the name -function.set_name("new_name") # Changes just the name -function.rename("new_name") # Changes name and updates all usages - -# Get the underlying name node -name_node = function.get_name() -``` - -## Working with Values - -The [HasValue](/api-reference/core/HasValue) behavior provides APIs for elements that have values: - -```python -# Access the value -value = variable.value # Gets the value Expression node -print(value.source) # Gets the string content - -# Modify the value -variable.set_value("new_value") - -# Common patterns -if variable.value is not None: - print(f"{variable.name} = {variable.value.source}") -``` - -## Working with Code Blocks - -The [HasBlock](/api-reference/core/HasBlock) behavior provides APIs for elements containing code: - -```python -# Access the code block -block = function.code_block -print(len(block.statements)) # Number of statements -printS(block.source) -``` - - - Learn more about [CodeBlocks and Statements - here](/building-with-codegen/statements-and-code-blocks) - - -## Working with Attributes - -The [get_attribute](/api-reference/core/Class#get-attribute) method provides APIs for attribute access: - -```python -# Common patterns -class_attr = class_def.get_attribute("attribute_name") -if class_attr: - print(f"Class variable value: {class_attr.value.source}") -``` - - - Learn more about [working with Attributes - here](/building-with-codegen/class-api#class-attributes). - - -## Behavior Combinations - -Many code elements inherit multiple behaviors. For example, a function typically has: - -```python -# Functions combine multiple behaviors -function = codebase.get_function("process_data") - -# HasName behavior -print(function.name) -function.rename("process_input") - -# HasBlock behavior -print(len(function.code_block.statements)) -function.add_decorator("@timer") - -# Editable behavior -function.edit("def process_input():\n pass") -``` - - ---- -title: "Statements and Code Blocks" -sidebarTitle: "Statements and Code Blocks" -icon: "code" -iconType: "solid" ---- - -Codegen uses two classes to represent code structure at the highest level: - -- [Statement](../api-reference/core/Statement): Represents a single line or block of code - - - Can be assignments, imports, loops, conditionals, etc. - - Contains source code, dependencies, and type information - - May contain nested code blocks (like in functions or loops) - -- [CodeBlock](../api-reference/core/CodeBlock): A container for multiple Statements - - Found in files, functions, classes, and control flow blocks - - Provides APIs for analyzing and manipulating statements - - Handles scope, variables, and dependencies - -Codegen provides rich APIs for working with code statements and blocks, allowing you to analyze and manipulate code structure at a granular level. - -## Working with Statements - -### Basic Usage - -Every file, function, and class in Codegen has a [CodeBlock](../api-reference/core/CodeBlock) that contains its statements: - -```python -# Access statements in a file -file = codebase.get_file("main.py") -for statement in file.code_block.statements: - print(f"Statement type: {statement.statement_type}") - -# Access statements in a function -function = file.get_function("process_data") -for statement in function.code_block.statements: - print(f"Statement: {statement.source}") -``` - -### Filtering Statements - -Filter through statements using Python's builtin `isinstance` function. - -```python -# Filter statements by type -for stmt in file.code_block.statements: - if isinstance(stmt, ImportStatement): - print(stmt) -``` - -### Adding Statements - -Functions and Files support [.prepend_statement(...)](../api-reference/core/Symbol#prepend-statement) and [.add_statement(...)](../api-reference/core/Function#add-statement) to add statements to the symbol. - - - See [Adding - Statements](/building-with-codegen/symbol-api#function-statement-manipulation) - for details. - - -### Working with Nested Structures - -Frequently you will want to check if a statement is nested within another structure, for example if a statement is inside an `if` block or a `try/catch` statement. - -Codegen supports this functionality with the [Editable.is_wrapped_in(...)](../api-reference/core/Editable#is-wrapped-in) method. - -```python -func = codebase.get_function("process_data") -for usage in func.local_variable_usages: - if usage.is_wrapped_in(IfStatement): - print(f"Usage of {usage.name} is inside an if block") -``` - -Similarly, all Editable objects support the `.parent_statement`, which can be used to navigate the statement hierarchy. - -```python -func = codebase.get_function("process_data") -for usage in func.local_variable_usages: - if isinstance(usage.parent_statement, IfStatement): - print(f"Usage of {usage.name} is directly beneath an IfStatement") -``` - -### Wrapping and Unwrapping Statements - -[CodeBlocks](../api-reference/core/CodeBlock) support wrapping and unwrapping with the following APIs: - -- [.wrap(...)](../api-reference/core/CodeBlock#wrap) - allows you to wrap a statement in a new structure. -- [.unwrap(...)](../api-reference/core/CodeBlock#unwrap) - allows you to remove the wrapping structure while preserving the code block's contents. - -```python -# Wrap code blocks with new structures -function.code_block.wrap("with open('test.txt', 'w') as f:") -# Result: -# with open('test.txt', 'w') as f: -# original_code_here... - -# Wrap code in a function -file.code_block.wrap("def process_data(a, b):") -# Result: -# def process_data(a, b): -# original_code_here... - -# Unwrap code from its container -if_block.code_block.unwrap() # Removes the if statement but keeps its body -while_loop.code_block.unwrap() # Removes the while loop but keeps its body -``` - - - Both `wrap` and `unwrap` are potentially unsafe changes and will modify - business logic. - - - - The `unwrap()` method preserves the indentation of the code block's contents - while removing the wrapping structure. This is useful for refactoring nested - code structures. - - -## Statement Types - -Codegen supports various statement types, each with specific APIs: - -### [Import Statements](../api-reference/core/ImportStatement) / [Export Statements](../api-reference/core/ExportStatement) - - - See [imports](/building-with-codegen/imports) and [exports](../building-with-codegen/exports) for - more details. - - -```python -# Access import statements -for import_stmt in file.import_statements: - print(f"Module: {import_stmt.module}") - for imported in import_stmt.imports: - print(f" Imported: {imported.name}") - -# Remove specific imports -import_stmt = file.import_statements[0] -import_stmt.imports[0].remove() # Remove first import - -# Remove entire import statement -import_stmt.remove() -``` - -### [If/Else Statements](../api-reference/core/IfBlockStatement) - -If/Else statements provide rich APIs for analyzing and manipulating conditional logic: - -```python -# Access if/else blocks -if_block = file.code_block.statements[0] -print(f"Condition: {if_block.condition.source}") - -# Check block types -if if_block.is_if_statement: - print("Main if block") -elif if_block.is_elif_statement: - print("Elif block") -elif if_block.is_else_statement: - print("Else block") - -# Access alternative blocks -for elif_block in if_block.elif_statements: - print(f"Elif condition: {elif_block.condition.source}") - -if else_block := if_block.else_statement: - print("Has else block") - -# Access nested code blocks -for block in if_block.nested_code_blocks: - print(f"Block statements: {len(block.statements)}") -``` - -If blocks also support condition reduction, which can simplify conditional logic: - -```python -# Reduce if condition to True -if_block.reduce_condition(True) -# Before: -# if condition: -# print("a") -# else: -# print("b") -# After: -# print("a") - -# Reduce elif condition to False -elif_block.reduce_condition(False) -# Before: -# if a: -# print("a") -# elif condition: -# print("b") -# else: -# print("c") -# After: -# if a: -# print("a") -# else: -# print("c") -``` - - - When reducing conditions, Codegen automatically handles the restructuring of - elif/else chains and preserves the correct control flow. - - -### [Switch](../api-reference/core/SwitchStatement)/[Match](../api-reference/python/PyMatchStatement) Statements - -```python -# TypeScript switch statements -switch_stmt = file.code_block.statements[0] -for case_stmt in switch_stmt.cases: - print(f"Case condition: {case_stmt.condition}") - print(f"Is default: {case_stmt.default}") - - # Access statements in each case - for statement in case_stmt.code_block.statements: - print(f"Statement: {statement.source}") - -# Python match statements -match_stmt = file.code_block.statements[0] -for case in match_stmt.cases: - print(f"Pattern: {case.pattern}") - for statement in case.code_block.statements: - print(f"Statement: {statement.source}") -``` - -### [While Statements](../api-reference/core/WhileStatement) - -```python -while_stmt = file.code_block.statements[0] -print(f"Condition: {while_stmt.condition}") - -# Access loop body -for statement in while_stmt.code_block.statements: - print(f"Body statement: {statement.source}") - -# Get function calls within the loop -for call in while_stmt.function_calls: - print(f"Function call: {call.source}") -``` - -### [Assignment Statements](../api-reference/core/AssignmentStatement) - -```python -# Access assignments in a code block -for statement in code_block.statements: - if statement.statement_type == StatementType.ASSIGNMENT: - for assignment in statement.assignments: - print(f"Variable: {assignment.name}") - print(f"Value: {assignment.value}") -``` - -## Working with Code Blocks - -Code blocks provide several ways to analyze and manipulate their content: - -### Statement Access - -```python -code_block = function.code_block - -# Get all statements -all_statements = code_block.statements - -# Get statements by type -if_blocks = code_block.if_blocks -while_loops = code_block.while_loops -try_blocks = code_block.try_blocks - -# Get local variables -local_vars = code_block.get_local_var_assignments() -``` - -### Statement Dependencies - -```python -# Get dependencies between statements -function = file.get_function("process") -for statement in function.code_block.statements: - deps = statement.dependencies - print(f"Statement {statement.source} depends on: {[d.name for d in deps]}") -``` - -### Parent-Child Relationships - -```python -# Access parent statements -function = file.get_function("main") -parent_stmt = function.parent_statement - -# Access nested symbols -class_def = file.get_class("MyClass") -for method in class_def.methods: - parent = method.parent_statement - print(f"Method {method.name} is defined in {parent.source}") -``` - -## Common Operations - -### Finding Statements - -```python -# Find specific statements -assignments = [s for s in code_block.statements - if s.statement_type == StatementType.ASSIGNMENT] - -# Find statements by content -matching = [s for s in code_block.statements - if "specific_function()" in s.source] -``` - -### Analyzing Flow Control - -```python -# Analyze control flow -for statement in code_block.statements: - if statement.statement_type == StatementType.IF_BLOCK: - print("Condition:", statement.condition) - print("Then:", statement.consequence_block.statements) - if statement.alternative_block: - print("Else:", statement.alternative_block.statements) -``` - -### Working with Functions - -```python -# Analyze function calls in statements -for statement in code_block.statements: - for call in statement.function_calls: - print(f"Calls function: {call.name}") - print(f"With arguments: {[arg.source for arg in call.arguments]}") -``` - - ---- -title: "Dependencies and Usages" -sidebarTitle: "Dependencies and Usages" -icon: "share-nodes" -iconType: "solid" ---- - -Codegen pre-computes dependencies and usages for all symbols in the codebase, enabling constant-time queries for these relationships. - -## Overview - -Codegen provides two main ways to track relationships between symbols: - -- [.dependencies](/api-reference/core/Symbol#dependencies) / - What symbols does this symbol depend on? -- [.usages](/api-reference/core/Symbol#usages) / [.usages(...)](/api-reference/core/Symbol#usages) - Where is this symbol used? - -Dependencies and usages are inverses of each other. For example, given the following input code: - -```python -# Input code -from module import BaseClass - -class MyClass(BaseClass): - pass -``` - -The following assertions will hold in the Codegen API: - -```python -base = codebase.get_symbol("BaseClass") -my_class = codebase.get_symbol("MyClass") - -# MyClass depends on BaseClass -assert base in my_class.dependencies - -# BaseClass is used by MyClass -assert my_class in base.usages -``` - -If `A` depends on `B`, then `B` is used by `A`. This relationship is tracked in both directions, allowing you to navigate the codebase from either perspective. - -```mermaid - -flowchart LR - B(BaseClass) - - - - A(MyClass) - B ---| used by |A - A ---|depends on |B - - classDef default fill:#fff,stroke:#000,color:#000; -``` - -- `MyClass.dependencies` answers the question: *"which symbols in the codebase does MyClass depend on?"* - -- `BaseClass.usages` answers the question: *"which symbols in the codebase use BaseClass?"* - -## Usage Types - -Both APIs use the [UsageType](../api-reference/core/UsageType) enum to specify different kinds of relationships: - -```python -class UsageType(IntFlag): - DIRECT = auto() # Direct usage within the same file - CHAINED = auto() # Usage through attribute access (module.symbol) - INDIRECT = auto() # Usage through a non-aliased import - ALIASED = auto() # Usage through an aliased import -``` - -### DIRECT Usage - -A direct usage occurs when a symbol is used in the same file where it's defined, without going through any imports or attribute access. - -```python -# Define MyClass -class MyClass: - def __init__(self): - pass - -# Direct usage of MyClass in same file -class Child(MyClass): - pass -``` - -### CHAINED Usage - -A chained usage occurs when a symbol is accessed through module or object attribute access, using dot notation. - -```python -import module - -# Chained usage of ClassB through module -obj = module.ClassB() -# Chained usage of method through obj -result = obj.method() -``` - -### INDIRECT Usage - -An indirect usage happens when a symbol is used through a non-aliased import statement. - -```python -from module import BaseClass - -# Indirect usage of BaseClass through import -class MyClass(BaseClass): - pass -``` - -### ALIASED Usage - -An aliased usage occurs when a symbol is used through an import with an alias. - -```python -from module import BaseClass as AliasedBase - -# Aliased usage of BaseClass -class MyClass(AliasedBase): - pass -``` - -## Dependencies API - -The dependencies API lets you find what symbols a given symbol depends on. - -### Basic Usage - -```python -# Get all direct dependencies -deps = my_class.dependencies # Shorthand for dependencies(UsageType.DIRECT) - -# Get dependencies of specific types -direct_deps = my_class.dependencies(UsageType.DIRECT) -chained_deps = my_class.dependencies(UsageType.CHAINED) -indirect_deps = my_class.dependencies(UsageType.INDIRECT) -``` - -### Combining Usage Types - -You can combine usage types using the bitwise OR operator: - -```python -# Get both direct and indirect dependencies -deps = my_class.dependencies(UsageType.DIRECT | UsageType.INDIRECT) - -# Get all types of dependencies -deps = my_class.dependencies( - UsageType.DIRECT | UsageType.CHAINED | - UsageType.INDIRECT | UsageType.ALIASED -) -``` - -### Common Patterns - -1. Finding dead code (symbols with no usages): - -```python -# Check if a symbol is unused -def is_dead_code(symbol): - return not symbol.usages - -# Find all unused functions in a file -dead_functions = [f for f in file.functions if not f.usages] -``` - - - See [Deleting Dead Code](/tutorials/deleting-dead-code) to learn more about finding - unused code. - - -2. Finding all imports that a symbol uses: - -```python -# Get all imports a class depends on -class_imports = [dep for dep in my_class.dependencies if isinstance(dep, Import)] - -# Get all imports used by a function, including indirect ones -all_function_imports = [ - dep for dep in my_function.dependencies(UsageType.DIRECT | UsageType.INDIRECT) - if isinstance(dep, Import) -] -``` -## Traversing the Dependency Graph - -Sometimes you need to analyze not just direct dependencies, but the entire dependency graph up to a certain depth. The `dependencies` method allows you to traverse the dependency graph and collect all dependencies up to a specified depth level. - -### Basic Usage - -```python - -# Get only direct dependencies -deps = symbol.dependencies(max_depth=1) - -# Get deep dependencies (up to 5 levels) -deps = symbol.dependencies(max_depth=5) -``` - -The method returns a dictionary mapping each symbol to its list of direct dependencies. This makes it easy to analyze the dependency structure: - -```python -# Print the dependency tree -for sym, direct_deps in deps.items(): - print(f"{sym.name} depends on: {[d.name for d in direct_deps]}") -``` - -### Example: Analyzing Class Inheritance - -Here's an example of using `dependencies` to analyze a class inheritance chain: - -```python -class A: - def method_a(self): pass - -class B(A): - def method_b(self): - self.method_a() - -class C(B): - def method_c(self): - self.method_b() - -# Get the full inheritance chain -symbol = codebase.get_class("C") -deps = symbol.dependencies( - max_depth=3 -) - -# Will show: -# C depends on: [B] -# B depends on: [A] -# A depends on: [] -``` - -### Handling Cyclic Dependencies - -The method properly handles cyclic dependencies in the codebase: - -```python -class A: - def method_a(self): - return B() - -class B: - def method_b(self): - return A() - -# Get dependencies including cycles -symbol = codebase.get_class("A") -deps = symbol.dependencies() - -# Will show: -# A depends on: [B] -# B depends on: [A] -``` - - - The `max_depth` parameter helps prevent excessive recursion in large codebases or when there are cycles in the dependency graph. - - - ---- -title: "Function Calls and Call Sites" -sidebarTitle: "Function Calls" -icon: "function" -iconType: "solid" ---- - -Codegen provides comprehensive APIs for working with function calls through several key classes: - -- [FunctionCall](../api-reference/core/FunctionCall) - Represents a function invocation -- [Argument](../api-reference/core/Argument) - Represents arguments passed to a function -- [Parameter](../api-reference/core/Parameter) - Represents parameters in a function definition - - - See [Migrating APIs](/tutorials/migrating-apis) for relevant tutorials and - applications. - - -## Navigating Function Calls - -Codegen provides two main ways to navigate function calls: - -1. From a function to its call sites using [call_sites](../api-reference/core/Function#call-sites) -2. From a function to the calls it makes (within it's [CodeBlock](../api-reference/core/CodeBlock)) using [function_calls](../api-reference/core/Function#function-calls) - -Here's how to analyze function usage patterns: - -```python -# Find the most called function -most_called = max(codebase.functions, key=lambda f: len(f.call_sites)) -print(f"\nMost called function: {most_called.name}") -print(f"Called {len(most_called.call_sites)} times from:") -for call in most_called.call_sites: - print(f" - {call.parent_function.name} at line {call.start_point[0]}") - -# Find function that makes the most calls -most_calls = max(codebase.functions, key=lambda f: len(f.function_calls)) -print(f"\nFunction making most calls: {most_calls.name}") -print(f"Makes {len(most_calls.function_calls)} calls to:") -for call in most_calls.function_calls: - print(f" - {call.name}") - -# Find functions with no callers (potential dead code) -unused = [f for f in codebase.functions if len(f.call_sites) == 0] -print(f"\nUnused functions:") -for func in unused: - print(f" - {func.name} in {func.filepath}") - -# Find recursive functions -recursive = [f for f in codebase.functions - if any(call.name == f.name for call in f.function_calls)] -print(f"\nRecursive functions:") -for func in recursive: - print(f" - {func.name}") -``` - -This navigation allows you to: - -- Find heavily used functions -- Analyze call patterns -- Map dependencies between functions - -## Arguments and Parameters - -The [Argument](../api-reference/core/Argument) class represents values passed to a function, while [Parameter](../api-reference/core/Parameter) represents the receiving variables in the function definition: - -Consider the following code: - -```python -# Source code: -def process_data(input_data: str, debug: bool = False): - pass - -process_data("test", debug=True) -``` - -You can access and modify the arguments and parameters of the function call with APIs detailed below. - -### Finding Arguments - -The primary APIs for finding arguments are: - -- [FunctionCall.args](/api-reference/core/FunctionCall#args) -- [FunctionCall.get_arg_by_parameter_name(...)](/api-reference/core/FunctionCall#get-arg-by-parameter-name) -- [FunctionCall.get_arg_by_index(...)](/api-reference/core/FunctionCall#get-arg-by-index) - -```python -# Get the function call -call = file.function_calls[0] - -# Working with arguments -for arg in call.args: - print(f"Arg {arg.index}: {arg.value}") # Access argument value - print(f"Is named: {arg.is_named}") # Check if it's a kwarg - print(f"Name: {arg.name}") # For kwargs, e.g. "debug" - - # Get corresponding parameter - if param := arg.parameter: - print(f"Parameter type: {param.type}") - print(f"Is optional: {param.is_optional}") - print(f"Has default: {param.default}") - -# Finding specific arguments -debug_arg = call.get_arg_by_parameter_name("debug") -first_arg = call.get_arg_by_index(0) -``` - -### Modifying Arguments - -There are two ways to modify function call arguments: - -1. Using [FunctionCall.set_kwarg(...)](/api-reference/core/FunctionCall#set-kwarg) to add or modify keyword arguments: - -```python -# Modifying keyword arguments -call.set_kwarg("debug", "False") # Modifies existing kwarg -call.set_kwarg("new_param", "value", create_on_missing=True) # Adds new kwarg -call.set_kwarg("input_data", "'new_value'", override_existing=True) # Converts positional to kwarg -``` - -2. Using [FuncionCall.args.append(...)](/api-reference/core/FunctionCall#args) to add new arguments: - - [FunctionCall.args](/api-reference/core/FunctionCall#args) is a - [Collection](/building-with-codegen/collections) of - [Argument](/api-reference/core/Argument) objects, so it supports - [.append(...)](/api-reference/core/List#append), - [.insert(...)](/api-reference/core/List#insert) and other collection - methods. - - -```python -# Adding new arguments -call.args.append('cloud="aws"') # Add a new keyword argument -call.args.append('"value"') # Add a new positional argument - -# Real-world example: Adding arguments to a decorator -@app.function(image=runner_image) -def my_func(): - pass - -# Add cloud and region if not present -if "cloud=" not in decorator.call.source: - decorator.call.args.append('cloud="aws"') -if "region=" not in decorator.call.source: - decorator.call.args.append('region="us-east-1"') -``` - -The `set_kwarg` method provides intelligent argument manipulation: - -- If the argument exists and is positional, it converts it to a keyword argument -- If the argument exists and is already a keyword, it updates its value (if override_existing=True) -- If the argument doesn't exist, it creates it (if create_on_missing=True) -- When creating new arguments, it intelligently places them based on parameter order - -Arguments and parameters support safe edit operations like so: - -```python -# Modifying arguments -debug_arg.edit("False") # Change argument value -first_arg.add_keyword("input_data") # Convert to named argument - -# modifying parameters -param = codebase.get_function('process_data').get_parameter('debug') -param.rename('_debug') # updates all call-sites -param.set_type_annotation('bool') -``` - -## Finding Function Definitions - -Every [FunctionCall](../api-reference/core/FunctionCall) can navigate to its definition through [function_definition](../api-reference/core/FunctionCall#function-definition) and [function_definitions](../api-reference/core/FunctionCall#function-definitions): - -```python -function_call = codebase.files[0].function_calls[0] -function_definition = function_call.function_definition -print(f"Definition found in: {function_definition.filepath}") -``` - -## Finding Parent (Containing) Functions - -FunctionCalls can access the function that invokes it via [parent_function](../api-reference/core/FunctionCall#parent-function). - -For example, given the following code: - -```python -# Source code: -def outer(): - def inner(): - helper() - inner() -``` - -You can find the parent function of the helper call: - -```python -# Manipulation code: -# Find the helper() call -helper_call = file.get_function("outer").function_calls[1] - -# Get containing function -parent = helper_call.parent_function -print(f"Call is inside: {parent.name}") # 'inner' - -# Get the full call hierarchy -outer = parent.parent_function -print(f"Which is inside: {outer.name}") # 'outer' -``` - -## Method Chaining - -Codegen enables working with chained method calls through [predecessor](../api-reference/core/FunctionCall#predecessor) and related properties: - -For example, for the following database query: - -```python -# Source code: -query.select(Table) - .where(id=1) - .order_by("name") - .limit(10) -``` - -You can access the chain of calls: - -```python -# Manipulation code: -# Get the `limit` call in the chain -limit_call = next(f for f in file.function.function_calls if f.name == "limit", None) - -# Navigate backwards through the chain -order_by = limit_call.predecessor -where = order_by.predecessor -select = where.predecessor - -# Get the full chain at once -chain = limit_call.call_chain # [select, where, order_by, limit] - -# Access the root object -base = limit_call.base # Returns the 'query' object - -# Check call relationships -print(f"After {order_by.name}: {limit_call.name}") -print(f"Before {where.name}: {select.name}") -``` - - ---- -title: "Variable Assignments" -sidebarTitle: "Variable Assignments" -icon: "equals" -iconType: "solid" ---- - -Codegen's enables manipulation of variable assignments via the following classes: - -- [AssignmentStatement](../api-reference/core/AssignmentStatement) - A statement containing one or more assignments -- [Assignment](../api-reference/core/Assignment) - A single assignment within an AssignmentStatement - - -### Simple Value Changes - -Consider the following source code: - -```typescript -const userId = 123; -const [userName, userAge] = ["Eve", 25]; -``` - -In Codegen, you can access assignments with the [get_local_var_assignment](../api-reference/core/CodeBlock#get-local-var-assignment) method. - -You can then manipulate the assignment with the [set_value](../api-reference/core/Assignment#set-value) method. - -```python -id_assignment = file.code_block.get_local_var_assignment("userId") -id_assignment.set_value("456") - -name_assignment = file.code_block.get_local_var_assignment("name") -name_assignment.rename("userName") -``` - - - Assignments inherit both [HasName](/api-reference/core/HasName) and - [HasValue](/api-reference/core/HasValue) behaviors. See [Inheritable - Behaviors](/building-with-codegen/inheritable-behaviors) for more details. - - -### Type Annotations - -Similarly, you can set type annotations with the [set_type_annotation](../api-reference/core/Assignment#set-type-annotation) method. - -For example, consider the following source code: - -```typescript -let status; -const data = fetchData(); -``` - -You can manipulate the assignments as follows: - -```python -status_assignment = file.code_block.get_local_var_assignment("status") -status_assignment.set_type_annotation("Status") -status_assignment.set_value("Status.ACTIVE") - -data_assignment = file.code_block.get_local_var_assignment("data") -data_assignment.set_type_annotation("ResponseData") - -# Result: -let status: Status = Status.ACTIVE; -const data: ResponseData = fetchData(); -``` - -## Tracking Usages and Dependencies - -Like other symbols, Assignments support [usages](/api-reference/core/Assignment#usages) and [dependencies](/api-reference/core/Assignment#dependencies). - -```python -assignment = file.code_block.get_local_var_assignment("userId") - -# Get all usages of the assignment -usages = assignment.usages - -# Get all dependencies of the assignment -dependencies = assignment.dependencies -``` - - - See [Dependencies and Usages](/building-with-codegen/dependencies-and-usages) - for more details. - - - ---- -title: "Local Variables" -sidebarTitle: "Local Variables" -icon: "cube" -iconType: "solid" ---- - -This document explains how to work with local variables in Codegen. - -## Overview - -Through the [CodeBlock](../api-reference/core/CodeBlock) class, Codegen exposes APIs for analyzing and manipulating local variables within code blocks. - -- [local_var_assignments](../api-reference/core/CodeBlock#local-var-assignments): find all [Assignments](../api-reference/core/Assignment) in this scope -- [get_local_var_assignment(...)](../api-reference/core/CodeBlock#get-local-var-assignment): get specific [Assignments](../api-reference/core/Assignment) by name -- [rename_local_variable(...)](../api-reference/core/CodeBlock#rename-local-variable): rename variables safely across the current scope - -## Basic Usage - -Every code block (function body, loop body, etc.) provides access to its local variables: - -```python -# Get all local variables in a function -function = codebase.get_function("process_data") -local_vars = function.code_block.local_var_assignments -for var in local_vars: - print(var.name) - -# Find a specific variable -config_var = function.code_block.get_local_var_assignment("config") -config_var.rename("settings") # Updates all references safely - -# Rename a variable used in this scope (but not necessarily declared here) -function.rename_local_variable("foo", "bar") -``` - -## Fuzzy Matching - -Codegen supports fuzzy matching when searching for local variables. This allows you to find variables whose names contain a substring, rather than requiring exact matches: - -```python -# Get all local variables containing "config" -function = codebase.get_function("process_data") - -# Exact match - only finds variables named exactly "config" -exact_matches = function.code_block.get_local_var_assignments("config") -# Returns: config = {...} - -# Fuzzy match - finds any variable containing "config" -fuzzy_matches = function.code_block.get_local_var_assignments("config", fuzzy_match=True) -# Returns: config = {...}, app_config = {...}, config_settings = {...} - -# Fuzzy matching also works for variable usages -usages = function.code_block.get_variable_usages("config", fuzzy_match=True) - -# And for renaming variables -function.code_block.rename_variable_usages("config", "settings", fuzzy_match=True) -# Renames: config -> settings, app_config -> app_settings, config_settings -> settings_settings -``` - - - Be careful with fuzzy matching when renaming variables, as it will replace the - matched substring in all variable names. This might lead to unintended renames - like `config_settings` becoming `settings_settings`. - - - ---- -title: "Comments and Docstrings" -sidebarTitle: "Comments & Docstrings" -icon: "comment" -iconType: "solid" ---- - -Codegen enables reading, modifying, and manipulating comments and docstrings while preserving proper formatting. - -This guide describes proper usage of the following classes: - -- [Comment](/api-reference/core/Comment) - Represents a single comment. -- [CommentGroup](/api-reference/core/CommentGroup) - Represents a group of comments. - -## Accessing with Comments - -Comments can be accessed through any symbol or directly from code blocks. Each comment is represented by a `Comment` object that provides access to both the raw source and parsed text: - -```python -# Find all comments in a file -file = codebase.get_file("my_file.py") -for comment in file.code_block.comments: - print(comment.text) - -# Access comments associated with a symbol -symbol = file.get_symbol("my_function") -if symbol.comment: - print(symbol.comment.text) # Comment text without delimiters - print(symbol.comment.source) # Full comment including delimiters - -# Access inline comments -if symbol.inline_comment: - print(symbol.inline_comment.text) - -# Accessing all comments in a function -for comment in symbol.code_block.comments: - print(comment.text) -``` - -### Editing Comments - -Comments can be modified using the `edit_text()` method, which handles formatting and delimiters automatically: - -```python -# Edit a regular comment -symbol.comment.edit_text("Updated comment text") - -# Edit an inline comment -symbol.set_inline_comment("New inline comment") -``` - -### Comment Groups - -Multiple consecutive comments are automatically grouped into a `CommentGroup`, which can be edited as a single unit: - -```python -# Original comments: -# First line -# Second line -# Third line - -comment_group = symbol.comment -print(comment_group.text) # "First line\nSecond line\nThird line" - -# Edit the entire group at once -comment_group.edit_text("New first line\nNew second line") -``` - -## Working with Docstrings - -Docstrings are special comments that document functions, classes, and modules. Codegen provides similar APIs for working with docstrings: - -```python -function = file.get_symbol("my_function") -if function.docstring: - print(function.docstring.text) # Docstring content - print(function.docstring.source) # Full docstring with delimiters -``` - -### Adding Docstrings - -You can add docstrings to any symbol that supports them: - -```python -# Add a single-line docstring -function.set_docstring("A brief description") - -# Add a multi-line docstring -function.set_docstring(""" - A longer description that - spans multiple lines. - - Args: - param1: Description of first parameter -""") -``` - -### Language-Specific Formatting - -Codegen automatically handles language-specific docstring formatting: - -```python -# Python: Uses triple quotes -def my_function(): - """Docstring is formatted with triple quotes.""" - pass -``` - -```typescript -// TypeScript: Uses JSDoc style -function myFunction() { - /** Docstring is formatted as JSDoc */ -} -``` - -### Editing Docstrings - -Like comments, docstrings can be modified while preserving formatting: - -```python -# Edit a docstring -function.docstring.edit_text("Updated documentation") - -# Edit a multi-line docstring -function.docstring.edit_text(""" - Updated multi-line documentation - that preserves indentation and formatting. -""") -``` - -## Comment Operations - -Codegen provides utilities for working with comments at scale. For example, you can update or remove specific types of comments across your codebase: - -```python -# Example: Remove eslint disable comments for a specific rule -for file in codebase.files: - for comment in file.code_block.comments: - if "eslint-disable" in comment.source: - # Check if comment disables specific rule - if "@typescript-eslint/no-explicit-any" in comment.text: - comment.remove() -``` - - - When editing multi-line comments or docstrings, Codegen automatically handles - indentation and maintains the existing comment style. - - -## Special APIs and AI Integration - -### Google Style Docstrings - -Codegen supports Google-style docstrings and can handle their specific formatting, using the [CommentGroup.to_google_docstring(...)](/api-reference/core/CommentGroup#to-google-docstring) method. - -```python -# Edit while preserving Google style -symbol_a = file.get_symbol("SymbolA") -func_b = symbol_a.get_method("funcB") -func_b.docstring.to_google_docstring(func_b) -``` - -### Using AI for Documentation - -Codegen integrates with LLMs to help generate and improve documentation. You can use the [Codebase.ai(...)](/api-reference/core/Codebase#ai) method to: - -- Generate comprehensive docstrings -- Update existing documentation -- Convert between documentation styles -- Add parameter descriptions - -```python -# Generate a docstring using AI -function = codebase.get_function("my_function") - -new_docstring = codebase.ai( - "Generate a comprehensive docstring in Google style", - target=function - context={ - # provide additional context to the LLM - 'usages': function.usages, - 'dependencies': function.dependencies - } -) -function.set_docstring(new_docstring) -``` - - - Learn more about AI documentation capabilities in our [Documentation - Guide](/tutorials/creating-documentation) and [LLM Integration - Guide](/building-with-codegen/calling-out-to-llms). - - -### Documentation Coverage - -You can analyze and improve documentation coverage across your codebase: - -```python -# Count documented vs undocumented functions -total = 0 -documented = 0 -for function in codebase.functions: - total += 1 - if function.docstring: - documented += 1 - -coverage = (documented / total * 100) if total > 0 else 0 -print(f"Documentation coverage: {coverage:.1f}%") -``` - - - Check out the [Documentation Guide](/tutorials/creating-documentation) for - more advanced coverage analysis and bulk documentation generation. - - - ---- -title: "External Modules" -sidebarTitle: "External Modules" -icon: "box-archive" -iconType: "solid" ---- - -Codegen provides a way to handle imports from external packages and modules through the [ExternalModule](/api-reference/core/ExternalModule) class. - -```python -# Python examples -import datetime -from requests import get - -# TypeScript/JavaScript examples -import React from 'react' -import { useState, useEffect } from 'react' -import type { ReactNode } from 'react' -import axios from 'axios' -``` - -## What are External Modules? - -When writing code, you often import from packages that aren't part of your project - like `datetime` and `requests` in Python, or `react` and `axios` in TypeScript. In Codegen, these are represented as [ExternalModule](/api-reference/core/ExternalModule) instances. - -```python -for imp in codebase.imports: - if isinstance(imp.symbol, ExternalModule): - print(f"Importing from external package: {imp.resolved_symbol.source}") -``` - - - External modules are read-only - you can analyze them but can't modify their - implementation. This makes sense since they live in your project's - dependencies! - - -## Working with External Modules - -The most common use case is handling external modules differently from your project's code: - -### Identifying Function Calls as External Modules - -For [FunctionCall](/api-reference/core/FunctionCall) instances, you can check if the function definition is an [ExternalModule](/api-reference/core/ExternalModule) via the [FunctionCall.function_definition](/api-reference/core/FunctionCall#function-definition) property: - -```python -for fcall in file.function_calls: - definition = fcall.function_definition - if isinstance(definition, ExternalModule): - # Skip external functions - print(f'External function: {definition.name}') - else: - # Process local functions... - print(f'Local function: {definition.name}') -``` - -### Import Resolution - -Similarly, when working with imports, you can determine if they resolve to external modules by checking the [Import.resolved_symbol](/api-reference/core/Import#resolved-symbol) property: - -```python -for imp in file.imports: - resolved = imp.resolved_symbol - if isinstance(resolved, ExternalModule): - print(f"Import from external package: from {imp.module} import {imp.name}") -``` - - - Use `isinstance(symbol, ExternalModule)` to reliably identify external - modules. This works better than checking names or paths since it handles all - edge cases. - - -## Properties and Methods - -External modules provide several useful properties: - -```python -# Get the module name -module_name = external_module.name # e.g. "datetime" or "useState" - -# Check if it's from node_modules (TypeScript/JavaScript) -if external_module.filepath == "": - print("This is an external package from node_modules") -``` - -## Common Patterns - -Here are some typical ways you might work with external modules: - -### Skip External Processing: - -When modifying function calls or imports, skip external modules since they can't be changed: - -```python -# Example from a codemod that adds type hints -def add_type_hints(function): - if isinstance(function.definition, ExternalModule): - return # Can't add type hints to external modules like React.FC - # Add type hints to local functions... -``` - -### Analyze Dependencies - -Track which external packages your code uses: - -```python -# Find all external package dependencies -external_deps = set() -for imp in codebase.imports: - if isinstance(imp.resolved_symbol, ExternalModule): - external_deps.add(imp.resolved_symbol.source) - # Will find things like 'react', 'lodash', 'datetime', etc. -``` - - - When working with imports, always handle external modules as a special case. - This ensures your codemods work correctly with both local and external code. - - - ---- -title: "Working with Type Annotations" -sidebarTitle: "Type Annotations" -icon: "code" -iconType: "solid" ---- - -This guide covers the core APIs and patterns for working with type annotations in Codegen. - -## Type Resolution - -Codegen builds a complete dependency graph of your codebase, connecting functions, classes, imports, and their relationships. This enables powerful type resolution capabilities: - -```python -from codegen import Codebase - -# Initialize codebase with dependency graph -codebase = Codebase("./") - -# Get a function with a type annotation -function = codebase.get_file("path/to/file.py").get_function("my_func") - -# Resolve its return type to actual symbols -return_type = function.return_type -resolved_symbols = return_type.resolved_types # Returns the actual Symbol objects - -# For generic types, you can resolve parameters -if hasattr(return_type, "parameters"): - for param in return_type.parameters: - resolved_param = param.resolved_types # Get the actual type parameter symbols - -# For assignments, resolve their type -assignment = codebase.get_file("path/to/file.py").get_assignment("my_var") -resolved_type = assignment.type.resolved_types -``` - - - Type resolution follows imports and handles complex cases like type aliases, forward references, and generic type parameters. - - -## Core Interfaces - -Type annotations in Codegen are built on two key interfaces: - -- [Typeable](/api-reference/core/Typeable) - The base interface for any node that can have a type annotation (parameters, variables, functions, etc). Provides `.type` and `.is_typed`. -- [Type](/api-reference/core/Type) - The base class for all type annotations. Provides type resolution and dependency tracking. - -Any node that inherits from `Typeable` will have a `.type` property that returns a `Type` object, which can be used to inspect and modify type annotations. - -Learn more about [inheritable behaviors](/building-with-codegen/inheritable-behaviors) like Typeable here - -## Core Type APIs - -Type annotations can be accessed and modified through several key APIs: - -### Function Types - -The main APIs for function types are [Function.return_type](/api-reference/python/PyFunction#return-type) and [Function.set_return_type](/api-reference/python/PyFunction#set-return-type): - -```python -# Get return type -return_type = function.return_type # -> TypeAnnotation -print(return_type.source) # "List[str]" -print(return_type.is_typed) # True/False - -# Set return type -function.set_return_type("List[str]") -function.set_return_type(None) # Removes type annotation -``` - -### Parameter Types - -Parameters use [Parameter.type](/api-reference/core/Parameter#type) and [Parameter.set_type_annotation](/api-reference/core/Parameter#set-type-annotation): - -```python -for param in function.parameters: - # Get parameter type - param_type = param.type # -> TypeAnnotation - print(param_type.source) # "int" - print(param_type.is_typed) # True/False - - # Set parameter type - param.set_type("int") - param.set_type(None) # Removes type annotation -``` - -### Variable Types - -Variables and attributes use [Assignment.type](/api-reference/core/Assignment#type) and [Assignment.set_type_annotation](/api-reference/core/Assignment#set-type-annotation). This applies to: -- Global variables -- Local variables -- Class attributes (via [Class.attributes](/api-reference/core/Class#attributes)) - -```python -# For global/local assignments -assignment = file.get_assignment("my_var") -var_type = assignment.type # -> TypeAnnotation -print(var_type.source) # "str" - -# Set variable type -assignment.set_type("str") -assignment.set_type(None) # Removes type annotation - -# For class attributes -class_def = file.get_class("MyClass") -for attr in class_def.attributes: - # Each attribute has an assignment property - attr_type = attr.assignment.type # -> TypeAnnotation - print(f"{attr.name}: {attr_type.source}") # e.g. "x: int" - - # Set attribute type - attr.assignment.set_type("int") - -# You can also access attributes directly by index -first_attr = class_def.attributes[0] -first_attr.assignment.set_type("str") -``` - -## Working with Complex Types - -### Union Types - -Union types ([UnionType](/api-reference/core/UnionType)) can be manipulated as collections: - -```python -# Get union type -union_type = function.return_type # -> A | B -print(union_type.symbols) # ["A", "B"] - -# Add/remove options -union_type.append("float") -union_type.remove("None") - -# Check contents -if "str" in union_type.options: - print("String is a possible type") -``` -Learn more about [working with collections here](/building-with-codegen/collections) - -### Generic Types - -Generic types ([GenericType](/api-reference/core/GenericType)) expose their parameters as collection of [Parameters](/api-reference/core/Parameter): - -```python -# Get generic type -generic_type = function.return_type # -> GenericType -print(generic_type.base) # "List" -print(generic_type.parameters) # ["str"] - -# Modify parameters -generic_type.parameters.append("int") -generic_type.parameters[0] = "float" - -# Create new generic -function.set_return_type("List[str]") -``` -Learn more about [working with collections here](/building-with-codegen/collections) - -### Type Resolution - -Type resolution uses [Type.resolved_value](/api-reference/core/Type#resolved-value) to get the actual symbols that a type refers to: - -```python -# Get the actual symbols for a type -type_annotation = function.return_type # -> Type -resolved_types = type_annotation.resolved_value # Returns an Expression, likely a Symbol or collection of Symbols - -# For generic types, resolve each parameter -if hasattr(type_annotation, "parameters"): - for param in type_annotation.parameters: - param_types = param.resolved_value # Get symbols for each parameter - -# For union types, resolve each option -if hasattr(type_annotation, "options"): - for option in type_annotation.options: - option_types = option.resolved_value # Get symbols for each union option -``` - - ---- -title: "Moving Symbols" -sidebarTitle: "Moving Symbols" -icon: "arrows-up-down-left-right" -iconType: "solid" ---- - -Codegen provides fast, configurable and safe APIs for moving symbols (functions, classes, variables) between files while automatically handling imports and dependencies. - -The key API is [Symbol.move_to_file(...)](/api-reference/core/Symbol#move-to-file). - -## Basic Symbol Movement - -Simply call [Symbol.move_to_file(...)](/api-reference/core/Symbol#move-to-file) to move a symbol to a new file. - -```python -# Manipulation code: -file1 = codebase.get_file("file1.py") -file2 = codebase.get_file("file2.py") - -helper_func = file1.get_symbol("helper") - -# Ensure the destination file exists -if not file2.exists(): - file2 = codebase.create_file('file2.py') - -# Move the symbol -helper_func.move_to_file(file2) -``` - - - By default, this will move any dependencies, including imports, to the new - file. - - -## Moving Strategies - -The [Symbol.move_to_file(...)](/api-reference/core/Symbol#move-to-file) method accepts a `strategy` parameter, which can be used to control how imports are updated. - -Your options are: - -- `"update_all_imports"`: Updates all import statements across the codebase (default) -- `"add_back_edge"`: Adds import and re-export in the original file - -`"add_back_edge"` is useful when moving a symbol that is depended on by other symbols in the original file, and will result in smaller diffs. - - - `"add_back_edge"` will result in circular dependencies if the symbol has - non-import dependencies in it's original file. - - -## Moving Symbols in Bulk - -Make sure to call [Codebase.commit(...)](/api-reference/core/Codebase#commit) _after_ moving symbols in bulk for performant symbol movement. - -```python -# Move all functions with a specific prefix -for file in codebase.files: - for function in file.functions: - if function.name.startswith("pylsp_"): - function.move_to_file( - shared_file, - include_dependencies=True, - strategy="update_all_imports" - ) - -# Commit the changes once, at the end -codebase.commit() -``` - - ---- -title: "Collections" -sidebarTitle: "Collections" -icon: "layer-group" -iconType: "solid" ---- - -Codegen enables traversing and manipulating collections through the [List](/api-reference/core/List) and [Dict](/api-reference/core/Dict) classes. - -These APIs work consistently across Python and TypeScript while preserving formatting and structure. - -## Core Concepts - -The [List](/api-reference/core/List) and [Dict](/api-reference/core/Dict) classes provide a consistent interface for working with ordered sequences of elements. Key features include: - -- Standard sequence operations (indexing, length, iteration) -- Automatic formatting preservation -- Safe modification operations -- Language-agnostic behavior -- Comment and whitespace preservation - -Collections handle: - -- Proper indentation -- Delimiters (commas, newlines) -- Multi-line formatting -- Leading/trailing whitespace -- Nested structures - -## List Operations - -Lists in both Python and TypeScript can be manipulated using the same APIs: - -```python -# Basic operations -items_list = file.get_symbol("items").value # Get list value -first = items_list[0] # Access elements -length = len(items_list) # Get length -items_list[0] = "new" # Modify element -items_list.append("d") # Add to end -items_list.insert(1, "x") # Insert at position -del items_list[1] # Remove element - -# Iteration -for item in items_list: - print(item.source) - -# Bulk operations -items_list.clear() # Remove all elements -``` - -### Single vs Multi-line Lists - -Collections automatically preserve formatting: - -```python -# Source code: -items = [a, b, c] -config = [ - "debug", - "verbose", - "trace", -] - -# Manipulation code: -items_list = file.get_symbol("items").value -items_list.append("d") # Adds new element - -config_list = file.get_symbol("config").value -config_list.append("info") # Adds with formatting - -# Result: -items = [a, b, c, d] -config = [ - "debug", - "verbose", - "trace", - "info", -] -``` - -## Dictionary Operations - -Dictionaries provide a similar consistent interface: - -```python -# Basic operations -settings = file.get_symbol("settings").value # Get dict value -value = settings["key"] # Get value -settings["key"] = "value" # Set value -del settings["key"] # Remove key -has_key = "key" in settings # Check existence - -# Iteration -for key in settings: - print(f"{key}: {settings[key]}") - -# Bulk operations -settings.clear() # Remove all entries -``` - - ---- -title: "Traversing the Call Graph" -sidebarTitle: "Call Graph" -icon: "sitemap" -iconType: "solid" ---- - -Codegen provides powerful capabilities for analyzing and visualizing function call relationships in your codebase. This guide will show you how to traverse the call graph and create visual representations of function call paths. - -## Understanding Call Graph Traversal - -At the heart of call graph traversal is the [.function_calls](/api-reference/core/Function#function-calls) property, which returns information about all function calls made within a function: - -```python -def example_function(): - result = helper_function() - process_data() - return result - -# Get all calls made by example_function -successors = example_function.function_calls -for successor in successors: - print(f"Call: {successor.source}") # The actual function call - print(f"Called: {successor.function_definition.name}") # The function being called -``` - -## Building a Call Graph - -Here's how to build a directed graph of function calls using NetworkX: - -```python -import networkx as nx -from codegen.sdk.core.interfaces.callable import FunctionCallDefinition -from codegen.sdk.core.function import Function -from codegen.sdk.core.external_module import ExternalModule - -def create_call_graph(start_func, end_func, max_depth=5): - G = nx.DiGraph() - - def traverse_calls(parent_func, current_depth): - if current_depth > max_depth: - return - - # Determine source node - if isinstance(parent_func, Function): - src_call = src_func = parent_func - else: - src_func = parent_func.function_definition - src_call = parent_func - - # Skip external modules - if isinstance(src_func, ExternalModule): - return - - # Traverse all function calls - for call in src_func.function_calls: - func = call.function_definition - - # Skip recursive calls - if func.name == src_func.name: - continue - - # Add nodes and edges - G.add_node(call) - G.add_edge(src_call, call) - - # Check if we reached the target - if func == end_func: - G.add_edge(call, end_func) - return - - # Continue traversal - traverse_calls(call, current_depth + 1) - - # Initialize graph - G.add_node(start_func, color="blue") # Start node - G.add_node(end_func, color="red") # End node - - # Start traversal - traverse_calls(start_func, 1) - return G - -# Usage example -start = codebase.get_function("create_skill") -end = codebase.get_function("auto_define_skill_description") -graph = create_call_graph(start, end) -``` - -## Filtering and Visualization - -You can filter the graph to show only relevant paths and visualize the results: - -```python -# Find all paths between start and end -all_paths = nx.all_simple_paths(graph, source=start, target=end) - -# Create subgraph of only the nodes in these paths -nodes_in_paths = set() -for path in all_paths: - nodes_in_paths.update(path) -filtered_graph = graph.subgraph(nodes_in_paths) - -# Visualize the graph -codebase.visualize(filtered_graph) -``` - -## Advanced Usage - -### Example: Finding Dead Code - -You can use call graph analysis to find unused functions: - -```python -def find_dead_code(codebase): - dead_functions = [] - for function in codebase.functions: - if not any(function.function_calls): - # No other functions call this one - dead_functions.append(function) - return dead_functions -``` - -### Example: Analyzing Call Chains - -Find the longest call chain in your codebase: - -```python -def get_max_call_chain(function): - G = nx.DiGraph() - - def build_graph(func, depth=0): - if depth > 10: # Prevent infinite recursion - return - for call in func.function_calls: - called_func = call.function_definition - G.add_edge(func, called_func) - build_graph(called_func, depth + 1) - - build_graph(function) - return nx.dag_longest_path(G) -``` - - -The `.function_calls` property is optimized for performance and uses Codegen's internal graph structure to quickly traverse relationships. It's much faster than parsing the code repeatedly. - - - -When traversing call graphs, be mindful of: -- Recursive calls that could create infinite loops -- External module calls that might not be resolvable -- Dynamic/runtime function calls that can't be statically analyzed - - - ---- -title: "React and JSX" -sidebarTitle: "React and JSX" -icon: "react" -iconType: "brands" ---- - -GraphSitter exposes several React and JSX-specific APIs for working with modern React codebases. - -Key APIs include: - -- [Function.is_jsx](/api-reference/typescript/TSFunction#is-jsx) - Check if a function contains JSX elements -- [Class.jsx_elements](/api-reference/typescript/TSClass#jsx-elements) - Get all JSX elements in a class -- [Function.jsx_elements](/api-reference/typescript/TSFunction#jsx-elements) - Get all JSX elements in a function -- [JSXElement](/api-reference/typescript/JSXElement) - Manipulate JSX elements -- [JSXProp](/api-reference/typescript/JSXProp) - Manipulate JSX props - - - See [React Modernization](/tutorials/react-modernization) for tutorials and - applications of the concepts described here - - -## Detecting React Components with `is_jsx` - -Codegen exposes a `is_jsx` property on both classes and functions, which can be used to check if a symbol is a React component. - -```python -# Check if a function is a React component -function = file.get_function("MyComponent") -is_component = function.is_jsx # True for React components - -# Check if a class is a React component -class_def = file.get_class("MyClassComponent") -is_component = class_def.is_jsx # True for React class components -``` - -## Working with JSX Elements - -Given a React component, you can access its JSX elements using the [jsx_elements](/api-reference/typescript/TSFunction#jsx-elements) property. - -You can manipulate these elements by using the [JSXElement](/api-reference/typescript/JSXElement) and [JSXProp](/api-reference/typescript/JSXProp) APIs. - -```python -# Get all JSX elements in a component -for element in component.jsx_elements: - # Access element name - if element.name == "Button": - # Wrap element in a div - element.wrap("
", "
") - - # Get specific prop - specific_prop = element.get_prop("className") - - # Iterate over all props - for prop in element.props: - if prop.name == "className": - # Set prop value - prop.set_value('"my-classname"') - - # Modify element - element.set_name("NewComponent") - element.add_prop("newProp", "{value}") - - # Get child JSX elements - child_elements = element.jsx_elements - - # Wrap element in a JSX expression (preserves whitespace) - element.wrap("
", "
") -``` - -## Common React Operations - -See [React Modernization](/tutorials/react-modernization) for more - -### Refactoring Components into Separate Files - -Split React components into individual files: - -```python -# Find (named) React components -react_components = [ - func for func in codebase.functions - if func.is_jsx and func.name is not None -] - -# Filter out those that are not the default export -non_default_components = [ - comp for comp in react_components - if not comp.export or not comp.export.is_default_export() -] - -# Move these non-default components to new files -for component in react_components: - if component != default_component: - # Create new file - new_file_path = '/'.join(component.filepath.split('/')[:-1]) + f"{component.name}.tsx" - new_file = codebase.create_file(new_file_path) - - # Move component and update imports - component.move_to_file(new_file, strategy="add_back_edge") -``` - - - See [Moving Symbols](/building-with-codegen/moving-symbols) for more details - on moving symbols between files. - - -### Updating Component Names and Props - -Replace components throughout the codebase with prop updates: - -```python -# Find target component -new_component = codebase.get_symbol("NewComponent") - -for function in codebase.functions: - if function.is_jsx: - # Update JSX elements - for element in function.jsx_elements: - if element.name == "OldComponent": - # Update name - element.set_name("NewComponent") - - # Edit props - needs_clsx = not file.has_import("clsx") - for prop in element.props: - if prop.name == "className": - prop.set_value('clsx("new-classname")') - needs_clsx = True - elif prop.name == "onClick": - prop.set_name('handleClick') - - # Add import if needed - if needs_clsx: - file.add_import_from_import_source("import clsx from 'clsx'") - - # Add import if needed - if not file.has_import("NewComponent"): - file.add_import(new_component) -``` - - ---- -title: "Codebase Visualization" -sidebarTitle: "Visualization" -icon: "share-nodes" -iconType: "solid" ---- - -Codegen provides the ability to create interactive graph visualizations via the [codebase.visualize(...)](/api-reference/core/Codebase#visualize) method. - -These visualizations have a number of applications, including: - -- Understanding codebase structure -- Monitoring critical code paths -- Analyzing dependencies -- Understanding inheritance hierarchies - -This guide provides a basic overview of graph creation and customization. Like the one below which displays the call_graph for the [modal/client.py](https://github.com/modal-labs/modal-client/blob/v0.72.49/modal/client.py) module. - - - - - Codegen visualizations are powered by [NetworkX](https://networkx.org/) and - rendered using [d3](https://d3js.org/what-is-d3). - - -## Basic Usage - -The [Codebase.visualize](/api-reference/core/Codebase#visualize) method operates on a NetworkX [DiGraph](https://networkx.org/documentation/stable/reference/classes/graph.DiGraph.html). - -```python -import networkx as nx - -# Basic visualization -G = nx.grid_2d_graph(5, 5) -# Or start with an empty graph -# G = nx.DiGraph() -codebase.visualize(G) - -``` - -It is up to the developer to add nodes and edges to the graph. - -### Adding Nodes and Edges - -When adding nodes to your graph, you can either add the symbol directly or just its name: - -```python -import networkx as nx -G = nx.DiGraph() -function = codebase.get_function("my_function") - -# Add the function object directly - enables source code preview -graph.add_node(function) # Will show function's source code on click - -# Add just the name - no extra features -graph.add_node(function.name) # Will only show the name -``` - - - Adding symbols to the graph directly (as opposed to adding by name) enables - automatic type information, code preview on hover, and more. - - -## Common Visualization Types - -### Call Graphs - -Visualize how functions call each other and trace execution paths: - -```python -def create_call_graph(entry_point: Function): - graph = nx.DiGraph() - - def add_calls(func): - for call in func.call_sites: - called_func = call.resolved_symbol - if called_func: - # Add function objects for rich previews - graph.add_node(func) - graph.add_node(called_func) - graph.add_edge(func, called_func) - add_calls(called_func) - - add_calls(entry_point) - return graph - -# Visualize API endpoint call graph -endpoint = codebase.get_function("handle_request") -call_graph = create_call_graph(endpoint) -codebase.visualize(call_graph, root=endpoint) -``` - - - Learn more about [traversing the call graph - here](/building-with-codegen/traversing-the-call-graph). - - -### React Component Trees - -Visualize the hierarchy of React components: - -```python -def create_component_tree(root_component: Class): - graph = nx.DiGraph() - - def add_children(component): - for usage in component.usages: - if isinstance(usage.parent, Class) and "Component" in usage.parent.bases: - graph.add_edge(component.name, usage.parent.name) - add_children(usage.parent) - - add_children(root_component) - return graph - -# Visualize component hierarchy -app = codebase.get_class("App") -component_tree = create_component_tree(app) -codebase.visualize(component_tree, root=app) -``` - -### Inheritance Graphs - -Visualize class inheritance relationships: - -```python -import networkx as nx - -G = nx.DiGraph() -base = codebase.get_class("BaseModel") - -def add_subclasses(cls): - for subclass in cls.subclasses: - G.add_edge(cls, subclass) - add_subclasses(subclass) - -add_subclasses(base) - -codebase.visualize(G, root=base) -``` - -### Module Dependencies - -Visualize dependencies between modules: - -```python -def create_module_graph(start_file: File): - G = nx.DiGraph() - - def add_imports(file): - for imp in file.imports: - if imp.resolved_symbol and imp.resolved_symbol.file: - graph.add_edge(file, imp.resolved_symbol.file) - add_imports(imp.resolved_symbol.file) - - add_imports(start_file) - return graph - -# Visualize module dependencies -main = codebase.get_file("main.py") -module_graph = create_module_graph(main) -codebase.visualize(module_graph, root=main) -``` - -### Function Modularity - -Visualize function groupings by modularity: - -```python -def create_modularity_graph(functions: list[Function]): - graph = nx.Graph() - - # Group functions by shared dependencies - for func in functions: - for dep in func.dependencies: - if isinstance(dep, Function): - weight = len(set(func.dependencies) & set(dep.dependencies)) - if weight > 0: - graph.add_edge(func.name, dep.name, weight=weight) - - return graph - -# Visualize function modularity -funcs = codebase.functions -modularity_graph = create_modularity_graph(funcs) -codebase.visualize(modularity_graph) -``` - -## Customizing Visualizations - -You can customize your visualizations using NetworkX's attributes while still preserving the smart node features: - -```python -def create_custom_graph(codebase): - graph = nx.DiGraph() - - # Add nodes with custom attributes while preserving source preview - for func in codebase.functions: - graph.add_node(func, - color='red' if func.is_public else 'blue', - shape='box' if func.is_async else 'oval' - ) - - # Add edges between actual function objects - for func in codebase.functions: - for call in func.call_sites: - if call.resolved_symbol: - graph.add_edge(func, call.resolved_symbol, - style='dashed' if call.is_conditional else 'solid', - weight=call.count - ) - - return graph -``` - -## Best Practices - -1. **Use Symbol Objects for Rich Features** - - ```python - # Better: Add symbol objects for rich previews - # This will include source code previews, syntax highlighting, type information, etc. - for func in api_funcs: - graph.add_node(func) - - # Basic: Just names, no extra features - for func in api_funcs: - graph.add_node(func.name) - ``` - -2. **Focus on Relevant Subgraphs** - - ```python - # Better: Visualize specific subsystem - api_funcs = [f for f in codebase.functions if "api" in f.filepath] - api_graph = create_call_graph(api_funcs) - codebase.visualize(api_graph) - - # Avoid: Visualizing entire codebase - full_graph = create_call_graph(codebase.functions) # Too complex - ``` - -3. **Use Meaningful Layouts** - - ```python - # Group related nodes together - graph.add_node(controller_class, cluster="api") - graph.add_node(service_class, cluster="db") - ``` - -4. **Add Visual Hints** - ```python - # Color code by type while preserving rich previews - for node in codebase.functions: - if "Controller" in node.name: - graph.add_node(node, color="red") - elif "Service" in node.name: - graph.add_node(node, color="blue") - ``` - -## Limitations - -- Large graphs may become difficult to read -- Complex relationships might need multiple views -- Some graph layouts may take time to compute -- Preview features only work when adding symbol objects directly - - - ---- -title: "Flagging Symbols" -description: "Learn how to use symbol flags for debugging, tracking changes, and marking code for review" -icon: "flag" -iconType: "solid" ---- - -# Flagging Symbols - -Symbol flags are a powerful feature in Codegen that allow you to mark and track specific code elements during development, debugging, or code review processes. Flags can be used to visually highlight code in the editor and can also integrate with various messaging systems. - -## Basic Usage - -The simplest way to flag a symbol is to call the `flag()` method on any symbol: - -```python -# Flag a function -function.flag(message="This function needs optimization") - -# Flag a class -my_class.flag(message="Consider breaking this into smaller classes") - -# Flag a variable -variable.flag(message="Type hints needed here") -``` - -When you flag a symbol, two things happen: -1. A visual flag emoji (🚩) is added as an inline comment -2. A `CodeFlag` object is created to track the flag in the system - - -## Language-Specific Behavior - -The flag system adapts automatically to the programming language being used: - -```python -# Python -# Results in: def my_function(): # 🚩 Review needed -python_function.flag(message="Review needed") - -# TypeScript -# Results in: function myFunction() { // 🚩 Review needed -typescript_function.flag(message="Review needed") -``` - - -## Example: Code Analysis - -Here's an example of using flags during code analysis: - -```python -def analyze_codebase(codebase): - for function in codebase.functions: - # Check documentation - if not function.docstring: - function.flag( - message="Missing docstring", - ) - - # Check error handling - if function.is_async and not function.has_try_catch: - function.flag( - message="Async function missing error handling", - ) -``` - -This feature is particularly useful when building, and iterating on the symbols that you are trying to modify. - ---- -title: "Calling Out to LLMs" -sidebarTitle: "LLM Integration" -icon: "brain" -iconType: "solid" ---- - -Codegen natively integrates with LLMs via the [codebase.ai(...)](../api-reference/core/Codebase#ai) method, which lets you use large language models (LLMs) to help generate, modify, and analyze code. - -## Configuration - -Before using AI capabilities, you need to provide an OpenAI API key via [codebase.set_ai_key(...)](../api-reference/core/Codebase#set-ai-key): - -```python -# Set your OpenAI API key -codebase.set_ai_key("your-openai-api-key") -``` - -## Calling Codebase.ai(...) - -The [Codebase.ai(...)](../api-reference/core/Codebase#ai) method takes three key arguments: - -```python -result = codebase.ai( - prompt="Your instruction to the AI", - target=symbol_to_modify, # Optional: The code being operated on - context=additional_info # Optional: Extra context from static analysis -) -``` - -- **prompt**: Clear instruction for what you want the AI to do -- **target**: The symbol (function, class, etc.) being operated on - its source code will be provided to the AI -- **context**: Additional information you want to provide to the AI, which you can gather using GraphSitter's analysis tools - - - Codegen does not automatically provide any context to the LLM by default. It - does not "understand" your codebase, only the context you provide. - - -The context parameter can include: - -- A single symbol (its source code will be provided) -- A list of related symbols -- A dictionary mapping descriptions to symbols/values -- Nested combinations of the above - -### How Context Works - -The AI doesn't automatically know about your codebase. Instead, you can provide relevant context by: - -1. Using GraphSitter's static analysis to gather information: - -```python -function = codebase.get_function("process_data") -context = { - "call_sites": function.call_sites, # Where the function is called - "dependencies": function.dependencies, # What the function depends on - "parent": function.parent, # Class/module containing the function - "docstring": function.docstring, # Existing documentation -} -``` - -2. Passing this information to the AI: - -```python -result = codebase.ai( - "Improve this function's implementation", - target=function, - context=context # AI will see the gathered information -) -``` - -## Common Use Cases - -### Code Generation - -Generate new code or refactor existing code: - -```python -# Break up a large function -function = codebase.get_function("large_function") -new_code = codebase.ai( - "Break this function into smaller, more focused functions", - target=function -) -function.edit(new_code) - -# Generate a test -my_function = codebase.get_function("my_function") -test_code = codebase.ai( - f"Write a test for the function {my_function.name}", - target=my_function -) -my_function.insert_after(test_code) -``` - -### Documentation - -Generate and format documentation: - -```python -# Generate docstrings for a class -class_def = codebase.get_class("MyClass") -for method in class_def.methods: - docstring = codebase.ai( - "Generate a docstring describing this method", - target=method, - context={ - "class": class_def, - "style": "Google docstring format" - } - ) - method.set_docstring(docstring) -``` - -### Code Analysis and Improvement - -Use AI to analyze and improve code: - -```python -# Improve function names -for function in codebase.functions: - if codebase.ai( - "Does this function name clearly describe its purpose? Answer yes/no", - target=function - ).lower() == "no": - new_name = codebase.ai( - "Suggest a better name for this function", - target=function, - context={"call_sites": function.call_sites} - ) - function.rename(new_name) -``` - -### Contextual Modifications - -Make changes with full context awareness: - -```python -# Refactor a class method -method = codebase.get_class("MyClass").get_method("target_method") -new_impl = codebase.ai( - "Refactor this method to be more efficient", - target=method, - context={ - "parent_class": method.parent, - "call_sites": method.call_sites, - "dependencies": method.dependencies - } -) -method.edit(new_impl) -``` - -## Best Practices - -1. **Provide Relevant Context** - - ```python - # Good: Providing specific, relevant context - summary = codebase.ai( - "Generate a summary of this method's purpose", - target=method, - context={ - "class": method.parent, # Class containing the method - "usages": list(method.usages), # How the method is used - "dependencies": method.dependencies, # What the method depends on - "style": "concise" - } - ) - - # Bad: Missing context that could help the AI - summary = codebase.ai( - "Generate a summary", - target=method # AI only sees the method's code - ) - ``` - -2. **Gather Comprehensive Context** - - ```python - # Gather relevant information before AI call - def get_method_context(method): - return { - "class": method.parent, - "call_sites": list(method.call_sites), - "dependencies": list(method.dependencies), - "related_methods": [m for m in method.parent.methods - if m.name != method.name] - } - - # Use gathered context in AI call - new_impl = codebase.ai( - "Refactor this method to be more efficient", - target=method, - context=get_method_context(method) - ) - ``` - -3. **Handle AI Limits** - - ```python - # Set custom AI request limits for large operations - codebase.set_session_options(max_ai_requests=200) - ``` - -4. **Review Generated Code** - ```python - # Generate and review before applying - new_code = codebase.ai( - "Optimize this function", - target=function - ) - print("Review generated code:") - print(new_code) - if input("Apply changes? (y/n): ").lower() == 'y': - function.edit(new_code) - ``` - -## Limitations and Safety - -- The AI doesn't automatically know about your codebase - you must provide relevant context -- AI-generated code should always be reviewed -- Default limit of 150 AI requests per codemod execution - - Use [set_session_options(...)](../api-reference/core/Codebase#set-session-options) to adjust limits: - ```python - codebase.set_session_options(max_ai_requests=200) - ``` - - You can also use `codebase.set_session_options` to increase the execution time and the number of operations allowed in a session. This is useful for handling larger tasks or more complex operations that require additional resources. Adjust the `max_seconds` and `max_transactions` parameters to suit your needs: - ```python - codebase.set_session_options(max_seconds=300, max_transactions=500) - ``` - - ---- -title: "Semantic Code Search" -sidebarTitle: "Semantic Code Search" -icon: "magnifying-glass" -iconType: "solid" ---- - -Codegen provides semantic code search capabilities using embeddings. This allows you to search codebases using natural language queries and find semantically related code, even when the exact terms aren't present. - -This is under active development. Interested in an application? [Reach out to the team!](/introduction/about.tsx) - -## Basic Usage - -Here's how to create and use a semantic code search index: - -```python -# Parse a codebase -codebase = Codebase.from_repo('fastapi/fastapi', language='python') - -# Create index -index = FileIndex(codebase) -index.create() # computes per-file embeddings - -# Save index to .pkl -index.save('index.pkl') - -# Load index into memory -index.load('index.pkl') - -# Update index after changes -codebase.files[0].edit('# 🌈 Replacing File Content 🌈') -codebase.commit() -index.update() # re-computes 1 embedding -``` - - -## Searching Code - -Once you have an index, you can perform semantic searches: - -```python -# Search with natural language -results = index.similarity_search( - "How does FastAPI handle dependency injection?", - k=5 # number of results -) - -# Print results -for file, score in results: - print(f"\nScore: {score:.3f} | File: {file.filepath}") - print(f"Preview: {file.content[:200]}...") -``` -The `FileIndex` returns tuples of ([File](/api-reference/core/SourceFile), `score`) - - -The search uses cosine similarity between embeddings to find the most semantically related files, regardless of exact keyword matches. - - -## Available Indices - -Codegen provides two types of semantic indices: - -### FileIndex - -The `FileIndex` operates at the file level: -- Indexes entire files, splitting large files into chunks -- Best for finding relevant files or modules -- Simpler and faster to create/update - -```python -from codegen.extensions.index.file_index import FileIndex - -index = FileIndex(codebase) -index.create() -``` - -### SymbolIndex (Experimental) - -The `SymbolIndex` operates at the symbol level: -- Indexes individual functions, classes, and methods -- Better for finding specific code elements -- More granular search results - -```python -from codegen.extensions.index.symbol_index import SymbolIndex - -index = SymbolIndex(codebase) -index.create() -``` - -## How It Works - -The semantic indices: -1. Process code at either file or symbol level -2. Split large content into chunks that fit within token limits -3. Use OpenAI's text-embedding-3-small model to create embeddings -4. Store embeddings efficiently for similarity search -5. Support incremental updates when code changes - -When searching: -1. Your query is converted to an embedding -2. Cosine similarity is computed with all stored embeddings -3. The most similar items are returned with their scores - - -Creating embeddings requires an OpenAI API key with access to the embeddings endpoint. - - -## Example Searches - -Here are some example semantic searches: - -```python -# Find authentication-related code -results = index.similarity_search( - "How is user authentication implemented?", - k=3 -) - -# Find error handling patterns -results = index.similarity_search( - "Show me examples of error handling and custom exceptions", - k=3 -) - -# Find configuration management -results = index.similarity_search( - "Where is the application configuration and settings handled?", - k=3 -) -``` - -The semantic search can understand concepts and return relevant results even when the exact terms aren't present in the code. - - ---- -title: "Reducing Conditions" -sidebarTitle: "Reducing Conditions" -icon: "code-branch" -iconType: "solid" ---- - -Codegen provides powerful APIs for reducing conditional logic to constant values. This is particularly useful for removing feature flags, cleaning up dead code paths, and simplifying conditional logic. - -## Overview - -The `reduce_condition()` method is available on various conditional constructs: - -- [If/else statements](/api-reference/core/IfBlockStatement#reduce-condition) -- [Ternary expressions](/api-reference/core/TernaryExpression#reduce-condition) -- [Binary expressions](/api-reference/core/BinaryExpression#reduce-condition) -- [Function calls](/api-reference/core/FunctionCall#reduce-condition) - -When you reduce a condition to `True` or `False`, Codegen automatically: - -1. Evaluates which code path(s) to keep -2. Removes unnecessary branches -3. Preserves proper indentation and formatting - -### Motivating Example - -For example, consider the following code: - -```python -flag = get_feature_flag('MY_FEATURE') -if flag: - print('MY_FEATURE: ON') -else: - print('MY_FEATURE: OFF') -``` - -`.reduce_condition` allows you to deterministically reduce this code to the following: - -```python -print('MY_FEATURE: ON') -``` - -This is useful when a feature flag is fully "rolled out". - -## Implementations - -### [IfBlockStatements](/api-reference/core/IfBlockStatement#reduce-condition) - -You can reduce if/else statements to either their "true" or "false" branch. - -For example, in the code snippet above: - -```python -# Grab if statement -if_block = file.code_block.statements[1] - -# Reduce to True branch -if_block.reduce_condition(True) -``` - -This will remove the `else` branch and keep the `print` statement, like so: - -```python -flag = get_feature_flag('MY_FEATURE') -print('MY_FEATURE: ON') -``` - -### Handling Elif Chains - -Codegen intelligently handles elif chains when reducing conditions: - -```python -# Original code -if condition_a: - print("A") -elif condition_b: - print("B") -else: - print("C") - -# Reduce first condition to False -if_block.reduce_condition(False) -# Result: -if condition_b: - print("B") -else: - print("C") - -# Reduce elif condition to True -elif_block.reduce_condition(True) -# Result: -print("B") -``` - -## Ternary Expressions - -Ternary expressions (conditional expressions) can also be reduced: - -```python -# Original code -result = 'valueA' if condition else 'valueB' - -# Reduce to True -ternary_expr.reduce_condition(True) -# Result: -result = 'valueA' - -# Reduce to False -ternary_expr.reduce_condition(False) -# Result: -result = 'valueB' -``` - -### Nested Ternaries - -Codegen handles nested ternary expressions correctly: - -```python -# Original code -result = 'A' if a else 'B' if b else 'C' - -# Reduce outer condition to False -outer_ternary.reduce_condition(False) -# Result: -result = 'B' if b else 'C' - -# Then reduce inner condition to True -inner_ternary.reduce_condition(True) -# Result: -result = 'B' -``` - -## Binary Operations - -Binary operations (and/or) can be reduced to simplify logic: - -```python -# Original code -result = (x or y) and b - -# Reduce x to True -x_assign.reduce_condition(True) -# Result: -result = b - -# Reduce y to False -y_assign.reduce_condition(False) -# Result: -result = x and b -``` - -## Function Calls - -[Function calls](/api-reference/core/FunctionCall#reduce-condition) can also be reduced, which is particularly useful when dealing with hooks or utility functions that return booleans: - -```typescript -// Original code -const isEnabled = useFeatureFlag("my_feature"); -return isEnabled ? : ; - -// After reducing useFeatureFlag to True -return ; -``` - -### Feature Flag Hooks - -A common use case is reducing feature flag hooks to constants. Consider the following code: - -```typescript -// Original code -function MyComponent() { - const showNewUI = useFeatureFlag("new_ui_enabled"); - - if (showNewUI) { - return ; - } - return ; -} -``` - -We can reduce the `useFeatureFlag` hook to a constant value like so, with [FunctionCall.reduce_condition](/api-reference/core/FunctionCall#reduce-condition): - -```python -hook = codebase.get_function("useFeatureFlag") -for usage in hook.usages(): - if isinstance(usage.match, FunctionCall): - fcall = usage.match - if fcall.args[0].value.content == 'new_ui_enabled': - # This will automatically reduce any conditions using the flag - fcall.reduce_condition(True) -``` - -This produces the following code: - -```typescript -function MyComponent() { - return ; -} -``` - -### Comprehensive Example - -Here's a complete example of removing a feature flag from both configuration and usage: - -```python -feature_flag_name = "new_ui_enabled" -target_value = True - -# 1. Remove from config -config_file = codebase.get_file("src/featureFlags/config.ts") -feature_flag_config = config_file.get_symbol("FEATURE_FLAG_CONFIG").value -feature_flag_config.pop(feature_flag_name) - -# 2. Find and reduce all usages -hook = codebase.get_function("useFeatureFlag") -for usage in hook.usages(): - fcall = usage.match - if isinstance(fcall, FunctionCall): - # Check if this usage is for our target flag - first_arg = fcall.args[0].value - if isinstance(first_arg, String) and first_arg.content == feature_flag_name: - print(f'Reducing in: {fcall.parent_symbol.name}') - # This will automatically reduce: - # - Ternary expressions using the flag - # - If statements checking the flag - # - Binary operations with the flag - fcall.reduce_condition(target_value) - -# Commit changes to disk -codebase.commit() -``` - -This example: - -1. Removes the feature flag from configuration -2. Finds all usages of the feature flag hook -3. Reduces each usage to a constant value -4. Automatically handles all conditional constructs using the flag - - - When reducing a function call, Codegen automatically handles all dependent - conditions. This includes: - [If/else - statements](/api-reference/core/IfBlockStatement#reduce-condition) - [Ternary - expressions](/api-reference/core/TernaryExpression#reduce-condition) - [Binary - operations](/api-reference/core/BinaryExpression#reduce-condition) - - -## TypeScript and JSX Support - -Condition reduction works with TypeScript and JSX, including conditional rendering: - -```typescript -// Original JSX -const MyComponent: React.FC = () => { - let isVisible = true; - return ( -
- {isVisible && Visible} - {!isVisible && Hidden} -
- ); -}; - -// After reducing isVisible to True -const MyComponent: React.FC = () => { - return ( -
- Visible -
- ); -}; -``` - - - Condition reduction is particularly useful for cleaning up feature flags in - React components, where conditional rendering is common. - - - ---- -title: "Learn by Example" -sidebarTitle: "At a Glance" -icon: "graduation-cap" -iconType: "solid" ---- - -Explore our tutorials to learn how to use Codegen for various code transformation tasks. - -## Featured Tutorials - - - - Create an intelligent code agent with Langchain and powerful, codegen-powered tools - - - Generate interactive visualizations of your codebase's structure, dependencies, and relationships. - - - Create high-quality training data for LLM pre-training similar to word2vec or node2vec - - - Remove unused imports, functions, and variables with confidence. - - - -## API Migrations - - - - Update API calls, handle breaking changes, and manage bulk updates across your codebase. - - - Update SQLAlchemy code to use the new 2.0-style query interface and patterns. - - - Convert Flask applications to FastAPI, updating routes and dependencies. - - - Migrate Python 2 code to Python 3, updating syntax and modernizing APIs. - - - -## Code Organization - - - - Restructure files, enforce naming conventions, and improve project layout. - - - Split large files, extract shared logic, and manage dependencies. - - - Organize and optimize TypeScript module exports. - - - Convert between default and named exports in TypeScript/JavaScript. - - - -## Testing & Types - - - - Convert unittest test suites to pytest's modern testing style. - - - Add TypeScript types, infer types from usage, and improve type safety. - - - -## Documentation & AI - - - - Generate JSDoc comments, README files, and API documentation. - - - Generate system prompts, create hierarchical documentation, and optimize for AI assistance. - - - - - Each tutorial includes practical examples, code snippets, and best practices. - Follow them in order or jump to the ones most relevant to your needs. - - - ---- -title: "Building Code Agents" -sidebarTitle: "Code Agent" -icon: "robot" -iconType: "solid" ---- - -This guide demonstrates how to build an intelligent code agent that can analyze and manipulate codebases. - -```python -from codegen import CodeAgent, Codebase - -# Grab a repo from Github -codebase = Codebase.from_repo('fastapi/fastapi') - -# Create a code agent with read/write codebase access -agent = CodeAgent(codebase) - -# Run the agent with a prompt -agent.run("Tell me about this repo") -``` - -The agent has access to powerful code viewing and manipulation tools powered by Codegen, including: -- `ViewFileTool`: View contents and metadata of files -- `SemanticEditTool`: Make intelligent edits to files -- `RevealSymbolTool`: Analyze symbol dependencies and usages -- `MoveSymbolTool`: Move symbols between files with import handling -- `ReplacementEditTool`: Make regex-based replacement editing on files -- `ListDirectoryTool`: List directory contents -- `SearchTool`: Search for files and symbols -- `CreateFileTool`: Create new files -- `DeleteFileTool`: Delete files -- `RenameFileTool`: Rename files -- `EditFileTool`: Edit files - - - -View the full code for the default tools and agent implementation in our [examples repository](https://github.com/codegen-sh/codegen-sdk/tree/develop/src/codegen/extensions/langchain/tools) - -# Basic Usage - -The following example shows how to create and run a `CodeAgent`: - -```python -from codegen import CodeAgent, Codebase - -# Grab a repo from Github -codebase = Codebase.from_repo('fastapi/fastapi') - -# Create a code agent with read/write codebase access -agent = CodeAgent(codebase) - -# Run the agent with a prompt -agent.run("Tell me about this repo") -``` - - -Your `ANTHROPIC_API_KEY` must be set in your env. - -The default implementation uses `anthropic/claude-3-5-sonnet-latest` for the model but this can be changed through the `model_provider` and `model_name` arguments. - -```python -agent = CodeAgent( - codebase=codebase, - model_provider="openai", - model_name="gpt-4o", -) -``` - -If using a non-default model provider, make sure to set the appropriate API key (e.g., `OPENAI_API_KEY` for OpenAI models) in your env. - -# Available Tools - -The agent comes with a comprehensive set of tools for code analysis and manipulation. Here are some key tools: - -```python -from codegen.extensions.langchain.tools import ( - CreateFileTool, - DeleteFileTool, - EditFileTool, - ListDirectoryTool, - MoveSymbolTool, - RenameFileTool, - ReplacementEditTool, - RevealSymbolTool, - SearchTool, - SemanticEditTool, - ViewFileTool, -) -``` - -View the full set of [tools on Github](https://github.com/codegen-sh/codegen-sdk/blob/develop/src/codegen/extensions/langchain/tools.py) - -Each tool provides specific capabilities: - -# Extensions - -## GitHub Integration - -The agent includes tools for GitHub operations like PR management. Set up GitHub access with: - -```bash -CODEGEN_SECRETS__GITHUB_TOKEN="..." -``` - -Import the GitHub tools: - -```python -from codegen.extensions.langchain.tools import ( - GithubCreatePRTool, - GithubViewPRTool, - GithubCreatePRCommentTool, - GithubCreatePRReviewCommentTool -) -``` - -These tools enable: -- Creating pull requests -- Viewing PR contents and diffs -- Adding general PR comments -- Adding inline review comments - -View all Github tools on [Github](https://github.com/codegen-sh/codegen-sdk/blob/develop/src/codegen/extensions/langchain/tools.py) - - -## Linear Integration - -The agent can interact with Linear for issue tracking and project management. To use Linear tools, set the following environment variables: - -```bash -LINEAR_ACCESS_TOKEN="..." -LINEAR_TEAM_ID="..." -LINEAR_SIGNING_SECRET="..." -``` - -Import and use the Linear tools: - -```python -from codegen.extensions.langchain.tools import ( - LinearGetIssueTool, - LinearGetIssueCommentsTool, - LinearCommentOnIssueTool, - LinearSearchIssuesTool, - LinearCreateIssueTool, - LinearGetTeamsTool -) -``` - -These tools allow the agent to: -- Create and search issues -- Get issue details and comments -- Add comments to issues -- View team information - -View all Linear tools on [Github](https://github.com/codegen-sh/codegen-sdk/blob/develop/src/codegen/extensions/langchain/tools.py) - - -## Adding Custom Tools - -You can extend the agent with custom tools: - -```python -from langchain.tools import BaseTool -from pydantic import BaseModel, Field -from codegen import CodeAgent - -class CustomToolInput(BaseModel): - """Input schema for custom tool.""" - param: str = Field(..., description="Parameter description") - -class CustomCodeTool(BaseTool): - """A custom tool for the code agent.""" - name = "custom_tool" - description = "Description of what the tool does" - args_schema = CustomToolInput - - def _run(self, param: str) -> str: - # Tool implementation - return f"Processed {param}" - -# Add custom tool to agent -tools.append(CustomCodeTool()) -agent = CodebaseAgent(codebase, tools=tools, model_name="claude-3-5-sonnet-latest") -``` - - ---- -title: "Building a RAG-powered Slack Bot" -sidebarTitle: "Slack Bot" -icon: "slack" -iconType: "solid" ---- - -This tutorial demonstrates how to build a Slack bot that can answer code questions using simple RAG (Retrieval Augmented Generation) over a codebase. The bot uses semantic search to find relevant code snippets and generates detailed answers using OpenAI's APIs. - -View the full code and setup instructions in our [examples repository](https://github.com/codegen-sh/codegen-sdk/tree/develop/codegen-examples/examples/slack_chatbot) - -While this example uses the Codegen codebase, you can adapt it to any repository by changing the repository URL - -## Overview - -The process involves three main steps: - -1. Initializing and indexing the codebase -2. Finding relevant code snippets for a query -3. Generating answers using RAG - -Let's walk through each step using Codegen. - -## Step 1: Initializing the Codebase - -First, we initialize the codebase and create a vector index for semantic search: - -```python -from codegen import Codebase -from codegen.extensions import VectorIndex - -def initialize_codebase(): - """Initialize and index the codebase.""" - # Initialize codebase with smart caching - codebase = Codebase.from_repo( - "codegen-sh/codegen-sdk", - language="python", - tmp_dir="/root" - ) - - # Initialize vector index - index = VectorIndex(codebase) - - # Try to load existing index or create new one - index_path = "/root/E.pkl" - try: - index.load(index_path) - except FileNotFoundError: - # Create new index if none exists - index.create() - index.save(index_path) - - return codebase, index -``` - - -The vector index is persisted to disk, so subsequent queries will be much faster. -See [semantic code search](/building-with-codegen/semantic-code-search) to learn more about VectorIndex. - - -## Step 2: Finding Relevant Code - -Next, we use the vector index to find code snippets relevant to a query: - -```python -def find_relevant_code(index: VectorIndex, query: str) -> list[tuple[str, float]]: - """Find code snippets relevant to the query.""" - # Get top 10 most relevant files - results = index.similarity_search(query, k=10) - - # Clean up chunk references from index - cleaned_results = [] - for filepath, score in results: - if "#chunk" in filepath: - filepath = filepath.split("#chunk")[0] - cleaned_results.append((filepath, score)) - - return cleaned_results -``` - - -VectorIndex automatically chunks large files for better search results. We clean up the chunk references to show clean file paths. - - -## Step 3: Generating Answers - -Finally, we use GPT-4 to generate answers based on the relevant code: - -```python -from openai import OpenAI - -def generate_answer(query: str, context: str) -> str: - """Generate an answer using RAG.""" - prompt = f"""You are a code expert. Given the following code context and question, -provide a clear and accurate answer. - -Note: Keep it short and sweet - 2 paragraphs max. - -Question: {query} - -Relevant code: -{context} - -Answer:""" - - client = OpenAI() - response = client.chat.completions.create( - model="gpt-4o", - messages=[ - {"role": "system", "content": "You are a code expert. Answer questions about the given repo based on RAG'd results."}, - {"role": "user", "content": prompt}, - ], - temperature=0, - ) - - return response.choices[0].message.content -``` - -## Putting It All Together - -Here's how the components work together to answer questions: - -```python -def answer_question(query: str) -> tuple[str, list[tuple[str, float]]]: - """Answer a question about the codebase using RAG.""" - # Initialize or load codebase and index - codebase, index = initialize_codebase() - - # Find relevant files - results = find_relevant_code(index, query) - - # Collect context from relevant files - context = "" - for filepath, score in results: - file = codebase.get_file(filepath) - context += f"File: {filepath}\n```\n{file.content}\n```\n\n" - - # Generate answer - answer = generate_answer(query, context) - - return answer, results -``` - -This will: -1. Load or create the vector index -2. Find relevant code snippets -3. Generate a detailed answer -4. Return both the answer and file references - -## Example Usage - -Here's what the output looks like: - -```python -answer, files = answer_question("How does VectorIndex handle large files?") - -print("Answer:", answer) -print("\nRelevant files:") -for filepath, score in files: - print(f"• {filepath} (score: {score:.2f})") -``` - -Output: -``` -Answer: -VectorIndex handles large files by automatically chunking them into smaller pieces -using tiktoken. Each chunk is embedded separately and can be searched independently, -allowing for more precise semantic search results. - -Relevant files: -• src/codegen/extensions/vector_index.py (score: 0.92) -• src/codegen/extensions/tools/semantic_search.py (score: 0.85) -``` - -## Extensions - -While this example demonstrates a simple RAG-based bot, you can extend it to build a more powerful code agent that can: -- Do more sophisticated code retrieval -- Make code changes using Codegen's edit APIs -- Gather further context from Slack channels -- ... etc. - -Check out our [Code Agent tutorial](/tutorials/build-code-agent) to learn how to build an intelligent agent with access to Codegen's full suite of tools - - -## Learn More - - - - Learn how to use VectorIndex for semantic code search and embeddings. - - - Create a more powerful agent with multi-step reasoning and code manipulation. - - - Learn about OpenAI's text embeddings and how they work. - - - Understand RAG patterns and best practices for better results. - - - - ---- -title: "Building an AI-Powered GitHub PR Review Bot" -sidebarTitle: "GitHub PR Review Bot" -icon: "github" -iconType: "solid" ---- - -This tutorial demonstrates how to build an intelligent GitHub PR review bot that automatically reviews pull requests when triggered by labels. The bot uses Codegen's GitHub integration and AI capabilities to provide comprehensive code reviews with actionable feedback. - - -The bot is triggered by adding a "Codegen" label to PRs, making it easy to integrate into your existing workflow - -## Overview - -The process involves three main components: - -1. Setting up a Modal web endpoint for GitHub webhooks -2. Handling PR label events -3. Running an AI-powered code review agent - -Let's walk through each component using Codegen. - -## Step 1: Setting Up the Modal App - -First, we set up a Modal application to handle GitHub webhooks: - -```python -import modal -from codegen.extensions.events.app import CodegenApp -from fastapi import Request - -# Set up the base image with required dependencies -base_image = ( - modal.Image.debian_slim(python_version="3.12") - .apt_install("git") - .pip_install( - "codegen>=0.18", - "openai>=1.1.0", - "fastapi[standard]", - "slack_sdk", - ) -) - -# Initialize the Codegen app with GitHub integration -app = CodegenApp(name="github", image=base_image) - -@app.function(secrets=[modal.Secret.from_dotenv()]) -@modal.web_endpoint(method="POST") -def entrypoint(event: dict, request: Request): - return app.github.handle(event, request) -``` - - -The Modal app provides a webhook endpoint that GitHub can call when PR events occur. -Make sure to configure your GitHub repository's webhook settings to point to your Modal endpoint. - - -## Step 2: Handling PR Events - -Next, we set up event handlers for PR label events: - -```python -from codegen.extensions.github.types.events.pull_request import ( - PullRequestLabeledEvent, - PullRequestUnlabeledEvent -) - -@app.github.event("pull_request:labeled") -def handle_labeled(event: PullRequestLabeledEvent): - """Handle PR labeled events.""" - if event.label.name == "Codegen": - # Optional: Notify a Slack channel - app.slack.client.chat_postMessage( - channel="YOUR_CHANNEL_ID", - text=f"PR #{event.number} labeled with Codegen, starting review", - ) - # Start the review process - pr_review_agent(event) - -@app.github.event("pull_request:unlabeled") -def handle_unlabeled(event: PullRequestUnlabeledEvent): - """Handle PR unlabeled events.""" - if event.label.name == "Codegen": - # Clean up bot comments when label is removed - remove_bot_comments(event) -``` - - -The bot only triggers on PRs labeled with "Codegen", giving you control over which PRs get reviewed. - - -## Step 3: Implementing the Review Agent - -Finally, we implement the AI-powered review agent: - -```python -from codegen import Codebase, CodeAgent -from codegen.extensions.langchain.tools import ( - GithubViewPRTool, - GithubCreatePRCommentTool, - GithubCreatePRReviewCommentTool, -) - -def pr_review_agent(event: PullRequestLabeledEvent) -> None: - """Run the PR review agent.""" - # Initialize codebase for the repository - repo_str = f"{event.organization.login}/{event.repository.name}" - codebase = Codebase.from_repo( - repo_str, - language='python', - secrets=SecretsConfig(github_token=os.environ["GITHUB_TOKEN"]) - ) - - # Create a temporary comment to show the bot is working - review_message = "CodegenBot is starting to review the PR please wait..." - comment = codebase._op.create_pr_comment(event.number, review_message) - - # Set up PR review tools - pr_tools = [ - GithubViewPRTool(codebase), - GithubCreatePRCommentTool(codebase), - GithubCreatePRReviewCommentTool(codebase), - ] - - # Create and run the review agent - agent = CodeAgent(codebase=codebase, tools=pr_tools) - prompt = f""" -Review this pull request like a senior engineer: -{event.pull_request.url} - -Be explicit about the changes, produce a short summary, and point out possible improvements. -Focus on facts and technical details, using code snippets where helpful. -""" - result = agent.run(prompt) - - # Clean up the temporary comment - comment.delete() -``` - -## Setting Up the Environment - -Before running the bot, you'll need: - -1. Create a `.env` file with your credentials: - -```env -GITHUB_TOKEN=your_github_token -GITHUB_API_KEY=your_github_token -ANTHROPIC_API_KEY=your_anthropic_key -SLACK_BOT_TOKEN=your_slack_token # Optional -``` - -2. Deploy the Modal app: -```bash -uv sync # Install dependencies -uv run modal deploy app.py -``` - -3. Configure GitHub webhook: - - Go to your repository settings - - Add webhook pointing to your Modal endpoint - - Select "Pull request" events - - Add a webhook secret (optional but recommended) - -## Example Usage - -1. Create or update a pull request in your repository -2. Add the "Codegen" label to trigger a review -3. The bot will: - - Post a temporary "starting review" comment - - Analyze the PR changes - - Post detailed review comments - - Remove the temporary comment when done - -To remove the bot's comments: -1. Remove the "Codegen" label -2. The bot will automatically clean up its comments - -## Extensions - -While this example demonstrates a basic PR review bot, you can extend it to: -- Customize the review criteria -- Add more sophisticated analysis tools -- Integrate with other services -- Add automatic fix suggestions -- ... etc. - -Check out our [Code Agent tutorial](/tutorials/build-code-agent) to learn more about building sophisticated AI agents with Codegen - ---- -title: "Deep Code Research with AI" -sidebarTitle: "Code Research Agent" -icon: "magnifying-glass" -iconType: "solid" ---- - -This guide demonstrates how to build an intelligent code research tool that can analyze and explain codebases using Codegen's and LangChain. The tool combines semantic code search, dependency analysis, and natural language understanding to help developers quickly understand new codebases. - -View the full code on [GitHub](https://github.com/codegen-sh/codegen-sdk/tree/develop/codegen-examples/examples/deep_code_research) - -This example works with any public GitHub repository - just provide the repo name in the format owner/repo - -## Overview - -The process involves three main components: - -1. A CLI interface for interacting with the research agent -2. A set of code analysis tools powered by Codegen -3. An LLM-powered agent that combines the tools to answer questions - -Let's walk through building each component. - -## Step 1: Setting Up the Research Tools - -First, let's import the necessary components and set up our research tools: - -```python -from codegen import Codebase -from codegen.extensions.langchain.agent import create_agent_with_tools -from codegen.extensions.langchain.tools import ( - ListDirectoryTool, - RevealSymbolTool, - SearchTool, - SemanticSearchTool, - ViewFileTool, -) -from langchain_core.messages import SystemMessage -``` - -We'll create a function to initialize our codebase with a nice progress indicator: - -```python -def initialize_codebase(repo_name: str) -> Optional[Codebase]: - """Initialize a codebase with a spinner showing progress.""" - with console.status("") as status: - try: - status.update(f"[bold blue]Cloning {repo_name}...[/bold blue]") - codebase = Codebase.from_repo(repo_name) - status.update("[bold green]✓ Repository cloned successfully![/bold green]") - return codebase - except Exception as e: - console.print(f"[bold red]Error initializing codebase:[/bold red] {e}") - return None -``` - -Then we'll set up our research tools: - -```python -# Create research tools -tools = [ - ViewFileTool(codebase), # View file contents - ListDirectoryTool(codebase), # Explore directory structure - SearchTool(codebase), # Text-based search - SemanticSearchTool(codebase), # Natural language search - RevealSymbolTool(codebase), # Analyze symbol relationships -] -``` - -Each tool provides specific capabilities: -- `ViewFileTool`: Read and understand file contents -- `ListDirectoryTool`: Explore the codebase structure -- `SearchTool`: Find specific code patterns -- `SemanticSearchTool`: Search using natural language -- `RevealSymbolTool`: Analyze dependencies and usages - -## Step 2: Creating the Research Agent - -Next, we'll create an agent that can use these tools intelligently. We'll give it a detailed prompt about its role: - -```python -RESEARCH_AGENT_PROMPT = """You are a code research expert. Your goal is to help users understand codebases by: -1. Finding relevant code through semantic and text search -2. Analyzing symbol relationships and dependencies -3. Exploring directory structures -4. Reading and explaining code - -Always explain your findings in detail and provide context about how different parts of the code relate to each other. -When analyzing code, consider: -- The purpose and functionality of each component -- How different parts interact -- Key patterns and design decisions -- Potential areas for improvement - -Break down complex concepts into understandable pieces and use examples when helpful.""" - -# Initialize the agent -agent = create_agent_with_tools( - codebase=codebase, - tools=tools, - chat_history=[SystemMessage(content=RESEARCH_AGENT_PROMPT)], - verbose=True -) -``` - -## Step 3: Building the CLI Interface - -Finally, we'll create a user-friendly CLI interface using rich-click: - -```python -import rich_click as click -from rich.console import Console -from rich.markdown import Markdown - -@click.group() -def cli(): - """🔍 Codegen Code Research CLI""" - pass - -@cli.command() -@click.argument("repo_name", required=False) -@click.option("--query", "-q", default=None, help="Initial research query.") -def research(repo_name: Optional[str] = None, query: Optional[str] = None): - """Start a code research session.""" - # Initialize codebase - codebase = initialize_codebase(repo_name) - - # Create and run the agent - agent = create_research_agent(codebase) - - # Main research loop - while True: - if not query: - query = Prompt.ask("[bold cyan]Research query[/bold cyan]") - - result = agent.invoke( - {"input": query}, - config={"configurable": {"thread_id": 1}} - ) - console.print(Markdown(result["messages"][-1].content)) - - query = None # Clear for next iteration -``` - -## Using the Research Tool - -You can use the tool in several ways: - -1. Interactive mode (will prompt for repo): -```bash -python run.py research -``` - -2. Specify a repository: -```bash -python run.py research "fastapi/fastapi" -``` - -3. Start with an initial query: -```bash -python run.py research "fastapi/fastapi" -q "Explain the main components" -``` - -Example research queries: -- "Explain the main components and their relationships" -- "Find all usages of the FastAPI class" -- "Show me the dependency graph for the routing module" -- "What design patterns are used in this codebase?" - - - The agent maintains conversation history, so you can ask follow-up questions - and build on previous findings. - - -## Advanced Usage - -### Custom Research Tools - -You can extend the agent with custom tools for specific analysis needs: - -```python -from langchain.tools import BaseTool -from pydantic import BaseModel, Field - -class CustomAnalysisTool(BaseTool): - """Custom tool for specialized code analysis.""" - name = "custom_analysis" - description = "Performs specialized code analysis" - - def _run(self, query: str) -> str: - # Custom analysis logic - return results - -# Add to tools list -tools.append(CustomAnalysisTool()) -``` - -### Customizing the Agent - -You can modify the agent's behavior by adjusting its prompt: - -```python -CUSTOM_PROMPT = """You are a specialized code reviewer focused on: -1. Security best practices -2. Performance optimization -3. Code maintainability -... -""" - -agent = create_agent_with_tools( - codebase=codebase, - tools=tools, - chat_history=[SystemMessage(content=CUSTOM_PROMPT)], -) -``` - - ---- -title: "Codebase Analytics" -sidebarTitle: "Analytics" -icon: "calculator" -iconType: "solid" ---- - -This tutorial explains how codebase metrics are efficiently calculated using the `codegen` library in the Codebase Analytics Dashboard. The metrics include indices of codebase maintainabilith and complexity. - -View the full code and setup instructions in our [codebase-analytics repository](https://github.com/codegen-sh/codebase-analytics). - - -## Complexity Metrics - -Complexity metrics help quantify how easy or difficult a codebase is to understand and maintain. These metrics are calculated by analyzing various aspects of the code structure, including control flow, code volume, and inheritance patterns. The following metrics provide different perspectives on code complexity. - -### Cyclomatic Complexity -Cyclomatic Complexity measures the number of linearly independent paths through the codebase, making it a valuable indicator of how difficult code will be to test and maintain. - -**Calculation Method**: - - Base complexity of 1 - - +1 for each: - - if statement - - elif statement - - for loop - - while loop - - +1 for each boolean operator (and, or) in conditions - - +1 for each except block in try-catch statements - -The `calculate_cyclomatic_complexity()` function traverses the Codgen codebase object and uses the above rules to find statement objects within each function and calculate the overall cyclomatic complexity of the codebase. - -```python -def calculate_cyclomatic_complexity(function): - def analyze_statement(statement): - complexity = 0 - - if isinstance(statement, IfBlockStatement): - complexity += 1 - if hasattr(statement, "elif_statements"): - complexity += len(statement.elif_statements) - - elif isinstance(statement, (ForLoopStatement, WhileStatement)): - complexity += 1 - - return complexity -``` - -### Halstead Volume -Halstead Volume is a software metric which measures the complexity of a codebase by counting the number of unique operators and operands. It is calculated by multiplying the sum of unique operators and operands by the logarithm base 2 of the sum of unique operators and operands. - -**Halstead Volume**: `V = (N1 + N2) * log2(n1 + n2)` - -This calculation uses codegen's expression types to make this calculation very efficient - these include BinaryExpression, UnaryExpression and ComparisonExpression. The function extracts operators and operands from the codebase object and calculated in `calculate_halstead_volume()` function. - -```python -def calculate_halstead_volume(operators, operands): - n1 = len(set(operators)) - n2 = len(set(operands)) - - N1 = len(operators) - N2 = len(operands) - - N = N1 + N2 - n = n1 + n2 - - if n > 0: - volume = N * math.log2(n) - return volume, N1, N2, n1, n2 - return 0, N1, N2, n1, n2 -``` - -### Depth of Inheritance (DOI) -Depth of Inheritance measures the length of inheritance chain for each class. It is calculated by counting the length of the superclasses list for each class in the codebase. The implementation is handled through a simple calculation using codegen's class information in the `calculate_doi()` function. - -```python -def calculate_doi(cls): - return len(cls.superclasses) -``` - -## Maintainability Index -Maintainability Index is a software metric which measures how maintainable a codebase is. Maintainability is described as ease to support and change the code. This index is calculated as a factored formula consisting of SLOC (Source Lines Of Code), Cyclomatic Complexity and Halstead volume. - -**Maintainability Index**: `M = 171 - 5.2 * ln(HV) - 0.23 * CC - 16.2 * ln(SLOC)` - -This formula is then normalized to a scale of 0-100, where 100 is the maximum maintainability. - -The implementation is handled through the `calculate_maintainability_index()` function. The codegen codebase object is used to efficiently extract the Cyclomatic Complexity and Halstead Volume for each function and class in the codebase, which are then used to calculate the maintainability index. - -```python -def calculate_maintainability_index( - halstead_volume: float, cyclomatic_complexity: float, loc: int -) -> int: - """Calculate the normalized maintainability index for a given function.""" - if loc <= 0: - return 100 - - try: - raw_mi = ( - 171 - - 5.2 * math.log(max(1, halstead_volume)) - - 0.23 * cyclomatic_complexity - - 16.2 * math.log(max(1, loc)) - ) - normalized_mi = max(0, min(100, raw_mi * 100 / 171)) - return int(normalized_mi) - except (ValueError, TypeError): - return 0 -``` - -## Line Metrics - -Line metrics provide insights into the size, complexity, and maintainability of a codebase. These measurements help determine the scale of a project, identify areas that may need refactoring, and track the growth of the codebase over time. - -### Lines of Code -Lines of Code refers to the total number of lines in the source code, including blank lines and comments. This is accomplished with a simple count of all lines in the source file. - -### Logical Lines of Code (LLOC) -LLOC is the amount of lines of code which contain actual functional statements. It excludes comments, blank lines, and other lines which do not contribute to the utility of the codebase. A high LLOC relative to total lines of code suggests dense, potentially complex code that may benefit from breaking into smaller functions or modules with more documentation. - -### Source Lines of Code (SLOC) -SLOC refers to the number of lines containing actual code, excluding blank lines. This includes programming language keywords and comments. While a higher SLOC indicates a larger codebase, it should be evaluated alongside other metrics like cyclomatic complexity and maintainability index to assess if the size is justified by the functionality provided. - -### Comment Density -Comment density is calculated by dividing the lines of code which contain comments by the total lines of code in the codebase. The formula is: - -```python -"comment_density": (total_comments / total_loc * 100) -``` - -It measures the proportion of comments in the codebase and is a good indicator of how much code is properly documented. Accordingly, it can show how maintainable and easy to understand the codebase is. - -## General Codebase Statistics -The number of files is determined by traversing codegen's FileNode objects in the parsed codebase. The number of functions is calculated by counting FunctionDef nodes across all parsed files. The number of classes is obtained by summing ClassDef nodes throughout the codebase. - -```python -num_files = len(codebase.files(extensions="*")) -num_functions = len(codebase.functions) -num_classes = len(codebase.classes) -``` - -The commit activity is calculated by using the git history of the repository. The number of commits is counted for each month in the last 12 months. - -## Using the Analysis Tool (Modal Server) - -The tool is implemented as a FastAPI application wrapped in a Modal deployment. To analyze a repository: - -1. Send a POST request to `/analyze_repo` with the repository URL -2. The tool will: - - Clone the repository - - Parse the codebase using codegen - - Calculate all metrics - - Return a comprehensive JSON response with all metrics - -This is the only endpoint in the FastAPI server, as it takes care of the entire analysis process. To run the FastAPI server locally, install all dependencies and run the server with `modal serve modal_main.py`. - -The server can be connected to the frontend dashboard. This web component is implemented as a Next.js application with appropriate comments and visualizations for the raw server data. To run the frontend locally, install all dependencies and run the server with `npm run dev`. This can be connected to the FastAPI server by setting the URL in the request to the `/analyze_repo` endpoint. - - - ---- -title: "Mining Training Data for LLMs" -sidebarTitle: "Mining Data" -description: "Learn how to generate training data for large language models using Codegen" -icon: "network-wired" -iconType: "solid" ---- - -This guide demonstrates how to use Codegen to generate high-quality training data for large language models (LLMs) by extracting function implementations along with their dependencies and usages. This approach is similar to [word2vec](https://www.tensorflow.org/text/tutorials/word2vec) or [node2vec](https://snap.stanford.edu/node2vec/) - given the context of a function, learn to predict the function's implementation. - -View the full code in our [examples repository](https://github.com/codegen-sh/codegen-sdk/tree/develop/codegen-examples/examples/generate_training_data) - -This example works with both Python and Typescript repositories without modification - -## Overview - -The process involves three main steps: - -1. Finding all functions in the codebase -2. Extracting their implementations, dependencies, and usages -3. Generating structured training data - -Let's walk through each step using Codegen. - -## Step 1: Finding Functions and Their Context - -First, we will do a "graph expansion" for each function - grab the function's source, as well as the full source of all usages of the function and all dependencies. - -See [dependencies and usages](/building-with-codegen/dependencies-and-usages) to learn more about navigating the code graph - -First, let's import the types we need from Codegen: - -```python -import codegen -from codegen import Codebase -from codegen.sdk.core.external_module import ExternalModule -from codegen.sdk.core.import_resolution import Import -from codegen.sdk.core.symbol import Symbol -``` - -Here's how we get the full context for each function: - -```python -def get_function_context(function) -> dict: - """Get the implementation, dependencies, and usages of a function.""" - context = { - "implementation": {"source": function.source, "filepath": function.filepath}, - "dependencies": [], - "usages": [], - } - - # Add dependencies - for dep in function.dependencies: - # Hop through imports to find the root symbol source - if isinstance(dep, Import): - dep = hop_through_imports(dep) - - context["dependencies"].append({"source": dep.source, "filepath": dep.filepath}) - - # Add usages - for usage in function.usages: - context["usages"].append({ - "source": usage.usage_symbol.source, - "filepath": usage.usage_symbol.filepath, - }) - - return context -``` - -Notice how we use `hop_through_imports` to resolve dependencies. When working with imports, symbols can be re-exported multiple times. For example, a helper function might be imported and re-exported through several files before being used. We need to follow this chain to find the actual implementation: - -```python -def hop_through_imports(imp: Import) -> Symbol | ExternalModule: - """Finds the root symbol for an import.""" - if isinstance(imp.imported_symbol, Import): - return hop_through_imports(imp.imported_symbol) - return imp.imported_symbol -``` - -This creates a structured representation of each function's context: - -```json -{ - "implementation": { - "source": "def process_data(input: str) -> dict: ...", - "filepath": "src/data_processor.py" - }, - "dependencies": [ - { - "source": "def validate_input(data: str) -> bool: ...", - "filepath": "src/validators.py" - } - ], - "usages": [ - { - "source": "result = process_data(user_input)", - "filepath": "src/api.py" - } - ] -} -``` - -## Step 2: Processing the Codebase - -Next, we process all functions in the codebase to generate our training data: - -```python -def run(codebase: Codebase): - """Generate training data using a node2vec-like approach for code embeddings.""" - # Track all function contexts - training_data = { - "functions": [], - "metadata": { - "total_functions": len(codebase.functions), - "total_processed": 0, - "avg_dependencies": 0, - "avg_usages": 0, - }, - } - - # Process each function in the codebase - for function in codebase.functions: - # Skip if function is too small - if len(function.source.split("\n")) < 2: - continue - - # Get function context - context = get_function_context(function) - - # Only keep functions with enough context - if len(context["dependencies"]) + len(context["usages"]) > 0: - training_data["functions"].append(context) - - # Update metadata - training_data["metadata"]["total_processed"] = len(training_data["functions"]) - if training_data["functions"]: - training_data["metadata"]["avg_dependencies"] = sum( - len(f["dependencies"]) for f in training_data["functions"] - ) / len(training_data["functions"]) - training_data["metadata"]["avg_usages"] = sum( - len(f["usages"]) for f in training_data["functions"] - ) / len(training_data["functions"]) - - return training_data -``` - -## Step 3: Running the Generator - -Finally, we can run our training data generator on any codebase. - -See [parsing codebases](/building-with-codegen/parsing-codebases) to learn more - -```python -if __name__ == "__main__": - print("Initializing codebase...") - codebase = Codebase.from_repo("fastapi/fastapi") - - print("Generating training data...") - training_data = run(codebase) - - print("Saving training data...") - with open("training_data.json", "w") as f: - json.dump(training_data, f, indent=2) - print("Training data saved to training_data.json") -``` - -This will: -1. Load the target codebase -2. Process all functions -3. Save the structured training data to a JSON file - - - You can use any Git repository as your source codebase by passing the repo URL - to [Codebase.from_repo(...)](/api-reference/core/Codebase#from-repo). - - -## Using the Training Data - -The generated data can be used to train LLMs in several ways: - -1. **Masked Function Prediction**: Hide a function's implementation and predict it from dependencies and usages -2. **Code Embeddings**: Generate embeddings that capture semantic relationships between functions -3. **Dependency Prediction**: Learn to predict which functions are likely to be dependencies -4. **Usage Pattern Learning**: Train models to understand common usage patterns - -For example, to create a masked prediction task: - -```python -def create_training_example(function_data): - """Create a masked prediction example from function data.""" - return { - "context": { - "dependencies": function_data["dependencies"], - "usages": function_data["usages"] - }, - "target": function_data["implementation"] - } - -# Create training examples -examples = [create_training_example(f) for f in training_data["functions"]] -``` - - - ---- -title: "Codebase Visualization" -sidebarTitle: "Visualization" -description: "This guide will show you how to create codebase visualizations using [codegen](/introduction/overview)." -icon: "share-nodes" -iconType: "solid" ---- - - - - - -## Overview - -To demonstrate the visualization capabilities of the codegen we will generate three different visualizations of PostHog's open source [repository](https://github.com/PostHog/posthog). - - [Call Trace Visualization](#call-trace-visualization) - - [Function Dependency Graph](#function-dependency-graph) - - [Blast Radius Visualization](#blast-radius-visualization) - - -## Call Trace Visualization - -Visualizing the call trace of a function is a great way to understand the flow of a function and for debugging. In this tutorial we will create a call trace visualization of the `patch` method of the `SharingConfigurationViewSet` class. View the source code [here](https://github.com/PostHog/posthog/blob/c2986d9ac7502aa107a4afbe31b3633848be6582/posthog/api/sharing.py#L163). - - -### Basic Setup -First, we'll set up our codebase, graph and configure some basic parameters: - -```python -import networkx as nx -from codegen import Codebase - -# Initialize codebase -codebase = Codebase("path/to/posthog/") - -# Create a directed graph for representing call relationships -G = nx.DiGraph() - -# Configuration flags -IGNORE_EXTERNAL_MODULE_CALLS = True # Skip calls to external modules -IGNORE_CLASS_CALLS = False # Include class definition calls -MAX_DEPTH = 10 - -COLOR_PALETTE = { - "StartFunction": "#9cdcfe", # Light blue - Start Function - "PyFunction": "#a277ff", # Soft purple/periwinkle - PyFunction - "PyClass": "#ffca85", # Warm peach/orange - PyClass - "ExternalModule": "#f694ff" # Bright magenta/pink - ExternalModule -} -``` - -### Building the Visualization -We'll create a function that will recursively traverse the call trace of a function and add nodes and edges to the graph: - -```python -def create_downstream_call_trace(src_func: Function, depth: int = 0): - """Creates call graph by recursively traversing function calls - - Args: - src_func (Function): Starting function for call graph - depth (int): Current recursion depth - """ - # Prevent infinite recursion - if MAX_DEPTH <= depth: - return - - # External modules are not functions - if isinstance(src_func, ExternalModule): - return - - # Process each function call - for call in src_func.function_calls: - # Skip self-recursive calls - if call.name == src_func.name: - continue - - # Get called function definition - func = call.function_definition - if not func: - continue - - # Apply configured filters - if isinstance(func, ExternalModule) and IGNORE_EXTERNAL_MODULE_CALLS: - continue - if isinstance(func, Class) and IGNORE_CLASS_CALLS: - continue - - # Generate display name (include class for methods) - if isinstance(func, Class) or isinstance(func, ExternalModule): - func_name = func.name - elif isinstance(func, Function): - func_name = f"{func.parent_class.name}.{func.name}" if func.is_method else func.name - - # Add node and edge with metadata - G.add_node(func, name=func_name, - color=COLOR_PALETTE.get(func.__class__.__name__)) - G.add_edge(src_func, func, **generate_edge_meta(call)) - - # Recurse for regular functions - if isinstance(func, Function): - create_downstream_call_trace(func, depth + 1) -``` - -### Adding Edge Metadata -We can enrich our edges with metadata about the function calls: - -```python -def generate_edge_meta(call: FunctionCall) -> dict: - """Generate metadata for call graph edges - - Args: - call (FunctionCall): Function call information - - Returns: - dict: Edge metadata including name and location - """ - return { - "name": call.name, - "file_path": call.filepath, - "start_point": call.start_point, - "end_point": call.end_point, - "symbol_name": "FunctionCall" - } -``` -### Visualizing the Graph -Finally, we can visualize our call graph starting from a specific function: -```python -# Get target function to analyze -target_class = codebase.get_class('SharingConfigurationViewSet') -target_method = target_class.get_method('patch') - -# Add root node -G.add_node(target_method, - name=f"{target_class.name}.{target_method.name}", - color=COLOR_PALETTE["StartFunction"]) - -# Build the call graph -create_downstream_call_trace(target_method) - -# Render the visualization -codebase.visualize(G) -``` - - -### Take a look - - -View on [codegen.sh](https://www.codegen.sh/codemod/6a34b45d-c8ad-422e-95a8-46d4dc3ce2b0/public/diff) - - -### Common Use Cases -The call graph visualization is particularly useful for: - - Understanding complex codebases - - Planning refactoring efforts - - Identifying tightly coupled components - - Analyzing critical paths - - Documenting system architecture - -## Function Dependency Graph - -Understanding symbol dependencies is crucial for maintaining and refactoring code. This tutorial will show you how to create visual dependency graphs using Codegen and NetworkX. We will be creating a dependency graph of the `get_query_runner` function. View the source code [here](https://github.com/PostHog/posthog/blob/c2986d9ac7502aa107a4afbe31b3633848be6582/posthog/hogql_queries/query_runner.py#L152). - -### Basic Setup - -We'll use the same basic setup as the [Call Trace Visualization](/tutorials/codebase-visualization#call-trace-visualization) tutorial. - - -### Building the Dependency Graph -The core function for building our dependency graph: -```python -def create_dependencies_visualization(symbol: Symbol, depth: int = 0): - """Creates visualization of symbol dependencies - - Args: - symbol (Symbol): Starting symbol to analyze - depth (int): Current recursion depth - """ - # Prevent excessive recursion - if depth >= MAX_DEPTH: - return - - # Process each dependency - for dep in symbol.dependencies: - dep_symbol = None - - # Handle different dependency types - if isinstance(dep, Symbol): - # Direct symbol reference - dep_symbol = dep - elif isinstance(dep, Import): - # Import statement - get resolved symbol - dep_symbol = dep.resolved_symbol if dep.resolved_symbol else None - - if dep_symbol: - # Add node with appropriate styling - G.add_node(dep_symbol, - color=COLOR_PALETTE.get(dep_symbol.__class__.__name__, - "#f694ff")) - - # Add dependency relationship - G.add_edge(symbol, dep_symbol) - - # Recurse unless it's a class (avoid complexity) - if not isinstance(dep_symbol, PyClass): - create_dependencies_visualization(dep_symbol, depth + 1) -``` - -### Visualizing the Graph -Finally, we can visualize our dependency graph starting from a specific symbol: -```python -# Get target symbol -target_func = codebase.get_function("get_query_runner") - -# Add root node -G.add_node(target_func, color=COLOR_PALETTE["StartFunction"]) - -# Generate dependency graph -create_dependencies_visualization(target_func) - -# Render visualization -codebase.visualize(G) -``` - -### Take a look - - -View on [codegen.sh](https://www.codegen.sh/codemod/39a36f0c-9d35-4666-9db7-12ae7c28fc17/public/diff) - - -## Blast Radius visualization - -Understanding the impact of code changes is crucial for safe refactoring. A blast radius visualization shows how changes to one function might affect other parts of the codebase by tracing usage relationships. In this tutorial we will create a blast radius visualization of the `export_asset` function. View the source code [here](https://github.com/PostHog/posthog/blob/c2986d9ac7502aa107a4afbe31b3633848be6582/posthog/tasks/exporter.py#L57). - -### Basic Setup - -We'll use the same basic setup as the [Call Trace Visualization](/tutorials/codebase-visualization#call-trace-visualization) tutorial. - - -### Helper Functions -We'll create some utility functions to help build our visualization: -```python -# List of HTTP methods to highlight -HTTP_METHODS = ["get", "put", "patch", "post", "head", "delete"] - -def generate_edge_meta(usage: Usage) -> dict: - """Generate metadata for graph edges - - Args: - usage (Usage): Usage relationship information - - Returns: - dict: Edge metadata including name and location - """ - return { - "name": usage.match.source, - "file_path": usage.match.filepath, - "start_point": usage.match.start_point, - "end_point": usage.match.end_point, - "symbol_name": usage.match.__class__.__name__ - } - -def is_http_method(symbol: PySymbol) -> bool: - """Check if a symbol is an HTTP endpoint method - - Args: - symbol (PySymbol): Symbol to check - - Returns: - bool: True if symbol is an HTTP method - """ - if isinstance(symbol, PyFunction) and symbol.is_method: - return symbol.name in HTTP_METHODS - return False -``` - -### Building the Blast Radius Visualization -The main function for creating our blast radius visualization: -```python -def create_blast_radius_visualization(symbol: PySymbol, depth: int = 0): - """Create visualization of symbol usage relationships - - Args: - symbol (PySymbol): Starting symbol to analyze - depth (int): Current recursion depth - """ - # Prevent excessive recursion - if depth >= MAX_DEPTH: - return - - # Process each usage of the symbol - for usage in symbol.usages: - usage_symbol = usage.usage_symbol - - # Determine node color based on type - if is_http_method(usage_symbol): - color = COLOR_PALETTE.get("HTTP_METHOD") - else: - color = COLOR_PALETTE.get(usage_symbol.__class__.__name__, "#f694ff") - - # Add node and edge to graph - G.add_node(usage_symbol, color=color) - G.add_edge(symbol, usage_symbol, **generate_edge_meta(usage)) - - # Recursively process usage symbol - create_blast_radius_visualization(usage_symbol, depth + 1) -``` - -### Visualizing the Graph -Finally, we can create our blast radius visualization: -```python -# Get target function to analyze -target_func = codebase.get_function('export_asset') - -# Add root node -G.add_node(target_func, color=COLOR_PALETTE.get("StartFunction")) - -# Build the visualization -create_blast_radius_visualization(target_func) - -# Render graph to show impact flow -# Note: a -> b means changes to a will impact b -codebase.visualize(G) -``` - -### Take a look - - -View on [codegen.sh](https://www.codegen.sh/codemod/d255db6c-9a86-4197-9b78-16c506858a3b/public/diff) - - -## What's Next? - - - - Learn how to use Codegen to create modular codebases. - - - Learn how to use Codegen to delete dead code. - - - Learn how to use Codegen to increase type coverage. - - - Explore the complete API documentation for all Codegen classes and methods. - - - ---- -title: "Migrating APIs" -sidebarTitle: "API Migrations" -icon: "webhook" -iconType: "solid" ---- - -API migrations are a common task in large codebases. Whether you're updating a deprecated function, changing parameter names, or modifying return types, Codegen makes it easy to update all call sites consistently. - -## Common Migration Scenarios - -### Renaming Parameters - -When updating parameter names across an API, you need to update both the function definition and all call sites: - -```python -# Find the API function to update -api_function = codebase.get_function("process_data") - -# Update the parameter name -old_param = api_function.get_parameter("input") -old_param.rename("data") - -# All call sites are automatically updated: -# process_data(input="test") -> process_data(data="test") -``` - -See [dependencies and usages](/building-with-codegen/dependencies-and-usages) for more on updating parameter names and types. - -### Adding Required Parameters - -When adding a new required parameter to an API: - -```python -# Find all call sites before modifying the function -call_sites = list(api_function.call_sites) - -# Add the new parameter -api_function.add_parameter("timeout: int") - -# Update all existing call sites to include the new parameter -for call in call_sites: - call.add_argument("timeout=30") # Add with a default value -``` - -See [function calls and callsites](/building-with-codegen/function-calls-and-callsites) for more on handling call sites. - -### Changing Parameter Types - -When updating parameter types: - -```python -# Update the parameter type -param = api_function.get_parameter("user_id") -param.type = "UUID" # Change from string to UUID - -# Find all call sites that need type conversion -for call in api_function.call_sites: - arg = call.get_arg_by_parameter_name("user_id") - if arg: - # Convert string to UUID - arg.edit(f"UUID({arg.value})") -``` - -See [working with type annotations](/building-with-codegen/type-annotations) for more on changing parameter types. - -### Deprecating Functions - -When deprecating an old API in favor of a new one: - -```python -old_api = codebase.get_function("old_process_data") -new_api = codebase.get_function("new_process_data") - -# Add deprecation warning -old_api.add_decorator('@deprecated("Use new_process_data instead")') - -# Update all call sites to use the new API -for call in old_api.call_sites: - # Map old arguments to new parameter names - args = [ - f"data={call.get_arg_by_parameter_name('input').value}", - f"timeout={call.get_arg_by_parameter_name('wait').value}" - ] - - # Replace the old call with the new API - call.replace(f"new_process_data({', '.join(args)})") -``` - -## Bulk Updates to Method Chains - -When updating chained method calls, like database queries or builder patterns: - -```python -# Find all query chains ending with .execute() -for execute_call in codebase.function_calls: - if execute_call.name != "execute": - continue - - # Get the full chain - chain = execute_call.call_chain - - # Example: Add .timeout() before .execute() - if "timeout" not in {call.name for call in chain}: - execute_call.insert_before("timeout(30)") -``` - -## Handling Breaking Changes - -When making breaking changes to an API, it's important to: -1. Identify all affected call sites -2. Make changes consistently -3. Update related documentation -4. Consider backward compatibility - -Here's a comprehensive example: - -```python -def migrate_api_v1_to_v2(codebase): - old_api = codebase.get_function("create_user_v1") - - # Document all existing call patterns - call_patterns = {} - for call in old_api.call_sites: - args = [arg.source for arg in call.args] - pattern = ", ".join(args) - call_patterns[pattern] = call_patterns.get(pattern, 0) + 1 - - print("Found call patterns:") - for pattern, count in call_patterns.items(): - print(f" {pattern}: {count} occurrences") - - # Create new API version - new_api = old_api.copy() - new_api.rename("create_user_v2") - - # Update parameter types - new_api.get_parameter("email").type = "EmailStr" - new_api.get_parameter("role").type = "UserRole" - - # Add new required parameters - new_api.add_parameter("tenant_id: UUID") - - # Update all call sites - for call in old_api.call_sites: - # Get current arguments - email_arg = call.get_arg_by_parameter_name("email") - role_arg = call.get_arg_by_parameter_name("role") - - # Build new argument list with type conversions - new_args = [ - f"email=EmailStr({email_arg.value})", - f"role=UserRole({role_arg.value})", - "tenant_id=get_current_tenant_id()" - ] - - # Replace old call with new version - call.replace(f"create_user_v2({', '.join(new_args)})") - - # Add deprecation notice to old version - old_api.add_decorator('@deprecated("Use create_user_v2 instead")') - -# Run the migration -migrate_api_v1_to_v2(codebase) -``` - -## Best Practices - -1. **Analyze First**: Before making changes, analyze all call sites to understand usage patterns - ```python - # Document current usage - for call in api.call_sites: - print(f"Called from: {call.parent_function.name}") - print(f"With args: {[arg.source for arg in call.args]}") - ``` - -2. **Make Atomic Changes**: Update one aspect at a time - ```python - # First update parameter names - param.rename("new_name") - - # Then update types - param.type = "new_type" - - # Finally update call sites - for call in api.call_sites: - # ... update calls - ``` - -3. **Maintain Backwards Compatibility**: - ```python - # Add new parameter with default - api.add_parameter("new_param: str = None") - - # Later make it required - api.get_parameter("new_param").remove_default() - ``` - -4. **Document Changes**: - ```python - # Add clear deprecation messages - old_api.add_decorator('''@deprecated( - "Use new_api() instead. Migration guide: docs/migrations/v2.md" - )''') - ``` - - -Remember to test thoroughly after making bulk changes to APIs. While Codegen ensures syntactic correctness, you'll want to verify the semantic correctness of the changes. - - ---- -title: "Organizing Your Codebase" -sidebarTitle: "Organization" -icon: "folder-tree" -iconType: "solid" ---- - -Codegen SDK provides a powerful set of tools for deterministically moving code safely and efficiently. This guide will walk you through the basics of moving code with Codegen SDK. - -Common use cases include: - - - - -```python -print(f"🔍 Processing file: {filepath}") -file = codebase.get_file(filepath) - -# Get the directory path for creating new files -dir_path = file.directory.path if file.directory else "" - -# Iterate through all functions in the file -for function in file.functions: - # Create new filename based on function name - new_filepath = f"{dir_path}/{function.name}.py" - print(f"📝 Creating new file: {new_filepath}") - - # Create the new file - new_file = codebase.create_file(new_filepath) - - # Move the function to the new file, including dependencies - print(f"➡️ Moving function: {function.name}") - function.move_to_file(new_file, include_dependencies=True) -``` - - - - - -```python -# Dictionary to track modules and their functions -module_map = { - "utils": lambda f: f.name.startswith("util_") or f.name.startswith("helper_"), - "api": lambda f: f.name.startswith("api_") or f.name.startswith("endpoint_"), - "data": lambda f: f.name.startswith("data_") or f.name.startswith("db_"), - "core": lambda f: True # Default module for other functions -} - -print("🔍 Starting code organization...") - -# Create module directories if they don't exist -for module in module_map.keys(): - if not codebase.has_directory(module): - print(f"📁 Creating module directory: {module}") - codebase.create_directory(module, exist_ok=True) - -# Process each file in the codebase -for file in codebase.files: - print(f"\n📄 Processing file: {file.filepath}") - - # Skip if file is already in a module directory - if any(file.filepath.startswith(module) for module in module_map.keys()): - continue - - # Process each function in the file - for function in file.functions: - # Determine which module this function belongs to - target_module = next( - (module for module, condition in module_map.items() - if condition(function)), - "core" - ) - - # Create the new file path - new_filepath = f"{target_module}/{function.name}.py" - - print(f" ➡️ Moving {function.name} to {target_module} module") - - # Create new file and move function - if not codebase.has_file(new_filepath): - new_file = codebase.create_file(new_filepath) - function.move_to_file(new_file, include_dependencies=True) - -print("\n✅ Code organization complete!") -``` - - - - - -```python -# Create a graph to detect cycles -import networkx as nx - -# Build dependency graph -G = nx.DiGraph() - -# Add edges for imports between files -for file in codebase.files: - for imp in file.imports: - if imp.from_file: - G.add_edge(file.filepath, imp.from_file.filepath) - -# Find cycles in the graph -cycles = list(nx.simple_cycles(G)) - -if not cycles: - print("✅ No import cycles found!") - exit() - -print(f"🔍 Found {len(cycles)} import cycles") - -# Process each cycle -for cycle in cycles: - print(f"\n⭕ Processing cycle: {' -> '.join(cycle)}") - - # Get the first two files in the cycle - file1 = codebase.get_file(cycle[0]) - file2 = codebase.get_file(cycle[1]) - - # Find functions in file1 that are used by file2 - for function in file1.functions: - if any(usage.file == file2 for usage in function.usages): - # Create new file for the shared function - new_filepath = f"shared/{function.name}.py" - print(f" ➡️ Moving {function.name} to {new_filepath}") - - if not codebase.has_directory("shared"): - codebase.create_directory("shared") - - new_file = codebase.create_file(new_filepath) - function.move_to_file(new_file, include_dependencies=True) - -print("\n✅ Import cycles resolved!") -``` - - - - - - Most operations in Codegen will automatically handle updaging - [dependencies](/building-with-codegen/dependencies-and-usages) and - [imports](/building-with-codegen/imports). See [Moving - Symbols](/building-with-codegen/moving-symbols) to learn more. - - -## Basic Symbol Movement - -To move a symbol from one file to another, you can use the [move_to_file](/api-reference/core/Function#move-to-file) method. - - -```python python -# Get the symbol -symbol_to_move = source_file.get_symbol("my_function") -# Pick a destination file -dst_file = codebase.get_file("path/to/dst/location.py") -# Move the symbol, move all of its dependencies with it (remove from old file), and add an import of symbol into old file -symbol_to_move.move_to_file(dst_file, include_dependencies=True, strategy="add_back_edge") -``` - -```python typescript -# Get the symbol -symbol_to_move = source_file.get_symbol("myFunction") -# Pick a destination file -dst_file = codebase.get_file("path/to/dst/location.ts") -# Move the symbol, move all of its dependencies with it (remove from old file), and add an import of symbol into old file -symbol_to_move.move_to_file(dst_file, include_dependencies=True, strategy="add_back_edge") -``` - - - -This will move `my_function` to `path/to/dst/location.py`, safely updating all references to it in the process. - -## Updating Imports - -After moving a symbol, you may need to update imports throughout your codebase. GraphSitter offers two strategies for this: - -1. **Update All Imports**: This strategy updates all imports across the codebase to reflect the new location of the symbol. - - -```python python -symbol_to_move = codebase.get_symbol("symbol_to_move") -dst_file = codebase.create_file("new_file.py") -symbol_to_move.move_to_file(dst_file, strategy="update_all_imports") -``` - -```python typescript -symbol_to_move = codebase.get_symbol("symbolToMove") -dst_file = codebase.create_file("new_file.ts") -symbol_to_move.move_to_file(dst_file, strategy="update_all_imports") -``` - - - -Updating all imports can result in very large PRs - -2. **Add Back Edge**: This strategy adds an import in the original file that re-imports (and exports) the moved symbol, maintaining backwards compatibility. This will result in fewer total modifications, as existing imports will not need to be updated. - - -```python python -symbol_to_move = codebase.get_symbol("symbol_to_move") -dst_file = codebase.create_file("new_file.py") -symbol_to_move.move_to_file(dst_file, strategy="add_back_edge") -``` - -```python typescript -symbol_to_move = codebase.get_symbol("symbolToMove") -dst_file = codebase.create_file("new_file.ts") -symbol_to_move.move_to_file(dst_file, strategy="add_back_edge") -``` - - - -## Handling Dependencies - -By default, Codegen will move all of a symbols dependencies along with it. This ensures that your codebase remains consistent and functional. - - -```python python -my_symbol = codebase.get_symbol("my_symbol") -dst_file = codebase.create_file("new_file.py") -my_symbol.move_to_file(dst_file, include_dependencies=True) -``` - -```python typescript -my_symbol = codebase.get_symbol("mySymbol") -dst_file = codebase.create_file("new_file.ts") -my_symbol.move_to_file(dst_file, include_dependencies=True) -``` - - - -If you set `include_dependencies=False`, only the symbol itself will be moved, and any dependencies will remain in the original file. - -## Moving Multiple Symbols - -If you need to move multiple symbols, you can do so in a loop: - -```python -source_file = codebase.get_file("path/to/source_file.py") -dest_file = codebase.get_file("path/to/destination_file.py") -# Create a list of symbols to move -symbols_to_move = [source_file.get_function("my_function"), source_file.get_class("MyClass")] -# Move each symbol to the destination file -for symbol in symbols_to_move: - symbol.move_to_file(dest_file, include_dependencies=True, strategy="update_all_imports") -``` - -## Best Practices - -1. **Commit After Major Changes**: If you're making multiple significant changes, use `codebase.commit()` between them to ensure the codebase graph is up-to-date. - -2. **Re-fetch References**: After a commit, re-fetch any file or symbol references you're working with, as they may have become stale. - -3. **Handle Errors**: Be prepared to handle cases where symbols or files might not exist, or where moves might fail due to naming conflicts. - -By following these guidelines, you can effectively move symbols around your codebase while maintaining its integrity and functionality. - - ---- -title: "Converting Promise Chains to Async/Await" -sidebarTitle: "Promise to Async/Await" -icon: "code-merge" -iconType: "solid" ---- - -Modern JavaScript/TypeScript codebases often need to migrate from Promise-based code to the more readable async/await syntax. Codegen provides powerful tools to automate this conversion while preserving business logic and handling complex scenarios. - - -You can find the complete example code in our [examples repository](https://github.com/codegen-sh/codegen-sdk/blob/develop/codegen-examples/examples/promises_to_async_await/promises_to_async_await.ipynb). - - - -## Finding Promise Chains - -Codegen offers multiple ways to locate Promise chains in your codebase: -- In files -- In functions -- Part of a function call chain - -### Promise Chains in a File - -Find all Promise chains in a file: - -```python -ts_file = codebase.get_file("api_client.ts") -promise_chains = ts_file.promise_chains - -print(f"Found {len(promise_chains)} Promise chains") -``` - -### Promise Chains in a Function - -Find Promise chains within a specific function: - -```python -ts_func = codebase.get_function("getUserData") -chains = ts_func.promise_chains - -for chain in chains: - print(f"Found chain starting with: {chain.name}") -``` - -### Promise Chain starting from a Function Call - -Find Promise chains starting from a specific function call: - -```python -# Assuming the function call is part of a promise chain -fetch_call = codebase.get_function("fetchUserData").function_calls[2] -chain = fetch_call.promise_chain -``` - - -## Converting Promise Chains - -### In-Place Conversion - -Convert Promise chains directly in your codebase: - -```python -# Find and convert all Promise chains in a file -for chain in typescript_file.promise_chains: - chain.convert_to_async_await() -``` - -### Handle Business Logic Without In-Place Edit - -Generate the transformed code without inplace edit by returning the new code as a string. This is useful when you want to add additional business logic to the overall conversion. - -```python -async_await_code = chain.convert_to_async_await(inplace_edit=False) -print("Converted code:", async_await_code) - -promise_statement = chain.parent_statement -new_code = promise_statement.edit( - f""" - {async_await_code} - - // handle additional business logic here - """ -) -``` - - -## Supported Promise Chain Patterns - -- Basic `promise.then()` statements of any length -- Catch `promise.then().catch()` statements of any length -- Finally `promise.then().catch().finally()` statements of any length -- Desctructure `promise.then((var1, var2))` statements -> `let [var1, var2] = await statement;` -- Implicit returns -> `return promise.then(() => console.log("hello"))` -- Top level variable assignments -> `let assigned_var = promise.then()` -- Top level variable assignments -> `let assigned_var = promise.then()` -- Ambiguous/conditional return blocks - - -A list of all the covered cases can be found in the [example notebook](https://github.com/codegen-sh/codegen-sdk/tree/codegen-examples/examples/promises_to_async_await/promise_to_async_await.ipynb). - - - -## Examples -### 1. Basic Promise Chains - -```typescript -// Before -function getValue(): Promise { - return Promise.resolve(10) - .then(value => value * 2); -} -``` - -***Applying the conversion...*** -```python -promise_chain = codebase.get_function("getValue").promise_chains[0] -promise_chain.convert_to_async_await() -codebase.commit() -``` - -```typescript -// After -async function getValue(): Promise { - let value = await Promise.resolve(10); - return value * 2; -} -``` - -### 2. Error Handling with Catch/Finally - -```typescript -// Before -function processData(): Promise { - return fetchData() - .then(data => processData(data)) - .catch(error => { - console.error("Error:", error); - throw error; - }) - .finally(() => { - cleanup(); - }); -} -``` - -***Applying the conversion...*** -```python -promise_chain = codebase.get_function("processData").promise_chains[0] -promise_chain.convert_to_async_await() -codebase.commit() -``` - -```typescript -// After -async function processData(): Promise { - try { - let data = await fetchData(); - return processData(data); - } catch (error) { - console.error("Error:", error); - throw error; - } finally { - cleanup(); - } -} -``` - -### 3. Promise.all with Destructuring - -```typescript -// Before -function getAllUserInfo(userId: number) { - return Promise.all([ - fetchUserData(userId), - fetchUserPosts(userId) - ]).then(([user, posts]) => { - return { user, posts }; - }); -} -``` - -***Applying the conversion...*** -```python -promise_chain = codebase.get_function("getAllUserInfo").promise_chains[0] -promise_chain.convert_to_async_await() -codebase.commit() -``` - - -```typescript -// After -async function getAllUserInfo(userId: number) { - const [user, posts] = await Promise.all([ - fetchUserData(userId), - fetchUserPosts(userId) - ]); - return { user, posts }; -} -``` - - -### 4. Handling Ambiguous Returns Using Anonymous functions - - -For `then` blocks that have more than one return statement, Codegen will add an anonymous function to handle the ambiguous return to guarantee a deterministic conversion. - -```typescript -// Before -function create(opts: any): Promise { - let qResponse = request(opts); - qResponse = qResponse.then(function success(response) { - if (response.statusCode < 200 || response.statusCode >= 300) { - throw new Error(JSON.stringify(response)); - } - if (typeof response.body === "string") { - return JSON.parse(response.body); - } - return response.body; - }); - - return qResponse; -} - -``` - -***Applying the conversion...*** -```python -promise_chain = codebase.get_function("create").promise_chains[0] -promise_chain.convert_to_async_await() -codebase.commit() -``` -```typescript -// After -async function create(opts): Promise { - let qResponse = request(opts); - let response = await qResponse; - qResponse = (async (response) => { - if (response.statusCode < 200 || response.statusCode >= 300) { - throw new Error(JSON.stringify(response)); - } - if (typeof response.body === "string") { - return JSON.parse(response.body); - } - return response.body; - })(response); - - return qResponse; -} -``` - - - -## Handling Top-Level Assignment Variables - -When converting Promise chains that involve top-level assignment variables, you can specify the variable name of your choice or pick the default which is the original variable assignment name. - -```python -# Convert with custom variable names for clarity -chain.convert_to_async_await( - assignment_variable_name="operationResult", -) -``` - - -## Next Steps - - -Converting Promise chains to async/await improves code readability and maintainability. Codegen's tools make this migration process automated and reliable, handling complex cases while preserving business logic. -Here are some next steps to ensure a successful migration: - -1. Ensure to run `npx prettier --write .` after the migration to fix indentation + linting -2. **Incremental Migration**: Convert one module at a time -3. **Handle Additional Business Logic**: Use `.promise_statement.edit()` to modify the entire chain and handle external business logic -4. If the specific conversion case is not covered, open an issue on the [Codegen](https://github.com/codegen-sh/codegen-sdk) repository or try to right your own transformation logic using the codegen-sdk - - ---- -title: "Improving Code Modularity" -sidebarTitle: "Modularity" -icon: "diagram-project" -iconType: "solid" ---- - -Codegen SDK provides powerful tools for analyzing and improving code modularity. This guide will help you identify and fix common modularity issues like circular dependencies, tight coupling, and poorly organized imports. - -Common use cases include: -- Breaking up circular dependencies -- Organizing imports and exports -- Identifying highly coupled modules -- Extracting shared code into common modules -- Analyzing module boundaries - -## Analyzing Import Relationships - -First, let's see how to analyze import relationships in your codebase: - -```python -import networkx as nx -from collections import defaultdict - -# Create a graph of file dependencies -def create_dependency_graph(): - G = nx.DiGraph() - - for file in codebase.files: - # Add node for this file - G.add_node(file.filepath) - - # Add edges for each import - for imp in file.imports: - if imp.from_file: # Skip external imports - G.add_edge(file.filepath, imp.from_file.filepath) - - return G - -# Create and analyze the graph -graph = create_dependency_graph() - -# Find circular dependencies -cycles = list(nx.simple_cycles(graph)) -if cycles: - print("🔄 Found circular dependencies:") - for cycle in cycles: - print(f" • {' -> '.join(cycle)}") - -# Calculate modularity metrics -print("\n📊 Modularity Metrics:") -print(f" • Number of files: {len(graph.nodes)}") -print(f" • Number of imports: {len(graph.edges)}") -print(f" • Average imports per file: {len(graph.edges)/len(graph.nodes):.1f}") -``` - -## Breaking Circular Dependencies - -When you find circular dependencies, here's how to break them: - -```python -def break_circular_dependency(cycle): - # Get the first two files in the cycle - file1 = codebase.get_file(cycle[0]) - file2 = codebase.get_file(cycle[1]) - - # Create a shared module for common code - shared_dir = "shared" - if not codebase.has_directory(shared_dir): - codebase.create_directory(shared_dir) - - # Find symbols used by both files - shared_symbols = [] - for symbol in file1.symbols: - if any(usage.file == file2 for usage in symbol.usages): - shared_symbols.append(symbol) - - # Move shared symbols to a new file - if shared_symbols: - shared_file = codebase.create_file(f"{shared_dir}/shared_types.py") - for symbol in shared_symbols: - symbol.move_to_file(shared_file, strategy="update_all_imports") - -# Break each cycle found -for cycle in cycles: - break_circular_dependency(cycle) -``` - -## Organizing Imports - -Clean up and organize imports across your codebase: - -```python -def organize_file_imports(file): - # Group imports by type - std_lib_imports = [] - third_party_imports = [] - local_imports = [] - - for imp in file.imports: - if imp.is_standard_library: - std_lib_imports.append(imp) - elif imp.is_third_party: - third_party_imports.append(imp) - else: - local_imports.append(imp) - - # Sort each group - for group in [std_lib_imports, third_party_imports, local_imports]: - group.sort(key=lambda x: x.module_name) - - # Remove all existing imports - for imp in file.imports: - imp.remove() - - # Add imports back in organized groups - if std_lib_imports: - for imp in std_lib_imports: - file.add_import(imp.source) - file.insert_after_imports("") # Add newline - - if third_party_imports: - for imp in third_party_imports: - file.add_import(imp.source) - file.insert_after_imports("") # Add newline - - if local_imports: - for imp in local_imports: - file.add_import(imp.source) - -# Organize imports in all files -for file in codebase.files: - organize_file_imports(file) -``` - -## Identifying Highly Coupled Modules - -Find modules that might need to be split up: - -```python -from collections import defaultdict - -def analyze_module_coupling(): - coupling_scores = defaultdict(int) - - for file in codebase.files: - # Count unique files imported from - imported_files = {imp.from_file for imp in file.imports if imp.from_file} - coupling_scores[file.filepath] = len(imported_files) - - # Count files that import this file - importing_files = {usage.file for symbol in file.symbols - for usage in symbol.usages if usage.file != file} - coupling_scores[file.filepath] += len(importing_files) - - # Sort by coupling score - sorted_files = sorted(coupling_scores.items(), - key=lambda x: x[1], - reverse=True) - - print("\n🔍 Module Coupling Analysis:") - print("\nMost coupled files:") - for filepath, score in sorted_files[:5]: - print(f" • {filepath}: {score} connections") - -analyze_module_coupling() -``` - -## Extracting Shared Code - -When you find highly coupled modules, extract shared code: - -```python -def extract_shared_code(file, min_usages=3): - # Find symbols used by multiple files - for symbol in file.symbols: - # Get unique files using this symbol - using_files = {usage.file for usage in symbol.usages - if usage.file != file} - - if len(using_files) >= min_usages: - # Create appropriate shared module - module_name = determine_shared_module(symbol) - if not codebase.has_file(f"shared/{module_name}.py"): - shared_file = codebase.create_file(f"shared/{module_name}.py") - else: - shared_file = codebase.get_file(f"shared/{module_name}.py") - - # Move symbol to shared module - symbol.move_to_file(shared_file, strategy="update_all_imports") - -def determine_shared_module(symbol): - # Logic to determine appropriate shared module name - if symbol.is_type: - return "types" - elif symbol.is_constant: - return "constants" - elif symbol.is_utility: - return "utils" - else: - return "common" -``` - ---- -title: "Managing Feature Flags" -sidebarTitle: "Feature Flags" -icon: "flag" -iconType: "solid" ---- - -Codegen has been used in production for multi-million line codebases to automatically delete "dead" (rolled-out) feature flags. This guide will walk you through analyzing feature flag usage and safely removing rolled out flags. - - - Every codebase does feature flags differently. This guide shows common techniques and syntax but likely requires adaptation to codebase-specific circumstances. - - -## Analyzing Feature Flag Usage - -Before removing a feature flag, it's important to analyze its usage across the codebase. Codegen provides tools to help identify where and how feature flags are used. - -### For Python Codebases - -For Python codebases using a `FeatureFlags` class pattern like so: -```python -class FeatureFlags: - FEATURE_1 = False - FEATURE_2 = True -``` - -You can use [Class.get_attribute(...)](/api-reference/core/Class#get-attribute) and [Attribute.usages](/api-reference/core/Attribute#usages) to analyze the coverage of your flags, like so: - - - -```python -feature_flag_usage = {} -feature_flag_class = codebase.get_class('FeatureFlag') - -if feature_flag_class: - # Initialize usage count for all attributes - for attr in feature_flag_class.attributes: - feature_flag_usage[attr.name] = 0 - - # Get all usages of the FeatureFlag class - for usage in feature_flag_class.usages: - usage_source = usage.usage_symbol.source if hasattr(usage, 'usage_symbol') else str(usage) - for flag_name in feature_flag_usage.keys(): - if f"FeatureFlag.{flag_name}" in usage_source: - feature_flag_usage[flag_name] += 1 - - sorted_flags = sorted(feature_flag_usage.items(), key=lambda x: x[1], reverse=True) - - print("Feature Flag Usage Table:") - print("-------------------------") - print(f"{'Feature Flag':<30} | {'Usage Count':<12}") - print("-" * 45) - for flag, count in sorted_flags: - print(f"{flag:<30} | {count:<12}") - - print(f"\nTotal feature flags: {len(sorted_flags)}") -else: - print("❗ FeatureFlag enum not found in the codebase") -``` - -This will output a table showing all feature flags and their usage counts, helping identify which flags are candidates for removal. - - - Learn more about [Attributes](/building-with-codegen/class-api#class-attributes) and [tracking usages](/building-with-codegen/dependencies-and-usages) here - - - -## Removing Rolled Out Flags - -Once you've identified a flag that's ready to be removed, Codegen can help safely delete it and its associated code paths. - - - This primarily leverages Codegen's API for [reduction conditions](/building-with-codegen/reducing-conditions) - - -### Python Example - -For Python codebases, here's how to remove a feature flag and its usages: - -```python -flag_name = "FEATURE_TO_REMOVE" - -# Get the feature flag variable -feature_flag_file = codebase.get_file("app/utils/feature_flags.py") -flag_class = feature_flag_file.get_class("FeatureFlag") - -# Check if the flag exists -flag_var = flag_class.get_attribute(flag_name) -if not flag_var: - print(f'No such flag: {flag_name}') - return - -# Remove all usages of the feature flag -for usage in flag_var.usages: - if isinstance(usage.parent, IfBlockStatement): - # For if statements, reduce the condition to True - usage.parent.reduce_condition(True) - elif isinstance(usage.parent, WithStatement): - # For with statements, keep the code block - usage.parent.code_block.unwrap() - else: - # For other cases, remove the usage - usage.remove() - -# Remove the flag definition -flag_var.remove() - -# Commit changes -codebase.commit() -``` - -### React/TypeScript Example - -For React applications using a hooks-based feature flag system: - -```python -feature_flag_name = "NEW_UI_ENABLED" -target_value = True # The value to reduce the flag to - -print(f'Removing feature flag: {feature_flag_name}') - -# 1. Remove from configuration -config_file = codebase.get_file("src/featureFlags/config.ts") -feature_flag_config = config_file.get_symbol("FEATURE_FLAG_CONFIG").value -if feature_flag_name in feature_flag_config.keys(): - feature_flag_config.pop(feature_flag_name) - print('✅ Removed from feature flag config') - -# 2. Find and reduce all hook usages -hook = codebase.get_function("useFeatureFlag") -for usage in hook.usages: - fcall = usage.match - if isinstance(fcall, FunctionCall): - # Check if this usage is for our target flag - first_arg = fcall.args[0].value - if isinstance(first_arg, String) and first_arg.content == feature_flag_name: - print(f'Reducing in: {fcall.parent_symbol.name}') - # This automatically handles: - # - Ternary expressions: flag ? : - # - If statements: if (flag) { ... } - # - Conditional rendering: {flag && } - fcall.reduce_condition(target_value) - -# 3. Commit changes -codebase.commit() -``` - -This will: -1. Remove the feature flag from the configuration -2. Find all usages of the `useFeatureFlag` hook for this flag -3. Automatically reduce any conditional logic using the flag -4. Handle common React patterns like ternaries and conditional rendering - - -## Related Resources -- [Reducing Conditions](/building-with-codegen/reducing-conditions) - Details on condition reduction APIs -- [Dead Code Removal](/tutorials/deleting-dead-code) - Remove unused code after flag deletion - ---- -title: "Deleting Dead Code" -sidebarTitle: "Dead Code" -icon: "trash" -iconType: "solid" ---- - -Dead code refers to code that is not being used or referenced anywhere in your codebase. - -However, it's important to note that some code might appear unused but should not be deleted, including: -- Test files and test functions -- Functions with decorators (which may be called indirectly) -- Public API endpoints -- Event handlers or callback functions -- Code used through reflection or dynamic imports - -This guide will show you how to safely identify and remove genuinely unused code while preserving important functionality. - -## Overview - -To simply identify code without any external usages, you can check for the absence of [Symbol.usages](/api-reference/core/Symbol#usages). - -See [Dependencies and Usages](/building-with-codegen/dependencies-and-usages) for more information on how to use these properties. - -```python -# Iterate through all functions in the codebase -for function in codebase.functions: - # Remove functions with no usages - if not function.usages: - function.remove() - -# Commit -codebase.commit() -``` - - -This will remove all code that is not explicitly referenced elsewhere, including tests, endpoints, etc. This is almost certainly not what you want. We recommend further filtering. - - -## Filtering for Special Cases - -To filter out special cases that are not explicitly referenced yet are, nonetheless, worth keeping around, you can use the following pattern: - - -```python -for function in codebase.functions: - - # Skip test files - if "test" in function.file.filepath: - continue - - # Skip decorated functions - if function.decorators: - continue - - # Skip public routes, e.g. next.js endpoints - # (Typescript only) - if 'routes' in function.file.filepath and function.is_jsx: - continue - - # ... etc. - - # Check if the function has no usages and no call sites - if not function.usages and not function.call_sites: - # Print a message indicating the removal of the function - print(f"Removing unused function: {function.name}") - # Remove the function from the file - function.remove() - -# Commit -codebase.commit() -``` - - -## Cleaning Up Unused Variables - -To remove unused variables, you can check for their usages within their scope: - -```python typescript -for func in codebase.functions: - # Iterate through local variable assignments in the function - for var_assignments in func.code_block.local_var_assignments: - # Check if the local variable assignment has no usages - if not var_assignments.local_usages: - # Remove the local variable assignment - var_assignments.remove() - -# Commit -codebase.commit() -``` - - -## Cleaning Up After Removal - -After removing dead code, you may need to clean up any remaining artifacts: - -```python -for file in codebase.files: - # Check if the file is empty - if not file.content.strip(): - # Print a message indicating the removal of the empty file - print(f"Removing empty file: {file.filepath}") - # Remove the empty file - file.remove() - -# commit is NECESSARY to remove the files from the codebase -codebase.commit() - -# Remove redundant newlines -for file in codebase.files: - # Replace three or more consecutive newlines with two newlines - file.edit(re.sub(r"\n{3,}", "\n\n", file.content)) -``` - - ---- -title: "Increasing Type Coverage" -sidebarTitle: "Type Coverage" -icon: "shield-check" -iconType: "solid" ---- - -This guide demonstrates how to analyze and manipulate type annotations with Codegen SDK. - -Common use cases include: - -- Adding a type to a union or generic type -- Checking if a generic type has a given subtype -- Resolving a type annotation - - - Adding type hints can improve developer experience and [significantly speed up](https://github.com/microsoft/Typescript/wiki/Performance#using-type-annotations) programs like the Typescript compiler and `mypy`. - - -See [Type Annotations](/building-with-codegen/type-annotations) for a general overview of the type maninpulation - -## APIs for monitoring types - -Codegen programs typically access type annotations through the following APIs: -- [Parameter.type](/api-reference/core/Parameter#type) -- [Function.return_type](/api-reference/python/PyFunction#return-type) -- [Assignment.type](/api-reference/core/Assignment#type) - -Each of these has an associated setter. - - -## Finding the extent of your type coverage - -To get an indication of your progress on type coverage, analyze the percentage of typed elements across your codebase - -```python -# Initialize counters for parameters -total_parameters = 0 -typed_parameters = 0 - -# Initialize counters for return types -total_functions = 0 -typed_returns = 0 - -# Initialize counters for class attributes -total_attributes = 0 -typed_attributes = 0 - -# Count parameter and return type coverage -for function in codebase.functions: - # Count parameters - total_parameters += len(function.parameters) - typed_parameters += sum(1 for param in function.parameters if param.is_typed) - - # Count return types - total_functions += 1 - if function.return_type and function.return_type.is_typed: - typed_returns += 1 - -# Count class attribute coverage -for cls in codebase.classes: - for attr in cls.attributes: - total_attributes += 1 - if attr.is_typed: - typed_attributes += 1 - -# Calculate percentages -param_percentage = (typed_parameters / total_parameters * 100) if total_parameters > 0 else 0 -return_percentage = (typed_returns / total_functions * 100) if total_functions > 0 else 0 -attr_percentage = (typed_attributes / total_attributes * 100) if total_attributes > 0 else 0 - -# Print results -print("\nType Coverage Analysis") -print("---------------------") -print(f"Parameters: {param_percentage:.1f}% ({typed_parameters}/{total_parameters} typed)") -print(f"Return types: {return_percentage:.1f}% ({typed_returns}/{total_functions} typed)") -print(f"Class attributes: {attr_percentage:.1f}% ({typed_attributes}/{total_attributes} typed)") -``` - -This analysis gives you a breakdown of type coverage across three key areas: -1. Function parameters - Arguments passed to functions -2. Return types - Function return type annotations -3. Class attributes - Type hints on class variables - - - Focus first on adding types to the most frequently used functions and classes, as these will have the biggest impact on type checking and IDE support. - - -## Adding simple return type annotations - -To add a return type, use `function.set_return_type`. The script below will add a `-> None` return type to all functions that contain no return statements: - - -```python For Python -for file in codebase.files: - # Check if 'app' is in the file's filepath - if "app" in file.filepath: - # Iterate through all functions in the file - for function in file.functions: - # Check if the function has no return statements - if len(function.return_statements) == 0: - # Set the return type to None - function.set_return_type("None") -``` - -```python For Typescript -for file in codebase.files: - # Check if 'app' is in the file's filepath - if "app" in file.filepath: - # Iterate through all functions in the file - for function in file.functions: - # Check if the function has no return statements - if len(function.return_statements) == 0: - # Set the return type to None - function.set_return_type("null") -``` - - - -## Coming Soon: Advanced Type Inference - -Codegen is building out an API for direct interface with `tsc` and `mypy` for precise type inference. Interested piloting this API? Let us know! - ---- -title: "Managing TypeScript Exports" -sidebarTitle: "Export Management" -description: "Safely and systematically manage exports in your TypeScript codebase" -icon: "ship" -iconType: "solid" ---- - -Codegen provides powerful tools for managing and reorganizing exports in TypeScript codebases. This tutorial builds on the concepts covered in [exports](/building-with-codegen/exports) to show you how to automate common export management tasks and ensure your module boundaries stay clean and maintainable. - -## Common Export Management Tasks - -### Collecting and Processing Exports - -When reorganizing exports, the first step is identifying which exports need to be processed: - -```python -processed_imports = set() - -for file in codebase.files: - # Only process files under /src/shared - if '/src/shared' not in file.filepath: - continue - - # Gather all reexports that are not external exports - all_reexports = [] - for export_stmt in file.export_statements: - for export in export_stmt.exports: - if export.is_reexport() and not export.is_external_export: - all_reexports.append(export) - - # Skip if there are none - if not all_reexports: - continue -``` - -### Moving Exports to Public Files - -When centralizing exports in public-facing files: - -```python -# Replace "src/" with "src/shared/" -resolved_public_file = export.resolved_symbol.filepath.replace("src/", "src/shared/") - -# Get relative path from the "public" file back to the original file -relative_path = codebase.get_relative_path( - from_file=resolved_public_file, - to_file=export.resolved_symbol.filepath -) - -# Ensure the "public" file exists -if not codebase.has_file(resolved_public_file): - target_file = codebase.create_file(resolved_public_file, sync=True) -else: - target_file = codebase.get_file(resolved_public_file) - -# If target file already has a wildcard export for this relative path, skip -if target_file.has_export_statement_for_path(relative_path, "WILDCARD"): - has_wildcard = True - continue -``` - -### Managing Different Export Types - -Codegen can handle all types of exports automatically: - - - - ```python - # A) Wildcard export, e.g. `export * from "..."` - if export.is_wildcard_export(): - target_file.insert_before(f'export * from "{relative_path}"') - ``` - - - - ```python - # B) Type export, e.g. `export type { Foo, Bar } from "..."` - elif export.is_type_export(): - # Does this file already have a type export statement for the path? - statement = file.get_export_statement_for_path(relative_path, "TYPE") - if statement: - # Insert into existing statement - if export.is_aliased(): - statement.insert(0, f"{export.resolved_symbol.name} as {export.name}") - else: - statement.insert(0, f"{export.name}") - else: - # Insert a new type export statement - if export.is_aliased(): - target_file.insert_before( - f'export type {{ {export.resolved_symbol.name} as {export.name} }} ' - f'from "{relative_path}"' - ) - else: - target_file.insert_before( - f'export type {{ {export.name} }} from "{relative_path}"' - ) - ``` - - - - ```python - # C) Normal export, e.g. `export { Foo, Bar } from "..."` - else: - statement = file.get_export_statement_for_path(relative_path, "EXPORT") - if statement: - # Insert into existing statement - if export.is_aliased(): - statement.insert(0, f"{export.resolved_symbol.name} as {export.name}") - else: - statement.insert(0, f"{export.name}") - else: - # Insert a brand-new normal export statement - if export.is_aliased(): - target_file.insert_before( - f'export {{ {export.resolved_symbol.name} as {export.name} }} ' - f'from "{relative_path}"' - ) - else: - target_file.insert_before( - f'export {{ {export.name} }} from "{relative_path}"' - ) - ``` - - - -## Updating Import References - -After moving exports, you need to update all import references: - -```python -# Now update all import usages that refer to this export -for usage in export.symbol_usages(): - if isinstance(usage, TSImport) and usage not in processed_imports: - processed_imports.add(usage) - - # Translate the resolved_public_file to the usage file's TS config import path - new_path = usage.file.ts_config.translate_import_path(resolved_public_file) - - if has_wildcard and export.name != export.resolved_symbol.name: - name = f"{export.resolved_symbol.name} as {export.name}" - else: - name = usage.name - - if usage.is_type_import(): - new_import = f'import type {{ {name} }} from "{new_path}"' - else: - new_import = f'import {{ {name} }} from "{new_path}"' - - usage.file.insert_before(new_import) - usage.remove() - -# Remove the old export from the original file -export.remove() - -# If the file ends up with no exports, remove it entirely -if not file.export_statements and len(file.symbols) == 0: - file.remove() -``` - -## Best Practices - -1. **Check for Wildcards First**: Always check for existing wildcard exports before adding new ones: -```python -if target_file.has_export_statement_for_path(relative_path, "WILDCARD"): - has_wildcard = True - continue -``` - -2. **Handle Path Translations**: Use TypeScript config for path translations: -```python -new_path = usage.file.ts_config.translate_import_path(resolved_public_file) -``` - -3. **Clean Up Empty Files**: Remove files that no longer contain exports or symbols: -```python -if not file.export_statements and len(file.symbols) == 0: - file.remove() -``` - -## Next Steps - -After reorganizing your exports: - -1. Run your test suite to verify everything still works -2. Review the generated import statements -3. Check for any empty files that should be removed -4. Verify that all export types (wildcard, type, named) are working as expected - - -Remember that managing exports is an iterative process. You may need to run the codemod multiple times as your codebase evolves. - - -### Related tutorials -- [Moving symbols](/building-with-codegen/moving-symbols) -- [Exports](/building-with-codegen/exports) -- [Dependencies and usages](/building-with-codegen/dependencies-and-usages) - -## Complete Codemod - -Here's the complete codemod that you can copy and use directly: - -```python -processed_imports = set() - -for file in codebase.files: - # Only process files under /src/shared - if '/src/shared' not in file.filepath: - continue - - # Gather all reexports that are not external exports - all_reexports = [] - for export_stmt in file.export_statements: - for export in export_stmt.exports: - if export.is_reexport() and not export.is_external_export: - all_reexports.append(export) - - # Skip if there are none - if not all_reexports: - continue - - for export in all_reexports: - has_wildcard = False - - # Replace "src/" with "src/shared/" - resolved_public_file = export.resolved_symbol.filepath.replace("src/", "src/shared/") - - # Get relative path from the "public" file back to the original file - relative_path = codebase.get_relative_path( - from_file=resolved_public_file, - to_file=export.resolved_symbol.filepath - ) - - # Ensure the "public" file exists - if not codebase.has_file(resolved_public_file): - target_file = codebase.create_file(resolved_public_file, sync=True) - else: - target_file = codebase.get_file(resolved_public_file) - - # If target file already has a wildcard export for this relative path, skip - if target_file.has_export_statement_for_path(relative_path, "WILDCARD"): - has_wildcard = True - continue - - # Compare "public" path to the local file's export.filepath - if codebase._remove_extension(resolved_public_file) != codebase._remove_extension(export.filepath): - - # A) Wildcard export, e.g. `export * from "..."` - if export.is_wildcard_export(): - target_file.insert_before(f'export * from "{relative_path}"') - - # B) Type export, e.g. `export type { Foo, Bar } from "..."` - elif export.is_type_export(): - # Does this file already have a type export statement for the path? - statement = file.get_export_statement_for_path(relative_path, "TYPE") - if statement: - # Insert into existing statement - if export.is_aliased(): - statement.insert(0, f"{export.resolved_symbol.name} as {export.name}") - else: - statement.insert(0, f"{export.name}") - else: - # Insert a new type export statement - if export.is_aliased(): - target_file.insert_before( - f'export type {{ {export.resolved_symbol.name} as {export.name} }} ' - f'from "{relative_path}"' - ) - else: - target_file.insert_before( - f'export type {{ {export.name} }} from "{relative_path}"' - ) - - # C) Normal export, e.g. `export { Foo, Bar } from "..."` - else: - statement = file.get_export_statement_for_path(relative_path, "EXPORT") - if statement: - # Insert into existing statement - if export.is_aliased(): - statement.insert(0, f"{export.resolved_symbol.name} as {export.name}") - else: - statement.insert(0, f"{export.name}") - else: - # Insert a brand-new normal export statement - if export.is_aliased(): - target_file.insert_before( - f'export {{ {export.resolved_symbol.name} as {export.name} }} ' - f'from "{relative_path}"' - ) - else: - target_file.insert_before( - f'export {{ {export.name} }} from "{relative_path}"' - ) - - # Now update all import usages that refer to this export - for usage in export.symbol_usages(): - if isinstance(usage, TSImport) and usage not in processed_imports: - processed_imports.add(usage) - - # Translate the resolved_public_file to the usage file's TS config import path - new_path = usage.file.ts_config.translate_import_path(resolved_public_file) - - if has_wildcard and export.name != export.resolved_symbol.name: - name = f"{export.resolved_symbol.name} as {export.name}" - else: - name = usage.name - - if usage.is_type_import(): - new_import = f'import type {{ {name} }} from "{new_path}"' - else: - new_import = f'import {{ {name} }} from "{new_path}"' - - usage.file.insert_before(new_import) - usage.remove() - - # Remove the old export from the original file - export.remove() - - # If the file ends up with no exports, remove it entirely - if not file.export_statements and len(file.symbols) == 0: - file.remove() -``` - ---- -title: "Converting Default Exports" -sidebarTitle: "Default Export Conversion" -description: "Convert default exports to named exports in your TypeScript codebase" -icon: "arrow-right-arrow-left" -iconType: "solid" ---- - -Codegen provides tools to help you migrate away from default exports to named exports in your TypeScript codebase. This tutorial builds on the concepts covered in [exports](/building-with-codegen/exports) to show you how to automate this conversion process. - -## Overview - -Default exports can make code harder to maintain and refactor. Converting them to named exports provides several benefits: - -- Better IDE support for imports and refactoring -- More explicit and consistent import statements -- Easier to track symbol usage across the codebase - -## Converting Default Exports - -Here's how to convert default exports to named exports: - -```python -for file in codebase.files: - target_file = file.filepath - if not target_file: - print(f"⚠️ Target file not found: {filepath}") - continue - - # Get corresponding non-shared file - non_shared_path = target_file.filepath.replace('/shared/', '/') - if not codebase.has_file(non_shared_path): - print(f"⚠️ No matching non-shared file for: {filepath}") - continue - - non_shared_file = codebase.get_file(non_shared_path) - print(f"📄 Processing {target_file.filepath}") - - # Process individual exports - for export in target_file.exports: - # Handle default exports - if export.is_reexport() and export.is_default_export(): - print(f" 🔄 Converting default export '{export.name}'") - default_export = next((e for e in non_shared_file.default_exports), None) - if default_export: - default_export.make_non_default() - - print(f"✨ Fixed exports in {target_file.filepath}") -``` - -## Understanding the Process - -Let's break down how this works: - - - - ```python - # Process individual exports - for export in target_file.exports: - # Handle default exports - if export.is_reexport() and export.is_default_export(): - print(f" 🔄 Converting default export '{export.name}'") - ``` - - The code identifies default exports by checking: - 1. If it's a re-export (`is_reexport()`) - 2. If it's a default export (`is_default_export()`) - - - - ```python - default_export = next((e for e in non_shared_file.default_exports), None) - if default_export: - default_export.make_non_default() - ``` - - For each default export: - 1. Find the corresponding export in the non-shared file - 2. Convert it to a named export using `make_non_default()` - - - - ```python - # Get corresponding non-shared file - non_shared_path = target_file.filepath.replace('/shared/', '/') - if not codebase.has_file(non_shared_path): - print(f"⚠️ No matching non-shared file for: {filepath}") - continue - - non_shared_file = codebase.get_file(non_shared_path) - ``` - - The code: - 1. Maps shared files to their non-shared counterparts - 2. Verifies the non-shared file exists - 3. Loads the non-shared file for processing - - - -## Best Practices - -1. **Check for Missing Files**: Always verify files exist before processing: -```python -if not target_file: - print(f"⚠️ Target file not found: {filepath}") - continue -``` - -2. **Log Progress**: Add logging to track the conversion process: -```python -print(f"📄 Processing {target_file.filepath}") -print(f" 🔄 Converting default export '{export.name}'") -``` - -3. **Handle Missing Exports**: Check that default exports exist before converting: -```python -default_export = next((e for e in non_shared_file.default_exports), None) -if default_export: - default_export.make_non_default() -``` - -## Next Steps - -After converting default exports: - -1. Run your test suite to verify everything still works -2. Update any import statements that were using default imports -3. Review the changes to ensure all exports were converted correctly -4. Consider adding ESLint rules to prevent new default exports - - -Remember to test thoroughly after converting default exports, as this change affects how other files import the converted modules. - - -### Related tutorials -- [Managing typescript exports](/tutorials/managing-typescript-exports) -- [Exports](/building-with-codegen/exports) -- [Dependencies and usages](/building-with-codegen/dependencies-and-usages) - -## Complete Codemod - -Here's the complete codemod that you can copy and use directly: - -```python - -for file in codebase.files: - target_file = file.filepath - if not target_file: - print(f"⚠️ Target file not found: {filepath}") - continue - - # Get corresponding non-shared file - non_shared_path = target_file.filepath.replace('/shared/', '/') - if not codebase.has_file(non_shared_path): - print(f"⚠️ No matching non-shared file for: {filepath}") - continue - - non_shared_file = codebase.get_file(non_shared_path) - print(f"📄 Processing {target_file.filepath}") - - # Process individual exports - for export in target_file.exports: - # Handle default exports - if export.is_reexport() and export.is_default_export(): - print(f" 🔄 Converting default export '{export.name}'") - default_export = next((e for e in non_shared_file.default_exports), None) - if default_export: - default_export.make_non_default() - - print(f"✨ Fixed exports in {target_file.filepath}") - -``` - ---- -title: "Creating Documentation" -sidebarTitle: "Documentation" -icon: "book" -iconType: "solid" ---- - -This guide demonstrates how to determine docs coverage and create documentation for your codebase. - -This primarily leverages two APIs: -- [codebase.ai(...)](/api-reference/core/Codebase#ai) for generating docstrings -- [function.set_docstring(...)](/api-reference/core/HasBlock#set-docstring) for modifying them - -## Determining Documentation Coverage - -In order to determine the extent of your documentation coverage, you can iterate through all symbols of interest and count the number of docstrings: - -To see your current documentation coverage, you can iterate through all symbols of interest and count the number of docstrings: - -```python python -# Initialize counters -total_functions = 0 -functions_with_docs = 0 -total_classes = 0 -classes_with_docs = 0 - -# Check functions -for function in codebase.functions: - total_functions += 1 - if function.docstring: - functions_with_docs += 1 - -# Check classes -for cls in codebase.classes: - total_classes += 1 - if cls.docstring: - classes_with_docs += 1 - -# Calculate percentages -func_coverage = (functions_with_docs / total_functions * 100) if total_functions > 0 else 0 -class_coverage = (classes_with_docs / total_classes * 100) if total_classes > 0 else 0 - -# Print results with emojis -print("\n📊 Documentation Coverage Report:") -print(f"\n📝 Functions:") -print(f" • Total: {total_functions}") -print(f" • Documented: {functions_with_docs}") -print(f" • Coverage: {func_coverage:.1f}%") - -print(f"\n📚 Classes:") -print(f" • Total: {total_classes}") -print(f" • Documented: {classes_with_docs}") -print(f" • Coverage: {class_coverage:.1f}%") - -print(f"\n🎯 Overall Coverage: {((functions_with_docs + classes_with_docs) / (total_functions + total_classes) * 100):.1f}%") -``` - -Which provides the following output: -``` -📊 Documentation Coverage Report: -📝 Functions: - • Total: 1384 - • Documented: 331 - • Coverage: 23.9% -📚 Classes: - • Total: 453 - • Documented: 91 - • Coverage: 20.1% -🎯 Overall Coverage: 23.0% -``` - -## Identifying Areas of Low Documentation Coverage - - -To identify areas of low documentation coverage, you can iterate through all directories and count the number of functions with docstrings. - -Learn more about [Directories here](/building-with-codegen/files-and-directories). - -```python python -# Track directory stats -dir_stats = {} - -# Analyze each directory -for directory in codebase.directories: - # Skip test, sql and alembic directories - if any(x in directory.path.lower() for x in ['test', 'sql', 'alembic']): - continue - - # Get undecorated functions - funcs = [f for f in directory.functions if not f.is_decorated] - total = len(funcs) - - # Only analyze dirs with >10 functions - if total > 10: - documented = sum(1 for f in funcs if f.docstring) - coverage = (documented / total * 100) - dir_stats[directory.path] = { - 'total': total, - 'documented': documented, - 'coverage': coverage - } - -# Find lowest coverage directory -if dir_stats: - lowest_dir = min(dir_stats.items(), key=lambda x: x[1]['coverage']) - path, stats = lowest_dir - - print(f"📉 Lowest coverage directory: '{path}'") - print(f" • Total functions: {stats['total']}") - print(f" • Documented: {stats['documented']}") - print(f" • Coverage: {stats['coverage']:.1f}%") - - # Print all directory stats for comparison - print("\n📊 All directory coverage rates:") - for path, stats in sorted(dir_stats.items(), key=lambda x: x[1]['coverage']): - print(f" '{path}': {stats['coverage']:.1f}% ({stats['documented']}/{stats['total']} functions)") -``` - -Which provides the following output: -```python -📉 Lowest coverage directory: 'codegen-backend/app/utils/github_utils/branch' - • Total functions: 12 - • Documented: 0 - • Coverage: 0.0% -📊 All directory coverage rates: - 'codegen-backend/app/utils/github_utils/branch': 0.0% (0/12 functions) - 'codegen-backend/app/utils/slack': 14.3% (2/14 functions) - 'codegen-backend/app/modal_app/github': 18.2% (2/11 functions) - 'codegen-backend/app/modal_app/slack': 18.2% (2/11 functions) - 'codegen-backend/app/utils/github_utils/webhook': 21.4% (6/28 functions) - 'codegen-backend/app/modal_app/cron': 23.1% (3/13 functions) - 'codegen-backend/app/utils/github_utils': 23.5% (39/166 functions) - 'codegen-backend/app/codemod': 25.0% (7/28 functions) -``` - -## Leveraging AI for Generating Documentation - -For non-trivial codebases, it can be challenging to achieve full documentation coverage. - -The most efficient way to edit informative docstrings is to use [codebase.ai](/api-reference/core/Codebase#ai) to generate docstrings, then use the [set_docstring](/api-reference/core/HasBlock#set-docstring) method to update the docstring. - -Learn more about using AI in our [guides](/building-with-codegen/calling-out-to-llms). - -```python python -# Import datetime for timestamp -from datetime import datetime - -# Get current timestamp -timestamp = datetime.now().strftime("%B %d, %Y") - -print("📚 Generating and Updating Function Documentation") - -# Process all functions in the codebase -for function in codebase.functions: - current_docstring = function.docstring() - - if current_docstring: - # Update existing docstring to be more descriptive - new_docstring = codebase.ai( - f"Update the docstring for {function.name} to be more descriptive and comprehensive.", - target=function - ) - new_docstring += f"\n\nUpdated on: {timestamp}" - else: - # Generate new docstring for function - new_docstring = codebase.ai( - f"Generate a comprehensive docstring for {function.name} including parameters, return type, and description.", - target=function - ) - new_docstring += f"\n\nCreated on: {timestamp}" - - # Set the new or updated docstring - function.set_docstring(new_docstring) -``` - - - -## Adding Explicit Parameter Names and Types - -Alternatively, you can also rely on deterministic string formatting to edit docstrings. - -To add "Google-style" parameter names and types to a function docstring, you can use the following code snippet: - -```python python -# Iterate through all functions in the codebase -for function in codebase.functions: - # Skip if function already has a docstring - if function.docstring: - continue - - # Build parameter documentation - param_docs = [] - for param in function.parameters: - param_type = param.type.source if param.is_typed else "Any" - param_docs.append(f" {param.name} ({param_type}): Description of {param.name}") - - # Get return type if present - return_type = function.return_type.source if function.return_type else "None" - - # Create Google-style docstring - docstring = f'''""" - Description of {function.name}. - - Args: -{chr(10).join(param_docs)} - - Returns: - {return_type}: Description of return value - """''' - - # Set the new docstring - function.set_docstring(docstring) -``` - - ---- -title: "React Modernization" -sidebarTitle: "React Modernization" -icon: "react" -iconType: "brands" -description: "Modernize your React codebase with Codegen" ---- - -Codegen SDK provides powerful APIs for modernizing React codebases. This guide will walk you through common React modernization patterns. - -Common use cases include: - -- Upgrading to modern APIs, including React 18+ -- Automatically memoizing components -- Converting to modern hooks -- Standardizing prop types -- Organizing components into individual files - -and much more. - -## Converting Class Components to Functions - -Here's how to convert React class components to functional components: - -```python -# Find all React class components -for class_def in codebase.classes: - # Skip if not a React component - if not class_def.is_jsx or "Component" not in [base.name for base in class_def.bases]: - continue - - print(f"Converting {class_def.name} to functional component") - - # Extract state from constructor - constructor = class_def.get_method("constructor") - state_properties = [] - if constructor: - for statement in constructor.code_block.statements: - if "this.state" in statement.source: - # Extract state properties - state_properties = [prop.strip() for prop in - statement.source.split("{")[1].split("}")[0].split(",")] - - # Create useState hooks for each state property - state_hooks = [] - for prop in state_properties: - hook_name = f"[{prop}, set{prop[0].upper()}{prop[1:]}]" - state_hooks.append(f"const {hook_name} = useState(null);") - - # Convert lifecycle methods to effects - effects = [] - if class_def.get_method("componentDidMount"): - effects.append(""" - useEffect(() => { - // TODO: Move componentDidMount logic here - }, []); - """) - - if class_def.get_method("componentDidUpdate"): - effects.append(""" - useEffect(() => { - // TODO: Move componentDidUpdate logic here - }); - """) - - # Get the render method - render_method = class_def.get_method("render") - - # Create the functional component - func_component = f""" -const {class_def.name} = ({class_def.get_method("render").parameters[0].name}) => {{ - {chr(10).join(state_hooks)} - {chr(10).join(effects)} - - {render_method.code_block.source} -}} -""" - - # Replace the class with the functional component - class_def.edit(func_component) - - # Add required imports - file = class_def.file - if not any("useState" in imp.source for imp in file.imports): - file.add_import("import { useState, useEffect } from 'react';") -``` - -## Migrating to Modern Hooks - -Convert legacy patterns to modern React hooks: - -```python -# Find components using legacy patterns -for function in codebase.functions: - if not function.is_jsx: - continue - - # Look for common legacy patterns - for call in function.function_calls: - # Convert withRouter to useNavigate - if call.name == "withRouter": - # Add useNavigate import - function.file.add_import( - "import { useNavigate } from 'react-router-dom';" - ) - # Add navigate hook - function.insert_before_first_return("const navigate = useNavigate();") - # Replace history.push calls - for history_call in function.function_calls: - if "history.push" in history_call.source: - history_call.edit( - history_call.source.replace("history.push", "navigate") - ) - - # Convert lifecycle methods in hooks - elif call.name == "componentDidMount": - call.parent.edit(""" -useEffect(() => { - // Your componentDidMount logic here -}, []); -""") -``` - -## Standardizing Props - -### Inferring Props from Usage - -Add proper prop types and TypeScript interfaces based on how props are used: - -```python -# Add TypeScript interfaces for props -for function in codebase.functions: - if not function.is_jsx: - continue - - # Get props parameter - props_param = function.parameters[0] if function.parameters else None - if not props_param: - continue - - # Collect used props - used_props = set() - for prop_access in function.function_calls: - if f"{props_param.name}." in prop_access.source: - prop_name = prop_access.source.split(".")[1] - used_props.add(prop_name) - - # Create interface - if used_props: - interface_def = f""" -interface {function.name}Props {{ - {chr(10).join(f' {prop}: any;' for prop in used_props)} -}} -""" - function.insert_before(interface_def) - # Update function signature - function.edit(function.source.replace( - f"({props_param.name})", - f"({props_param.name}: {function.name}Props)" - )) -``` - -### Extracting Inline Props - -Convert inline prop type definitions to separate type declarations: - -```python -# Iterate over all files in the codebase -for file in codebase.files: - # Iterate over all functions in the file - for function in file.functions: - # Check if the function is a React functional component - if function.is_jsx: # Assuming is_jsx indicates a function component - # Check if the function has inline props definition - if len(function.parameters) == 1 and isinstance(function.parameters[0].type, Dict): - # Extract the inline prop type - inline_props: TSObjectType = function.parameters[0].type.source - # Create a new type definition for the props - props_type_name = f"{function.name}Props" - props_type_definition = f"type {props_type_name} = {inline_props};" - - # Set the new type for the parameter - function.parameters[0].set_type_annotation(props_type_name) - # Add the new type definition to the file - function.insert_before('\n' + props_type_definition + '\n') -``` - -This will convert components from: - -```typescript -function UserCard({ name, age }: { name: string; age: number }) { - return ( -
- {name} ({age}) -
- ); -} -``` - -To: - -```typescript -type UserCardProps = { name: string; age: number }; - -function UserCard({ name, age }: UserCardProps) { - return ( -
- {name} ({age}) -
- ); -} -``` - - - Extracting prop types makes them reusable and easier to maintain. It also - improves code readability by separating type definitions from component logic. - - -## Updating Fragment Syntax - -Modernize React Fragment syntax: - -```python -for function in codebase.functions: - if not function.is_jsx: - continue - - # Replace React.Fragment with <> - for element in function.jsx_elements: - if element.name == "React.Fragment": - element.edit(element.source.replace( - "", - "<>" - ).replace( - "", - "" - )) -``` - -## Organizing Components into Individual Files - -A common modernization task is splitting files with multiple components into a more maintainable structure where each component has its own file. This is especially useful when modernizing legacy React codebases that might have grown organically. - -```python -# Initialize a dictionary to store files and their corresponding JSX components -files_with_jsx_components = {} - -# Iterate through all files in the codebase -for file in codebase.files: - # Check if the file is in the components directory - if 'components' not in file.filepath: - continue - - # Count the number of JSX components in the file - jsx_count = sum(1 for function in file.functions if function.is_jsx) - - # Only proceed if there are multiple JSX components - if jsx_count > 1: - # Identify non-default exported components - non_default_components = [ - func for func in file.functions - if func.is_jsx and not func.is_exported - ] - default_components = [ - func for func in file.functions - if func.is_jsx and func.is_exported and func.export.is_default_export() - ] - - # Log the file path and its components - print(f"📁 {file.filepath}:") - for component in default_components: - print(f" 🟢 {component.name} (default)") - for component in non_default_components: - print(f" 🔵 {component.name}") - - # Create a new directory path based on the original file's directory - new_dir_path = "/".join(file.filepath.split("/")[:-1]) + "/" + file.name.split(".")[0] - codebase.create_directory(new_dir_path, exist_ok=True) - - # Create a new file path for the component - new_file_path = f"{new_dir_path}/{component.name}.tsx" - new_file = codebase.create_file(new_file_path) - - # Log the movement of the component - print(f" 🫸 Moved to: {new_file_path}") - - # Move the component to the new file - component.move_to_file(new_file, strategy="add_back_edge") -``` - -This script will: - -1. Find files containing multiple React components -2. Create a new directory structure based on the original file -3. Move each non-default exported component to its own file -4. Preserve imports and dependencies automatically -5. Keep default exports in their original location - -For example, given this structure: - -``` -components/ - Forms.tsx # Contains Button, Input, Form (default) -``` - -It will create: - -``` -components/ - Forms.tsx # Contains Form (default) - forms/ - Button.tsx - Input.tsx -``` - - - The `strategy="add_back_edge"` parameter ensures that any components that were - previously co-located can still import each other without circular - dependencies. Learn more about [moving - code](/building-with-codegen/moving-symbols) here. - - - - ---- -title: "Migrating from unittest to pytest" -sidebarTitle: "Unittest to Pytest" -description: "Learn how to migrate unittest test suites to pytest using Codegen" -icon: "vial" -iconType: "solid" ---- - -Migrating from [unittest](https://docs.python.org/3/library/unittest.html) to [pytest](https://docs.pytest.org/) involves converting test classes and assertions to pytest's more modern and concise style. This guide will walk you through using Codegen to automate this migration. - - -You can find the complete example code in our [examples repository](https://github.com/codegen-sh/codegen-sdk/tree/develop/codegen-examples/examples/unittest_to_pytest). - - -## Overview - -The migration process involves four main steps: - -1. Converting test class inheritance and setup/teardown methods -2. Updating assertions to pytest style -3. Converting test discovery patterns -4. Modernizing fixture usage - -Let's walk through each step using Codegen. - -## Step 1: Convert Test Classes and Setup Methods - -The first step is to convert unittest's class-based tests to pytest's function-based style. This includes: - -- Removing `unittest.TestCase` inheritance -- Converting `setUp` and `tearDown` methods to fixtures -- Updating class-level setup methods - -```python -# From: -class TestUsers(unittest.TestCase): - def setUp(self): - self.db = setup_test_db() - - def tearDown(self): - self.db.cleanup() - - def test_create_user(self): - user = self.db.create_user("test") - self.assertEqual(user.name, "test") - -# To: -import pytest - -@pytest.fixture -def db(): - db = setup_test_db() - yield db - db.cleanup() - -def test_create_user(db): - user = db.create_user("test") - assert user.name == "test" -``` - -## Step 2: Update Assertions - -Next, we'll convert unittest's assertion methods to pytest's plain assert statements: - -```python -# From: -def test_user_validation(self): - self.assertTrue(is_valid_email("user@example.com")) - self.assertFalse(is_valid_email("invalid")) - self.assertEqual(get_user_count(), 0) - self.assertIn("admin", get_roles()) - self.assertRaises(ValueError, parse_user_id, "invalid") - -# To: -def test_user_validation(): - assert is_valid_email("user@example.com") - assert not is_valid_email("invalid") - assert get_user_count() == 0 - assert "admin" in get_roles() - with pytest.raises(ValueError): - parse_user_id("invalid") -``` - -## Step 3: Update Test Discovery - -pytest uses a different test discovery pattern than unittest. We'll update the test file names and patterns: - -```python -# From: -if __name__ == '__main__': - unittest.main() - -# To: -# Remove the unittest.main() block entirely -# Rename test files to test_*.py or *_test.py -``` - -## Step 4: Modernize Fixture Usage - -Finally, we'll update how test dependencies are managed using pytest's powerful fixture system: - -```python -# From: -class TestDatabase(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.db_conn = create_test_db() - - def setUp(self): - self.transaction = self.db_conn.begin() - - def tearDown(self): - self.transaction.rollback() - -# To: -@pytest.fixture(scope="session") -def db_conn(): - return create_test_db() - -@pytest.fixture -def transaction(db_conn): - transaction = db_conn.begin() - yield transaction - transaction.rollback() -``` - -## Common Patterns - -Here are some common patterns you'll encounter when migrating to pytest: - -1. **Parameterized Tests** - -```python -# From: -def test_validation(self): - test_cases = [("valid@email.com", True), ("invalid", False)] - for email, expected in test_cases: - with self.subTest(email=email): - self.assertEqual(is_valid_email(email), expected) - -# To: -@pytest.mark.parametrize("email,expected", [ - ("valid@email.com", True), - ("invalid", False) -]) -def test_validation(email, expected): - assert is_valid_email(email) == expected -``` - -2. **Exception Testing** - -```python -# From: -def test_exceptions(self): - self.assertRaises(ValueError, process_data, None) - with self.assertRaises(TypeError): - process_data(123) - -# To: -def test_exceptions(): - with pytest.raises(ValueError): - process_data(None) - with pytest.raises(TypeError): - process_data(123) -``` - -3. **Temporary Resources** - -```python -# From: -def setUp(self): - self.temp_dir = tempfile.mkdtemp() - -def tearDown(self): - shutil.rmtree(self.temp_dir) - -# To: -@pytest.fixture -def temp_dir(): - dir = tempfile.mkdtemp() - yield dir - shutil.rmtree(dir) -``` - -## Tips and Notes - -1. pytest fixtures are more flexible than unittest's setup/teardown methods: - - - They can be shared across test files - - They support different scopes (function, class, module, session) - - They can be parameterized - -2. pytest's assertion introspection provides better error messages by default: - - ```python - # pytest shows a detailed comparison - assert result == expected - ``` - -3. You can gradually migrate to pytest: - - - pytest can run unittest-style tests - - Convert one test file at a time - - Start with assertion style updates before moving to fixtures - -4. Consider using pytest's built-in fixtures: - - `tmp_path` for temporary directories - - `capsys` for capturing stdout/stderr - - `monkeypatch` for modifying objects - - `caplog` for capturing log messages - - ---- -title: "Migrating from SQLAlchemy 1.4 to 2.0" -sidebarTitle: "SQLAlchemy 1.4 to 2.0" -description: "Learn how to migrate SQLAlchemy 1.4 codebases to 2.0 using Codegen" -icon: "layer-group" -iconType: "solid" ---- - -Migrating from [SQLAlchemy](https://www.sqlalchemy.org/) 1.4 to 2.0 involves several API changes to support the new 2.0-style query interface. This guide will walk you through using Codegen to automate this migration, handling query syntax, session usage, and ORM patterns. - - -You can find the complete example code in our [examples repository](https://github.com/codegen-sh/codegen-sdk/tree/develop/codegen-examples/examples/sqlalchemy_1.4_to_2.0). - - -## Overview - -The migration process involves three main steps: - -1. Converting legacy Query objects to select() statements -2. Updating session execution patterns -3. Modernizing ORM relationship declarations - -Let's walk through each step using Codegen. - -## Step 1: Convert Query to Select - -First, we need to convert legacy Query-style operations to the new select() syntax: - -```python -def convert_query_to_select(file): - """Convert Query-style operations to select() statements""" - for call in file.function_calls: - if call.name == "query": - # Convert query(Model) to select(Model) - call.set_name("select") - - # Update method chains - if call.parent and call.parent.is_method_chain: - chain = call.parent - if "filter" in chain.source: - # Convert .filter() to .where() - chain.source = chain.source.replace(".filter(", ".where(") - if "filter_by" in chain.source: - # Convert .filter_by(name='x') to .where(Model.name == 'x') - model = call.args[0].value - conditions = chain.source.split("filter_by(")[1].split(")")[0] - new_conditions = [] - for cond in conditions.split(","): - if "=" in cond: - key, value = cond.split("=") - new_conditions.append(f"{model}.{key.strip()} == {value.strip()}") - chain.edit(f".where({' & '.join(new_conditions)})") -``` - -This transforms code from: - -```python -# Legacy Query style -session.query(User).filter_by(name='john').filter(User.age >= 18).all() -``` - -to: - -```python -# New select() style -session.execute( - select(User).where(User.name == 'john').where(User.age >= 18) -).scalars().all() -``` - - - SQLAlchemy 2.0 standardizes on select() statements for all queries, providing - better type checking and a more consistent API. - - -## Step 2: Update Session Execution - -Next, we update how queries are executed with the Session: - -```python -def update_session_execution(file): - """Update session execution patterns for 2.0 style""" - for call in file.function_calls: - if call.name == "query": - # Find the full query chain - chain = call - while chain.parent and chain.parent.is_method_chain: - chain = chain.parent - - # Wrap in session.execute() if needed - if not chain.parent or "execute" not in chain.parent.source: - chain.edit(f"execute(select{chain.source[5:]})") - - # Add .scalars() for single-entity queries - if len(call.args) == 1: - chain.edit(f"{chain.source}.scalars()") -``` - -This converts patterns like: - -```python -# Old style -users = session.query(User).all() -first_user = session.query(User).first() -``` - -to: - -```python -# New style -users = session.execute(select(User)).scalars().all() -first_user = session.execute(select(User)).scalars().first() -``` - - - The new execution pattern is more explicit about what's being returned, making - it easier to understand and maintain type safety. - - -## Step 3: Update ORM Relationships - -Finally, we update relationship declarations to use the new style: - -``` - -``` - - ---- -title: "Fixing Import Loops" -description: "Learn how to identify and fix problematic import loops using Codegen." -icon: "arrows-rotate" -iconType: "solid" ---- - - - - - -Import loops occur when two or more Python modules depend on each other, creating a circular dependency. While some import cycles can be harmless, others can lead to runtime errors and make code harder to maintain. - -In this tutorial, we'll explore how to identify and fix problematic import cycles using Codegen. - - -You can find the complete example code in our [examples repository](https://github.com/codegen-sh/codegen-sdk/tree/develop/codegen-examples/examples/removing_import_loops_in_pytorch). - - -## Overview - -The steps to identify and fix import loops are as follows: -1. Detect import loops -2. Visualize them -3. Identify problematic cycles with mixed static/dynamic imports -4. Fix these cycles using Codegen - -# Step 1: Detect Import Loops -- Create a graph -- Loop through imports in the codebase and add edges between the import files -- Find strongly connected components using Networkx (the import loops) -```python -G = nx.MultiDiGraph() - -# Add all edges to the graph -for imp in codebase.imports: - if imp.from_file and imp.to_file: - edge_color = "red" if imp.is_dynamic else "black" - edge_label = "dynamic" if imp.is_dynamic else "static" - - # Store the import statement and its metadata - G.add_edge( - imp.to_file.filepath, - imp.from_file.filepath, - color=edge_color, - label=edge_label, - is_dynamic=imp.is_dynamic, - import_statement=imp, # Store the whole import object - key=id(imp.import_statement), - ) -# Find strongly connected components -cycles = [scc for scc in nx.strongly_connected_components(G) if len(scc) > 1] - -print(f"🔄 Found {len(cycles)} import cycles:") -for i, cycle in enumerate(cycles, 1): - print(f"\nCycle #{i}:") - print(f"Size: {len(cycle)} files") - - # Create subgraph for this cycle to count edges - cycle_subgraph = G.subgraph(cycle) - - # Count total edges - total_edges = cycle_subgraph.number_of_edges() - print(f"Total number of imports in cycle: {total_edges}") - - # Count dynamic and static imports separately - dynamic_imports = sum(1 for u, v, data in cycle_subgraph.edges(data=True) if data.get("color") == "red") - static_imports = sum(1 for u, v, data in cycle_subgraph.edges(data=True) if data.get("color") == "black") - - print(f"Number of dynamic imports: {dynamic_imports}") - print(f"Number of static imports: {static_imports}") -``` - - -## Understanding Import Cycles - -Not all import cycles are problematic! Here's an example of a cycle that one may think would cause an error but it does not because due to using dynamic imports. - -```python -# top level import in in APoT_tensor.py -from quantizer.py import objectA -``` - -```python -# dynamic import in quantizer.py -def some_func(): - # dynamic import (evaluated when some_func() is called) - from APoT_tensor.py import objectB -``` - - - -A dynamic import is an import defined inside of a function, method or any executable body of code which delays the import execution until that function, method or body of code is called. - -You can use [Import.is_dynamic](/api-reference/core/Import#is-dynamic) to check if the import is dynamic allowing you to investigate imports that are handled more intentionally. - -# Step 2: Visualize Import Loops -- Create a new subgraph to visualize one cycle -- color and label the edges based on their type (dynamic/static) -- visualize the cycle graph using [codebase.visualize(graph)](/api-reference/core/Codebase#visualize) - -Learn more about codebase visualization [here](/building-with-codegen/codebase-visualization) - -```python -cycle = cycles[0] - -def create_single_loop_graph(cycle): - cycle_graph = nx.MultiDiGraph() # Changed to MultiDiGraph to support multiple edges - cycle = list(cycle) - for i in range(len(cycle)): - for j in range(len(cycle)): - # Get all edges between these nodes from original graph - edge_data_dict = G.get_edge_data(cycle[i], cycle[j]) - if edge_data_dict: - # For each edge between these nodes - for edge_key, edge_data in edge_data_dict.items(): - # Add edge with all its attributes to cycle graph - cycle_graph.add_edge(cycle[i], cycle[j], **edge_data) - return cycle_graph - - -cycle_graph = create_single_loop_graph(cycle) -codebase.visualize(cycle_graph) -``` - - - - - - -# Step 3: Identify problematic cycles with mixed static & dynamic imports - -The import loops that we are really concerned about are those that have mixed static/dynamic imports. - -Here's an example of a problematic cycle that we want to fix: - -```python -# In flex_decoding.py -from .flex_attention import ( - compute_forward_block_mn, - compute_forward_inner, - # ... more static imports -) - -# Also in flex_decoding.py -def create_flex_decoding_kernel(*args, **kwargs): - from .flex_attention import set_head_dim_values # dynamic import -``` - -It's clear that there is both a top level and a dynamic import that imports from the *same* module. Thus, this can cause issues if not handled carefully. - - - -Let's find these problematic cycles: - -```python -def find_problematic_import_loops(G, sccs): - """Find cycles where files have both static and dynamic imports between them.""" - problematic_cycles = [] - - for i, scc in enumerate(sccs): - if i == 2: # skipping the second import loop as it's incredibly long (it's also invalid) - continue - mixed_import_files = {} # (from_file, to_file) -> {dynamic: count, static: count} - - # Check all file pairs in the cycle - for from_file in scc: - for to_file in scc: - if G.has_edge(from_file, to_file): - # Get all edges between these files - edges = G.get_edge_data(from_file, to_file) - - # Count imports by type - dynamic_count = sum(1 for e in edges.values() if e["color"] == "red") - static_count = sum(1 for e in edges.values() if e["color"] == "black") - - # If we have both types between same files, this is problematic - if dynamic_count > 0 and static_count > 0: - mixed_import_files[(from_file, to_file)] = {"dynamic": dynamic_count, "static": static_count, "edges": edges} - - if mixed_import_files: - problematic_cycles.append({"files": scc, "mixed_imports": mixed_import_files, "index": i}) - - # Print findings - print(f"Found {len(problematic_cycles)} cycles with mixed imports:") - for i, cycle in enumerate(problematic_cycles): - print(f"\n⚠️ Problematic Cycle #{i + 1}:") - print(f"\n⚠️ Index #{cycle['index']}:") - print(f"Size: {len(cycle['files'])} files") - - for (from_file, to_file), data in cycle["mixed_imports"].items(): - print("\n📁 Mixed imports detected:") - print(f" From: {from_file}") - print(f" To: {to_file}") - print(f" Dynamic imports: {data['dynamic']}") - print(f" Static imports: {data['static']}") - - return problematic_cycles - -problematic_cycles = find_problematic_import_loops(G, cycles) -``` - -# Step 4: Fix the loop by moving the shared symbols to a separate `utils.py` file -One common fix to this problem to break this cycle is to move all the shared symbols to a separate `utils.py` file. We can do this using the method [symbol.move_to_file](/api-reference/core/Symbol#move-to-file): - -Learn more about moving symbols [here](/building-with-codegen/moving-symbols) - -```python -# Create new utils file -utils_file = codebase.create_file("torch/_inductor/kernel/flex_utils.py") - -# Get the two files involved in the import cycle -decoding_file = codebase.get_file("torch/_inductor/kernel/flex_decoding.py") -attention_file = codebase.get_file("torch/_inductor/kernel/flex_attention.py") -attention_file_path = "torch/_inductor/kernel/flex_attention.py" -decoding_file_path = "torch/_inductor/kernel/flex_decoding.py" - -# Track symbols to move -symbols_to_move = set() - -# Find imports from flex_attention in flex_decoding -for imp in decoding_file.imports: - if imp.from_file and imp.from_file.filepath == attention_file_path: - # Get the actual symbol from flex_attention - if imp.imported_symbol: - symbols_to_move.add(imp.imported_symbol) - -# Move identified symbols to utils file -for symbol in symbols_to_move: - symbol.move_to_file(utils_file) - -print(f"🔄 Moved {len(symbols_to_move)} symbols to flex_utils.py") -for symbol in symbols_to_move: - print(symbol.name) - -# Commit changes -codebase.commit() -``` - -# Conclusions & Next Steps - -Import loops can be tricky to identify and fix, but Codegen provides powerful tools to help manage them: - -- Use `codebase.imports` to analyze import relationships across your project -- Visualize import cycles to better understand dependencies -- Distinguish between static and dynamic imports using `Import.is_dynamic` -- Move shared symbols to break cycles using `symbol.move_to_file` - -Here are some next steps you can take: - -1. **Analyze Your Codebase**: Run similar analysis on your own codebase to identify potential import cycles -2. **Create Import Guidelines**: Establish best practices for your team around when to use static vs dynamic imports -3. **Automate Fixes**: Create scripts to automatically detect and fix problematic import patterns -4. **Monitor Changes**: Set up CI checks to prevent new problematic import cycles from being introduced - - -For more examples of codebase analysis and refactoring, check out our other [tutorials](/tutorials/at-a-glance). - - ---- -title: "Migrating from Python 2 to Python 3" -sidebarTitle: "Python 2 to 3" -description: "Learn how to migrate Python 2 codebases to Python 3 using Codegen" -icon: "snake" -iconType: "solid" ---- - -Migrating from Python 2 to Python 3 involves several syntax and API changes. This guide will walk you through using Codegen to automate this migration, handling print statements, string handling, iterators, and more. - - -You can find the complete example code in our [examples repository](https://github.com/codegen-sh/codegen-sdk/tree/develop/codegen-examples/examples/python2_to_python3). - - -## Overview - -The migration process involves five main steps: - -1. Converting print statements to function calls -2. Updating Unicode to str -3. Converting raw_input to input -4. Updating exception handling syntax -5. Modernizing iterator methods - -Let's walk through each step using Codegen. - -## Step 1: Convert Print Statements - -First, we need to convert Python 2's print statements to Python 3's print function calls: - -```python -def convert_print_statements(file): - """Convert Python 2 print statements to Python 3 function calls""" - lines = file.content.split('\n') - new_content = [] - - for line in lines: - stripped = line.strip() - if stripped.startswith('print '): - indent = line[:len(line) - len(line.lstrip())] - args = stripped[6:].strip() - new_content.append(f"{indent}print({args})") - else: - new_content.append(line) - - if new_content != lines: - file.edit('\n'.join(new_content)) -``` - -This transforms code from: - -```python -print "Hello, world!" -print x, y, z -``` - -to: - -```python -print("Hello, world!") -print(x, y, z) -``` - - - In Python 3, `print` is a function rather than a statement, requiring - parentheses around its arguments. - - -## Step 2: Update Unicode to str - -Next, we update Unicode-related code to use Python 3's unified string type: - -```python -def update_unicode_to_str(file): - """Convert Unicode-related code to str for Python 3""" - # Update imports from 'unicode' to 'str' - for imp in file.imports: - if imp.name == 'unicode': - imp.set_name("str") - - # Update function calls from Unicode to str - for func_call in file.function_calls: - if func_call.name == "unicode": - func_call.set_name("str") - - # Check function arguments for Unicode references - for arg in func_call.args: - if arg.value == "unicode": - arg.set_value("str") - - # Find and update Unicode string literals (u"...") - for string_literal in file.find('u"'): - if string_literal.source.startswith('u"') or string_literal.source.startswith("u'"): - new_string = string_literal.source[1:] # Remove the 'u' prefix - string_literal.edit(new_string) -``` - -This converts code from: - -```python -from __future__ import unicode_literals -text = unicode("Hello") -prefix = u"prefix" -``` - -to: - -```python -text = str("Hello") -prefix = "prefix" -``` - - - Python 3 unifies string types, making the `unicode` type and `u` prefix - unnecessary. - - -## Step 3: Convert raw_input to input - -Python 3 renames `raw_input()` to `input()`: - -```python -def convert_raw_input(file): - """Convert raw_input() calls to input()""" - for call in file.function_calls: - if call.name == "raw_input": - call.edit(f"input{call.source[len('raw_input'):]}") -``` - -This updates code from: - -```python -name = raw_input("Enter your name: ") -``` - -to: - -```python -name = input("Enter your name: ") -``` - - - Python 3's `input()` function always returns a string, like Python 2's - `raw_input()`. - - -## Step 4: Update Exception Handling - -Python 3 changes the syntax for exception handling: - -```python -def update_exception_syntax(file): - """Update Python 2 exception handling to Python 3 syntax""" - for editable in file.find("except "): - if editable.source.lstrip().startswith("except") and ", " in editable.source and " as " not in editable.source: - parts = editable.source.split(",", 1) - new_source = f"{parts[0]} as{parts[1]}" - editable.edit(new_source) -``` - -This converts code from: - -```python -try: - process_data() -except ValueError, e: - print(e) -``` - -to: - -```python -try: - process_data() -except ValueError as e: - print(e) -``` - - - Python 3 uses `as` instead of a comma to name the exception variable. - - -## Step 5: Update Iterator Methods - -Finally, we update iterator methods to use Python 3's naming: - -```python -def update_iterators(file): - """Update iterator methods from Python 2 to Python 3""" - for cls in file.classes: - next_method = cls.get_method("next") - if next_method: - # Create new __next__ method with same content - new_method_source = next_method.source.replace("def next", "def __next__") - cls.add_source(new_method_source) - next_method.remove() -``` - -This transforms iterator classes from: - -```python -class MyIterator: - def next(self): - return self.value -``` - -to: - -```python -class MyIterator: - def __next__(self): - return self.value -``` - - - Python 3 renames the `next()` method to `__next__()` for consistency with - other special methods. - - -## Running the Migration - -You can run the complete migration using our example script: - -```bash -git clone https://github.com/codegen-sh/codegen-sdk.git -cd codegen-examples/examples/python2_to_python3 -python run.py -``` - -The script will: - -1. Process all Python [files](/api-reference/python/PyFile) in your codebase -2. Apply the transformations in the correct order -3. Maintain your code's functionality while updating to Python 3 syntax - -## Next Steps - -After migration, you might want to: - -- Add type hints to your code -- Use f-strings for string formatting -- Update dependencies to Python 3 versions -- Run the test suite to verify functionality - -Check out these related tutorials: - -- [Increase Type Coverage](/tutorials/increase-type-coverage) -- [Organizing Your Codebase](/tutorials/organize-your-codebase) -- [Creating Documentation](/tutorials/creating-documentation) - -## Learn More - -- [Python 3 Documentation](https://docs.python.org/3/) -- [What's New in Python 3](https://docs.python.org/3/whatsnew/3.0.html) -- [Codegen API Reference](/api-reference) -- [Dependencies and Usages](/building-with-codegen/dependencies-and-usages) - - ---- -title: "Migrating from Flask to FastAPI" -sidebarTitle: "Flask to FastAPI" -icon: "bolt" -iconType: "solid" ---- - -Migrating from [Flask](https://flask.palletsprojects.com/) to [FastAPI](https://fastapi.tiangolo.com/) involves several key changes to your codebase. This guide will walk you through using Codegen to automate this migration, handling imports, route decorators, static files, and template rendering. - -You can find the complete example code in our [examples repository](https://github.com/codegen-sh/codegen-sdk/tree/develop/codegen-examples/examples/flask_to_fastapi_migration) - -## Overview - -The migration process involves four main steps: - -1. Updating imports and initialization -2. Converting route decorators -3. Setting up static file handling -4. Updating template handling - -Let's walk through each step using Codegen. - -## I: Update Imports and Initialization - -First, we need to update Flask imports to their FastAPI equivalents and modify the app initialization: - - - Learn more about [imports here](/building-with-codegen/imports). - - -```python -from codegen import Codebase - -# Parse the codebase -codebase = Codebase("./") - -# Update imports and initialization -for file in codebase.files: - # Update Flask to FastAPI imports - for imp in file.imports: - if imp.name == "Flask": - imp.set_name("FastAPI") - elif imp.module == "flask": - imp.set_module("fastapi") - - # Update app initialization - for call in file.function_calls: - if call.name == "Flask": - call.set_name("FastAPI") - # Remove __name__ argument (not needed in FastAPI) - if len(call.args) > 0 and call.args[0].value == "__name__": - call.args[0].remove() -``` - -This transforms code from: - -```python -from flask import Flask -app = Flask(__name__) -``` - -to: - -```python -from fastapi import FastAPI -app = FastAPI() -``` - - - FastAPI doesn't require the `__name__` argument that Flask uses for template - resolution. Codegen automatically removes it during migration. - - -## II: Convert Route Decorators - -Next, we update Flask's route decorators to FastAPI's operation decorators: - -```python -for function in file.functions: - for decorator in function.decorators: - if "@app.route" in decorator.source: - route = decorator.source.split('"')[1] - method = "get" # Default to GET - if "methods=" in decorator.source: - methods = decorator.source.split("methods=")[1].split("]")[0] - if "post" in methods.lower(): - method = "post" - elif "put" in methods.lower(): - method = "put" - elif "delete" in methods.lower(): - method = "delete" - decorator.edit(f'@app.{method}("{route}")') -``` - -This converts decorators from Flask style: - -```python -@app.route("/users", methods=["POST"]) -def create_user(): - pass -``` - -to FastAPI style: - -```python -@app.post("/users") -def create_user(): - pass -``` - - - FastAPI provides specific decorators for each HTTP method, making the API more - explicit and enabling better type checking and OpenAPI documentation. - - -## III: Setup Static Files - -FastAPI handles static files differently than Flask. We need to add the StaticFiles mounting: - -```python -# Add StaticFiles import -file.add_import("from fastapi.staticfiles import StaticFiles") - -# Mount static directory -file.add_symbol_from_source( - 'app.mount("/static", StaticFiles(directory="static"), name="static")' -) -``` - -This sets up static file serving equivalent to Flask's automatic static file handling. - - - FastAPI requires explicit mounting of static directories, which provides more - flexibility in how you serve static files. - - -## IV: Update Template Handling - -Finally, we update the template rendering to use FastAPI's Jinja2Templates: - -```python -for func_call in file.function_calls: - if func_call.name == "render_template": - # Convert to FastAPI's template response - func_call.set_name("Jinja2Templates(directory='templates').TemplateResponse") - if len(func_call.args) > 1: - # Convert template variables to context dict - context_arg = ", ".join( - f"{arg.name}={arg.value}" for arg in func_call.args[1:] - ) - func_call.set_kwarg("context", f"{'{'}{context_arg}{'}'}") - # Add required request parameter - func_call.set_kwarg("request", "request") -``` - -This transforms template rendering from Flask style: - -```python -@app.get("/users") -def list_users(): - return render_template("users.html", users=users) -``` - -to FastAPI style: - -```python -@app.get("/users") -def list_users(request: Request): - return Jinja2Templates(directory="templates").TemplateResponse( - "users.html", - context={"users": users}, - request=request - ) -``` - - - FastAPI requires the `request` object to be passed to templates. Codegen - automatically adds this parameter during migration. - - -## Running the Migration - -You can run the complete migration using our example script: - -```bash -git clone https://github.com/codegen-sh/codegen-sdk.git -cd codegen-examples/examples/flask_to_fastapi_migration -python run.py -``` - -The script will: - -1. Process all Python [files](/api-reference/python/PyFile) in your codebase -2. Apply the transformations in the correct order -3. Maintain your code's functionality while updating to FastAPI patterns - -## Next Steps - -After migration, you might want to: - -- Add type hints to your route parameters -- Set up dependency injection -- Add request/response models -- Configure CORS and middleware - -Check out these related tutorials: - -- [Increase Type Coverage](/tutorials/increase-type-coverage) -- [Managing TypeScript Exports](/tutorials/managing-typescript-exports) -- [Organizing Your Codebase](/tutorials/organize-your-codebase) - -## Learn More - -- [FastAPI Documentation](https://fastapi.tiangolo.com/) -- [Codegen API Reference](/api-reference) -- [Moving Symbols Guide](/building-with-codegen/moving-symbols) -- [Dependencies and Usages](/building-with-codegen/dependencies-and-usages) - - ---- -title: "Building a Model Context Protocol server with Codegen" -sidebarTitle: "MCP Server" -icon: "boxes-stacked" -iconType: "solid" ---- - -Learn how to build a Model Context Protocol (MCP) server that enables AI models to understand and manipulate code using Codegen's powerful tools. - -This guide will walk you through creating an MCP server that can provide semantic code search - -View the full code in our [examples repository](https://github.com/codegen-sh/codegen-sdk/tree/develop/src/codegen/extensions/mcp) - - -## Setup: -Install the MCP python library -``` -uv pip install mcp -``` - -## Step 1: Setting Up Your MCP Server - -First, let's create a basic MCP server using Codegen's MCP tools: - -server.py -```python -from codegen import Codebase -from mcp.server.fastmcp import FastMCP -from typing import Annotated -# Initialize the codebase -codebase = Codebase.from_repo(".") - -# create the MCP server using FastMCP -mcp = FastMCP(name="demo-mcp", instructions="Use this server for semantic search of codebases") - - -if __name__ == "__main__": - # Initialize and run the server - print("Starting demo mpc server...") - mcp.run(transport="stdio") - -``` - -## Step 2: Create the search tool - -Let's implement the semantic search tool. - -server.py -```python -from codegen.extensions.tools.semantic_search import semantic_search - -.... - -@mcp.tool('codebase_semantic_search', "search codebase with the provided query") -def search(query: Annotated[str, "search query to run against codebase"]): - codebase = Codebase("provide location to codebase", language="provide codebase Language") - # use the semantic search tool from codegen.extensions.tools OR write your own - results = semantic_search(codebase=codebase, query=query) - return results - -.... -``` - -## Run Your MCP Server - -You can run and inspect your MCP server with: - -``` -mcp dev server.py -``` - -If you'd like to integrate this into an IDE checkout out this [setup guide](/introduction/ide-usage#mcp-server-setup) - -And that's a wrap, chime in at our [community - Slack](https://community.codegen.com) if you have questions or ideas for additional MCP tools/capabilities - - ---- -title: "Neo4j Graph" -sidebarTitle: "Neo4j Graph" -icon: "database" -iconType: "solid" ---- - - - - - -# Neo4j Graph - -Codegen can export codebase graphs to Neo4j for visualization and analysis. - -## Installation -In order to use Neo4j you will need to install it and run it locally using Docker. - -### Neo4j -First, install Neo4j using the official [installation guide](https://neo4j.com/docs/desktop-manual/current/installation/download-installation/). - -### Docker -To run Neo4j locally using Docker, follow the instructions [here](https://neo4j.com/docs/apoc/current/installation/#docker). - -## Launch Neo4j Locally - -```bash -docker run \ - -p 7474:7474 -p 7687:7687 \ - -v $PWD/data:/data -v $PWD/plugins:/plugins \ - --name neo4j-apoc \ - -e NEO4J_apoc_export_file_enabled=true \ - -e NEO4J_apoc_import_file_enabled=true \ - -e NEO4J_apoc_import_file_use__neo4j__config=true \ - -e NEO4J_PLUGINS=\[\"apoc\"\] \ - neo4j:latest -``` -## Usage - -```python -from codegen import Codebase -from codegen.extensions.graph.main import visualize_codebase - -# parse codebase -codebase = Codebase("path/to/codebase") - -# export to Neo4j -visualize_codebase(codebase, "bolt://localhost:7687", "neo4j", "password") -``` - -## Visualization - -Once exported, you can open the Neo4j browser at `http://localhost:7474`, sign in with the username `neo4j` and the password `password`, and use the following Cypher queries to visualize the codebase: - -### Class Hierarchy - -```cypher -Match (s: Class )-[r: INHERITS_FROM*]-> (e:Class) RETURN s, e LIMIT 10 -``` - - - - -### Methods Defined by Each Class - -```cypher -Match (s: Class )-[r: DEFINES]-> (e:Method) RETURN s, e LIMIT 10 -``` - - - - -### Function Calls - -```cypher -Match (s: Func )-[r: CALLS]-> (e:Func) RETURN s, e LIMIT 10 -``` - - - - - -### Call Graph - -```cypher -Match path = (:(Method|Func)) -[:CALLS*5..10]-> (:(Method|Func)) -Return path -LIMIT 20 -``` - - - - - ---- -title: "Code Attributions" -sidebarTitle: "Code Attributions" -description: "Learn how to analyze code statistics and attributions using Codegen" -icon: "network-wired" -iconType: "solid" ---- - -# AI Impact Analysis - -This tutorial shows how to use Codegen's attribution extension to analyze the impact of AI on your -codebase. You'll learn how to identify which parts of your code were written by AI tools like -GitHub Copilot, Devin, or other AI assistants. - -Note: the code is flexible - you can track CI pipeline bots, or any other contributor you want. - - -## Overview - -The attribution extension analyzes git history to: - -1. Identify which symbols (functions, classes, etc.) were authored or modified by AI tools -2. Calculate the percentage of AI contributions in your codebase -3. Find high-impact AI-written code (code that many other parts depend on) -4. Track the evolution of AI contributions over time - -## Installation - -The attribution extension is included with Codegen. No additional installation is required. - -## Basic Usage - -### Running the Analysis - -You can run the AI impact analysis using the Codegen CLI: - -```bash -codegen analyze-ai-impact -``` - -Or from Python code: - -```python -from codegen import Codebase -from codegen.extensions.attribution.cli import run - -# Initialize codebase from current directory -codebase = Codebase.from_repo("your-org/your-repo", language="python") - -# Run the analysis -run(codebase) -``` - -### Understanding the Results - -The analysis will print a summary of AI contributions to your console and save detailed results to a JSON file. The summary includes: - -- List of all contributors (human and AI) -- Percentage of commits made by AI -- Number of files and symbols touched by AI -- High-impact AI-written code (code with many dependents) -- Top files by AI contribution percentage - -## Advanced Usage - -### Accessing Attribution Information - -After running the analysis, each symbol in your codebase will have attribution information attached to it: - -```python -from codegen import Codebase -from codegen.extensions.attribution.main import add_attribution_to_symbols - -# Initialize codebase -codebase = Codebase.from_repo("your-org/your-repo", language="python") - -# Add attribution information to symbols -ai_authors = ['github-actions[bot]', 'dependabot[bot]', 'copilot[bot]'] -add_attribution_to_symbols(codebase, ai_authors) - -# Access attribution information on symbols -for symbol in codebase.symbols: - if hasattr(symbol, 'is_ai_authored') and symbol.is_ai_authored: - print(f"AI-authored symbol: {symbol.name} in {symbol.filepath}") - print(f"Last editor: {symbol.last_editor}") - print(f"All editors: {symbol.editor_history}") -``` - -### Customizing AI Author Detection - -By default, the analysis looks for common AI bot names in commit authors. -You can customize this by providing your own list of AI authors: - -```python -from codegen import Codebase -from codegen.extensions.attribution.main import analyze_ai_impact - -# Initialize codebase -codebase = Codebase.from_repo("your-org/your-repo", language="python") - -# Define custom AI authors -ai_authors = [ - 'github-actions[bot]', - 'dependabot[bot]', - 'copilot[bot]', - 'devin[bot]', - 'your-custom-ai-email@example.com' -] - -# Run analysis with custom AI authors -results = analyze_ai_impact(codebase, ai_authors) -``` - -## Example: Contributor Analysis - -Here's a complete example that analyzes contributors to your codebase and their impact: - -```python -import os -from collections import Counter - -from codegen import Codebase -from codegen.extensions.attribution.main import add_attribution_to_symbols -from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.git.schemas.repo_config import RepoConfig -from codegen.sdk.codebase.config import ProjectConfig -from codegen.shared.enums.programming_language import ProgrammingLanguage - -def analyze_contributors(codebase): - """Analyze contributors to the codebase and their impact.""" - print("\n🔍 Contributor Analysis:") - - # Define which authors are considered AI - ai_authors = ['devin[bot]', 'codegen[bot]', 'github-actions[bot]', 'dependabot[bot]'] - - # Add attribution information to all symbols - print("Adding attribution information to symbols...") - add_attribution_to_symbols(codebase, ai_authors) - - # Collect statistics about contributors - contributor_stats = Counter() - ai_contributor_stats = Counter() - - print("Analyzing symbol attributions...") - for symbol in codebase.symbols: - if hasattr(symbol, 'last_editor') and symbol.last_editor: - contributor_stats[symbol.last_editor] += 1 - - # Track if this is an AI contributor - if any(ai in symbol.last_editor for ai in ai_authors): - ai_contributor_stats[symbol.last_editor] += 1 - - # Print top contributors overall - print("\n👥 Top Contributors by Symbols Authored:") - for contributor, count in contributor_stats.most_common(10): - is_ai = any(ai in contributor for ai in ai_authors) - ai_indicator = "🤖" if is_ai else "👤" - print(f" {ai_indicator} {contributor}: {count} symbols") - - # Print top AI contributors if any - if ai_contributor_stats: - print("\n🤖 Top AI Contributors:") - for contributor, count in ai_contributor_stats.most_common(5): - print(f" • {contributor}: {count} symbols") - -# Initialize codebase from current directory -if os.path.exists(".git"): - repo_path = os.getcwd() - repo_config = RepoConfig.from_repo_path(repo_path) - repo_operator = RepoOperator(repo_config=repo_config) - - project = ProjectConfig.from_repo_operator( - repo_operator=repo_operator, - programming_language=ProgrammingLanguage.PYTHON - ) - codebase = Codebase(projects=[project]) - - # Run the contributor analysis - analyze_contributors(codebase) -``` - -## Conclusion - -The attribution extension provides valuable insights into how AI tools are being used in your -development process. By understanding which parts of your codebase are authored by AI, you can: - -- Track the adoption of AI coding assistants in your team -- Identify areas where AI is most effective -- Ensure appropriate review of AI-generated code -- Measure the impact of AI on developer productivity diff --git a/src/codegen/sdk/topological_sort.py b/src/codegen/sdk/topological_sort.py deleted file mode 100644 index cb43c5ff9..000000000 --- a/src/codegen/sdk/topological_sort.py +++ /dev/null @@ -1,44 +0,0 @@ -import rustworkx as nx -from rustworkx import DAGHasCycle, PyDiGraph - -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -def pseudo_topological_sort(graph: PyDiGraph, flatten: bool = True): - """This will come up with an ordering of nodes within the graph respecting topological""" - try: - # Try to perform a topological sort - sorted_nodes = list(nx.topological_sort(graph)) - return sorted_nodes - except DAGHasCycle: - # If a cycle is detected, handle it separately - logger.warning("The graph contains a cycle. Performing an approximate topological sort.") - - # Find the strongly connected components in the graph - sccs = list(nx.strongly_connected_components(graph)) - - if not flatten: - return sccs - - # Create a new graph with each strongly connected component as a single node - scc_graph = nx.PyDiGraph() - for i, scc in enumerate(sccs): - scc_graph.add_node(i) - - for u, v in graph.edges(): - scc_u = next((i for i, scc in enumerate(sccs) if u in scc), None) - scc_v = next((i for i, scc in enumerate(sccs) if v in scc), None) - if scc_u is None or scc_v is None: - continue - if scc_u != scc_v: - scc_graph.add_edge(scc_u, scc_v, None) - - # Perform a topological sort on the condensed graph - sorted_sccs = list(nx.topological_sort(scc_graph)) - - # Expand the strongly connected components back to individual nodes - sorted_nodes = [node for scc_idx in sorted_sccs for node in sccs[scc_idx]] - - return sorted_nodes diff --git a/src/codegen/sdk/tree_sitter_parser.py b/src/codegen/sdk/tree_sitter_parser.py deleted file mode 100644 index c9a04c6cc..000000000 --- a/src/codegen/sdk/tree_sitter_parser.py +++ /dev/null @@ -1,92 +0,0 @@ -import os -from os import PathLike -from pathlib import Path -from typing import Union - -import tree_sitter_javascript as ts_javascript -import tree_sitter_python as ts_python -import tree_sitter_typescript as ts_typescript -from tree_sitter import Language, Parser -from tree_sitter import Node as TSNode - -from codegen.sdk.output.utils import stylize_error - -PY_LANGUAGE = Language(ts_python.language()) -JS_LANGUAGE = Language(ts_javascript.language()) -TS_LANGUAGE = Language(ts_typescript.language_typescript()) -TSX_LANGUAGE = Language(ts_typescript.language_tsx()) - - -def to_extension(filepath_or_extension: str | PathLike) -> str: - return Path(filepath_or_extension).suffix - - -class _TreeSitterAbstraction: - """Class to facilitate loading/retrieval of the Parser object for a given language. - Should not be used directly, instead use `get_tree_sitter_parser` to get the parser for a given extension. - """ - - _instance: Union["_TreeSitterAbstraction", None] = None - # TODO: use ProgrammingLanguages enum here instead - extension_to_lang = { - # ".js": JS_LANGUAGE, - # ".jsx": JS_LANGUAGE, - # ".ts": TS_LANGUAGE, - # Use TSX for ALL JS/TS files! - ".js": TSX_LANGUAGE, - ".jsx": TSX_LANGUAGE, - ".ts": TSX_LANGUAGE, - ".tsx": TSX_LANGUAGE, - ".py": PY_LANGUAGE, - } - extension_to_parser: dict[str, Parser] = {} - - def __init__(self) -> None: - self.initialize_parsers() - - def initialize_parsers(self) -> None: - for extension, language in self.extension_to_lang.items(): - parser = Parser(language) - self.extension_to_parser[extension] = parser - - -_ts_parser_factory = _TreeSitterAbstraction() - - -def get_parser_by_filepath_or_extension(filepath_or_extension: str | PathLike = ".py") -> Parser: - extension = to_extension(filepath_or_extension) - # HACK: we do not currently use a plain text parser, so default to python for now - if extension not in _ts_parser_factory.extension_to_parser: - extension = ".py" - return _ts_parser_factory.extension_to_parser[extension] - - -def get_lang_by_filepath_or_extension(filepath_or_extension: str = ".py") -> Language: - extension = to_extension(filepath_or_extension) - # HACK: we do not currently use a plain text parser, so default to python for now - if extension not in _ts_parser_factory.extension_to_parser: - extension = ".py" - return _ts_parser_factory.extension_to_lang[extension] - - -def parse_file(filepath: PathLike, content: str) -> TSNode: - parser = get_parser_by_filepath_or_extension(filepath) - ts_node = parser.parse(bytes(content, "utf-8")).root_node - return ts_node - - -def print_errors(filepath: PathLike, content: str) -> None: - if not os.path.exists(filepath): - return - parser = get_parser_by_filepath_or_extension(filepath) - ts_node = parser.parse(bytes(content, "utf-8")).root_node - if ts_node.has_error: - - def traverse(node): - if node.is_error or node.is_missing: - stylize_error(filepath, node.start_point, node.end_point, ts_node, content, "with ts_node type of " + node.type) - if node.has_error: - for child in node.children: - traverse(child) - - traverse(ts_node) diff --git a/src/codegen/sdk/types.py b/src/codegen/sdk/types.py deleted file mode 100644 index 7f070aa0d..000000000 --- a/src/codegen/sdk/types.py +++ /dev/null @@ -1,3 +0,0 @@ -from typing import TypeAlias - -JSON: TypeAlias = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None diff --git a/src/codegen/sdk/typescript/__init__.py b/src/codegen/sdk/typescript/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/sdk/typescript/assignment.py b/src/codegen/sdk/typescript/assignment.py deleted file mode 100644 index 538f0e5ca..000000000 --- a/src/codegen/sdk/typescript/assignment.py +++ /dev/null @@ -1,84 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from codegen.sdk.core.assignment import Assignment -from codegen.sdk.core.autocommit import writer -from codegen.sdk.core.expressions.multi_expression import MultiExpression -from codegen.sdk.typescript.symbol import TSSymbol -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.statements.export_statement import ExportStatement - from codegen.sdk.typescript.statements.assignment_statement import TSAssignmentStatement - - -@ts_apidoc -class TSAssignment(Assignment["TSAssignmentStatement | ExportStatement"], TSSymbol): - """A class representing TypeScript assignments, including variable declarations and property assignments. - - Handles various types of TypeScript assignments including variable declarators, assignment expressions, - augmented assignments, property signatures, and public field definitions. It provides functionality - for manipulating assignments and managing their associated types and comments. - """ - - assignment_types: list[str] = ["variable_declarator", "assignment_expression", "augmented_assignment_expression", "property_signature", "public_field_definition"] - - @noapidoc - @classmethod - def from_assignment(cls, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSAssignmentStatement) -> MultiExpression[TSAssignmentStatement, TSAssignment]: - if ts_node.type not in ["assignment_expression", "augmented_assignment_expression"]: - msg = f"Unknown assignment type: {ts_node.type}" - raise ValueError(msg) - - left_node = ts_node.child_by_field_name("left") - right_node = ts_node.child_by_field_name("right") - assignments = cls._from_left_and_right_nodes(ts_node, file_node_id, ctx, parent, left_node, right_node) - return MultiExpression(ts_node, file_node_id, ctx, parent, assignments) - - @classmethod - def from_named_expression(cls, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSAssignmentStatement) -> MultiExpression[TSAssignmentStatement, TSAssignment]: - """Creates a MultiExpression object from a TypeScript named expression node. - - Constructs assignments from a TypeScript named expression node (variable declarator, public field definition, or property signature) by extracting the left (name) and right (value) nodes. - - Args: - ts_node (TSNode): The TypeScript node representing the named expression. - file_node_id (NodeId): The unique identifier for the file containing this node. - ctx (CodebaseContext): The graph representation of the codebase. - parent (Parent): The parent node containing this expression. - - Returns: - MultiExpression[Parent, TSAssignment]: A MultiExpression object containing the constructed assignments. - - Raises: - ValueError: If the node type is not one of: "variable_declarator", "public_field_definition", or "property_signature". - """ - if ts_node.type not in ["variable_declarator", "public_field_definition", "property_signature"]: - msg = f"Unknown assignment type: {ts_node.type}" - raise ValueError(msg) - - left_node = ts_node.child_by_field_name("name") - right_node = ts_node.child_by_field_name("value") - assignments = cls._from_left_and_right_nodes(ts_node, file_node_id, ctx, parent, left_node, right_node) - return MultiExpression(ts_node, file_node_id, ctx, parent, assignments) - - @writer - def set_inline_comment(self, comment: str, auto_format: bool = True, clean_format: bool = True) -> None: - """Sets an inline comment for an assignment node. - - This method adds or updates an inline comment on the parent statement of the assignment node. - - Args: - comment (str): The comment text to set. - auto_format (bool, optional): Whether to automatically format the comment. Defaults to True. - clean_format (bool, optional): Whether to clean existing formatting. Defaults to True. - - Returns: - None - """ - super().set_inline_comment(comment, auto_format=auto_format, clean_format=clean_format, node=self.parent.ts_node) diff --git a/src/codegen/sdk/typescript/class_definition.py b/src/codegen/sdk/typescript/class_definition.py deleted file mode 100644 index 62fd0c0c8..000000000 --- a/src/codegen/sdk/typescript/class_definition.py +++ /dev/null @@ -1,228 +0,0 @@ -from __future__ import annotations - -import os -from typing import TYPE_CHECKING, Self - -from codegen.sdk.core.autocommit import commiter, reader, writer -from codegen.sdk.core.class_definition import Class -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.generic_type import GenericType -from codegen.sdk.core.expressions.placeholder_type import PlaceholderType -from codegen.sdk.core.external_module import ExternalModule -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.core.symbol_group import SymbolGroup -from codegen.sdk.core.symbol_groups.multi_line_collection import MultiLineCollection -from codegen.sdk.core.symbol_groups.parents import Parents -from codegen.sdk.typescript.detached_symbols.decorator import TSDecorator -from codegen.sdk.typescript.detached_symbols.parameter import TSParameter -from codegen.sdk.typescript.expressions.type import TSType -from codegen.sdk.typescript.function import TSFunction -from codegen.sdk.typescript.interfaces.has_block import TSHasBlock -from codegen.sdk.typescript.symbol import TSSymbol -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.statements.symbol_statement import SymbolStatement - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - - -@ts_apidoc -class TSClass(Class[TSFunction, TSDecorator, "TSCodeBlock", TSParameter, TSType], TSHasBlock, TSSymbol): - """A class representing a TypeScript/JavaScript class with enhanced functionality for class manipulation. - - The TSClass provides comprehensive functionality for working with TypeScript/JavaScript classes, - including handling class methods, attributes, JSX components, and inheritance relationships. - It supports operations like adding source code to class bodies, managing class attributes, - and handling React JSX components. - - Attributes: - parent_classes (Parents | None): The parent classes that this class extends or implements. - constructor_keyword (str): The keyword used to identify the constructor method. - """ - - constructor_keyword = "constructor" - """ - Representation of a Class in JavaScript/TypeScript - """ - - def __init__(self, ts_node: TSNode, file_id: NodeId, ctx: CodebaseContext, parent: SymbolStatement) -> None: - super().__init__(ts_node, file_id, ctx, parent) - if superclasses_node := self.child_by_field_types("class_heritage"): - if extends_clause := superclasses_node.child_by_field_types(["extends_clause", "implements_clause"]): - self.parent_classes = Parents(extends_clause.ts_node, self.file_node_id, self.ctx, self) - if self.constructor is not None and len(self.constructor.parameters) > 0: - self._parameters = SymbolGroup(self.file_node_id, self.ctx, self, children=self.constructor.parameters) - self.type_parameters = self.child_by_field_name("type_parameters") - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - """Adds an internal edge from itself to used symbol references within itself.""" - dest = dest or self.self_dest - # =====[ SUBCLASSING ]===== - if self.parent_classes is not None: - self.parent_classes._compute_dependencies(UsageKind.SUBCLASS, dest) - - if self.type_parameters: - self.type_parameters._compute_dependencies(UsageKind.GENERIC, dest) - # =====[ BODY IDENTIFIERS ]===== - # TODO - this breaks if there's a local variable that shadows a global variable... tough - self.code_block._compute_dependencies(usage_type, dest) - - @staticmethod - @noapidoc - def _get_name_node(ts_node: TSNode) -> TSNode | None: - """Returns the ID node from the root node of the symbol""" - if ts_node.parent and ts_node.parent.type == "pair": - return ts_node.parent.child_by_field_name("key") - return ts_node.child_by_field_name("name") - - @reader - def _parse_methods(self) -> MultiLineCollection[TSFunction, Self]: - methods = [m.symbol for m in self.code_block.symbol_statements if isinstance(m.symbol, TSFunction)] - block_node = self.code_block.ts_node - if len(block_node.children) == 2: - # If the class definition is an empty class, there is no indent - indent_size = 0 - else: - # Otherwise, the indent should match the first line that appears in the code block - indent_size = block_node.children[1].start_point[1] - if len(methods) > 0: - start_byte = methods[0].start_byte - methods[0].start_point[1] - elif len(self.code_block.statements) > 0: - start_byte = self.code_block.statements[-1].ts_node.end_byte + 2 - else: - start_byte = block_node.start_byte - block_node.start_point[1] - return MultiLineCollection( - children=methods, file_node_id=self.file_node_id, ctx=self.ctx, parent=self, node=self.code_block.ts_node, indent_size=indent_size, start_byte=start_byte, end_byte=block_node.end_byte - 1 - ) - - @property - @reader - def is_jsx(self) -> bool: - """Determine if the class is a React JSX component. - - Check if any parent class contains 'React' in its name or source. - - Returns: - bool: True if the class inherits from a React component, False otherwise. - """ - if self.parent_classes is None: - return False - - for p in self.parent_classes: - if isinstance(p, HasName): - if "React" in p.full_name: - return True - elif isinstance(p, PlaceholderType): - if "React" in p.source: - return True - for resolution in p.resolved_types: - if isinstance(resolution, ExternalModule): - if "react" in resolution.source: - return True - return False - - #################################################################################################################### - # MANIPULATIONS - #################################################################################################################### - - @writer - def add_source(self, source: str) -> None: - """Adds source code to a class body. - - Adds a block of source code to the class body. The code will be added at the end of the existing code block, - maintaining proper indentation based on the class's structure. - - Args: - source (str): The source code to be added to the class body. - - Returns: - None - """ - msg = "TODO" - raise NotImplementedError(msg) - - @writer - def add_attribute_from_source(self, source: str) -> None: - """Adds a class attribute from source code to a TypeScript/JavaScript class. - - Adds the attribute to the class in a suitable location based on the class's current structure: - after existing attributes if any exist, before methods if any exist, or in an empty class block. - - Args: - source (str): The source code of the attribute to add to the class. - - Returns: - None - """ - attributes = self.attributes - if len(attributes) > 0: - last_attribute = attributes[-1] - semi_colon = last_attribute.next_sibling - indent = " " * last_attribute.start_point[1] - semi_colon.insert_after(f"{indent}{source}", fix_indentation=False) - elif (methods := self.methods) and len(methods) > 0: - first_method = methods[0] - first_method.insert_before(f"{source}\n", fix_indentation=True) - else: - indent = " " * (4 * self.code_block.level) - self.code_block.edit(f"{{\n{indent}{source}\n}}", fix_indentation=False) - - def convert_props_to_interface(self) -> None: - """Converts React component props to TypeScript interfaces. - - For React class components, converts PropTypes declarations to a separate interface. - The interface will be named {ComponentName}Props and inserted before the component. - The component will be updated to extend React.Component with the interface type parameter. - - Handles both simple types and complex types including: - - PropTypes declarations - - Union types and optional props - - Nested object shapes - - Arrays and complex types - - Required vs optional props - - Example: - ```typescript - // Before - class Button extends React.Component { - render() { - return ; - } - } - Button.propTypes = { - text: PropTypes.string.isRequired, - onClick: PropTypes.func.isRequired - }; - - // After - interface ButtonProps { - text: string; - onClick: CallableFunction; - } - - class Button extends React.Component { - render() { - return ; - } - } - ``` - """ - if self.parent_classes and len(self.parent_classes) > 0: - react_parent = self.parent_classes[0] - if "Component" in react_parent.source: - if interface_name := self.convert_to_react_interface(): - if isinstance(react_parent, GenericType): - react_parent.parameters.insert(0, interface_name) - else: - react_parent.insert_after(f"<{interface_name}>", newline=False) - - @writer - def class_component_to_function_component(self) -> None: - """Converts a class component to a function component.""" - return self.ctx.ts_declassify.declassify(self.source, filename=os.path.basename(self.file.file_path)) diff --git a/src/codegen/sdk/typescript/config_parser.py b/src/codegen/sdk/typescript/config_parser.py deleted file mode 100644 index 0c9e8bfe7..000000000 --- a/src/codegen/sdk/typescript/config_parser.py +++ /dev/null @@ -1,63 +0,0 @@ -from pathlib import Path -from typing import TYPE_CHECKING - -from codegen.sdk.codebase.config_parser import ConfigParser -from codegen.sdk.core.file import File -from codegen.sdk.enums import NodeType -from codegen.sdk.typescript.ts_config import TSConfig - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.typescript.file import TSFile - -import os -from functools import cache - - -class TSConfigParser(ConfigParser): - # Cache of path names to TSConfig objects - config_files: dict[Path, TSConfig] - ctx: "CodebaseContext" - - def __init__(self, codebase_context: "CodebaseContext", default_config_name: str = "tsconfig.json"): - super().__init__() - self.config_files = dict() - self.ctx = codebase_context - self.default_config_name = default_config_name - - def get_config(self, config_path: os.PathLike) -> TSConfig | None: - path = self.ctx.to_absolute(config_path) - if path in self.config_files: - return self.config_files[path] - if path.exists(): - self.config_files[path] = TSConfig(File.from_content(config_path, path.read_text(), self.ctx, sync=False), self) - return self.config_files.get(path) - return None - - def parse_configs(self): - # This only yields a 0.05s speedup, but its funny writing dynamic programming code - @cache - def get_config_for_dir(dir_path: Path) -> TSConfig | None: - # Check if the config file exists in the directory - ts_config_path = dir_path / self.default_config_name - # If it does, return the config - if ts_config_path.exists(): - if ts_config := self.get_config(self.ctx.to_absolute(ts_config_path)): - self.config_files[ts_config_path] = ts_config - return ts_config - # Otherwise, check the parent directory - if dir_path.is_relative_to(self.ctx.repo_path): - return get_config_for_dir(dir_path.parent) - return None - - # Get all the files in the codebase - for file in self.ctx.get_nodes(NodeType.FILE): - file: TSFile # This should be safe because we only call this on TSFiles - # Get the config for the directory the file is in - config = get_config_for_dir(file.path.parent) - # Set the config for the file - file.ts_config = config - - # Loop through all the configs and precompute their import aliases - for config in self.config_files.values(): - config._precompute_import_aliases() diff --git a/src/codegen/sdk/typescript/detached_symbols/code_block.py b/src/codegen/sdk/typescript/detached_symbols/code_block.py deleted file mode 100644 index c6871100e..000000000 --- a/src/codegen/sdk/typescript/detached_symbols/code_block.py +++ /dev/null @@ -1,80 +0,0 @@ -from typing import TYPE_CHECKING, Generic, Self, TypeVar - -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.detached_symbols.code_block import CodeBlock -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.statements.statement import Statement -from codegen.sdk.core.symbol_groups.multi_line_collection import MultiLineCollection -from codegen.sdk.extensions.utils import find_line_start_and_end_nodes -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from codegen.sdk.typescript.assignment import TSAssignment - from codegen.sdk.typescript.interfaces.has_block import TSHasBlock - - -Parent = TypeVar("Parent", bound="TSHasBlock") - - -@ts_apidoc -class TSCodeBlock(CodeBlock[Parent, "TSAssignment"], Generic[Parent]): - """Extends the CodeBlock class to provide TypeScript-specific functionality.""" - - @noapidoc - @reader - def _parse_statements(self) -> MultiLineCollection[Statement, Self]: - statements: list[Statement] = self.ctx.parser.parse_ts_statements(self.ts_node, self.file_node_id, self.ctx, self) - line_nodes = find_line_start_and_end_nodes(self.ts_node) - start_node = line_nodes[1][0] if len(line_nodes) > 1 else line_nodes[0][0] - end_node = line_nodes[-2][1] if len(line_nodes) > 1 else line_nodes[-1][1] - indent_size = start_node.start_point[1] - collection = MultiLineCollection( - children=statements, - file_node_id=self.file_node_id, - ctx=self.ctx, - parent=self, - node=self.ts_node, - indent_size=indent_size, - leading_delimiter="", - start_byte=start_node.start_byte - indent_size, - end_byte=end_node.end_byte + 1, - ) - return collection - - @reader - @noapidoc - def _get_line_starts(self) -> list[Editable]: - """Returns an ordered list of first Editable for each non-empty line within the code block""" - line_start_nodes = super()._get_line_starts() - if len(line_start_nodes) >= 3 and line_start_nodes[0].source == "{" and line_start_nodes[-1].source == "}": - # Remove the first and last line of the code block as they are opening and closing braces. - return line_start_nodes[1:-1] - return line_start_nodes - - @reader - @noapidoc - def _get_line_ends(self) -> list[Editable]: - """Returns an ordered list of last Editable for each non-empty line within the code block""" - line_end_nodes = super()._get_line_ends() - # Remove the first and last line of the code block as they are opening and closing braces. - return line_end_nodes[1:-1] - - @writer - def unwrap(self) -> None: - """Unwraps a code block by removing its opening and closing braces. - - This method removes both the opening and closing braces of a code block, including any trailing whitespace - up to the next sibling node if it exists, or up to the closing brace of the last line if no sibling exists. - This is commonly used to flatten nested code structures like if statements, with statements, and function bodies. - - Returns: - None - """ - super().unwrap() - # Also remove the closing brace of the last line. - next_sibling = self.ts_node.next_sibling - if next_sibling: - self.remove_byte_range(self.ts_node.end_byte - 1, next_sibling.start_byte) - else: - # If there is no next sibling, remove up to the closing brace of the last line - self.remove_byte_range(self._get_line_ends()[-1].end_byte, self.ts_node.end_byte) diff --git a/src/codegen/sdk/typescript/detached_symbols/decorator.py b/src/codegen/sdk/typescript/detached_symbols/decorator.py deleted file mode 100644 index 2d24c10c0..000000000 --- a/src/codegen/sdk/typescript/detached_symbols/decorator.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.detached_symbols.decorator import Decorator -from codegen.sdk.core.detached_symbols.function_call import FunctionCall -from codegen.shared.decorators.docs import ts_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.typescript.class_definition import TSClass - from codegen.sdk.typescript.detached_symbols.parameter import TSParameter - from codegen.sdk.typescript.function import TSFunction - - -@ts_apidoc -class TSDecorator(Decorator["TSClass", "TSFunction", "TSParameter"]): - """Abstract representation of a Decorator""" - - @reader - def _get_name_node(self) -> TSNode: - """Returns the name of the decorator.""" - for child in self.ts_node.children: - # =====[ Identifier ]===== - # Just `@dataclass` etc. - if child.type == "identifier": - return child - - # =====[ Attribute ]===== - # e.g. `@a.b` - elif child.type == "member_expression": - return child - - # =====[ Call ]===== - # e.g. `@a.b()` - elif child.type == "call_expression": - func = child.child_by_field_name("function") - return func - - msg = f"Could not find decorator name within {self.source}" - raise ValueError(msg) - - @property - @reader - def call(self) -> FunctionCall | None: - """Retrieves the function call expression associated with the decorator. - - This property checks if the decorator has a function call expression (e.g., @decorator()) and returns it as a FunctionCall object. - If the decorator is a simple identifier (e.g., @decorator), returns None. - - Returns: - FunctionCall | None: A FunctionCall object representing the decorator's call expression if present, None otherwise. - """ - if call_node := next((x for x in self.ts_node.named_children if x.type == "call_expression"), None): - return FunctionCall(call_node, self.file_node_id, self.ctx, self.parent) - return None diff --git a/src/codegen/sdk/typescript/detached_symbols/jsx/element.py b/src/codegen/sdk/typescript/detached_symbols/jsx/element.py deleted file mode 100644 index f640f543e..000000000 --- a/src/codegen/sdk/typescript/detached_symbols/jsx/element.py +++ /dev/null @@ -1,199 +0,0 @@ -from __future__ import annotations - -from functools import cached_property -from typing import TYPE_CHECKING, Generic, TypeVar, override - -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.expressions import Expression, Value -from codegen.sdk.core.expressions.name import Name -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.extensions.autocommit import commiter -from codegen.sdk.typescript.detached_symbols.jsx.prop import JSXProp -from codegen.sdk.utils import find_all_descendants -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.typescript.detached_symbols.jsx.expression import JSXExpression - -Parent = TypeVar("Parent", bound="Editable") - - -@ts_apidoc -class JSXElement(Expression[Parent], HasName, Generic[Parent]): - """Abstract representation of TSX/JSX elements, e.g. ``. This allows for many React-specific modifications, like adding props, changing the name, etc.""" - - _name_node: Name | None - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent) -> None: - super().__init__(ts_node, file_node_id, ctx, parent) - open_tag = self.ts_node.child_by_field_name("open_tag") or self.ts_node - name_node = open_tag.child_by_field_name("name") - self._name_node = self._parse_expression(name_node, default=Name) - self.children # Force parse children of this JSX element - - @cached_property - @reader - def jsx_elements(self) -> list[JSXElement]: - """Returns a list of JSX elements nested within the current element. - - Gets all JSX elements that are descendants of this element in the syntax tree, excluding the element itself. - This includes both regular JSX elements (`...`) and self-closing elements (``). - - Args: - None - - Returns: - list[JSXElement]: A list of JSXElement objects representing all nested JSX elements. - """ - jsx_elements = [] - for node in self.extended_nodes: - jsx_element_nodes = find_all_descendants(node.ts_node, {"jsx_element", "jsx_self_closing_element"}) - jsx_elements.extend([self._parse_expression(x) for x in jsx_element_nodes if x != self.ts_node]) - return jsx_elements - - @cached_property - @reader - def expressions(self) -> list[JSXExpression]: - """Gets all JSX expressions within the JSX element. - - Retrieves all JSX expressions that are descendant nodes of the current JSX element, including expressions in child elements and attributes. - - Returns: - list[JSXExpression]: A list of JSX expression objects found within this element, excluding the current element itself. - """ - jsx_expressions = [] - for node in self.extended_nodes: - jsx_expressions_nodes = find_all_descendants(node.ts_node, {"jsx_expression"}) - jsx_expressions.extend([self._parse_expression(x) for x in jsx_expressions_nodes if x != self.ts_node]) - return jsx_expressions - - @property - @noapidoc - @reader - def _attribute_nodes(self) -> list[Editable]: - """Returns all attribute nodes of the element""" - open_tag = self.ts_node.child_by_field_name("open_tag") or self.ts_node - attribute_nodes = open_tag.children_by_field_name("attribute") - return [Value(x, self.file_node_id, self.ctx, self) for x in attribute_nodes] - - @property - @reader - def props(self) -> list[JSXProp]: - """Retrieves all JSXProps (attributes) from a JSX element. - - Gets all props (attributes) on the current JSX element. For example, in ``, this would return a list with one JSXProp object representing `prop1="value"`. - - Args: - self: The JSXElement instance. - - Returns: - list[JSXProp]: A list of JSXProp objects representing each attribute on the element. - """ - return [self._parse_expression(x.ts_node, default=JSXProp) for x in self._attribute_nodes] - - @reader - def get_prop(self, name: str) -> JSXProp | None: - """Returns the JSXProp with the given name from the JSXElement. - - Searches through the element's props to find a prop with a matching name. - - Args: - name (str): The name of the prop to find. - - Returns: - JSXProp | None: The matching JSXProp object if found, None if not found. - """ - for prop in self.props: - if prop.name == name: - return prop - return None - - @property - def attributes(self) -> list[JSXProp]: - """Returns all JSXProp on this JSXElement, an alias for JSXElement.props. - - Returns all JSXProp attributes (props) on this JSXElement. For example, for a JSX element like - ``, this would return a list containing one JSXProp object. - - Returns: - list[JSXProp]: A list of JSXProp objects representing each attribute/prop on the JSXElement. - """ - return [self._parse_expression(x.ts_node, default=JSXProp) for x in self._attribute_nodes] - - @writer - def set_name(self, name: str) -> None: - """Sets the name of a JSXElement by modifying both opening and closing tags. - - Updates the name of a JSX element, affecting both self-closing tags (``) and elements with closing tags (``). - - Args: - name (str): The new name to set for the JSX element. - - Returns: - None: The method modifies the JSXElement in place. - """ - # This should correctly set the name of both the opening and closing tags - if open_tag := self.ts_node.child_by_field_name("open_tag"): - name_node = self._parse_expression(open_tag.child_by_field_name("name"), default=Name) - name_node.edit(name) - if close_tag := self.ts_node.child_by_field_name("close_tag"): - name_node = self._parse_expression(close_tag.child_by_field_name("name"), default=Name) - name_node.edit(name) - else: - # If the element is self-closing, we only need to edit the name of the element - super().set_name(name) - - @writer - def add_prop(self, prop_name: str, prop_value: str) -> None: - """Adds a new prop to a JSXElement. - - Adds a prop with the specified name and value to the JSXElement. If the element already has props, - the new prop is added after the last existing prop. If the element has no props, the new prop is - added immediately after the element name. - - Args: - prop_name (str): The name of the prop to add. - prop_value (str): The value of the prop to add. - - Returns: - None - """ - if len(self.props) > 0: - last_prop = self.props[-1] - # Extra padding is handled by the insert_after method on prop - last_prop.insert_after(f"{prop_name}={prop_value}", newline=False) - else: - self._name_node.insert_after(f" {prop_name}={prop_value}", newline=False) - - @property - @reader - @noapidoc - def _source(self): - """Text representation of the Editable instance""" - return self.ts_node.text.decode("utf-8").strip() - - @writer - def wrap(self, opening_tag: str, closing_tag: str) -> None: - """Wraps the current JSXElement with the provided opening and closing tags, properly handling indentation. - - Args: - opening_tag (str): The opening JSX tag to wrap around the current element (e.g. `
`) - closing_tag (str): The closing JSX tag to wrap around the current element (e.g. `
`) - """ - current_source = self.source - indented_source = "\n".join(f" {line.rstrip()}" for line in current_source.split("\n")) - new_source = f"{opening_tag}\n{indented_source}\n{closing_tag}" - self.edit(new_source, fix_indentation=True) - - @commiter - @noapidoc - @override - def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: - for node in self.children: - node._compute_dependencies(usage_type, dest=dest) diff --git a/src/codegen/sdk/typescript/detached_symbols/jsx/expression.py b/src/codegen/sdk/typescript/detached_symbols/jsx/expression.py deleted file mode 100644 index 51f44e846..000000000 --- a/src/codegen/sdk/typescript/detached_symbols/jsx/expression.py +++ /dev/null @@ -1,79 +0,0 @@ -from functools import cached_property -from typing import TYPE_CHECKING, Self, override - -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.core.interfaces.unwrappable import Unwrappable -from codegen.sdk.extensions.autocommit import commiter -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.function import Function - from codegen.sdk.typescript.detached_symbols.jsx.element import JSXElement - from codegen.sdk.typescript.detached_symbols.jsx.prop import JSXProp - - -@ts_apidoc -class JSXExpression(Unwrappable["Function | JSXElement | JSXProp"]): - """Abstract representation of TSX/JSX expression""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.statement - - @cached_property - @reader - def statement(self) -> Editable[Self] | None: - """Returns the editable component of this JSX expression. - - Retrieves the editable contained within this JSX expression by accessing the second child node. Returns None if the JSX expression doesn't - contain an editable object. - - Returns: - Editable[Self]: A Editable object representing the statement of this JSX expression. None if the object doesn't have an Editable object. - """ - return self._parse_expression(self.ts_node.named_children[0]) if len(self.ts_node.named_children) > 0 else None - - @commiter - @noapidoc - @override - def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: - if self.statement: - self.statement._compute_dependencies(usage_type, dest=dest) - - @writer - def reduce_condition(self, bool_condition: bool, node: Editable) -> None: - """Simplifies a JSX expression by reducing it based on a boolean condition. - - - Args: - bool_condition (bool): The boolean value to reduce the condition to. - - """ - if self.ts_node.parent.type == "jsx_attribute" and not bool_condition: - node.edit(self.ctx.node_classes.bool_conversion[bool_condition]) - else: - self.remove() - - @writer - @override - def unwrap(self, node: Expression | None = None) -> None: - """Removes the brackets from a JSX expression. - - - Returns: - None - """ - from codegen.sdk.typescript.detached_symbols.jsx.element import JSXElement - from codegen.sdk.typescript.detached_symbols.jsx.prop import JSXProp - - if node is None: - node = self - if isinstance(self.parent, JSXProp): - return - if isinstance(node, JSXExpression | JSXElement | JSXProp): - for child in self._anonymous_children: - child.remove() diff --git a/src/codegen/sdk/typescript/detached_symbols/jsx/prop.py b/src/codegen/sdk/typescript/detached_symbols/jsx/prop.py deleted file mode 100644 index 8732cd8dc..000000000 --- a/src/codegen/sdk/typescript/detached_symbols/jsx/prop.py +++ /dev/null @@ -1,116 +0,0 @@ -from typing import TYPE_CHECKING, override - -from tree_sitter import Node as TSNode - -from codegen.sdk.codebase.codebase_context import CodebaseContext -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.expressions.name import Name -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.core.interfaces.has_value import HasValue -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.extensions.autocommit import commiter -from codegen.sdk.typescript.detached_symbols.jsx.expression import JSXExpression -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.function import Function - from codegen.sdk.typescript.detached_symbols.jsx.element import JSXElement - - -@ts_apidoc -class JSXProp(Expression["Function | JSXElement | JSXProp"], HasName, HasValue): - """Abstract representation of TSX/JSX prop, e.g .""" - - _name_node: Name | None - _expression_node: JSXExpression | None - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: "Function | JSXElement | JSXProp") -> None: - super().__init__(ts_node, file_node_id, ctx, parent) - self._name_node = self._parse_expression(self.ts_node.children[0], default=Name) - if len(self.ts_node.children) > 2: - self._value_node = self._parse_expression(self.ts_node.children[2]) - if self._value_node.ts_node.type == "jsx_expression": - self._expression_node = self._parse_expression(self._value_node.ts_node) - else: - self._expression_node = None - else: - # If there is no value node, then the prop is a boolean prop - # For example, is equivalent to - self._value_node = None - - @property - @reader - def expression(self) -> JSXExpression | None: - """Retrieves the JSX expression associated with this JSX prop. - - Returns the JSX expression node if this prop has one, e.g., for props like prop={expression}. - For boolean props or string literal props, returns None. - - Returns: - JSXExpression | None: The JSX expression node if present, None otherwise. - """ - return self._expression_node - - @writer - def insert_after( - self, - new_src: str, - fix_indentation: bool = False, - newline: bool = True, - priority: int = 0, - dedupe: bool = True, - ) -> None: - """Inserts source code after a JSX prop in a TypeScript/JSX file. - - Inserts the provided source code after the current JSX prop, adding necessary spacing. - - Args: - new_src (str): The source code to insert after the prop. - fix_indentation (bool, optional): Whether to fix the indentation of the inserted code. Defaults to False. - newline (bool, optional): Whether to add a newline after the inserted code. Defaults to True. - priority (int, optional): The priority of the insertion. Defaults to 0. - dedupe (bool, optional): Whether to prevent duplicate insertions. Defaults to True. - - Returns: - None - """ - # TODO: This may not be transaction save with adds and deletes - # Insert space after the prop name - super().insert_after(" " + new_src, fix_indentation, newline, priority, dedupe) - - @writer - def insert_before( - self, - new_src: str, - fix_indentation: bool = False, - newline: bool = True, - priority: int = 0, - dedupe: bool = True, - ) -> None: - """Insert a new source code string before a JSX prop in a React component. - - Inserts a new string of source code before a JSX prop, maintaining proper spacing. - Automatically adds a trailing space after the inserted code. - - Args: - new_src (str): The source code string to insert before the prop. - fix_indentation (bool, optional): Whether to adjust the indentation of the inserted code. Defaults to False. - newline (bool, optional): Whether to add a newline after the inserted code. Defaults to True. - priority (int, optional): Priority of this insertion relative to others. Defaults to 0. - dedupe (bool, optional): Whether to avoid duplicate insertions. Defaults to True. - - Returns: - None - """ - # TODO: This may not be transaction save with adds and deletes - # Insert space before the prop name - super().insert_before(new_src + " ", fix_indentation, newline, priority, dedupe) - - @commiter - @noapidoc - @override - def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: - for node in self.children: - node._compute_dependencies(usage_type, dest=dest) diff --git a/src/codegen/sdk/typescript/detached_symbols/parameter.py b/src/codegen/sdk/typescript/detached_symbols/parameter.py deleted file mode 100644 index 3a3a67dae..000000000 --- a/src/codegen/sdk/typescript/detached_symbols/parameter.py +++ /dev/null @@ -1,187 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, override - -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.autocommit.decorators import writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.detached_symbols.parameter import Parameter -from codegen.sdk.core.expressions.union_type import UnionType -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.sdk.extensions.autocommit import commiter -from codegen.sdk.typescript.expressions.object_type import TSObjectType -from codegen.sdk.typescript.expressions.type import TSType -from codegen.sdk.typescript.symbol_groups.dict import TSPair -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.placeholder.placeholder import Placeholder - from codegen.sdk.typescript.function import TSFunction - - -@ts_apidoc -class TSParameter(Parameter[TSType, Collection["TSParameter", "TSFunction"]]): - """A class representing a TypeScript function parameter with extensive type analysis capabilities. - - This class provides functionality to inspect and manipulate TypeScript function parameters, - including support for destructured parameters, optional parameters, variadic parameters, - default values, and type annotations. - - Attributes: - type (TSType): The TypeScript type annotation of the parameter. - """ - - def __init__(self, ts_node: TSNode, index: int, parent: TSFunction, type: TSType | Placeholder | None = None) -> None: - super().__init__(ts_node, index, parent) - if not self.type and type is not None: - self.type = type # Destructured types - - @property - @reader - def is_destructured(self) -> bool: - """Determines if a parameter is part of an object destructuring pattern. - - Checks the parameter's tree-sitter node type to determine if it represents a destructured parameter. - A parameter is considered destructured if it appears within an object destructuring pattern. - - Returns: - bool: True if the parameter is destructured, False otherwise. - """ - return self.ts_node.type in ("shorthand_property_identifier_pattern", "object_assignment_pattern") - - @property - @reader - def is_optional(self) -> bool: - """Determines if a parameter is marked as optional in TypeScript. - - Checks whether a parameter is marked with the '?' syntax in TypeScript, indicating that it is optional. - If the parameter is part of a destructured pattern, this function returns False as optionality is - handled at the function level for destructured parameters. - - Returns: - bool: True if the parameter is marked as optional, False otherwise. - """ - if self.is_destructured: - # In this case, individual destructured parameters are not marked as optional - # The entire object might be optional, but that's handled at the function level - return False - else: - return self.ts_node.type == "optional_parameter" - - @property - @reader - def is_variadic(self) -> bool: - """Determines if a parameter is variadic (using the rest operator). - - A property that checks if the parameter uses the rest pattern (e.g., ...args in TypeScript), - which allows the parameter to accept an arbitrary number of arguments. - - Returns: - bool: True if the parameter is variadic (uses rest pattern), False otherwise. - """ - pattern = self.ts_node.child_by_field_name("pattern") - return pattern is not None and pattern.type == "rest_pattern" - - @property - @reader - def default(self) -> str | None: - """Returns the default value of a parameter. - - Retrieves the default value of a parameter, handling both destructured and non-destructured parameters. - For destructured parameters, returns the default value if it's an object assignment pattern. - For non-destructured parameters, returns the value specified after the '=' sign. - - Returns: - str | None: The default value of the parameter as a string if it exists, None otherwise. - """ - # =====[ Destructured ]===== - if self.is_destructured: - if self.ts_node.type == "object_assignment_pattern": - return self.ts_node.children[-1].text.decode("utf-8") - else: - return None - - # =====[ Not destructured ]===== - default_node = self.ts_node.child_by_field_name("value") - if default_node is None: - return None - return default_node.text.decode("utf-8") - - @noapidoc - @commiter - @override - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - if self.type: - if not (self.is_destructured and self.index > 0): - self.type._compute_dependencies(UsageKind.TYPE_ANNOTATION, dest or self.parent.self_dest) - if self.value: - self.value._compute_dependencies(UsageKind.DEFAULT_VALUE, dest or self.parent.self_dest) - - @writer - def convert_to_interface(self) -> None: - """Converts a parameter's inline type definition to an interface. - - For React components, converts inline props type definitions to a separate interface. - Handles both simple types and complex types including generics, extends patterns, and union types. - The interface will be named {ComponentName}Props and inserted before the component. - Supports extracting types from destructured parameters and preserves any type parameters. - - Example: - ```typescript - // Before - function Button(props: { text: string, onClick: () => void }) { - return ; - } - - // After - interface ButtonProps { - text: string; - onClick: () => void; - } - function Button(props: ButtonProps) { - return ; - } - ``` - """ - if not self.type or not self.parent_function.is_jsx or not isinstance(self.type, TSObjectType | UnionType): - return - - # # Get the type definition and component name - # type_def = self.type.source - component_name = self.parent_function.name - - # # Handle extends pattern - extends_clause: str = "" - - type = self.type - if isinstance(type, UnionType): - for subtype in type: - if isinstance(subtype, TSObjectType): - type = subtype - else: - extends_clause += f" extends {subtype.source}" - - # # Extract generic type parameters if present - generic_params = "" - if self.parent_function.type_parameters: - generic_params = self.parent_function.type_parameters.source - interface_name = f"{component_name}Props" - # # Update parameter type to use interface - if generic_params: - interface_name += generic_params - - # # Convert type definition to interface - interface_def = f"interface {interface_name}{extends_clause} {{\n" - - # Strip outer braces and convert to semicolon-separated lines - for value in type.values(): - interface_def += f" {value.parent_of_type(TSPair).source.rstrip(',')};\n" - interface_def += "}" - - # Insert interface before the function - self.parent_function.insert_before(interface_def + "\n") - - self.type.edit(interface_name) diff --git a/src/codegen/sdk/typescript/detached_symbols/promise_chain.py b/src/codegen/sdk/typescript/detached_symbols/promise_chain.py deleted file mode 100644 index 1e523bb17..000000000 --- a/src/codegen/sdk/typescript/detached_symbols/promise_chain.py +++ /dev/null @@ -1,559 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.expressions import Name -from codegen.sdk.core.statements.statement import StatementType - -if TYPE_CHECKING: - from codegen.sdk.core.class_definition import Class - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.statements.statement import Statement - from codegen.sdk.core.symbol_groups.multi_line_collection import MultiLineCollection - from codegen.sdk.typescript.function import TSFunction - - -class TSPromiseChain: - """A class representing a TypeScript Promise chain. - - This class parses and handles Promise chains in TypeScript code, including .then(), .catch(), and .finally() chains. - It provides functionality to convert Promise chains to async/await syntax. - """ - - base_chain: list[FunctionCall | Name] - then_chain: list[FunctionCall] - catch_call: FunctionCall | None - finally_call: FunctionCall | None - after_promise_chain: list[FunctionCall | Name] - base_attribute: Name - parent_statement: Statement - parent_function: FunctionCall - parent_class: Class - declared_vars: set[str] - base_indent: str - name: str | None - log_statements: list[str] = ["console.error", "console.warn", "console.log"] - - def __init__(self, attribute_chain: list[FunctionCall | Name]) -> None: - """Initialize a TSPromiseChain instance. - - Args: - attribute_chain: A list of function calls or a Name object representing the Promise chain - """ - # Parse the chain and assign all attributes - (self.base_chain, self.then_chain, self.catch_call, self.finally_call, self.after_promise_chain) = self._parse_chain(attribute_chain) - - self.base_attribute = self.base_chain[-1].parent.object - self.parent_statement = self.base_chain[0].parent_statement - self.parent_function = self.parent_statement.parent_function - self.parent_class = self.parent_statement.parent_class - self.declared_vars = set() - self.base_indent = " " * self.parent_statement._get_indent() - self.name = self.base_chain[0].source if isinstance(self.base_chain[0], Name) else self.base_chain[0].name - - @reader - def _parse_chain(self, attribute_chain: list[FunctionCall | Name]) -> tuple[list[FunctionCall], list[FunctionCall], FunctionCall | None, FunctionCall | None, list[FunctionCall | Name]]: - """Parse the Promise chain into its component parts. - - Args: - attribute_chain: The chain of function calls to parse - - Returns: - A tuple containing: - - base_chain: Initial function calls - - then_chain: .then() calls - - catch_call: .catch() call if present - - finally_call: .finally() call if present - - after_promise_chain: Calls after the Promise chain - """ - base_chain: list[FunctionCall | Name] = [] - then_chain: list[FunctionCall] = [] - catch_call: FunctionCall | None = None - finally_call: FunctionCall | None = None - after_promise_chain: list[FunctionCall | Name] = [] - - in_then_chain: bool = False - promise_chain_ended: bool = False - - for attribute in attribute_chain: - if not isinstance(attribute, Name): - if attribute.name == "then": - in_then_chain = True - then_chain.append(attribute) - elif attribute.name == "catch": - catch_call = attribute - in_then_chain = False - elif attribute.name == "finally": - finally_call = attribute - in_then_chain = False - promise_chain_ended = True - else: - if promise_chain_ended: - after_promise_chain.append(attribute) - elif in_then_chain: - then_chain.append(attribute) - else: - base_chain.append(attribute) - else: - if promise_chain_ended: - after_promise_chain.append(attribute) - elif in_then_chain: - then_chain.append(attribute) - else: - base_chain.append(attribute) - - return base_chain, then_chain, catch_call, finally_call, after_promise_chain - - @property - @reader - def is_return_statement(self) -> bool: - """Check if the parent statement is a return statement. - - Returns: - bool: True if the parent statement is a return statement - """ - return self.parent_statement.statement_type == StatementType.RETURN_STATEMENT - - @property - @reader - def assigned_var(self) -> str | None: - """Get the variable being assigned to in an assignment statement. - - Returns: - Optional[str]: The name of the variable being assigned to, or None if not an assignment - """ - if self.parent_statement.statement_type == StatementType.ASSIGNMENT: - return self.parent_statement.left - - @reader - def get_next_call_params(self, call: FunctionCall | None) -> list[str]: - from codegen.sdk.typescript.function import TSFunction - - """Get parameters from the next then/catch/finally call. - - Args: - call: The function call to extract parameters from - - Returns: - List[str]: List of parameter names from the call - """ - # handling the .then in parameter function - if call and len(call.args) > 0 and isinstance(call.args[0].value, TSFunction): - return [p.source for p in call.args[0].value.parameters] - - return [] - - @reader - def _needs_anonymous_function(self, arrow_fn: TSFunction) -> bool: - """Determine if we need to use an anonymous function wrapper. - - Returns True if: - 1. There are multiple return statements - 2. The code block has complex control flow (if/else, loops, etc) - - Args: - arrow_fn: The arrow function to analyze - - Returns: - bool: True if an anonymous function wrapper is needed - """ - statements = arrow_fn.code_block.get_statements() - return_count = sum(1 for stmt in statements if stmt.statement_type == StatementType.RETURN_STATEMENT) - return return_count > 1 - - @reader - def format_param_assignment(self, params: list[str], base_expr: str, declare: bool = True) -> str: - """Format parameter assignment with proper let declaration if needed. - - Args: - params: List of parameter names to assign - base_expr: The base expression to assign from - declare: Whether to declare new variables with 'let' - - Returns: - str: Formatted parameter assignment string - """ - if not params: - return base_expr - - if len(params) > 1: - param_str = ", ".join(params) - if declare and not any(p in self.declared_vars for p in params): - self.declared_vars.update(params) - return f"let [{param_str}] = {base_expr}" - return f"[{param_str}] = {base_expr}" - else: - param = params[0] - if declare and param not in self.declared_vars: - self.declared_vars.add(param) - return f"let {param} = {base_expr}" - return f"{param} = {base_expr}" - - @reader - def handle_base_call(self) -> str: - """Format the base promise call. - - Returns: - str: Formatted base call string - """ - new_handle = None - if "await" not in self.base_attribute.extended_source: - new_handle = f"await {self.base_attribute.extended_source};" - else: - new_handle = f"{self.base_attribute.extended_source};" - - next_params = self.get_next_call_params(self.then_chain[0]) - if next_params: - new_handle = self.format_param_assignment(next_params, new_handle) - return new_handle - - @reader - def handle_then_block(self, call: FunctionCall, next_call: FunctionCall | None = None) -> str: - from codegen.sdk.typescript.function import TSFunction - - """Format a then block in the promise chain. - - Args: - call: The then call to format - next_call: The next function call in the chain, if any - - Returns: - str: Formatted then block code - """ - # a then block must have a callback handler - if not call or call.name != "then" or len(call.args) != 1: - msg = "Invalid then call provided" - raise Exception(msg) - - arrow_fn = call.args[0].value - if not isinstance(arrow_fn, TSFunction): - msg = "callback function not provided in the argument" - raise Exception(msg) - - statements = arrow_fn.code_block.statements - - formatted_statements = [] - - # adds anonymous function if then block handler has ambiguous returns - if self._needs_anonymous_function(arrow_fn): - anon_block = self._format_anonymous_function(arrow_fn, next_call) - formatted_statements.append(f"{self.base_indent}{anon_block}") - - elif self._is_implicit_return(arrow_fn): - implicit_block = self._handle_last_block_implicit_return(statements, is_catch=False) - formatted_statements.append(f"{self.base_indent}{implicit_block}") - else: - for stmt in statements: - if stmt.statement_type == StatementType.RETURN_STATEMENT: - return_value = stmt.source[7:].strip() - next_params = self.get_next_call_params(next_call) - await_expression = f"await {return_value}" - if next_params: - formatted_statements.append(f"{self.base_indent}{self.format_param_assignment(next_params, await_expression, declare=True)}") - else: - formatted_statements.append(f"{self.base_indent}{await_expression}") - else: - formatted_statements.append(f"{self.base_indent}{stmt.source.strip()}") - - return "\n".join(formatted_statements) - - @reader - def parse_last_then_block(self, call: FunctionCall, assignment_variable_name: str | None = None) -> str: - from codegen.sdk.typescript.function import TSFunction - - """Parse the last .then() block in the chain. - - Args: - call: The last .then() call to parse - assignment_variable_name: Optional custom variable name for assignment - - Returns: - str: Formatted code for the last .then() block - """ - arrow_fn = call.args[0].value - - if not isinstance(arrow_fn, TSFunction): - msg = "callback function not provided in the argument" - raise Exception(msg) - - statements = arrow_fn.code_block.statements - - if self._needs_anonymous_function(arrow_fn): - return self._format_anonymous_function(arrow_fn, assignment_variable_name=assignment_variable_name) - - if self._is_implicit_return(arrow_fn): - return self._handle_last_block_implicit_return(statements, assignment_variable_name=assignment_variable_name) - else: - formatted_statements = [] - for stmt in statements: - if stmt.statement_type == StatementType.RETURN_STATEMENT: - return_value = self._handle_last_block_normal_return(stmt, assignment_variable_name=assignment_variable_name) - formatted_statements.append(return_value) - else: - formatted_statements.append(stmt.source.strip()) - return "\n".join(formatted_statements) - - @reader - def _handle_last_block_normal_return(self, stmt: Statement, is_catch: bool = False, assignment_variable_name: str | None = None) -> str: - """Handle a normal return statement in the last block of a Promise chain. - - Args: - stmt: The return statement to handle - is_catch: Whether this is in a catch block - assignment_variable_name: Optional custom variable name for assignment - - Returns: - str: Formatted return statement code - """ - return_value = stmt.source[7:].strip() # Remove 'return ' prefix - - var_name = assignment_variable_name if assignment_variable_name else self.assigned_var - if var_name: - return self.format_param_assignment([var_name], return_value) - elif self.is_return_statement: - if is_catch: - return f"throw {return_value}" - else: - return f"return {return_value}" - else: - if is_catch: - return f"throw {return_value}" - else: - return f"await {return_value}" - - @reader - def _handle_last_block_implicit_return(self, statements: MultiLineCollection[Statement], is_catch: bool = False, assignment_variable_name: str | None = None) -> str: - """Handle an implicit return in the last block of a Promise chain. - - Args: - statements: The statements in the block - is_catch: Whether this is in a catch block - assignment_variable_name: Optional custom variable name for assignment - - Returns: - str: Formatted implicit return code - """ - stmt_source = statements[0].source.strip() - var_name = assignment_variable_name if assignment_variable_name else self.assigned_var - - if any(stmt_source.startswith(console_method) for console_method in self.log_statements): - return stmt_source + ";" - elif is_catch: - return "throw " + stmt_source + ";" - elif var_name: - return self.format_param_assignment([var_name], stmt_source) - elif self.is_return_statement: - return "return " + stmt_source + ";" - else: - return "await " + stmt_source + ";" - - @reader - def handle_catch_block(self, call: FunctionCall, assignment_variable_name: str | None = None) -> str: - """Handle catch block in the promise chain. - - Args: - call: The catch function call to handle - assignment_variable_name: Optional custom variable name for assignment - - Returns: - str: Formatted catch block code - """ - # a catch block must have a callback handler - if not call or call.name != "catch" or len(call.args) != 1: - msg = "Invalid catch call provided" - raise Exception(msg) - - arrow_fn = call.args[0].value - statements = arrow_fn.code_block.statements - if len(arrow_fn.parameters) > 0: - error_param = arrow_fn.parameters[0].source - else: - error_param = "" - - formatted_statements = [f"{self.base_indent}}} catch({error_param}: any) {{"] - - # adds annonymous function if catch block handler has ambiguous returns - if self._needs_anonymous_function(arrow_fn): - anon_block = self._format_anonymous_function(arrow_fn, assignment_variable_name=assignment_variable_name) - formatted_statements.append(f"{self.base_indent}{anon_block}") - - elif self._is_implicit_return(arrow_fn): - implicit_block = self._handle_last_block_implicit_return(statements, is_catch=True, assignment_variable_name=assignment_variable_name) - formatted_statements.append(f"{self.base_indent}{implicit_block}") - else: - for stmt in statements: - if stmt.statement_type == StatementType.RETURN_STATEMENT: - return_block = self._handle_last_block_normal_return(stmt, is_catch=True, assignment_variable_name=assignment_variable_name) - formatted_statements.append(f"{self.base_indent}{return_block}") - else: - formatted_statements.append(f"{self.base_indent}{stmt.source.strip()}") - - return "\n".join(formatted_statements) - - @reader - def handle_finally_block(self, call: FunctionCall) -> str: - """Handle finally block in the promise chain. - - Args: - call: The finally function call to handle - - Returns: - str: Formatted finally block code - """ - if not call or call.name != "finally": - msg = "Invalid finally call provided" - raise Exception(msg) - - arrow_fn = call.args[0].value - statements = arrow_fn.code_block.statements - - formatted_statements = [f"{self.base_indent}}} finally {{"] - - for stmt in statements: - formatted_statements.append(f"{self.base_indent}{stmt.source.strip()}") - - return "\n".join(formatted_statements) - - @writer - def convert_to_async_await(self, assignment_variable_name: str | None = None, inplace_edit: bool = True) -> str | None: - """Convert the promise chain to async/await syntax. - - Args: - assignment_variable_name: Optional custom variable name for assignment - inplace_edit: If set to true, will call statement.edit(); else will return a string of the new code - - Returns: - Optional[str]: The converted async/await code - """ - # check if promise expression needs to be wrapped in a try/catch/finally block - needs_wrapping = self.has_catch_call or self.has_finally_call - formatted_blocks = [] - - if needs_wrapping: - formatted_blocks.append(f"\n{self.base_indent}try {{") - - base_call = self.handle_base_call() - formatted_blocks.append(f"{self.base_indent}{base_call}") - - for idx, then_call in enumerate(self.then_chain): - is_last_then = idx == len(self.then_chain) - 1 - - # if it's the last then block, then parse differently - if is_last_then: - formatted_block = self.parse_last_then_block(then_call, assignment_variable_name=assignment_variable_name) - else: - next_call = self.then_chain[idx + 1] if idx + 1 < len(self.then_chain) else None - formatted_block = self.handle_then_block(then_call, next_call) - formatted_blocks.append(f"{self.base_indent}{formatted_block}") - - if self.catch_call: - catch_block = self.handle_catch_block(self.catch_call, assignment_variable_name=assignment_variable_name) - formatted_blocks.append(catch_block) - - if self.finally_call: - finally_block = self.handle_finally_block(self.finally_call) - formatted_blocks.append(finally_block) - - if needs_wrapping: - formatted_blocks.append(f"{self.base_indent}}}") - - if self.parent_statement.parent_function: - self.parent_statement.parent_function.asyncify() - - diff_changes = "\n".join(formatted_blocks) - if inplace_edit: - self.parent_statement.edit(diff_changes) - else: - return diff_changes - - @reader - def _is_implicit_return(self, arrow_fn: TSFunction) -> bool: - """Check if an arrow function has an implicit return. - - An implicit return occurs when: - 1. The function has exactly one statement - 2. The statement is not a comment - 3. The function body is not wrapped in curly braces - - Args: - arrow_fn: The arrow function to check - - Returns: - bool: True if the function has an implicit return - """ - statements = arrow_fn.code_block.statements - if len(statements) != 1: - return False - - stmt = statements[0] - return not stmt.statement_type == StatementType.COMMENT and not arrow_fn.code_block.source.strip().startswith("{") - - @reader - def _format_anonymous_function(self, arrow_fn: TSFunction, next_call: FunctionCall | None = None, assignment_variable_name: str | None = None) -> str: - """Format an arrow function as an anonymous async function. - - Args: - arrow_fn: The arrow function to format - next_call: The next function call in the chain, if any - assignment_variable_name: Optional custom variable name for assignment - - Returns: - str: Formatted anonymous function code - """ - params = arrow_fn.parameters - params_str = ", ".join(p.source for p in params) if params else "" - lines = [] - - var_name = assignment_variable_name if assignment_variable_name else self.assigned_var - - if next_call and next_call.name == "then": - next_params = self.get_next_call_params(next_call) - if next_params: - lines.append(f"{self.base_indent}{self.format_param_assignment(next_params, f'await (async ({params_str}) => {{', declare=True)}") - else: - prefix = "" - if self.is_return_statement: - prefix = "return " - elif var_name: - prefix = f"{var_name} = " - lines.append(f"{self.base_indent}{prefix}await (async ({params_str}) => {{") - - code_block = arrow_fn.code_block - block_content = code_block.source.strip() - if block_content.startswith("{"): - block_content = block_content[1:] - if block_content.endswith("}"): - block_content = block_content[:-1] - - block_lines = block_content.split("\n") - for line in block_lines: - if line.strip(): - lines.append(f"{self.base_indent} {line.strip()}") - - if params_str: - lines.append(f"{self.base_indent}}})({params_str});") - else: - lines.append(f"{self.base_indent}}})();") - - return "\n".join(lines) - - @property - @reader - def has_catch_call(self) -> bool: - """Check if the Promise chain has a catch call. - - Returns: - bool: True if there is a catch call - """ - return self.catch_call is not None - - @property - @reader - def has_finally_call(self) -> bool: - """Check if the Promise chain has a finally call. - - Returns: - bool: True if there is a finally call - """ - return self.finally_call is not None diff --git a/src/codegen/sdk/typescript/enum_definition.py b/src/codegen/sdk/typescript/enum_definition.py deleted file mode 100644 index faacc6e32..000000000 --- a/src/codegen/sdk/typescript/enum_definition.py +++ /dev/null @@ -1,100 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Self, TypeVar, override - -from codegen.sdk.core.autocommit import commiter, reader -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.interfaces.has_attribute import HasAttribute -from codegen.sdk.enums import SymbolType -from codegen.sdk.typescript.interfaces.has_block import TSHasBlock -from codegen.sdk.typescript.statements.attribute import TSAttribute -from codegen.sdk.typescript.symbol import TSSymbol -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.expressions import Expression - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.statements.statement import Statement - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - -Parent = TypeVar("Parent", bound="TSHasBlock") - - -@ts_apidoc -class TSEnum(TSHasBlock, TSSymbol, HasAttribute[TSAttribute]): - """Representation of an Enum in TypeScript. - - Attributes: - symbol_type: The type of symbol, set to SymbolType.Enum. - body: The expression representing the body of the enum. - code_block: The code block associated with the enum. - """ - - symbol_type = SymbolType.Enum - body: Expression[Self] - code_block: TSCodeBlock - - def __init__( - self, - ts_node: TSNode, - file_id: NodeId, - ctx: CodebaseContext, - parent: Statement[CodeBlock[Parent, ...]], - ) -> None: - name_node = ts_node.child_by_field_name("name") - super().__init__(ts_node, file_id, ctx, parent, name_node=name_node) - self.body = self._parse_expression(ts_node.child_by_field_name("body")) - - @property - @reader - def attributes(self) -> list[TSAttribute[Self, None]]: - """Property that retrieves the attributes of a TypeScript enum. - - Returns the list of attributes defined within the enum's code block. - - Returns: - list[TSAttribute[Self, None]]: List of TSAttribute objects representing the enum's attributes. - """ - return self.code_block.attributes - - @reader - def get_attribute(self, name: str) -> TSAttribute | None: - """Returns an attribute from the TypeScript enum by its name. - - Args: - name (str): The name of the attribute to retrieve. - - Returns: - TSAttribute | None: The attribute with the given name if it exists, None otherwise. - """ - return next((x for x in self.attributes if x.name == name), None) - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind = UsageKind.BODY, dest: HasName | None = None) -> None: - dest = dest or self.self_dest - self.body._compute_dependencies(usage_type, dest) - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - return super().descendant_symbols + self.body.descendant_symbols - - @noapidoc - @reader - @override - def resolve_attribute(self, name: str) -> TSAttribute | None: - return self.get_attribute(name) - - @staticmethod - @noapidoc - def _get_name_node(ts_node: TSNode) -> TSNode | None: - if ts_node.type == "enum_declaration": - return ts_node.child_by_field_name("name") - return None diff --git a/src/codegen/sdk/typescript/enums.py b/src/codegen/sdk/typescript/enums.py deleted file mode 100644 index ce101ec60..000000000 --- a/src/codegen/sdk/typescript/enums.py +++ /dev/null @@ -1,36 +0,0 @@ -from enum import StrEnum - - -class TSFunctionTypeNames(StrEnum): - # const a = function functionExpression(): void { - # console.log("This is a regular function expression"); - # }; - FunctionExpression = "function_expression" - - # let arrowFunction = (x,y) => { x + y }; - ArrowFunction = "arrow_function" - - # function* generatorFunctionDeclaration(): Generator { - # yield 1; - # } - GeneratorFunctionDeclaration = "generator_function_declaration" - - # const a = function* generatorFunction(): Generator { - # yield 1; - # }; - GeneratorFunction = "generator_function" - - # function functionDeclaration(name: string): string { - # return `Hello, ${name}!`; - # } - FunctionDeclaration = "function_declaration" - - # class Example { - # methodDefinition(): void { - # console.log("This is a method definition"); - # } - # } - MethodDefinition = "method_definition" - - # Decorated methods (assuming decorators are supported in your JavaScript/TypeScript parser) - DecoratedMethodDefinition = "decorated_method_definition" diff --git a/src/codegen/sdk/typescript/export.py b/src/codegen/sdk/typescript/export.py deleted file mode 100644 index 36c499358..000000000 --- a/src/codegen/sdk/typescript/export.py +++ /dev/null @@ -1,705 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, Literal, Self, TypeVar, override - -from codegen.sdk.core.autocommit import commiter, reader -from codegen.sdk.core.autocommit.decorators import writer -from codegen.sdk.core.dataclasses.usage import UsageKind, UsageType -from codegen.sdk.core.export import Export -from codegen.sdk.core.expressions.name import Name -from codegen.sdk.core.external_module import ExternalModule -from codegen.sdk.core.import_resolution import Import -from codegen.sdk.core.interfaces.chainable import Chainable -from codegen.sdk.core.interfaces.has_value import HasValue -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.enums import EdgeType, ImportType, NodeType -from codegen.sdk.extensions.utils import cached_property -from codegen.sdk.typescript.assignment import TSAssignment -from codegen.sdk.typescript.class_definition import TSClass -from codegen.sdk.typescript.enum_definition import TSEnum -from codegen.sdk.typescript.enums import TSFunctionTypeNames -from codegen.sdk.typescript.function import TSFunction -from codegen.sdk.typescript.import_resolution import TSImport -from codegen.sdk.typescript.interface import TSInterface -from codegen.sdk.typescript.namespace import TSNamespace -from codegen.sdk.typescript.statements.assignment_statement import TSAssignmentStatement -from codegen.sdk.typescript.type_alias import TSTypeAlias -from codegen.sdk.utils import find_all_descendants -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from collections.abc import Generator - - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.codebase.resolution_stack import ResolutionStack - from codegen.sdk.core.interfaces.exportable import Exportable - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.statements.export_statement import ExportStatement - from codegen.sdk.core.symbol_groups.collection import Collection - from codegen.sdk.typescript.symbol import TSSymbol - - -@ts_apidoc -class TSExport(Export["Collection[TSExport, ExportStatement[TSExport]]"], HasValue, Chainable): - """Represents a single exported symbol. - - There is a 1:M relationship between an ExportStatement and an Export - - Attributes: - node_type: The type of the node, set to NodeType.EXPORT. - """ - - _declared_symbol: TSSymbol | TSImport | None - _exported_symbol: Name | None - _name_node: Name | None - node_type: Literal[NodeType.EXPORT] = NodeType.EXPORT - - def __init__( - self, - ts_node: TSNode, - file_node_id: NodeId, - parent: Collection[TSExport, ExportStatement[TSExport]], - ctx: CodebaseContext, - name_node: TSNode | None = None, - declared_symbol: TSSymbol | TSImport | None = None, - exported_symbol: TSNode | None = None, - value_node: TSNode | None = None, - ) -> None: - """Given an `export_statement` tree sitter node, parses all implicit export symbols.""" - if declared_symbol and exported_symbol and declared_symbol.name != exported_symbol.text.decode("utf-8"): - msg = "The exported symbol name must match the declared symbol name" - raise ValueError(msg) - - super().__init__(ts_node, file_node_id, ctx, parent) - self._name_node = self._parse_expression(name_node, default=Name) - self._declared_symbol = declared_symbol - self._exported_symbol = self._parse_expression(exported_symbol, default=Name) - # if self.is_wildcard_export(): - # self.node_id = NodeIdFactory.export_node_id(name=f"wildcard_export_<{self._declared_symbol.node_id}>", file_id=self.file_node_id, is_default=self.is_default_export()) - # else: - # self.node_id = NodeIdFactory.export_node_id(name=self.name, file_id=self.file_node_id, is_default=self.is_default_export()) - self.parse(ctx) - self._value_node = self._parse_expression(value_node) - - @classmethod - @noapidoc - def from_export_statement_with_declaration( - cls, - export_statement: TSNode, - declaration: TSNode, - file_id: NodeId, - ctx: CodebaseContext, - parent: ExportStatement[TSExport], - pos: int, - ) -> list[TSExport]: - declared_symbols = [] - - # =====[ Symbol Definitions ]===== - if declaration.type in ["function_declaration", "generator_function_declaration"]: - # e.g. export function* namedGenerator() {} - declared_symbols.append(TSFunction(declaration, file_id, ctx, parent)) - elif declaration.type == "class_declaration": - # e.g. export class NamedClass {} - declared_symbols.append(TSClass(declaration, file_id, ctx, parent)) - elif declaration.type in ["variable_declaration", "lexical_declaration"]: - if len(arrow_functions := find_all_descendants(declaration, {"arrow_function"}, max_depth=2)) > 0: - # e.g. export const arrowFunction = () => {}, but not export const a = { func: () => null } - for arrow_func in arrow_functions: - declared_symbols.append(TSFunction.from_function_type(arrow_func, file_id, ctx, parent)) - else: - # e.g. export const a = value; - for child in declaration.named_children: - if child.type in TSAssignmentStatement.assignment_types: - s = TSAssignmentStatement.from_assignment(declaration, file_id, ctx, parent.parent, pos, assignment_node=child) - declared_symbols.extend(s.assignments) - elif declaration.type == "interface_declaration": - # e.g. export interface MyInterface {} - declared_symbols.append(TSInterface(declaration, file_id, ctx, parent)) - elif declaration.type == "type_alias_declaration": - # e.g. export type MyType = {} - declared_symbols.append(TSTypeAlias(declaration, file_id, ctx, parent)) - elif declaration.type == "enum_declaration": - # e.g. export enum MyEnum {} - declared_symbols.append(TSEnum(declaration, file_id, ctx, parent)) - elif declaration.type == "internal_module": - # e.g. export namespace MyNamespace {} - declared_symbols.append(TSNamespace(declaration, file_id, ctx, parent)) - else: - declared_symbols.append(None) - - exports = [] - for declared_symbol in declared_symbols: - name_node = declared_symbol._name_node.ts_node if declared_symbol and declared_symbol._name_node else declaration - export = cls(ts_node=declaration, file_node_id=file_id, ctx=ctx, name_node=name_node, declared_symbol=declared_symbol, parent=parent.exports) - exports.append(export) - return exports - - @classmethod - @noapidoc - def from_export_statement_with_value(cls, export_statement: TSNode, value: TSNode, file_id: NodeId, ctx: CodebaseContext, parent: ExportStatement[TSExport], pos: int) -> list[TSExport]: - declared_symbols = [] - exported_name_and_symbol = [] # tuple of export name node and export symbol name - detached_value_node = None - - # =====[ Symbol Definitions ]===== - if value.type in [function_type.value for function_type in TSFunctionTypeNames]: - # e.g. export default async function() {} - declared_symbols.append(parent._parse_expression(value)) - elif value.type == "class": - # e.g. export default class {} - declared_symbols.append(parent._parse_expression(value, default=TSClass)) - elif value.type == "object": - # e.g. export default { a, b, c }, export = { a, b, c } - # Export symbol usage will get resolved in _compute_dependencies based on identifiers in value - # TODO: parse as TSDict - detached_value_node = value - for child in value.named_children: - if child.type == "pair": - key_value = child.child_by_field_name("key") - pair_value = child.child_by_field_name("value") - if pair_value.type in [function_type.value for function_type in TSFunctionTypeNames]: - declared_symbols.append(TSFunction(pair_value, file_id, ctx, parent)) - elif pair_value.type == "class": - declared_symbols.append(TSClass(pair_value, file_id, ctx, parent)) - else: - exported_name_and_symbol.append((key_value, pair_value)) - elif child.type == "shorthand_property_identifier": - exported_name_and_symbol.append((child, child)) - elif value.type == "assignment_expression": - left = value.child_by_field_name("left") - right = value.child_by_field_name("right") - assignment = TSAssignment(value, file_id, ctx, parent, left, right, left) - declared_symbols.append(assignment) - else: - # Other values are detached symbols: array, number, string, true, null, undefined, new_expression, call_expression - # Export symbol usage will get resolved in _compute_dependencies based on identifiers in value - detached_value_node = value - declared_symbols.append(None) - - exports = [] - for declared_symbol in declared_symbols: - if declared_symbol is None: - name_node = value - else: - name_node = declared_symbol._name_node.ts_node if declared_symbol._name_node else declared_symbol.ts_node - export = cls(ts_node=export_statement, file_node_id=file_id, ctx=ctx, name_node=name_node, declared_symbol=declared_symbol, value_node=detached_value_node, parent=parent.exports) - exports.append(export) - for name_node, symbol_name_node in exported_name_and_symbol: - exports.append(cls(ts_node=export_statement, file_node_id=file_id, ctx=ctx, name_node=name_node, exported_symbol=symbol_name_node, value_node=detached_value_node, parent=parent.exports)) - return exports - - @noapidoc - @commiter - def parse(self, ctx: CodebaseContext) -> None: - pass - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - if self.exported_symbol: - for frame in self.resolved_type_frames: - if frame.parent_frame: - frame.parent_frame.add_usage(self._name_node or self, UsageKind.EXPORTED_SYMBOL, self, self.ctx) - elif self._exported_symbol: - if not next(self.resolve_name(self._exported_symbol.source), None): - self._exported_symbol._compute_dependencies(UsageKind.BODY, dest=dest or self) - elif self.value: - self.value._compute_dependencies(UsageKind.EXPORTED_SYMBOL, self) - - @noapidoc - @commiter - def compute_export_dependencies(self) -> None: - """Create Export edges from this export to it's used symbols""" - if self.declared_symbol is not None: - assert self.ctx.has_node(self.declared_symbol.node_id) - self.ctx.add_edge(self.node_id, self.declared_symbol.node_id, type=EdgeType.EXPORT) - elif self._exported_symbol is not None: - symbol_name = self._exported_symbol.source - if (used_node := next(self.resolve_name(symbol_name), None)) and isinstance(used_node, Importable) and self.ctx.has_node(used_node.node_id): - self.ctx.add_edge(self.node_id, used_node.node_id, type=EdgeType.EXPORT) - elif self.value is not None: - if isinstance(self.value, Chainable): - for resolved in self.value.resolved_types: - if self.ctx.has_node(getattr(resolved, "node_id", None)): - self.ctx.add_edge(self.node_id, resolved.node_id, type=EdgeType.EXPORT) - elif self.name is None: - # This is the export *; case - self.ctx.add_edge(self.node_id, self.file_node_id, type=EdgeType.EXPORT) - if self.is_wildcard_export(): - for file in self.file.importers: - file.__dict__.pop("valid_symbol_names", None) - file.__dict__.pop("valid_import_names", None) - - @reader - def is_named_export(self) -> bool: - """Determines whether this export is a named export. - - Named exports are exports that are not default exports. For example, `export const foo = 'bar'` is a named export, - while `export default foo` is not. - - Returns: - bool: True if this is a named export, False if it is a default export. - """ - return not self.is_default_export() - - @reader - def is_default_export(self) -> bool: - """Determines if an export is the default export for a file. - - This function checks if the export is a default export by examining the export source code and the export's symbol. It handles various cases of default exports including: - - Re-exports as default (`export { foo as default }`) - - Default exports (`export default foo`) - - Module exports (`export = foo`) - - Returns: - bool: True if this is a default export, False otherwise. - """ - exported_symbol = self.exported_symbol - if exported_symbol and isinstance(exported_symbol, TSImport) and exported_symbol.is_default_import(): - return True - - # ==== [ Case: Named re-export as default ] ==== - # e.g. export { foo as default } from './other-module'; - exported_symbol = self.exported_symbol - if exported_symbol is not None and exported_symbol.node_type == NodeType.IMPORT and exported_symbol.source == self.source: - return self.name == "default" - - # ==== [ Case: Default export ] ==== - # e.g. export default foo; export default { foo }; export = foo; export = { foo }; - return self.parent.parent.source.startswith("export default ") or self.parent.parent.source.startswith("export = ") - - @reader - def is_default_symbol_export(self) -> bool: - """Returns True if this is exporting a default symbol, as opposed to a default object export. - - This method checks if an export is a default symbol export (e.g. 'export default foo') rather than a default object export (e.g. 'export default { foo }'). - It handles both direct exports and re-exports. - - Args: - self (TSExport): The export object being checked. - - Returns: - bool: True if this is a default symbol export, False otherwise. - """ - if not self.is_default_export(): - return False - - # ==== [ Case: Default import re-export ] ==== - exported_symbol = self.exported_symbol - if exported_symbol is not None and exported_symbol.node_type == NodeType.IMPORT and exported_symbol.source == self.source: - return self.name == "default" - - # === [ Case: Default symbol export ] ==== - export_object = next((x for x in self.ts_node.children if x.type == "object"), None) - return export_object is None - - @reader - def is_type_export(self) -> bool: - """Determines if this export is exclusively exporting a type. - - Checks if this export starts with "export type" to identify if it's only exporting a type definition. - This method is used to distinguish between value exports and type exports in TypeScript. - - Returns: - bool: True if this is a type-only export, False otherwise. - """ - # TODO: do this more robustly - return self.source.startswith("export type ") - - @reader - def is_reexport(self) -> bool: - """Returns whether the export is re-exporting an import or export. - - Checks if this export node is re-exporting a symbol that was originally imported from another module or exported from another location. This includes wildcard re-exports of entire modules. - - Args: - self (TSExport): The export node being checked. - - Returns: - bool: True if this export re-exports an imported/exported symbol or entire module, False otherwise. - """ - if exported_symbol := self.exported_symbol: - return exported_symbol.node_type == NodeType.IMPORT or exported_symbol.node_type == NodeType.EXPORT or exported_symbol == self.file - return False - - @reader - def is_wildcard_export(self) -> bool: - """Determines if the export is a wildcard export. - - Checks if the export statement contains a wildcard export pattern 'export *' or 'export *;'. A wildcard export exports all symbols from a module. - - Returns: - bool: True if the export is a wildcard export (e.g. 'export * from "./module"'), False otherwise. - """ - return "export * " in self.source or "export *;" in self.source - - @reader - def is_module_export(self) -> bool: - """Determines if the export is exporting a module rather than a symbol. - - Returns True if the export is a wildcard export (e.g. 'export *') or if it is a default export but not of a symbol (e.g. 'export default { foo }'). - - Returns: - bool: True if the export represents a module export, False otherwise. - """ - return self.is_wildcard_export() or (self.is_default_export() and not self.is_default_symbol_export()) - - @property - @reader(cache=False) - def declared_symbol(self) -> TSSymbol | TSImport | None: - """Returns the symbol that was defined in this export. - - Returns the symbol that was directly declared within this export statement. For class, function, - interface, type alias, enum declarations or assignments, returns the declared symbol. - For re-exports or exports without declarations, returns None. - - Returns: - Union[TSSymbol, TSImport, None]: The symbol declared within this export statement, - or None if no symbol was declared. - """ - return self._declared_symbol - - @property - @reader - def exported_symbol(self) -> Exportable | None: - """Returns the symbol, file, or import being exported from this export object. - - Retrieves the symbol or module being exported by this export node by finding the node connected via an EXPORT edge. - This method is the inverse of Import.imported_symbol. - - Args: - None - - Returns: - Exportable | None: The exported symbol, file, or import, or None if no symbol is exported. - """ - return next(iter(self.ctx.successors(self.node_id, edge_type=EdgeType.EXPORT)), None) - - @property - @reader - def resolved_symbol(self) -> Exportable | None: - """Returns the Symbol, SourceFile or External module that this export resolves to. - - Recursively traverses through indirect imports and exports to find the final resolved symbol. - This is useful for determining what symbol an export ultimately points to, particularly in cases of re-exports and import-export chains. - - Returns: - Exportable | None: The final resolved Symbol, SourceFile or External module, or None if the resolution fails. The resolution follows this chain: - - If the symbol is an Import, resolves to its imported symbol - - If the symbol is an Export, resolves to its exported symbol - - Otherwise returns the symbol itself - - Note: - Handles circular references by tracking visited symbols to prevent infinite loops. - """ - ix_seen = set() - resolved_symbol = self.exported_symbol - - while resolved_symbol is not None and (resolved_symbol.node_type == NodeType.IMPORT or resolved_symbol.node_type == NodeType.EXPORT): - if resolved_symbol in ix_seen: - return resolved_symbol - - ix_seen.add(resolved_symbol) - if resolved_symbol.node_type == NodeType.IMPORT: - resolved_symbol = resolved_symbol.resolved_symbol - else: - resolved_symbol = resolved_symbol.exported_symbol - - return resolved_symbol - - @writer - def make_non_default(self) -> None: - """Converts the export to a named export. - - Transforms default exports into named exports by modifying the export syntax and updating any corresponding export/import usages. - For default exports, it removes the 'default' keyword and adjusts all import statements that reference this export. - - Args: - None - - Returns: - None - """ - if self.is_default_export(): - # Default node is: - # export default foo = ... - # ^^^^^^^ - default_node = self.parent.parent._anonymous_children[1] - - if default_node.ts_node.type == "default": - if isinstance(self.declared_symbol, TSAssignment): - # Converts `export default foo` to `export const foo` - default_node.edit("const") - else: - # Converts `export default foo` to `export { foo }` - default_node.remove() - if name_node := self.get_name(): - name_node.insert_before("{ ", newline=False) - name_node.insert_after(" }", newline=False) - - # Update all usages of this export - for usage in self.usages(usage_types=UsageType.DIRECT): - if usage.match is not None and usage.kind == UsageKind.IMPORTED: - # === [ Case: Exported Symbol ] === - # Fixes Exports of the form `export { ... } from ...` - if usage.usage_symbol.source.startswith("export") and usage.match.source == "default": - # Export clause is: - # export { default as foo } from ... - # ^^^^^^^^^^^^^^^^^^ - export_clause = usage.usage_symbol.children[0] - for export_specifier in export_clause.children: - # This is the case where `export { ... as ... }` - if len(export_specifier.children) == 2 and export_specifier.children[0] == usage.match: - if export_specifier.children[1].source == self.name: - # Converts `export { default as foo }` to `export { foo }` - export_specifier.edit(self.name) - else: - # Converts `export { default as renamed_foo }` to `export { foo as renamed_foo }` - usage.match.edit(self.name) - # This is the case where `export { ... } from ...`, (specifically `export { default }`) - elif len(export_specifier.children) == 1 and export_specifier.children[0] == usage.match: - # Converts `export { default }` to `export { foo }` - export_specifier.edit(self.name) - - # === [ Case: Imported Symbol ] === - # Fixes Imports of the form `import { default as foo }` - else: - # Import clause is: - # import A, { B } from ... - # ^^^^^^^^ - import_clause = usage.usage_symbol.children[0] - - # Fixes imports of the form `import foo, { ... } from ...` - if len(import_clause.children) > 1 and import_clause.children[0] == usage.match: - # This is a terrible hack :skull: - - # Named imports are: - # import foo, { ... } - # ^^^^^^^ - named_imports = import_clause.children[1] - - # This converts `import foo, { bar, baz as waz }` to `import { foo, bar, baz as waz }` - import_clause.children[0].remove() # Remove `foo, ` - named_imports.children[0].insert_before(f"{self.name}, ", newline=False) # Add the `foo, ` - # Fixes imports of the form `import foo from ...` - else: - # This converts `import foo` to `import { foo }` - usage.match.insert_before("{ ", newline=False) - usage.match.insert_after(" }", newline=False) - - @cached_property - @noapidoc - @reader - def _wildcards(self) -> dict[str, WildcardExport[Self]]: - if self.is_wildcard_export() and isinstance(self.exported_symbol, Import): - res = {} - for name, symbol in self.exported_symbol._wildcards.items(): - res[name] = WildcardExport(self, symbol) - return res - return {} - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - aliased = self.is_aliased() - if self.exported_symbol is not None: - yield from self.with_resolution_frame(self.exported_symbol, direct=True, aliased=aliased) - elif self.value is not None: - yield from self.with_resolution_frame(self.value, direct=True, aliased=aliased) - - @property - @noapidoc - def names(self) -> Generator[tuple[str, Self | WildcardExport[Self]], None, None]: - if self.exported_name is None: - if self.is_wildcard_export(): - yield from self._wildcards.items() - else: - yield self.exported_name, self - - @property - def descendant_symbols(self) -> list[Importable]: - """Returns a list of all descendant symbols from this export's declared symbol. - - Returns all child symbols that are contained within the declared symbol of this export. For example, - if the declared symbol is a class, this will return all methods, properties and nested classes. - If the export has no declared symbol, returns an empty list. - - Returns: - list[Importable]: List of descendant symbols. Empty list if no declared symbol exists. - """ - if self.declared_symbol: - return [self, *self.declared_symbol.descendant_symbols] - return [self] - - def __hash__(self): - if self._hash is None: - self._hash = hash((self.filepath, self.range, self.ts_node.kind_id, self.name)) - return self._hash - - @reader - def __eq__(self, other: object): - if isinstance(other, TSExport): - return super().__eq__(other) and self.name == other.name - return super().__eq__(other) - - @property - @reader - def source(self) -> str: - """Returns the source code of the symbol. - - Gets the source code of the symbol from its extended representation, which includes the export statement. - - Returns: - str: The complete source code of the symbol including any extended nodes. - """ - return self.parent.parent.source - - @property - @reader - def is_external_export(self) -> bool: - """Determines if this export is exporting a symbol from an external (non-relative) module. - - An external module is one that comes from outside the project's codebase. - - Returns: - bool: True if the export is from an external module, False otherwise. - """ - if self.is_reexport(): - if isinstance(self.exported_symbol, TSImport): - for resolved in self.exported_symbol.resolved_types: - if isinstance(resolved, ExternalModule): - return True - return False - - @reader - def to_import_string(self) -> str: - """Converts this export into its equivalent import string representation. - - This is primarily used for handling re-exports, converting them into their - equivalent import statements. - - Returns: - str: The import string representation of this export. - - Examples: - - For `export { foo } from './bar'` -> `import { foo } from './bar'` - - For `export * from './bar'` -> `import * as _namespace from './bar'` - - For `export { default as foo } from './bar'` -> `import foo from './bar'` - """ - module_path = self.exported_symbol.module.source.strip("'\"") if self.exported_symbol.module is not None else "" - type_prefix = "type " if self.is_type_export() else "" - - if self.is_wildcard_export(): - namespace = self.name or module_path.split("/")[-1].split(".")[0] - return f"import * as {namespace} from '{module_path}';" - - if self.is_default_export(): - if self.is_type_export() and self.is_aliased(): - original_name = self.exported_symbol.symbol_name.source if self.exported_symbol.symbol_name is not None else self.exported_symbol.name - print(original_name) - if original_name == "default": - return f"import {type_prefix}{{ default as {self.name} }} from '{module_path}';" - else: - return f"import {type_prefix}{{ {original_name} as default }} from '{module_path}';" - - # Handle mixed type and value exports - if "type" in self.source and "," in self.source and "{" in self.source and "}" in self.source: - content = self.source[self.source.index("{") + 1 : self.source.index("}")].strip() - return f"import {{ {content} }} from '{module_path}';" - - original_name = self.exported_symbol.symbol_name.source if self.exported_symbol.symbol_name is not None else self.exported_symbol.name - return f"import {{ {original_name} as {self.name} }} from '{module_path}';" - - @reader - def get_import_string(self, alias: str | None = None, module: str | None = None, import_type: ImportType = ImportType.UNKNOWN, is_type_import: bool = False) -> str: - """Returns the import string for this export. - - Args: - alias (str | None): Optional alias to use when importing the symbol. - module (str | None): Optional module name to import from. - import_type (ImportType): The type of import to generate. - is_type_import (bool): Whether this is a type-only import. - - Returns: - str: The formatted import string. - """ - if self.is_reexport(): - return self.to_import_string() - - module_path = self.file.import_module_name.strip("'\"") - type_prefix = "type " if is_type_import else "" - - if import_type == ImportType.WILDCARD: - namespace = alias or module_path.split("/")[-1].split(".")[0] - return f"import * as {namespace} from '{module_path}';" - - # Handle default exports - if self.is_default_export(): - name = alias or self.name - return f"import {name} from '{module_path}';" - - # Handle named exports - original_name = self.name - if alias and alias != original_name: - return f"import {type_prefix}{{ {original_name} as {alias} }} from '{module_path}';" - return f"import {type_prefix}{{ {original_name} }} from '{module_path}';" - - @reader - def reexport_symbol(self) -> TSImport | None: - """Returns the import object that is re-exporting this symbol. - - For re-exports like: - - `export { foo } from './bar'` # Direct re-export - - `export { default as baz } from './bar'` # Direct default re-export - - `export * from './bar'` # Direct wildcard re-export - - `import { foo } from './bar'; export { foo }` # Local re-export - - This returns the corresponding import object that's being re-exported. - - Returns: - TSImport | None: The import object being re-exported, or None if this - is not a re-export or no import was found. - """ - # Only exports can have re-export sources - if not self.is_reexport(): - return None - - # For direct re-exports (export { x } from './y'), use declared_symbol - if self.declared_symbol is not None: - return self.declared_symbol - - # For local re-exports (import x; export { x }), use exported_symbol - if self.exported_symbol is not None and self.exported_symbol.node_type == NodeType.IMPORT: - return self.exported_symbol - - return None - - -TExport = TypeVar("TExport", bound="Export") - - -class WildcardExport(Chainable, Generic[TExport]): - """Class to represent one of many wildcard exports.""" - - exp: TExport - symbol: Exportable - - def __init__(self, exp: TExport, symbol: Exportable): - self.exp = exp - self.symbol = symbol - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - """Resolve the types used by this import.""" - yield from self.exp.with_resolution_frame(self.symbol, direct=False) - - @noapidoc - @reader - def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: - pass diff --git a/src/codegen/sdk/typescript/expressions/array_type.py b/src/codegen/sdk/typescript/expressions/array_type.py deleted file mode 100644 index 0fe714bbd..000000000 --- a/src/codegen/sdk/typescript/expressions/array_type.py +++ /dev/null @@ -1,19 +0,0 @@ -from typing import TypeVar - -from tree_sitter import Node as TSNode - -from codegen.sdk.typescript.expressions.named_type import TSNamedType -from codegen.shared.decorators.docs import ts_apidoc - -Parent = TypeVar("Parent") - - -@ts_apidoc -class TSArrayType(TSNamedType[Parent]): - """Array type - Examples: - string[] - """ - - def _get_name_node(self) -> TSNode | None: - return self.ts_node.named_children[0] diff --git a/src/codegen/sdk/typescript/expressions/chained_attribute.py b/src/codegen/sdk/typescript/expressions/chained_attribute.py deleted file mode 100644 index 87734848c..000000000 --- a/src/codegen/sdk/typescript/expressions/chained_attribute.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.detached_symbols.function_call import FunctionCall -from codegen.sdk.core.expressions import Expression, Name -from codegen.sdk.core.expressions.chained_attribute import ChainedAttribute -from codegen.sdk.extensions.autocommit import reader -from codegen.shared.decorators.docs import ts_apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.editable import Editable - -Parent = TypeVar("Parent", bound="Editable") - - -@ts_apidoc -class TSChainedAttribute(ChainedAttribute[Expression, Name, Parent], Generic[Parent]): - """A TypeScript chained attribute class representing member access expressions. - - This class handles the representation and analysis of chained attribute access expressions in TypeScript, - such as 'object.property' or 'object.method()'. It provides functionality for accessing the object - and property components of the expression, as well as analyzing function calls made on the object. - """ - - def __init__(self, ts_node, file_node_id, ctx, parent: Parent): - super().__init__(ts_node, file_node_id, ctx, parent=parent, object=ts_node.child_by_field_name("object"), attribute=ts_node.child_by_field_name("property")) - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Returns a list of function calls associated with this chained attribute's object. - - Retrieves all function calls made on the object component of this chained attribute. - This is useful for analyzing call sites and call patterns in code analysis and refactoring tasks. - - Returns: - list[FunctionCall]: A list of function calls made on this chained attribute's object. - """ - # Move the parent reference to its own parent to skip over an identifier type in parent chain - return self._object.function_calls diff --git a/src/codegen/sdk/typescript/expressions/conditional_type.py b/src/codegen/sdk/typescript/expressions/conditional_type.py deleted file mode 100644 index 654876e4a..000000000 --- a/src/codegen/sdk/typescript/expressions/conditional_type.py +++ /dev/null @@ -1,59 +0,0 @@ -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from tree_sitter import Node as TSNode - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.core.node_id_factory import NodeId -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.typescript.expressions.type import TSType - - -Parent = TypeVar("Parent") - - -@ts_apidoc -class TSConditionalType(Type[Parent], Generic[Parent]): - """Conditional Type - - Examples: - typeof s - - Attributes: - left: The left-hand side type of the conditional type. - right: The right-hand side type of the conditional type. - consequence: The type if the condition is true. - alternative: The type if the condition is false. - """ - - left: "TSType[Self]" - right: "TSType[Self]" - consequence: "TSType[Self]" - alternative: "TSType[Self]" - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent): - super().__init__(ts_node, file_node_id, ctx, parent) - self.left = self.child_by_field_name("left") - self.right = self.child_by_field_name("right") - self.consequence = self.child_by_field_name("consequence") - self.alternative = self.child_by_field_name("alternative") - - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): - self.left._compute_dependencies(usage_type, dest) - self.right._compute_dependencies(usage_type, dest) - self.consequence._compute_dependencies(usage_type, dest) - self.alternative._compute_dependencies(usage_type, dest) - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - yield from self.with_resolution_frame(self.consequence) - yield from self.with_resolution_frame(self.alternative) diff --git a/src/codegen/sdk/typescript/expressions/expression_type.py b/src/codegen/sdk/typescript/expressions/expression_type.py deleted file mode 100644 index 8a866bb60..000000000 --- a/src/codegen/sdk/typescript/expressions/expression_type.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import TYPE_CHECKING, Generic, TypeVar - -from tree_sitter import Node as TSNode - -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.typescript.expressions.named_type import TSNamedType -from codegen.shared.decorators.docs import ts_apidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.interfaces.editable import Editable - -Parent = TypeVar("Parent", bound="Editable") - - -@ts_apidoc -class TSExpressionType(TSNamedType, Generic[Parent]): - """Type defined by evaluation of an expression - - Attributes: - expression: The expression to evaluate that yields the type - """ - - expression: Expression["TSExpressionType[Parent]"] - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent): - super().__init__(ts_node, file_node_id, ctx, parent) - self.expression = self._parse_expression(ts_node) diff --git a/src/codegen/sdk/typescript/expressions/function_type.py b/src/codegen/sdk/typescript/expressions/function_type.py deleted file mode 100644 index 85807bade..000000000 --- a/src/codegen/sdk/typescript/expressions/function_type.py +++ /dev/null @@ -1,94 +0,0 @@ -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from tree_sitter import Node as TSNode - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.sdk.typescript.detached_symbols.parameter import TSParameter -from codegen.sdk.typescript.placeholder.placeholder_return_type import TSReturnTypePlaceholder -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.typescript.expressions.type import TSType - - -Parent = TypeVar("Parent") - - -@ts_apidoc -class TSFunctionType(Type[Parent], Generic[Parent]): - """Function type definition. - - Example: - a: (a: number) => number - - Attributes: - return_type: Return type of the function. - name: This lets parameters generate their node_id properly. - """ - - return_type: "TSType[Self] | TSReturnTypePlaceholder[Self]" - _parameters: Collection[TSParameter, Self] - name: None = None - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent): - super().__init__(ts_node, file_node_id, ctx, parent) - self.return_type = self.child_by_field_name("return_type", placeholder=TSReturnTypePlaceholder) - params_node = self.ts_node.child_by_field_name("parameters") - params = [TSParameter(child, idx, self) for idx, child in enumerate(params_node.named_children) if child.type != "comment"] - self._parameters = Collection(params_node, file_node_id, ctx, self, children=params) - - @property - @reader - def parameters(self) -> Collection[TSParameter, Self]: - """Retrieves the parameters of a function type. - - Returns the collection of parameters associated with this function type. These parameters represent the arguments that can be passed to the function. - - Returns: - Collection[TSParameter, Self]: A collection of TSParameter objects representing the function's parameters. - """ - return self._parameters - - @writer - def asyncify(self) -> None: - """Modifies the function type to be asynchronous by wrapping its return type in a Promise. - - This method transforms a synchronous function type into an asynchronous one by modifying - its return type. It wraps the existing return type in a Promise, effectively changing - 'T' to 'Promise'. - - Args: - self: The TSFunctionType instance to modify. - - Returns: - None - """ - if self.return_type: - self.return_type.insert_before("Promise<", newline=False) - self.return_type.insert_after(">", newline=False) - - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: Importable | None = None): - if self.return_type: - self.return_type._compute_dependencies(UsageKind.GENERIC, dest) - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - yield from self.with_resolution_frame(self.return_type) - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - symbols = [] - for param in self.parameters: - symbols.extend(param.descendant_symbols) - return symbols diff --git a/src/codegen/sdk/typescript/expressions/generic_type.py b/src/codegen/sdk/typescript/expressions/generic_type.py deleted file mode 100644 index 6e43572a7..000000000 --- a/src/codegen/sdk/typescript/expressions/generic_type.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import TYPE_CHECKING, Self, TypeVar - -from tree_sitter import Node as TSNode - -from codegen.sdk.core.expressions.generic_type import GenericType -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.sdk.core.symbol_groups.dict import Dict -from codegen.shared.decorators.docs import ts_apidoc - -if TYPE_CHECKING: - from codegen.sdk.typescript.expressions.type import TSType - -Parent = TypeVar("Parent") - - -@ts_apidoc -class TSGenericType(GenericType["TSType", Parent]): - """Generic type - - Examples: - `Array` - """ - - def _get_name_node(self) -> TSNode: - return self.child_by_field_name("name").ts_node - - def _get_parameters(self) -> Collection[Self, Self] | Dict[Self, Self] | None: - type_parameter = self.child_by_field_types("type_arguments").ts_node - types = [self._parse_type(child) for child in type_parameter.named_children] - return Collection(node=type_parameter, file_node_id=self.file_node_id, ctx=self.ctx, parent=self, children=types) diff --git a/src/codegen/sdk/typescript/expressions/lookup_type.py b/src/codegen/sdk/typescript/expressions/lookup_type.py deleted file mode 100644 index 1885b9545..000000000 --- a/src/codegen/sdk/typescript/expressions/lookup_type.py +++ /dev/null @@ -1,65 +0,0 @@ -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from tree_sitter import Node as TSNode - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.core.node_id_factory import NodeId -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.typescript.expressions.type import TSType - - -Parent = TypeVar("Parent") - - -@ts_apidoc -class TSLookupType(Type[Parent], Generic[Parent]): - """Type lookup - - Examples: - a["key"] - - Attributes: - type: The type of the TypeScript object being looked up. - lookup: The expression used for the lookup operation. - """ - - type: "TSType[Self]" - lookup: Expression - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent): - super().__init__(ts_node, file_node_id, ctx, parent) - self.type = self._parse_type(ts_node.named_children[0]) - if literal_type := self.child_by_field_types("literal_type"): - self.lookup = self._parse_expression(literal_type.ts_node.named_children[0]) - - @property - @reader - def name(self) -> str | None: - """Retrieves the name of the type object. - - Gets the name property of the underlying type object. This property is commonly used to access type names in TypeScript-style type lookups. - - Returns: - str | None: The name of the type object if it exists, None otherwise. - """ - return self.type.name - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - # TODO: not implemented properly. Needs to look at the actual lookup - self._log_parse("Cannot resolve lookup type properly") - yield from self.with_resolution_frame(self.type) - - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): - self.type._compute_dependencies(usage_type, dest) diff --git a/src/codegen/sdk/typescript/expressions/named_type.py b/src/codegen/sdk/typescript/expressions/named_type.py deleted file mode 100644 index 223f61de5..000000000 --- a/src/codegen/sdk/typescript/expressions/named_type.py +++ /dev/null @@ -1,19 +0,0 @@ -from typing import TypeVar - -from tree_sitter import Node as TSNode - -from codegen.sdk.core.expressions.named_type import NamedType -from codegen.shared.decorators.docs import ts_apidoc - -Parent = TypeVar("Parent") - - -@ts_apidoc -class TSNamedType(NamedType[Parent]): - """Named type - Examples: - string - """ - - def _get_name_node(self) -> TSNode | None: - return self.ts_node diff --git a/src/codegen/sdk/typescript/expressions/object_type.py b/src/codegen/sdk/typescript/expressions/object_type.py deleted file mode 100644 index 60198f750..000000000 --- a/src/codegen/sdk/typescript/expressions/object_type.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import TYPE_CHECKING, Generic, Self, TypeVar - -from tree_sitter import Node as TSNode - -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.expression import Expression -from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.expressions.value import Value -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.typescript.symbol_groups.dict import TSDict, TSPair -from codegen.shared.decorators.docs import ts_apidoc -from codegen.shared.logging.get_logger import get_logger - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - - -logger = get_logger(__name__) - - -Parent = TypeVar("Parent") - - -class TSObjectPair(TSPair, Generic[Parent]): - """Object type - - Examples: - a: {a: int; b?(a: int): c} - """ - - def _get_key_value(self) -> tuple[Expression[Self] | None, Expression[Self] | None]: - from codegen.sdk.typescript.expressions.function_type import TSFunctionType - - key, value = None, None - if self.ts_node_type == "property_signature": - type_node = self.ts_node.child_by_field_name("type") - value = self._parse_expression(type_node) - key = self._parse_expression(self.ts_node.child_by_field_name("name")) - elif self.ts_node_type == "call_signature": - value = TSFunctionType(self.ts_node, self.file_node_id, self.ctx, self) - elif self.ts_node_type == "index_signature": - value = self._parse_expression(self.ts_node.child_by_field_name("type")) - key = self._parse_expression(self.ts_node.named_children[0]) - elif self.ts_node_type == "method_signature": - value = TSFunctionType(self.ts_node, self.file_node_id, self.ctx, self) - key = self._parse_expression(self.ts_node.child_by_field_name("name")) - elif self.ts_node_type == "method_definition": - key = self._parse_expression(self.ts_node.child_by_field_name("mapped_clause_type")) - value = self._parse_expression(self.ts_node.child_by_field_name("type")) - else: - key, value = super()._get_key_value() - if isinstance(value, Value): - # HACK: sometimes types are weird - value = self._parse_expression(value.ts_node.named_children[0]) - elif not isinstance(value, Type): - self._log_parse(f"{value} of type {value.__class__.__name__} from {self.ts_node} not a valid type") - - return key, value - - -Parent = TypeVar("Parent") - - -@ts_apidoc -class TSObjectType(TSDict, Type[Parent], Generic[Parent]): - """A class representing a TypeScript object type with type annotations and dependencies. - - A specialized class extending `TSDict` and implementing `Type` for handling TypeScript object type annotations. - This class handles object type definitions including nested type structures and manages their dependencies. - It provides functionality for computing dependencies within the type structure and handling type relationships - in TypeScript code. - """ - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, delimiter=";", pair_type=TSObjectPair) - - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): - for child in self.values(): - if isinstance(child, Type): - child._compute_dependencies(usage_type, dest) diff --git a/src/codegen/sdk/typescript/expressions/query_type.py b/src/codegen/sdk/typescript/expressions/query_type.py deleted file mode 100644 index 1fde11789..000000000 --- a/src/codegen/sdk/typescript/expressions/query_type.py +++ /dev/null @@ -1,59 +0,0 @@ -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from tree_sitter import Node as TSNode - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.core.node_id_factory import NodeId -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.typescript.expressions.type import TSType - - -Parent = TypeVar("Parent") - - -@ts_apidoc -class TSQueryType(Type[Parent], Generic[Parent]): - """Type query - - Examples: - typeof s - - Attributes: - query: The TypeScript type associated with the query. - """ - - query: "TSType[Self]" - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent): - super().__init__(ts_node, file_node_id, ctx, parent) - self.query = self._parse_type(ts_node.named_children[0]) - - @property - @reader - def name(self) -> str | None: - """Returns the name of the query type. - - A property that retrieves the name of the query type. This property is used to get the name - associated with TypeScript type queries (e.g., 'typeof s'). - - Returns: - str | None: The name of the query type, or None if no name is available. - """ - return self.query.name - - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): - self.query._compute_dependencies(usage_type, dest) - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - yield from self.with_resolution_frame(self.query) diff --git a/src/codegen/sdk/typescript/expressions/readonly_type.py b/src/codegen/sdk/typescript/expressions/readonly_type.py deleted file mode 100644 index 40e12083f..000000000 --- a/src/codegen/sdk/typescript/expressions/readonly_type.py +++ /dev/null @@ -1,59 +0,0 @@ -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override - -from tree_sitter import Node as TSNode - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.core.node_id_factory import NodeId -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.typescript.expressions.type import TSType - - -Parent = TypeVar("Parent") - - -@ts_apidoc -class TSReadonlyType(Type[Parent], Generic[Parent]): - """Readonly type - - Examples: - readonly s - - Attributes: - type: The underlying TypeScript type associated with this readonly type. - """ - - type: "TSType[Self]" - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent): - super().__init__(ts_node, file_node_id, ctx, parent) - self.type = self._parse_type(ts_node.named_children[0]) - - @property - @reader - def name(self) -> str | None: - """Retrieves the name of the type. - - Gets the name from the underlying type object. Since this is a property getter, it is decorated with @reader - to ensure safe concurrent access. - - Returns: - str | None: The name of the type, or None if the type has no name. - """ - return self.type.name - - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): - self.type._compute_dependencies(usage_type, dest) - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - yield from self.with_resolution_frame(self.type) diff --git a/src/codegen/sdk/typescript/expressions/string.py b/src/codegen/sdk/typescript/expressions/string.py deleted file mode 100644 index 1b078abd7..000000000 --- a/src/codegen/sdk/typescript/expressions/string.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import TYPE_CHECKING, Generic, TypeVar - -from tree_sitter import Node as TSNode - -from codegen.sdk.core.expressions import Expression, String -from codegen.sdk.core.node_id_factory import NodeId -from codegen.shared.decorators.docs import ts_apidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - - -Parent = TypeVar("Parent", bound="Expression") - - -@ts_apidoc -class TSString(String, Generic[Parent]): - """A TypeScript string node representing both literal strings and template strings. - - This class handles both regular string literals and template strings in TypeScript, - providing functionality to parse and manage template string expressions. It extends - the base String class with TypeScript-specific capabilities. - - Attributes: - expressions (list): A list of parsed expressions from template string substitutions. - Empty for regular string literals. - """ - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent) -> None: - super().__init__(ts_node, file_node_id, ctx, parent=parent) - if ts_node.type == "template_string": - substitutions = [x for x in ts_node.named_children if x.type == "template_substitution"] - self.expressions = [self._parse_expression(x.named_children[0]) for x in substitutions] - else: - self.expressions = [] diff --git a/src/codegen/sdk/typescript/expressions/ternary_expression.py b/src/codegen/sdk/typescript/expressions/ternary_expression.py deleted file mode 100644 index a10d70667..000000000 --- a/src/codegen/sdk/typescript/expressions/ternary_expression.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import TYPE_CHECKING, TypeVar - -from codegen.sdk.core.expressions.ternary_expression import TernaryExpression -from codegen.shared.decorators.docs import ts_apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.editable import Editable - -Parent = TypeVar("Parent", bound="Editable") - - -@ts_apidoc -class TSTernaryExpression(TernaryExpression[Parent]): - """Any ternary expression in the code where a condition will determine branched execution""" - - def __init__(self, ts_node, file_node_id, ctx, parent: Parent) -> None: - super().__init__(ts_node, file_node_id, ctx, parent=parent) - self.condition = self.child_by_field_name("condition") - self.consequence = self.child_by_field_name("consequence") - self.alternative = self.child_by_field_name("alternative") diff --git a/src/codegen/sdk/typescript/expressions/type.py b/src/codegen/sdk/typescript/expressions/type.py deleted file mode 100644 index 74d4d7ad3..000000000 --- a/src/codegen/sdk/typescript/expressions/type.py +++ /dev/null @@ -1,2 +0,0 @@ -TSType = "TSUnionType[Parent] | TSObjectType[Parent] | TSNamedType[Parent] | TSGenericType[Parent] | TSQueryType[Parent] | TSReadonlyType[Parent] | NoneType[Parent] | TSUndefinedType[Parent]" -__all__ = ["TSType"] diff --git a/src/codegen/sdk/typescript/expressions/undefined_type.py b/src/codegen/sdk/typescript/expressions/undefined_type.py deleted file mode 100644 index 0a0abd49e..000000000 --- a/src/codegen/sdk/typescript/expressions/undefined_type.py +++ /dev/null @@ -1,29 +0,0 @@ -from collections.abc import Generator -from typing import Generic, Self, TypeVar, override - -from codegen.sdk.codebase.resolution_stack import ResolutionStack -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.extensions.autocommit import reader -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -Parent = TypeVar("Parent") - - -@ts_apidoc -class TSUndefinedType(Type[Parent], Generic[Parent]): - """Undefined type. Represents the undefined keyword - Examples: - undefined - """ - - @noapidoc - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): - pass - - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - yield from [] diff --git a/src/codegen/sdk/typescript/expressions/union_type.py b/src/codegen/sdk/typescript/expressions/union_type.py deleted file mode 100644 index b0df5b24d..000000000 --- a/src/codegen/sdk/typescript/expressions/union_type.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.expressions.union_type import UnionType -from codegen.shared.decorators.docs import ts_apidoc - -if TYPE_CHECKING: - from codegen.sdk.typescript.expressions.type import TSType - -Parent = TypeVar("Parent") - - -@ts_apidoc -class TSUnionType(UnionType["TSType", Parent], Generic[Parent]): - """Union type - - Examples: - string | number - """ - - pass diff --git a/src/codegen/sdk/typescript/external/dependency_manager.py b/src/codegen/sdk/typescript/external/dependency_manager.py deleted file mode 100644 index 84d1e12a5..000000000 --- a/src/codegen/sdk/typescript/external/dependency_manager.py +++ /dev/null @@ -1,376 +0,0 @@ -import concurrent.futures -import json -import os -import shutil -import subprocess -import uuid -from dataclasses import dataclass -from enum import Enum - -import pyjson5 -import requests - -from codegen.sdk.core.external.dependency_manager import DependencyManager -from codegen.sdk.utils import shadow_files -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -class InstallerType(Enum): - NPM = "npm" - YARN = "yarn" - PNPM = "pnpm" - UNKNOWN = "unknown" - - -@dataclass -class PackageJsonData: - dependencies: dict[str, str] - dev_dependencies: dict[str, str] - package_data: dict - - -class TypescriptDependencyManager(DependencyManager): - should_install_dependencies: bool - installer_type: InstallerType - package_json_data: dict[str, PackageJsonData] - base_package_json_data: PackageJsonData | None - - """Handles dependency management for Typescript projects. Uses npm, yarn, or pnpm if applicable.""" - - def __init__(self, repo_path: str, base_path: str | None = None, should_install_dependencies: bool = True, force_installer: str | None = None): - super().__init__(repo_path, base_path) - logger.info(f"Initializing TypescriptDependencyManager with should_install_dependencies={should_install_dependencies}") - # Ensure that node, npm, yarn, and pnpm are installed - if not shutil.which("node"): - msg = "NodeJS is not installed" - raise RuntimeError(msg) - if not shutil.which("corepack"): - msg = "corepack is not installed" - raise RuntimeError(msg) - if not shutil.which("npm"): - msg = "npm is not installed" - raise RuntimeError(msg) - if not shutil.which("yarn"): - msg = "yarn is not installed" - raise RuntimeError(msg) - if not shutil.which("pnpm"): - msg = "pnpm is not installed" - raise RuntimeError(msg) - - self.should_install_dependencies = should_install_dependencies - # Detect the installer type - if force_installer: - self.installer_type = InstallerType(force_installer) - else: - self.installer_type = self._detect_installer_type() - - logger.info(f"Detected installer type: {self.installer_type}") - - # List of package.json files with their parsed data - self.package_json_data: dict[str, PackageJsonData] = {} - self.base_package_json_data: PackageJsonData | None = None - - def _detect_installer_type(self) -> InstallerType: - if os.path.exists(os.path.join(self.full_path, "yarn.lock")): - return InstallerType.YARN - elif os.path.exists(os.path.join(self.full_path, "package-lock.json")): - return InstallerType.NPM - elif os.path.exists(os.path.join(self.full_path, "pnpm-lock.yaml")): - return InstallerType.PNPM - else: - logger.warning("Could not detect installer type. Defaulting to NPM!") - return InstallerType.NPM - # return InstallerType.UNKNOWN - - @staticmethod - def _check_package_exists(package_name: str) -> bool: - """Check if a package exists on the npm registry.""" - url = f"https://registry.npmjs.org/{package_name}" - try: - response = requests.head(url) - return response.status_code == 200 - except requests.RequestException: - return False - - @classmethod - def _validate_dependencies(cls, deps: dict[str, str]) -> tuple[dict[str, str], dict[str, str]]: - """Validate a dictionary of dependencies against npm registry.""" - valid_deps = {} - invalid_deps = {} - - # Use ThreadPoolExecutor for concurrent validation - with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor: - future_to_package = {executor.submit(cls._check_package_exists, package): (package, version) for package, version in deps.items()} - - for future in concurrent.futures.as_completed(future_to_package): - package, version = future_to_package[future] - try: - exists = future.result() - # Hack to fix github packages - if "github" in version: - version = version.split("#")[0] - if exists: - valid_deps[package] = version - else: - invalid_deps[package] = version - except Exception as e: - logger.exception(f"Error checking package {package}: {e}") - - return valid_deps, invalid_deps - - def parse_dependencies(self): - # Clear the package_json_data - self.package_json_data.clear() - - # Walk through directory tree - for current_dir, subdirs, files in os.walk(self.full_path): - # Skip node_modules directories - if "node_modules" in current_dir: - continue - - # Check if package.json exists in current directory - if "package.json" in files: - # Convert to absolute path and append to results - package_json_path = os.path.join(current_dir, "package.json") - - # Parse the package.json file - try: - # Read package.json - with open(package_json_path) as f: - package_data = pyjson5.load(f) - - # Get dependencies and devDependencies - dependencies = package_data.get("dependencies", {}) - dev_dependencies = package_data.get("devDependencies", {}) - - self.package_json_data[package_json_path] = PackageJsonData(dependencies, dev_dependencies, package_data) - - except FileNotFoundError: - logger.exception(f"Could not find package.json at {package_json_path}") - except ValueError: - logger.exception(f"Invalid json in package.json at {package_json_path}") - except Exception as e: - raise e - - # Set the base package.json data - base_package_json_path = os.path.join(self.full_path, "package.json") - self.base_package_json_data = self.package_json_data.get(base_package_json_path, None) - - def _install_dependencies_npm(self): - logger.info("Installing dependencies with NPM") - # Shadow package-lock.json, if it exists - files_to_shadow = [] - # Check if package-lock.json exists. - if os.path.exists(os.path.join(self.full_path, "package-lock.json")): - files_to_shadow.append(os.path.join(self.full_path, "package-lock.json")) - - # Shadow the files - with shadow_files(files_to_shadow): - # Remove the original package-lock.json - for file_path in files_to_shadow: - os.remove(file_path) - - # Print the node version - logger.info(f"Node version: {subprocess.check_output(['node', '--version'], cwd=self.full_path, text=True).strip()}") - - # Print the npm version - logger.info(f"NPM version: {subprocess.check_output(['npm', '--version'], cwd=self.full_path, text=True).strip()}") - - # NPM Install - try: - logger.info(f"Running npm install with cwd {self.full_path}") - subprocess.run(["npm", "install"], cwd=self.full_path, check=True, capture_output=True, text=True) - except subprocess.CalledProcessError as e: - logger.exception(f"NPM FAIL: npm install failed with exit code {e.returncode}") - logger.exception(f"NPM FAIL stdout: {e.stdout}") - logger.exception(f"NPM FAIL stderr: {e.stderr}") - raise - - def _install_dependencies_yarn(self): - logger.info("Installing dependencies with Yarn") - # Shadow yarn.lock, yarn.config.cjs, and .yarnrc.yml, if they exist - files_to_shadow = [] - # Check if yarn.lock exists. - if os.path.exists(os.path.join(self.full_path, "yarn.lock")): - files_to_shadow.append(os.path.join(self.full_path, "yarn.lock")) - # Check if yarn.config.cjs exists. This fixes constraints - if os.path.exists(os.path.join(self.full_path, "yarn.config.cjs")): - files_to_shadow.append(os.path.join(self.full_path, "yarn.config.cjs")) - # Check if .yarnrc.yml exists. This fixes pre and post install scripts - if os.path.exists(os.path.join(self.full_path, ".yarnrc.yml")): - files_to_shadow.append(os.path.join(self.full_path, ".yarnrc.yml")) - - # Shadow the files - with shadow_files(files_to_shadow): - # If .yarnrc.yml exists, check if the yarnPath option is set and save it - yarn_path = None - if os.path.exists(os.path.join(self.full_path, ".yarnrc.yml")): - # Grab the line with "yarnPath" - with open(os.path.join(self.full_path, ".yarnrc.yml")) as f: - for line in f: - if "yarnPath" in line: - yarn_path = line.split(":")[1].strip() - break - # Remove all the shadowed files - for file_path in files_to_shadow: - os.remove(file_path) - - try: - # Disable PnP - with open(os.path.join(self.full_path, ".yarnrc.yml"), "w") as f: - f.write("nodeLinker: node-modules\n") - if yarn_path: - f.write(f"yarnPath: {yarn_path}\n") - - # Print the node version - logger.info(f"Node version: {subprocess.check_output(['node', '--version'], cwd=self.full_path, text=True).strip()}") - - # Print the yarn version - logger.info(f"Yarn version: {subprocess.check_output(['yarn', '--version'], cwd=self.full_path, text=True).strip()}") - - # This fixes a bug where swapping yarn versions corrups the metadata and package caches, - # causing all sorts of nasty issues - yarn_temp_global_dir: str = f"/tmp/yarn_tmp_{uuid.uuid4()}" - try: - # Yarn Install - try: - # Create custom flags for yarn - yarn_custom_flags = { - "YARN_ENABLE_IMMUTABLE_INSTALLS": "false", - "YARN_ENABLE_TELEMETRY": "false", - "YARN_ENABLE_GLOBAL_CACHE": "true", - "YARN_GLOBAL_FOLDER": yarn_temp_global_dir, - } - yarn_environ = { - **os.environ, - **yarn_custom_flags, - } - - # Set up yarn - logger.info(f"Running yarn install with cwd {self.full_path} and yarn_custom_flags {yarn_custom_flags}") - subprocess.run(["corepack", "enable"], cwd=self.full_path, check=True, capture_output=True, text=True) - subprocess.run(["corepack", "prepare", "--activate"], cwd=self.full_path, check=True, capture_output=True, text=True) - subprocess.run(["yarn", "install"], cwd=self.full_path, check=True, capture_output=True, text=True, env=yarn_environ) - except subprocess.CalledProcessError as e: - logger.exception(f"Yarn FAIL: yarn install failed with exit code {e.returncode}") - logger.exception(f"Yarn FAIL stdout: {e.stdout}") - logger.exception(f"Yarn FAIL stderr: {e.stderr}") - raise - finally: - # Clean up the temporary global directory - if os.path.exists(yarn_temp_global_dir): - shutil.rmtree(yarn_temp_global_dir) - finally: - # Check if the .yarnrc.yml file exists - if os.path.exists(os.path.join(self.full_path, ".yarnrc.yml")): - # Delete the .yarnrc.yml file - os.remove(os.path.join(self.full_path, ".yarnrc.yml")) - - def _install_dependencies_pnpm(self): - logger.info("Installing dependencies with PNPM") - # Shadow pnpm-lock.yaml, if it exists - files_to_shadow = [] - if os.path.exists(os.path.join(self.full_path, "pnpm-lock.yaml")): - files_to_shadow.append(os.path.join(self.full_path, "pnpm-lock.yaml")) - - # Shadow the files - with shadow_files(files_to_shadow): - # Remove all the shadowed files - for file_path in files_to_shadow: - os.remove(file_path) - - # Print the node version - logger.info(f"Node version: {subprocess.check_output(['node', '--version'], cwd=self.full_path, text=True).strip()}") - - # Print the pnpm version - logger.info(f"PNPM version: {subprocess.check_output(['pnpm', '--version'], cwd=self.full_path, text=True).strip()}") - - # PNPM Install - try: - logger.info(f"Running pnpm install with cwd {self.full_path}") - subprocess.run(["pnpm", "install"], cwd=self.full_path, check=True, capture_output=True, text=True) - except subprocess.CalledProcessError as e: - logger.exception(f"PNPM FAIL: pnpm install failed with exit code {e.returncode}") - logger.exception(f"PNPM FAIL stdout: {e.stdout}") - logger.exception(f"PNPM FAIL stderr: {e.stderr}") - raise - - def _clean_package_json(self, package_json_path: str): - # Get the package data - data = self.package_json_data[package_json_path] - - # Get valid dependencies - valid_deps, _ = self._validate_dependencies(data.dependencies) - valid_dev_deps, _ = self._validate_dependencies(data.dev_dependencies) - - # Create a slimmed down package.json with only the valid dependencies - clean_package_data = {} - - # Copy important fields - clean_package_data["name"] = data.package_data.get("name", "unknown") - clean_package_data["version"] = data.package_data.get("version", "v1.0.0") - if "packageManager" in data.package_data: - clean_package_data["packageManager"] = data.package_data["packageManager"] - if "workspaces" in data.package_data: - clean_package_data["workspaces"] = data.package_data["workspaces"] - - # Copy dependencies - clean_package_data["dependencies"] = valid_deps - clean_package_data["devDependencies"] = valid_dev_deps - - # Write the cleaned package.json - with open(package_json_path, "w") as f: - json_str = json.dumps(clean_package_data, indent=2) - f.write(json_str) - - def install_dependencies(self, validate_dependencies: bool = True): - if validate_dependencies: - with shadow_files(list(self.package_json_data.keys())): - logger.info(f"Cleaning package.json files: {list(self.package_json_data.keys())}") - with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor: - executor.map(self._clean_package_json, self.package_json_data.keys()) - - # Install dependencies, now that we have a valid package.json - return self.install_dependencies(validate_dependencies=False) - else: - if self.installer_type == InstallerType.NPM: - return self._install_dependencies_npm() - elif self.installer_type == InstallerType.YARN: - return self._install_dependencies_yarn() - elif self.installer_type == InstallerType.PNPM: - return self._install_dependencies_pnpm() - else: - logger.warning(f"Installer type {self.installer_type} not implemented") - - def remove_dependencies(self): - # Delete node_modules folder if it exists - node_modules_path = os.path.join(self.full_path, "node_modules") - if os.path.exists(node_modules_path): - shutil.rmtree(node_modules_path) - - def _start(self): - try: - logger.info(f"Starting TypescriptDependencyManager with should_install_dependencies={self.should_install_dependencies}") - super()._start() - # Remove dependencies if we are installing them - if self.should_install_dependencies: - logger.info("Removing existing dependencies") - self.remove_dependencies() - - # Parse dependencies - logger.info("Parsing dependencies") - self.parse_dependencies() - - # Install dependencies if we are installing them - if self.should_install_dependencies: - logger.info("Installing dependencies") - self.install_dependencies() - - # We are ready - logger.info("Finalizing TypescriptDependencyManager") - self.is_ready = True - except Exception as e: - self._error = e - logger.error(f"Error starting TypescriptDependencyManager: {e}", exc_info=True) diff --git a/src/codegen/sdk/typescript/external/mega_racer.py b/src/codegen/sdk/typescript/external/mega_racer.py deleted file mode 100644 index ea0a9807a..000000000 --- a/src/codegen/sdk/typescript/external/mega_racer.py +++ /dev/null @@ -1,30 +0,0 @@ -from py_mini_racer import MiniRacer, init_mini_racer -from py_mini_racer._context import Context -from py_mini_racer._set_timeout import INSTALL_SET_TIMEOUT - - -class MegaRacer(MiniRacer): - """MegaRacer is a patch on MiniRacer that allows for more memory. - - Original MiniRacer: - MiniRacer evaluates JavaScript code using a V8 isolate. - - A MiniRacer instance can be explicitly closed using the close() method, or by using - the MiniRacer as a context manager, i.e,: - - with MiniRacer() as mr: - ... - - The MiniRacer instance will otherwise clean up the underlying V8 resource upon - garbage collection. - - Attributes: - json_impl: JSON module used by helper methods default is - [json](https://docs.python.org/3/library/json.html) - """ - - def __init__(self) -> None: - # Set the max old space size to 64GB - dll = init_mini_racer(ignore_duplicate_init=True, flags=["--max-old-space-size=65536"]) - self._ctx = Context(dll) - self.eval(INSTALL_SET_TIMEOUT) diff --git a/src/codegen/sdk/typescript/external/ts_analyzer_engine.py b/src/codegen/sdk/typescript/external/ts_analyzer_engine.py deleted file mode 100644 index 01405b8ce..000000000 --- a/src/codegen/sdk/typescript/external/ts_analyzer_engine.py +++ /dev/null @@ -1,250 +0,0 @@ -import json -import os -import shutil -import subprocess -import uuid -from abc import abstractmethod -from pathlib import Path -from typing import TYPE_CHECKING - -from py_mini_racer import MiniRacer -from py_mini_racer._objects import JSMappedObject -from py_mini_racer._types import JSEvalException - -from codegen.sdk.core.external.language_engine import LanguageEngine -from codegen.sdk.typescript.external.mega_racer import MegaRacer -from codegen.shared.logging.get_logger import get_logger - -if TYPE_CHECKING: - from codegen.sdk.core.external.dependency_manager import DependencyManager - from codegen.sdk.core.interfaces.editable import Editable - - -logger = get_logger(__name__) - - -class TypescriptEngine(LanguageEngine): - dependency_manager: "DependencyManager | None" - - def __init__(self, repo_path: str, base_path: str | None = None, dependency_manager: "DependencyManager | None" = None): - super().__init__(repo_path, base_path) - self.dependency_manager = dependency_manager - - @abstractmethod - def _start(self): - # If a dependency manager is provided, make sure it is ready - if self.dependency_manager: - logger.info(f"TypescriptEngine: Waiting for {self.dependency_manager.__class__.__name__} to be ready...") - self.dependency_manager.wait_until_ready(ignore_error=True) - # Start the engine - super()._start() - - -class V8TypescriptEngine(TypescriptEngine): - """Typescript-compiler based language engine using MiniRacer's V8-based JS engine. - - More experimental approach to type inference, but is faster and more flexible. - - Attributes: - hard_memory_limit (int): Maximum memory limit in bytes before V8 will force garbage collection - soft_memory_limit (int): Memory threshold in bytes that triggers garbage collection - """ - - hard_memory_limit: int - soft_memory_limit: int - ctx: MiniRacer | None - mr_type_script_analyzer: JSMappedObject | None - - def __init__( - self, - repo_path: str, - base_path: str | None = None, - dependency_manager: "DependencyManager | None" = None, - hard_memory_limit: int = 1024 * 1024 * 1024 * 16, - soft_memory_limit: int = 1024 * 1024 * 1024 * 8, - ): - super().__init__(repo_path, base_path, dependency_manager) - logger.info(f"Initializing V8TypescriptEngine with hard_memory_limit={hard_memory_limit} and soft_memory_limit={soft_memory_limit}") - self.hard_memory_limit: int = hard_memory_limit - self.soft_memory_limit: int = soft_memory_limit - self.ctx: MiniRacer | None = None - self.mr_type_script_analyzer: JSMappedObject | None = None - # Get the path to the current file - self.current_file_path: str = os.path.abspath(__file__) - # Get the path of the language engine - self.engine_path: str = os.path.join(os.path.dirname(self.current_file_path), "typescript_analyzer", "dist", "index.js") - if not os.path.exists(self.engine_path): - msg = f"Typescript analyzer engine not found at {self.engine_path}" - raise FileNotFoundError(msg) - self.engine_source: str = open(self.engine_path).read() - self._patch_engine_source() - - def _start(self): - try: - logger.info("Starting V8TypescriptEngine") - super()._start() - # Create the MiniRacer/MegaRacer context - self.ctx = MegaRacer() # MegaRacer is a patch on MiniRacer that allows for more memory - # Set to 16GB - self.ctx.set_hard_memory_limit(self.hard_memory_limit) - self.ctx.set_soft_memory_limit(self.soft_memory_limit) - - # Load the engine - logger.info(f"Loading engine source with {len(self.engine_source)} bytes") - self.ctx.eval(self.engine_source) - - # Set up proxy file system - logger.info("Setting up proxy file system") - self.ctx.eval("var interop_fs = new ProxyFileSystem();") - self.ctx.eval("var fs_files = {};") - fs_files = self.ctx.eval("fs_files") - self._populate_fs_files(fs_files) - self.ctx.eval("fs_file_map = new Map(Object.entries(fs_files));") - self.ctx.eval("interop_fs.setFiles(fs_file_map);") - - # Set up the analyzer - logger.info(f"Setting up analyzer with path {self.full_path}") - self.ctx.eval(f"const type_script_analyzer = new TypeScriptAnalyzer('{self.full_path}', interop_fs);") - self.mr_type_script_analyzer = self.ctx.eval("type_script_analyzer") - - # Finalize - logger.info("Finalizing V8TypescriptEngine") - self.is_ready = True - except Exception as e: - self._error = e - logger.error(f"Error starting V8TypescriptEngine: {e}", exc_info=True) - - def _populate_fs_files(self, fs_files: dict): - for root, _, files in os.walk(self.full_path): - for filename in files: - file_path = Path(root) / filename - s_fp = str(file_path) - - # Only process JS/TS related files - if not s_fp.endswith((".ts", ".tsx", ".js", ".jsx", ".json", ".d.ts")): - continue - - try: - with open(file_path, encoding="utf-8") as f: - if "node_modules" in s_fp: - if not s_fp.endswith(".json") and not s_fp.endswith(".d.ts"): - continue - content = f.read() - fs_files[str(file_path)] = content - except (UnicodeDecodeError, OSError): - # Skip files that can't be read as text - continue - - def _patch_engine_source(self): - """MiniRacer does not support require and export, so we need to patch the engine source to remove them.""" - logger.info("Patching engine source to remove require and export") - patch_map = { - "var require$$1 = require('fs');": "", - "var require$$2 = require('path');": "", - "var require$$3 = require('os');": "", - "var require$$6 = require('inspector');": "", - "exports.ProxyFileSystem = ProxyFileSystem;": "", - "exports.TypeScriptAnalyzer = TypeScriptAnalyzer;": "", - } - for old, new in patch_map.items(): - self.engine_source = self.engine_source.replace(old, new) - - def get_return_type(self, node: "Editable") -> str | None: - file_path = os.path.join(self.repo_path, node.filepath) - try: - return self.ctx.eval(f"type_script_analyzer.getFunctionAtPosition('{file_path}', {node.start_byte})") - except JSEvalException as e: - return None - - -class NodeTypescriptEngine(TypescriptEngine): - """Typescript-compiler based language engine using NodeJS and the TypeScript compiler. - - More mature approach to type inference, but is slower and less flexible. - - Attributes: - type_data (dict | None): Type data for the codebase - """ - - type_data: dict | None - - def __init__(self, repo_path: str, base_path: str | None = None, dependency_manager: "DependencyManager | None" = None): - super().__init__(repo_path, base_path, dependency_manager) - logger.info("Initializing NodeTypescriptEngine") - self.type_data: dict | None = None - - # Get the path to the current file - self.current_file_path: str = os.path.abspath(__file__) - # Ensure NodeJS and npm are installed - if not shutil.which("node") or not shutil.which("npm"): - msg = "NodeJS or npm is not installed" - raise RuntimeError(msg) - - # Get the path to the typescript analyzer - self.analyzer_path: str = os.path.join(os.path.dirname(self.current_file_path), "typescript_analyzer") - self.analyzer_entry: str = os.path.join(self.analyzer_path, "src", "run_full.ts") - if not os.path.exists(self.analyzer_path): - msg = f"Typescript analyzer not found at {self.analyzer_path}" - raise FileNotFoundError(msg) - - def _start(self): - try: - logger.info("Starting NodeTypescriptEngine") - super()._start() - # NPM Install - try: - logger.info("Installing typescript analyzer dependencies") - subprocess.run(["npm", "install"], cwd=self.analyzer_path, check=True, capture_output=True, text=True) - except subprocess.CalledProcessError as e: - logger.exception(f"NPM FAIL: npm install failed with exit code {e.returncode}") - logger.exception(f"NPM FAIL stdout: {e.stdout}") - logger.exception(f"NPM FAIL stderr: {e.stderr}") - raise - - # Create a temporary output file with a random name - output_file_path: str = f"/tmp/ts_analyzer_output_{uuid.uuid4()}.json" - try: - # Run the analyzer - try: - # Create custom flags for node - node_environ = {**os.environ, "NODE_OPTIONS": "--max_old_space_size=8192"} - - # Run the analyzer - logger.info(f"Running analyzer with project path {self.full_path} and output file {output_file_path}") - subprocess.run( - ["node", "--loader", "ts-node/esm", self.analyzer_entry, "--project", self.full_path, "--output", output_file_path], - cwd=self.analyzer_path, - check=True, - capture_output=True, - text=True, - env=node_environ, - ) - except subprocess.CalledProcessError as e: - logger.exception(f"ANALYZER FAIL: analyzer failed with exit code {e.returncode}") - logger.exception(f"ANALYZER FAIL stdout: {e.stdout}") - logger.exception(f"ANALYZER FAIL stderr: {e.stderr}") - raise - - # Load the type data - self.type_data = json.load(open(output_file_path)) - finally: - # Clean up the output file - if os.path.exists(output_file_path): - os.remove(output_file_path) - - # Finalize - logger.info("Finalizing NodeTypescriptEngine") - self.is_ready = True - except Exception as e: - self._error = e - logger.error(f"Error starting NodeTypescriptEngine: {e}", exc_info=True) - - def get_return_type(self, node: "Editable") -> str | None: - file_path: str = os.path.join(self.repo_path, node.filepath) - if not self.type_data: - return None - codebase_data: dict = self.type_data.get("files", {}) - file_data: dict = codebase_data.get(file_path, {}) - functions_data: dict = file_data.get("functions", {}) - function_data: dict = functions_data.get(node.name, {}) - return function_data.get("returnType", None) diff --git a/src/codegen/sdk/typescript/external/ts_declassify/ts_declassify.py b/src/codegen/sdk/typescript/external/ts_declassify/ts_declassify.py deleted file mode 100644 index 318f4add6..000000000 --- a/src/codegen/sdk/typescript/external/ts_declassify/ts_declassify.py +++ /dev/null @@ -1,94 +0,0 @@ -import os -import shutil -import subprocess - -from codegen.sdk.core.external.external_process import ExternalProcess -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -class TSDeclassify(ExternalProcess): - def __init__(self, repo_path: str, base_path: str, working_dir: str = "/tmp/ts_declassify"): - super().__init__(repo_path, base_path) - self.working_dir = working_dir - - # Ensure NodeJS and npm are installed - if not shutil.which("node") or not shutil.which("npm"): - msg = "NodeJS or npm is not installed" - raise RuntimeError(msg) - - def _start(self): - try: - logger.info("Installing ts-declassify...") - - # Remove existing working directory - if os.path.exists(self.working_dir): - shutil.rmtree(self.working_dir) - - # Creating ts-declassify working directory - os.makedirs(self.working_dir, exist_ok=True) - - # NPM Init - try: - logger.info(f"Running npm init in {self.working_dir}") - subprocess.run(["npm", "init", "-y"], cwd=self.working_dir, check=True, capture_output=True, text=True) - except subprocess.CalledProcessError as e: - logger.exception(f"NPM FAIL: npm init failed with exit code {e.returncode}") - logger.exception(f"NPM FAIL stdout: {e.stdout}") - logger.exception(f"NPM FAIL stderr: {e.stderr}") - raise - - # NPM Install - try: - logger.info(f"Running npm install in {self.working_dir}") - subprocess.run(["npm", "install", "-D", "@codemod/cli", "react-declassify"], cwd=self.working_dir, check=True, capture_output=True, text=True) - except subprocess.CalledProcessError as e: - logger.exception(f"NPM FAIL: npm install failed with exit code {e.returncode}") - logger.exception(f"NPM FAIL stdout: {e.stdout}") - logger.exception(f"NPM FAIL stderr: {e.stderr}") - raise - - # Finalize - self.is_ready = True - except Exception as e: - self._error = e - logger.exception(f"Error installing ts-declassify: {e}") - raise e - - def reparse(self): - msg = "TSDeclassify does not support reparse" - raise NotImplementedError(msg) - - def declassify(self, source: str, filename: str = "file.tsx", error_on_failure: bool = True): - assert self.ready(), "TSDeclassify is not ready" - - try: - # Remove and recreate file.tsx - source_file = os.path.join(self.working_dir, filename) - with open(source_file, "w") as f: - f.write(source) - - # Run declassify - try: - subprocess.run(["npx", "codemod", "--plugin", "react-declassify", source_file], cwd=self.working_dir, check=True, capture_output=True, text=True) - except subprocess.CalledProcessError as e: - logger.exception(f"DECLASSIFY FAIL: declassify failed with exit code {e.returncode}") - logger.exception(f"DECLASSIFY FAIL stdout: {e.stdout}") - logger.exception(f"DECLASSIFY FAIL stderr: {e.stderr}") - raise - - # Get the declassified source - with open(source_file) as f: - declassified_source = f.read() - - # Raise an error if the declassification failed - if error_on_failure and "Cannot perform transformation" in declassified_source: - msg = "Declassification failed!" - raise RuntimeError(msg) - finally: - # Remove file.tsx if it exists - if os.path.exists(source_file): - os.remove(source_file) - - return declassified_source diff --git a/src/codegen/sdk/typescript/external/typescript_analyzer/.gitignore b/src/codegen/sdk/typescript/external/typescript_analyzer/.gitignore deleted file mode 100644 index 930dd1b95..000000000 --- a/src/codegen/sdk/typescript/external/typescript_analyzer/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -# Typescript Analyzer Specific GitIgnores -node_modules -dist -package-lock.json diff --git a/src/codegen/sdk/typescript/external/typescript_analyzer/package.json b/src/codegen/sdk/typescript/external/typescript_analyzer/package.json deleted file mode 100644 index 894321d00..000000000 --- a/src/codegen/sdk/typescript/external/typescript_analyzer/package.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "name": "typescript-analyzer", - "version": "1.0.0", - "description": "TypeScript project analyzer", - "main": "dist/index.js", - "types": "dist/index.d.ts", - "type": "module", - "scripts": { - "build": "rollup -c", - "start": "node --loader ts-node/esm src/run_full.ts", - "analyze": "node --loader ts-node/esm src/run_full.ts --project", - "get-type": "node --loader ts-node/esm src/get_type_at_position.ts" - }, - "dependencies": { - "typescript": "^5.0.0", - "yargs": "^17.7.2" - }, - "devDependencies": { - "@rollup/plugin-commonjs": "^28.0.0", - "@rollup/plugin-json": "^6.0.0", - "@rollup/plugin-node-resolve": "^15.0.0", - "@rollup/plugin-typescript": "^12.0.0", - "@rollup/plugin-virtual": "^3.0.2", - "@types/node": "^22.0.0", - "@types/yargs": "^17.0.32", - "rollup": "^4.9.0", - "ts-node": "^10.9.1", - "tslib": "^2.6.0" - } -} diff --git a/src/codegen/sdk/typescript/external/typescript_analyzer/rollup.config.js b/src/codegen/sdk/typescript/external/typescript_analyzer/rollup.config.js deleted file mode 100644 index 79a35c5f9..000000000 --- a/src/codegen/sdk/typescript/external/typescript_analyzer/rollup.config.js +++ /dev/null @@ -1,36 +0,0 @@ -import { builtinModules } from "node:module"; -import commonjs from "@rollup/plugin-commonjs"; -import resolve from "@rollup/plugin-node-resolve"; -import typescript from "@rollup/plugin-typescript"; - -export default { - input: "src/index.ts", - output: { - file: "dist/index.js", - format: "cjs", - sourcemap: false, - }, - // Only exclude Node.js built-in modules that can't be bundled - external: builtinModules, - plugins: [ - // Resolve node_modules dependencies - resolve({ - preferBuiltins: false, - mainFields: ["module", "main"], - // Bundle node_modules content - modulesOnly: false, - }), - // Convert CommonJS modules to ES6 - commonjs({ - ignoreTryCatch: true, - // Include node_modules - include: /node_modules/, - }), - // Handle TypeScript - typescript({ - tsconfig: "./tsconfig.json", - declaration: true, - declarationDir: "dist", - }), - ], -}; diff --git a/src/codegen/sdk/typescript/external/typescript_analyzer/src/analyzer.ts b/src/codegen/sdk/typescript/external/typescript_analyzer/src/analyzer.ts deleted file mode 100644 index 34c6d5b37..000000000 --- a/src/codegen/sdk/typescript/external/typescript_analyzer/src/analyzer.ts +++ /dev/null @@ -1,410 +0,0 @@ -import * as ts from "typescript"; -import type { FileSystemInterface } from "./fsi"; - -function resolvePath(p: string): string { - // Simple resolve that just returns the path as-is - // In a real implementation this would handle .. and . segments - return p; -} - -function getDirname(p: string): string { - // Simple dirname that returns everything before the last / - const lastSlash = p.lastIndexOf("/"); - if (lastSlash === -1) return "."; - return p.slice(0, lastSlash); -} - -const TYPE_FORMAT_FLAGS = ts.TypeFormatFlags.NoTruncation; - -export interface FunctionInfo { - name: string; - returnType: string; - parameters: string; - kind: string; - filePath: string; // Added to track which file the function is from -} - -export class TypeScriptAnalyzer { - private program: ts.Program; - private typeChecker: ts.TypeChecker; - - constructor(projectPath: string, fileSystem?: FileSystemInterface) { - // Create a custom compiler host if custom file system functions are provided - const compilerHost: ts.CompilerHost = fileSystem - ? { - getSourceFile: ( - fileName: string, - languageVersion: ts.ScriptTarget, - onError?: (message: string) => void, - ) => { - const sourceText = fileSystem.readFile(fileName); - return sourceText - ? ts.createSourceFile(fileName, sourceText, languageVersion) - : undefined; - }, - getDefaultLibFileName: (defaultLibOptions: ts.CompilerOptions) => - `/${ts.getDefaultLibFileName(defaultLibOptions)}`, - writeFile: fileSystem.writeFile || (() => {}), - getCurrentDirectory: () => "/", - getDirectories: fileSystem.getDirectories || (() => []), - readDirectory: fileSystem.readDirectory || (() => []), - fileExists: (fileName: string) => fileSystem.fileExists(fileName), - readFile: (fileName: string) => fileSystem.readFile(fileName), - getCanonicalFileName: (fileName: string) => fileName, - useCaseSensitiveFileNames: () => true, - getNewLine: () => "\n", - getEnvironmentVariable: () => "", - } - : ts.createCompilerHost({}); - - // Create a custom parse config host - const parseConfigHost: ts.ParseConfigHost = { - useCaseSensitiveFileNames: true, - readDirectory: ( - path: string, - extensions?: readonly string[], - exclude?: readonly string[], - include?: readonly string[], - depth?: number, - ) => { - return fileSystem - ? fileSystem.readDirectory?.(path) || [] - : ts.sys.readDirectory(path, extensions, exclude, include, depth); - }, - fileExists: fileSystem ? fileSystem.fileExists : ts.sys.fileExists, - readFile: fileSystem ? fileSystem.readFile : ts.sys.readFile, - }; - - // Find the base tsconfig.json file - const baseConfigPath = ts.findConfigFile( - projectPath, - fileSystem ? fileSystem.fileExists : ts.sys.fileExists, - "tsconfig.json", - ); - - if (!baseConfigPath) { - throw new Error("Could not find a valid 'tsconfig.json'."); - } - - // Parse the base config file - const baseConfig = ts.readConfigFile( - baseConfigPath, - fileSystem ? fileSystem.readFile : ts.sys.readFile, - ); - - if (baseConfig.error) { - throw new Error( - `Error reading tsconfig.json: ${baseConfig.error.messageText}`, - ); - } - // Parse the config content - const parsedConfig = ts.parseJsonConfigFileContent( - baseConfig.config, - parseConfigHost, - projectPath, - ); - - // Find all tsconfig files and parse them - const allConfigPaths = this.findTsConfigFiles(projectPath, fileSystem); - const allFileNames = new Set(parsedConfig.fileNames); - - // Add files from each config - for (const configPath of allConfigPaths) { - if (configPath === baseConfigPath) continue; - - const config = ts.readConfigFile( - configPath, - fileSystem ? fileSystem.readFile : ts.sys.readFile, - ); - if (!config.error) { - const parsed = ts.parseJsonConfigFileContent( - config.config, - parseConfigHost, - getDirname(configPath), - ); - for (const f of parsed.fileNames) { - allFileNames.add(f); - } - } - } - - // Create program with custom host if provided - this.program = ts.createProgram({ - rootNames: Array.from(allFileNames), - options: parsedConfig.options, - host: compilerHost, - }); - this.typeChecker = this.program.getTypeChecker(); - } - - private findTsConfigFiles( - projectPath: string, - fileSystem?: FileSystemInterface, - ): string[] { - const tsconfigPaths: string[] = []; - const readDirectory = fileSystem - ? fileSystem.readDirectory - : ts.sys.readDirectory; - - // Get all files recursively from the directory - const entries = readDirectory?.(projectPath) || []; - - // Filter for tsconfig.json files, excluding node_modules and hidden directories - for (const entry of entries) { - if ( - entry.includes("node_modules") || - entry.split("/").some((part) => part.startsWith(".")) - ) { - continue; - } - - if (entry.endsWith("/tsconfig.json")) { - tsconfigPaths.push(entry); - } - } - - return tsconfigPaths; - } - - getFunctionAtPosition(filePath: string, position: number): string { - const resolvedPath = resolvePath(filePath); - const sourceFile = this.program.getSourceFile(resolvedPath); - - if (!sourceFile) { - throw new Error(`Could not find source file: ${filePath}`); - } - - // Find the node at the exact position - function findNodeAtPosition(node: ts.Node): ts.Node | undefined { - if (position >= node.getStart() && position < node.getEnd()) { - // Check children first to get the most specific node - let matchingChild: ts.Node | undefined; - ts.forEachChild(node, (child) => { - const foundNode = findNodeAtPosition(child); - if (foundNode) { - matchingChild = foundNode; - } - }); - return matchingChild || node; - } - return undefined; - } - - // Enhanced findContainingFunction to handle arrow functions - function findContainingFunction(node: ts.Node): ts.Node | undefined { - if (!node) return undefined; - - // Check if the node itself is a function-like declaration - if (ts.isFunctionLike(node)) { - return node; - } - - // Handle variable declarations with arrow functions - if (ts.isVariableDeclaration(node)) { - const initializer = node.initializer; - if (initializer && ts.isArrowFunction(initializer)) { - return initializer; - } - } - - // Handle property assignments with arrow functions - if (ts.isPropertyAssignment(node)) { - const initializer = node.initializer; - if (initializer && ts.isArrowFunction(initializer)) { - return initializer; - } - } - - // Handle binary expressions (e.g., assignments) with arrow functions - if (ts.isBinaryExpression(node)) { - const right = node.right; - if (right && ts.isArrowFunction(right)) { - return right; - } - } - - // Recursively check parent nodes - return node.parent ? findContainingFunction(node.parent) : undefined; - } - - // Find the node at position and its containing function - const nodeAtPosition = findNodeAtPosition(sourceFile); - if (!nodeAtPosition) { - throw new Error(`No node found at position ${position}`); - } - - const containingFunction = findContainingFunction(nodeAtPosition); - if (!containingFunction) { - throw new Error(`No function found at position ${position}`); - } - - // Get the function's return type - if ( - ts.isArrowFunction(containingFunction) || - ts.isFunctionLike(containingFunction) - ) { - // Try to get the signature directly first - const signature = this.typeChecker.getSignatureFromDeclaration( - containingFunction as ts.SignatureDeclaration, - ); - - if (signature) { - const returnType = this.typeChecker.getReturnTypeOfSignature(signature); - return this.typeChecker.typeToString( - returnType, - undefined, - TYPE_FORMAT_FLAGS, - ); - } - - // If no direct signature, try to get it from the type - const type = this.typeChecker.getTypeAtLocation(containingFunction); - const signatures = this.typeChecker.getSignaturesOfType( - type, - ts.SignatureKind.Call, - ); - - if (signatures.length > 0) { - const returnType = this.typeChecker.getReturnTypeOfSignature( - signatures[0], - ); - return this.typeChecker.typeToString( - returnType, - undefined, - TYPE_FORMAT_FLAGS, - ); - } - } - - throw new Error("Could not determine function return type"); - } - - private getParameterTypesString(node: ts.SignatureDeclaration): string { - return `(${node.parameters - .map((param) => { - const paramType = this.typeChecker.getTypeAtLocation(param); - return `${param.name.getText()}: ${this.typeChecker.typeToString(paramType, undefined, TYPE_FORMAT_FLAGS)}`; - }) - .join(", ")})`; - } - - private getFunctionKind(node: ts.Node): string { - if (ts.isFunctionDeclaration(node)) return "function"; - if (ts.isArrowFunction(node)) return "arrow function"; - if (ts.isMethodDeclaration(node)) return "method"; - if (ts.isFunctionExpression(node)) return "function expression"; - return "unknown"; - } - - private getNodeReturnType(node: ts.Node): string { - const signature = this.typeChecker.getSignatureFromDeclaration( - node as ts.SignatureDeclaration, - ); - - if (signature) { - const returnType = this.typeChecker.getReturnTypeOfSignature(signature); - return this.typeChecker.typeToString( - returnType, - undefined, - TYPE_FORMAT_FLAGS, - ); - } - - const type = this.typeChecker.getTypeAtLocation(node); - const signatures = this.typeChecker.getSignaturesOfType( - type, - ts.SignatureKind.Call, - ); - if (signatures.length > 0) { - const returnType = this.typeChecker.getReturnTypeOfSignature( - signatures[0], - ); - return this.typeChecker.typeToString( - returnType, - undefined, - TYPE_FORMAT_FLAGS, - ); - } - - return "unknown"; - } - - private analyzeFunctionsInFile(sourceFile: ts.SourceFile): FunctionInfo[] { - const functions: FunctionInfo[] = []; - const filePath = sourceFile.fileName; - - const visit = (node: ts.Node) => { - // Handle function declarations - if (ts.isFunctionDeclaration(node) && node.name) { - functions.push({ - name: node.name.getText(), - returnType: this.getNodeReturnType(node), - parameters: this.getParameterTypesString(node), - kind: "function", - filePath, - }); - } - // Handle arrow functions and function expressions in variable declarations - else if (ts.isVariableStatement(node)) { - for (const declaration of node.declarationList.declarations) { - if (ts.isVariableDeclaration(declaration) && declaration.name) { - const initializer = declaration.initializer; - if ( - initializer && - (ts.isArrowFunction(initializer) || - ts.isFunctionExpression(initializer)) - ) { - functions.push({ - name: declaration.name.getText(), - returnType: this.getNodeReturnType(initializer), - parameters: this.getParameterTypesString(initializer), - kind: this.getFunctionKind(initializer), - filePath, - }); - } - } - } - } - // Handle class methods - else if (ts.isMethodDeclaration(node) && node.name) { - const parentClass = node.parent; - const className = - ts.isClassDeclaration(parentClass) && parentClass.name - ? `${parentClass.name.getText()}.` - : ""; - functions.push({ - name: className + node.name.getText(), - returnType: this.getNodeReturnType(node), - parameters: this.getParameterTypesString(node), - kind: "method", - filePath, - }); - } - - ts.forEachChild(node, visit); - }; - - ts.forEachChild(sourceFile, visit); - return functions; - } - - getAllFunctionsInProject(): FunctionInfo[] { - const functions: FunctionInfo[] = []; - - // Get all source files from the program - const sourceFiles = this.program.getSourceFiles(); - - // Filter out declaration files and analyze each source file - for (const sourceFile of sourceFiles) { - if ( - !sourceFile.isDeclarationFile && - !sourceFile.fileName.includes("node_modules") - ) { - const fileFunctions = this.analyzeFunctionsInFile(sourceFile); - functions.push(...fileFunctions); - } - } - - return functions; - } -} diff --git a/src/codegen/sdk/typescript/external/typescript_analyzer/src/fs_proxy.ts b/src/codegen/sdk/typescript/external/typescript_analyzer/src/fs_proxy.ts deleted file mode 100644 index d2a20f8f9..000000000 --- a/src/codegen/sdk/typescript/external/typescript_analyzer/src/fs_proxy.ts +++ /dev/null @@ -1,120 +0,0 @@ -import type { FileSystemInterface } from "./fsi"; - -export class ProxyFileSystem implements FileSystemInterface { - public files: Map = new Map(); - - constructor() { - // Bind methods to ensure correct 'this' context - this.setFiles = this.setFiles.bind(this); - this.fileExists = this.fileExists.bind(this); - this.readFile = this.readFile.bind(this); - this.readDirectory = this.readDirectory.bind(this); - this.getDirectories = this.getDirectories.bind(this); - this.normalizePath = this.normalizePath.bind(this); - this.getParentDirectory = this.getParentDirectory.bind(this); - this.debugPrintFiles = this.debugPrintFiles.bind(this); - } - - setFiles(files: Map): void { - this.files = files; - } - - addFile(path: string, content: string): void { - const normalized = this.normalizePath(path); - this.files.set(normalized, content); - } - - readFile = (path: string): string | undefined => { - const normalized = this.normalizePath(path); - console.log(`Reading file: ${normalized}`); - return this.files.get(normalized); - }; - - fileExists = (path: string): boolean => { - const normalized = this.normalizePath(path); - console.log(`Checking if file exists: ${normalized}`); - - // Direct file check - if (this.files.has(normalized)) { - return true; - } - - // For tsconfig.json, check parent directories - if (path.endsWith("tsconfig.json")) { - let currentDir = normalized; - while (currentDir !== "/") { - currentDir = this.getParentDirectory(currentDir); - const configPath = this.normalizePath(`${currentDir}/tsconfig.json`); - if (this.files.has(configPath)) { - return true; - } - } - // Check root - return this.files.has("/tsconfig.json"); - } - - return false; - }; - - readDirectory = (path: string): string[] => { - const normalized = this.normalizePath(path); - console.log(`Reading directory: ${normalized}`); - - const files: string[] = []; - for (const filePath of this.files.keys()) { - if (filePath.startsWith(normalized)) { - files.push(filePath); - } - } - return files; - }; - - getDirectories = (path: string): string[] => { - const normalized = this.normalizePath(path); - console.log(`Getting directories under: ${normalized}`); - - const directories = new Set(); - for (const filePath of this.files.keys()) { - if (filePath.startsWith(normalized)) { - // Get relative path from the requested directory - const relativePath = filePath.slice(normalized.length); - if (relativePath) { - // Split the relative path and look for directories - const parts = relativePath.split("/").filter((p) => p); - if (parts.length > 1) { - // If there are subdirectories - directories.add(parts[0]); // Add first subdirectory - } - } - } - } - return Array.from(directories); - }; - - protected normalizePath(path: string): string { - // Remove any './' or multiple slashes and ensure leading slash - let normalized = path.replace(/\/\.\//g, "/").replace(/\/+/g, "/"); - if (!normalized.startsWith("/")) { - normalized = `/${normalized}`; - } - return normalized; - } - - protected getParentDirectory(path: string): string { - const normalized = this.normalizePath(path); - const lastSlash = normalized.lastIndexOf("/"); - if (lastSlash <= 0) return "/"; - return normalized.slice(0, lastSlash) || "/"; - } - - debugPrintFiles(): void { - console.log("\nProxy File System Contents:"); - for (const [path, content] of this.files.entries()) { - console.log(`\nFile: ${path}`); - console.log( - "Content:", - content.slice(0, 100) + (content.length > 100 ? "..." : ""), - ); - } - } -} diff --git a/src/codegen/sdk/typescript/external/typescript_analyzer/src/fsi.ts b/src/codegen/sdk/typescript/external/typescript_analyzer/src/fsi.ts deleted file mode 100644 index 7500416df..000000000 --- a/src/codegen/sdk/typescript/external/typescript_analyzer/src/fsi.ts +++ /dev/null @@ -1,7 +0,0 @@ -export interface FileSystemInterface { - readFile: (path: string) => string | undefined; - writeFile?: (path: string, data: string) => void; - readDirectory?: (path: string) => string[]; - getDirectories?: (path: string) => string[]; - fileExists: (path: string) => boolean; -} diff --git a/src/codegen/sdk/typescript/external/typescript_analyzer/src/get_type_at_position.ts b/src/codegen/sdk/typescript/external/typescript_analyzer/src/get_type_at_position.ts deleted file mode 100644 index 8db1334e5..000000000 --- a/src/codegen/sdk/typescript/external/typescript_analyzer/src/get_type_at_position.ts +++ /dev/null @@ -1,57 +0,0 @@ -import path from "node:path"; -import yargs from "yargs"; -import { hideBin } from "yargs/helpers"; -import { TypeScriptAnalyzer } from "./analyzer.js"; - -function parseArgs() { - return yargs(hideBin(process.argv)) - .option("project", { - alias: "p", - type: "string", - description: "Path to the TypeScript project root", - demandOption: true, - }) - .option("file", { - alias: "f", - type: "string", - description: "Path to the specific TypeScript file", - demandOption: true, - }) - .option("position", { - alias: "pos", - type: "number", - description: "Byte position in the file", - demandOption: true, - }) - .help() - .parseSync(); -} - -function main() { - try { - const argv = parseArgs(); - const projectPath = path.resolve(argv.project); - const filePath = path.resolve(argv.file); - const position = argv.position; - - // Create analyzer instance - const analyzer = new TypeScriptAnalyzer(projectPath); - - try { - // Get return type at position - const returnType = analyzer.getFunctionAtPosition(filePath, position); - // Print only the return type, nothing else - console.log(returnType); - } catch (error) { - // Print just the error message without any formatting - console.error(error instanceof Error ? error.message : "Unknown error"); - process.exit(1); - } - } catch (error) { - // Handle project initialization errors - console.error(error instanceof Error ? error.message : "Unknown error"); - process.exit(1); - } -} - -main(); diff --git a/src/codegen/sdk/typescript/external/typescript_analyzer/src/index.ts b/src/codegen/sdk/typescript/external/typescript_analyzer/src/index.ts deleted file mode 100644 index 4e127a300..000000000 --- a/src/codegen/sdk/typescript/external/typescript_analyzer/src/index.ts +++ /dev/null @@ -1,3 +0,0 @@ -export { TypeScriptAnalyzer, FunctionInfo } from "./analyzer"; -export { ProxyFileSystem } from "./fs_proxy"; -export { FileSystemInterface } from "./fsi"; diff --git a/src/codegen/sdk/typescript/external/typescript_analyzer/src/run_full.ts b/src/codegen/sdk/typescript/external/typescript_analyzer/src/run_full.ts deleted file mode 100644 index f255dfe35..000000000 --- a/src/codegen/sdk/typescript/external/typescript_analyzer/src/run_full.ts +++ /dev/null @@ -1,140 +0,0 @@ -import * as fs from "node:fs"; -import * as path from "node:path"; -import yargs from "yargs"; -import { hideBin } from "yargs/helpers"; -import { type FunctionInfo, TypeScriptAnalyzer } from "./analyzer"; - -interface FunctionAnalysis { - name: string; - returnType: string; - parameters?: string; - kind?: string; -} - -interface FileAnalysis { - relativePath: string; - functions: { [functionName: string]: FunctionAnalysis }; -} - -interface AnalyzerOutput { - projectPath: string; - analysisDate: string; - files: { [filePath: string]: FileAnalysis }; - summary: { - totalFiles: number; - totalFunctions: number; - }; -} - -// Parse command line arguments -const argv = yargs(hideBin(process.argv)) - .option("project", { - alias: "p", - type: "string", - description: "Path to the TypeScript project", - demandOption: true, - }) - .option("output", { - alias: "o", - type: "string", - description: "Output JSON file path", - default: "typescript-analysis.json", - }) - .option("minimal", { - alias: "m", - type: "boolean", - description: "Output only function names and return types", - default: false, - }) - .option("pretty", { - type: "boolean", - description: "Pretty print JSON output", - default: true, - }) - .help() - .parseSync(); - -function groupFunctionsByFile( - functions: FunctionInfo[], - projectPath: string, - minimal: boolean, -): AnalyzerOutput["files"] { - const files: AnalyzerOutput["files"] = {}; - - for (const func of functions) { - const relativePath = path.relative(projectPath, func.filePath); - - if (!files[func.filePath]) { - files[func.filePath] = { - relativePath, - functions: {}, - }; - } - - const functionAnalysis: FunctionAnalysis = { - name: func.name, - returnType: func.returnType, - ...(minimal - ? {} - : { - parameters: func.parameters, - kind: func.kind, - }), - }; - - files[func.filePath].functions[func.name] = functionAnalysis; - } - - return files; -} - -async function main() { - try { - // Resolve absolute paths - const projectPath = path.resolve(argv.project as string); - const outputPath = path.resolve(argv.output as string); - - console.log(`Analyzing TypeScript project at: ${projectPath}`); - - // Create analyzer instance - const analyzer = new TypeScriptAnalyzer(projectPath); - - // Get all functions - const functions = analyzer.getAllFunctionsInProject(); - - // Group functions by file - const groupedFiles = groupFunctionsByFile( - functions, - projectPath, - argv.minimal as boolean, - ); - - // Prepare output data - const output: AnalyzerOutput = { - projectPath, - analysisDate: new Date().toISOString(), - files: groupedFiles, - summary: { - totalFiles: Object.keys(groupedFiles).length, - totalFunctions: functions.length, - }, - }; - - // Write to file - fs.writeFileSync( - outputPath, - JSON.stringify(output, null, argv.pretty ? 2 : 0), - ); - - console.log("\nAnalysis complete!"); - console.log(`Output written to: ${outputPath}`); - console.log("\nSummary:"); - console.log(`- Total files analyzed: ${output.summary.totalFiles}`); - console.log(`- Total functions found: ${output.summary.totalFunctions}`); - } catch (error) { - console.error("Error during analysis:", error); - process.exit(1); - } -} - -main(); diff --git a/src/codegen/sdk/typescript/external/typescript_analyzer/src/test_fsi.ts b/src/codegen/sdk/typescript/external/typescript_analyzer/src/test_fsi.ts deleted file mode 100644 index bfa90cbdc..000000000 --- a/src/codegen/sdk/typescript/external/typescript_analyzer/src/test_fsi.ts +++ /dev/null @@ -1,103 +0,0 @@ -import * as ts from "typescript"; -import { TypeScriptAnalyzer } from "./analyzer"; -import { ProxyFileSystem } from "./fs_proxy"; - -class MockTestFileSystem extends ProxyFileSystem { - constructor() { - super(); - - // Add a mock tsconfig.json with all required fields - this.addFile( - "/tsconfig.json", - JSON.stringify({ - compilerOptions: { - target: "ES2020", - module: "ES2020", - strict: true, - esModuleInterop: true, - skipLibCheck: true, - forceConsistentCasingInFileNames: true, - }, - files: ["test.ts"], - include: ["**/*"], - exclude: ["node_modules"], - }), - ); - - // Add typescript lib files that might be needed - this.addFile("/lib.es2020.d.ts", ""); // Empty placeholder - this.addFile("/lib.dom.d.ts", ""); // Empty placeholder - this.addFile("/lib.dom.iterable.d.ts", ""); // Empty placeholder - - // Add a sample TypeScript file - this.addFile( - "/src/test.ts", - ` - export function basicFunction(x: number): string { - return x.toString(); - } - - export const arrowFunction = (y: string): number => { - return parseInt(y); - }; - - export class TestClass { - classMethod(z: boolean): void { - console.log(z); - } - } - `, - ); - } -} - -async function runTests() { - console.log("Starting FileSystemInterface tests...\n"); - - const mockFS = new MockTestFileSystem(); - mockFS.debugPrintFiles(); - - try { - console.log("\nCreating TypeScriptAnalyzer with mock file system..."); - const analyzer = new TypeScriptAnalyzer("/", mockFS); - - console.log("\nTesting getAllFunctionsInProject():"); - const functions = analyzer.getAllFunctionsInProject(); - console.log(`Found ${functions.length} functions:`); - for (const func of functions) { - console.log(`- ${func.name} (${func.kind})`); - console.log(` Return type: ${func.returnType}`); - console.log(` Parameters: ${func.parameters}\n`); - } - - console.log("Testing getFunctionAtPosition():"); - const fileContent = mockFS.readFile("/src/test.ts") || ""; - const positions = [ - { pos: fileContent.indexOf("basicFunction"), desc: "basicFunction" }, - { pos: fileContent.indexOf("arrowFunction"), desc: "arrowFunction" }, - { pos: fileContent.indexOf("classMethod"), desc: "classMethod" }, - ]; - - for (const { pos, desc } of positions) { - if (pos === -1) { - console.log(`Could not find position for ${desc}`); - continue; - } - try { - const returnType = analyzer.getFunctionAtPosition( - "/src/test.ts", - pos + 20, - ); - console.log(`- ${desc} return type: ${returnType}`); - } catch (error) { - console.error(`- Error getting return type for ${desc}:`, error); - } - } - } catch (error) { - console.error("Test failed:", error); - console.error("Full error:", error instanceof Error ? error.stack : error); - process.exit(1); - } -} - -runTests().catch(console.error); diff --git a/src/codegen/sdk/typescript/external/typescript_analyzer/tsconfig.json b/src/codegen/sdk/typescript/external/typescript_analyzer/tsconfig.json deleted file mode 100644 index 5d1de5bd3..000000000 --- a/src/codegen/sdk/typescript/external/typescript_analyzer/tsconfig.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "compilerOptions": { - "target": "ES2020", - "module": "ES2020", - "moduleResolution": "node", - "esModuleInterop": true, - "strict": true, - "skipLibCheck": true, - "forceConsistentCasingInFileNames": true, - "outDir": "./dist", - "declaration": true, - "declarationDir": "./dist", - "rootDir": "./src" - }, - "include": ["src/**/*"], - "exclude": ["node_modules", "dist"], - "ts-node": { - "esm": true, - "experimentalSpecifierResolution": "node" - } -} diff --git a/src/codegen/sdk/typescript/file.py b/src/codegen/sdk/typescript/file.py deleted file mode 100644 index 4c937292d..000000000 --- a/src/codegen/sdk/typescript/file.py +++ /dev/null @@ -1,450 +0,0 @@ -from __future__ import annotations - -import os -from typing import TYPE_CHECKING - -from codegen.sdk.core.autocommit import mover, reader, writer -from codegen.sdk.core.file import SourceFile -from codegen.sdk.core.interfaces.exportable import Exportable -from codegen.sdk.enums import ImportType, NodeType, SymbolType -from codegen.sdk.extensions.sort import sort_editables -from codegen.sdk.extensions.utils import cached_property -from codegen.sdk.typescript.assignment import TSAssignment -from codegen.sdk.typescript.class_definition import TSClass -from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock -from codegen.sdk.typescript.export import TSExport -from codegen.sdk.typescript.function import TSFunction -from codegen.sdk.typescript.import_resolution import TSImport -from codegen.sdk.typescript.interface import TSInterface -from codegen.sdk.typescript.interfaces.has_block import TSHasBlock -from codegen.sdk.typescript.namespace import TSNamespace -from codegen.sdk.utils import calculate_base_path -from codegen.shared.decorators.docs import noapidoc, ts_apidoc -from codegen.shared.enums.programming_language import ProgrammingLanguage - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.statements.export_statement import ExportStatement - from codegen.sdk.core.symbol import Symbol - from codegen.sdk.typescript.detached_symbols.promise_chain import TSPromiseChain - from codegen.sdk.typescript.symbol import TSSymbol - from codegen.sdk.typescript.ts_config import TSConfig - from codegen.sdk.typescript.type_alias import TSTypeAlias - - -@ts_apidoc -class TSFile(SourceFile[TSImport, TSFunction, TSClass, TSAssignment, TSInterface, TSCodeBlock], TSHasBlock, Exportable): - """Extends the SourceFile class to provide TypeScript-specific functionality. - - Attributes: - programming_language: The programming language of the file. Set to ProgrammingLanguage.TYPESCRIPT. - ts_config: The ts_config file nearest to this file. - """ - - programming_language = ProgrammingLanguage.TYPESCRIPT - ts_config: TSConfig | None = None - - @cached_property - @reader(cache=False) - def exports(self) -> list[TSExport]: - """Returns all Export symbols in the file. - - Retrieves a list of all top-level export declarations in the current TypeScript file. - Does not include exports inside namespaces. - - Returns: - list[TSExport]: A list of TSExport objects representing all top-level export declarations in the file. - """ - # Filter to only get exports that are direct children of the file's code block - return sort_editables(filter(lambda node: isinstance(node, TSExport) and ((node.parent.parent.parent == self) or (node.parent.parent == self)), self.get_nodes(sort=False)), by_id=True) - - @property - @reader(cache=False) - def export_statements(self) -> list[ExportStatement[TSExport]]: - """Returns a list of all export statements in the file. - - Each export statement in the returned list can contain multiple exports. The export statements - are sorted by their position in the file. - - Args: - None - - Returns: - list[ExportStatement[TSExport]]: A list of ExportStatement objects, where each ExportStatement - contains one or more TSExport objects. - """ - export_statements = [exp.export_statement for exp in self.exports] - return sort_editables(export_statements) - - @property - @reader(cache=False) - def default_exports(self) -> list[TSExport]: - """Returns all default export symbols from the file. - - A property method that retrieves all export objects that are designated as default exports from the file. - - Returns: - list[TSExport]: A list of default export objects. Each object belongs to a single export statement. - """ - return [x for x in self.exports if x.is_default_export()] - - @property - @reader - def named_exports(self) -> list[TSExport]: - """Returns the named exports declared in the file. - - Gets all export statements in the file that are not default exports. These exports are defined - using the `export` keyword rather than `export default`. - - Args: - self (TSFile): The TypeScript file object. - - Returns: - list[TSExport]: A list of TSExport objects representing named exports in the file. - """ - return [x for x in self.exports if not x.is_default_export()] - - @reader - def get_export(self, export_name: str) -> TSExport | None: - """Returns an export object with the specified name from the file. - - This method searches for an export with the given name in the file. - - Args: - export_name (str): The name of the export to find. - - Returns: - TSExport | None: The export object if found, None otherwise. - """ - return next((x for x in self.exports if x.name == export_name), None) - - @property - @reader - def interfaces(self) -> list[TSInterface]: - """Returns all Interfaces in the file. - - Retrieves all symbols in the file that are of type Interface. - - Args: - None - - Returns: - list[TSInterface]: A list of TypeScript interface symbols defined in the file. - """ - return [s for s in self.symbols if s.symbol_type == SymbolType.Interface] - - @reader - def get_interface(self, name: str) -> TSInterface | None: - """Retrieves a specific interface from the file by its name. - - Args: - name (str): The name of the interface to find. - - Returns: - TSInterface | None: The interface with the specified name if found, None otherwise. - """ - return next((x for x in self.interfaces if x.name == name), None) - - @property - @reader - def types(self) -> list[TSTypeAlias]: - """Returns all type aliases in the file. - - Retrieves a list of all type aliases defined in the current TypeScript/JavaScript file. - - Returns: - list[TSTypeAlias]: A list of all type aliases in the file. Empty list if no type aliases are found. - """ - return [s for s in self.symbols if s.symbol_type == SymbolType.Type] - - @reader - def get_type(self, name: str) -> TSTypeAlias | None: - """Returns a specific Type by name from the file's types. - - Retrieves a TypeScript type alias by its name from the file's collection of types. - - Args: - name (str): The name of the type alias to retrieve. - - Returns: - TSTypeAlias | None: The TypeScript type alias with the matching name, or None if not found. - """ - return next((x for x in self.types if x.name == name), None) - - @staticmethod - def get_extensions() -> list[str]: - """Returns a list of file extensions that this class can parse. - - Returns a list of file extensions for TypeScript and JavaScript files that this File class can parse and process. - - Returns: - list[str]: A list of file extensions including '.tsx', '.ts', '.jsx', and '.js'. - """ - return [".tsx", ".ts", ".jsx", ".js"] - - def symbol_can_be_added(self, symbol: TSSymbol) -> bool: - """Determines if a TypeScript symbol can be added to this file based on its type and JSX compatibility. - - This method checks whether a given symbol can be added to the current TypeScript file by validating its compatibility with the file's extension. - In particular, it ensures that JSX functions are only added to appropriate file types (.tsx or .jsx). - - Args: - symbol (TSSymbol): The TypeScript symbol to be checked. - - Returns: - bool: True if the symbol can be added to this file, False otherwise. - """ - if symbol.symbol_type == SymbolType.Function: - if symbol.is_jsx: - if not (self.file_path.endswith("tsx") or self.file_path.endswith("jsx")): - return False - return True - - @reader - def get_config(self) -> TSConfig | None: - """Returns the nearest tsconfig.json applicable to this file. - - Gets the TypeScript configuration for the current file by retrieving the nearest tsconfig.json file in the directory hierarchy. - - Returns: - TSConfig | None: The TypeScript configuration object if found, None otherwise. - """ - return self.ts_config - - @writer - def add_export_to_symbol(self, symbol: TSSymbol) -> None: - """Adds an export keyword to a symbol in a TypeScript file. - - Marks a symbol for export by adding the 'export' keyword. This modifies the symbol's - declaration to make it available for import by other modules. - - Args: - symbol (TSSymbol): The TypeScript symbol (function, class, interface, etc.) to be exported. - - Returns: - None - """ - # TODO: this should be in symbol.py class. Rename as `add_export` - symbol.add_keyword("export") - - @writer - def remove_unused_exports(self) -> None: - """Removes unused exports from the file. - - Analyzes all exports in the file and removes any that are not used. An export is considered unused if it has no direct - symbol usages and no re-exports that are used elsewhere in the codebase. - - When removing unused exports, the method also cleans up any related unused imports. For default exports, it removes - the 'export default' keyword, and for named exports, it removes the 'export' keyword or the entire export statement. - - Args: - None - - Returns: - None - """ - for export in self.exports: - symbol_export_unused = True - symbols_to_remove = [] - - exported_symbol = export.resolved_symbol - for export_usage in export.symbol_usages: - if export_usage.node_type == NodeType.IMPORT or (export_usage.node_type == NodeType.EXPORT and export_usage.resolved_symbol != exported_symbol): - # If the import has no usages then we can add the import to the list of symbols to remove - reexport_usages = export_usage.symbol_usages - if len(reexport_usages) == 0: - symbols_to_remove.append(export_usage) - break - - # If any of the import's usages are valid symbol usages, export is used. - if any(usage.node_type == NodeType.SYMBOL for usage in reexport_usages): - symbol_export_unused = False - break - - symbols_to_remove.append(export_usage) - - elif export_usage.node_type == NodeType.SYMBOL: - symbol_export_unused = False - break - - # export is not used, remove it - if symbol_export_unused: - # remove the unused imports - for imp in symbols_to_remove: - imp.remove() - - if exported_symbol == exported_symbol.export.declared_symbol: - # change this to be more robust - if exported_symbol.source.startswith("export default "): - exported_symbol.replace("export default ", "") - else: - exported_symbol.replace("export ", "") - else: - exported_symbol.export.remove() - if exported_symbol.export != export: - export.remove() - - @noapidoc - def _get_export_data(self, relative_path: str, export_type: str = "EXPORT") -> tuple[tuple[str, str], dict[str, callable]]: - quoted_paths = (f"'{relative_path}'", f'"{relative_path}"') - export_type_conditions = { - "WILDCARD": lambda exp: exp.is_wildcard_export(), - "TYPE": lambda exp: exp.is_type_export(), - # Changed this condition - it was incorrectly handling type exports - "EXPORT": lambda exp: (not exp.is_type_export() and not exp.is_wildcard_export()), - } - return quoted_paths, export_type_conditions - - @reader - def has_export_statement_for_path(self, relative_path: str, export_type: str = "EXPORT") -> bool: - """Checks if the file has exports of specified type that contains the given path in single or double quotes. - - Args: - relative_path (str): The path to check for in export statements - export_type (str): Type of export to check for - "WILDCARD", "TYPE", or "EXPORT" (default) - - Returns: - bool: True if there exists an export of specified type with the exact relative path (quoted) - in its source, False otherwise. - """ - if not self.export_statements: - return False - - quoted_paths, export_type_conditions = self._get_export_data(relative_path, export_type) - condition = export_type_conditions[export_type] - - return any(any(quoted_path in stmt.source for quoted_path in quoted_paths) and any(condition(exp) for exp in stmt.exports) for stmt in self.export_statements) - - #################################################################################################################### - # GETTERS - #################################################################################################################### - - @reader - def get_export_statement_for_path(self, relative_path: str, export_type: str = "EXPORT") -> ExportStatement | None: - """Gets the first export of specified type that contains the given path in single or double quotes. - - Args: - relative_path (str): The path to check for in export statements - export_type (str): Type of export to get - "WILDCARD", "TYPE", or "EXPORT" (default) - - Returns: - TSExport | None: The first matching export if found, None otherwise. - """ - if not self.export_statements: - return None - - quoted_paths, export_type_conditions = self._get_export_data(relative_path, export_type) - condition = export_type_conditions[export_type] - - for stmt in self.export_statements: - if any(quoted_path in stmt.source for quoted_path in quoted_paths): - for exp in stmt.exports: - if condition(exp): - return exp - - return None - - @noapidoc - def get_import_module_name_for_file(self, filepath: str, ctx: CodebaseContext) -> str: - """Returns the module name that this file gets imported as""" - # TODO: support relative and absolute module path - import_path = filepath - - # Apply path import aliases to import_path - if self.ts_config: - import_path = self.ts_config.translate_absolute_path(import_path) - - # Remove file extension - import_path = os.path.splitext(import_path)[0] - return f"'{import_path}'" - - @reader - def get_import_string(self, alias: str | None = None, module: str | None = None, import_type: ImportType = ImportType.UNKNOWN, is_type_import: bool = False) -> str: - """Generates and returns an import statement for the file. - - Constructs an import statement string based on the file's name and module information. - - Args: - alias (str | None): Alternative name for the imported module. Defaults to None. - module (str | None): Module path to import from. If None, uses file's default module name. - import_type (ImportType): The type of import statement. Defaults to ImportType.UNKNOWN. - is_type_import (bool): Whether this is a type-only import. Defaults to False. - - Returns: - str: A formatted import statement string importing all exports from the module. - """ - import_module = module if module is not None else self.import_module_name - file_module = self.name - return f"import * as {file_module} from {import_module}" - - @cached_property - @noapidoc - @reader(cache=True) - def valid_import_names(self) -> dict[str, Symbol | TSImport]: - """Returns a dict mapping name => Symbol (or import) in this file that can be imported from another file""" - valid_export_names = {} - if len(self.default_exports) == 1: - valid_export_names["default"] = self.default_exports[0] - for export in self.exports: - for name, dest in export.names: - valid_export_names[name] = dest - return valid_export_names - - #################################################################################################################### - # MANIPULATIONS - #################################################################################################################### - - @mover - def update_filepath(self, new_filepath: str) -> None: - """Updates the file path of the current file and all associated imports. - - Renames the current file to a new file path and updates all imports that reference this file to point to the new location. - - Args: - new_filepath (str): The new file path to move the file to. - - Returns: - None - """ - # =====[ Add the new filepath as a new file node in the graph ]===== - new_file = self.ctx.node_classes.file_cls.from_content(new_filepath, self.content, self.ctx) - # =====[ Change the file on disk ]===== - self.transaction_manager.add_file_rename_transaction(self, new_filepath) - # =====[ Update all the inbound imports to point to the new module ]===== - for imp in self.inbound_imports: - existing_imp = imp.module.source.strip("'") - new_module_name = new_file.import_module_name.strip("'") - # Web specific hacks - if self.ctx.repo_name == "web": - if existing_imp.startswith("./"): - relpath = calculate_base_path(new_filepath, existing_imp) - new_module_name = new_module_name.replace(relpath, ".") - elif existing_imp.startswith("~/src"): - new_module_name = new_module_name.replace("src/", "~/src/") - imp.set_import_module(f"'{new_module_name}'") - - @reader - def get_namespace(self, name: str) -> TSNamespace | None: - """Returns a specific namespace by name from the file's namespaces. - - Args: - name (str): The name of the namespace to find. - - Returns: - TSNamespace | None: The namespace with the specified name if found, None otherwise. - """ - return next((x for x in self.symbols if isinstance(x, TSNamespace) and x.name == name), None) - - @property - @reader - def promise_chains(self) -> list[TSPromiseChain]: - """Returns all promise chains in the file. - - Returns: - list[TSPromiseChain]: A list of promise chains in the file. - """ - promise_chains = [] - for function in self.functions: - for promise_chain in function.promise_chains: - promise_chains.append(promise_chain) - return promise_chains diff --git a/src/codegen/sdk/typescript/function.py b/src/codegen/sdk/typescript/function.py deleted file mode 100644 index ee71ee9db..000000000 --- a/src/codegen/sdk/typescript/function.py +++ /dev/null @@ -1,453 +0,0 @@ -from __future__ import annotations - -from functools import cached_property -from typing import TYPE_CHECKING - -from codegen.sdk.core.autocommit import commiter, reader, writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.function import Function -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.sdk.typescript.detached_symbols.decorator import TSDecorator -from codegen.sdk.typescript.detached_symbols.parameter import TSParameter -from codegen.sdk.typescript.enums import TSFunctionTypeNames -from codegen.sdk.typescript.expressions.type import TSType -from codegen.sdk.typescript.interfaces.has_block import TSHasBlock -from codegen.sdk.typescript.placeholder.placeholder_return_type import TSReturnTypePlaceholder -from codegen.sdk.typescript.symbol import TSSymbol -from codegen.sdk.utils import find_all_descendants -from codegen.shared.decorators.docs import noapidoc, ts_apidoc -from codegen.shared.logging.get_logger import get_logger - -if TYPE_CHECKING: - from collections.abc import Generator - - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.import_resolution import Import, WildcardImport - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.statements.export_statement import ExportStatement - from codegen.sdk.core.statements.symbol_statement import SymbolStatement - from codegen.sdk.core.symbol import Symbol - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - from codegen.sdk.typescript.detached_symbols.promise_chain import TSPromiseChain -_VALID_TYPE_NAMES = {function_type.value for function_type in TSFunctionTypeNames} -logger = get_logger(__name__) - - -@ts_apidoc -class TSFunction(Function[TSDecorator, "TSCodeBlock", TSParameter, TSType], TSHasBlock, TSSymbol): - """Representation of a Function in JavaScript/TypeScript""" - - @noapidoc - @commiter - def parse(self, ctx: CodebaseContext) -> None: - super().parse(ctx) - - self.return_type = self.child_by_field_name("return_type", placeholder=TSReturnTypePlaceholder) - if parameters_node := self.ts_node.child_by_field_name("parameters"): - self._parameters = Collection(parameters_node, self.file_node_id, self.ctx, self) - params = [x for x in parameters_node.children if x.type in ("required_parameter", "optional_parameter")] - symbols = None - # Deconstructed object parameters - if len(params) == 1: - pattern = params[0].child_by_field_name("pattern") - type_annotation = None - if type_node := params[0].child_by_field_name("type"): - type_annotation = self._parse_type(type_node) - if pattern and pattern.type == "object_pattern": - params = [x for x in pattern.children if x.type in ("shorthand_property_identifier_pattern", "object_assignment_pattern", "pair_pattern")] - symbols = [TSParameter(x, i, self._parameters, type_annotation) for (i, x) in enumerate(params)] - # Default case - regular parameters - if symbols is None: - symbols = [TSParameter(x, i, self._parameters) for (i, x) in enumerate(params)] - self._parameters._init_children(symbols) - elif parameters_node := self.ts_node.child_by_field_name("parameter"): - self._parameters = Collection(parameters_node, self.file_node_id, self.ctx, self) - self._parameters._init_children([TSParameter(parameters_node, 0, self._parameters)]) - else: - logger.warning(f"Couldn't find parameters for {self!r}") - self._parameters = [] - self.type_parameters = self.child_by_field_name("type_parameters") - - @property - @reader - def function_type(self) -> TSFunctionTypeNames: - """Gets the type of function from its TreeSitter node. - - Extracts and returns the type of function (e.g., arrow function, generator function, function expression) - from the node's type information. - - Args: - None: Property method that uses instance's ts_node. - - Returns: - TSFunctionTypeNames: The function type enum value representing the specific type of function. - """ - return TSFunctionTypeNames(self.ts_node.type) - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - # If a destination is provided, use it, otherwise use the default destination - # This is used for cases where a non-symbol (eg. argument) value parses as a function - dest = dest or self.self_dest - - # =====[ Typed Parameters ]===== - # Have to grab types from the parameters - if self.parameters is not None: - for param in self.parameters: - assignment_patterns = find_all_descendants(param.ts_node, {"object_pattern", "object_assignment_pattern", "assignment_pattern"}) - if assignment_patterns: - dest.add_all_identifier_usages_for_child_node(UsageKind.GENERIC, assignment_patterns[0]) - if self.type_parameters: - self.type_parameters._compute_dependencies(UsageKind.GENERIC, dest) - # =====[ Return type ]===== - if self.return_type: - # Need to parse all the different types - self.return_type._compute_dependencies(UsageKind.RETURN_TYPE, dest) - - # =====[ Code Block ]===== - self.code_block._compute_dependencies(usage_type, dest) - - @classmethod - @noapidoc - def from_function_type(cls, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: SymbolStatement | ExportStatement) -> TSFunction: - """Creates a TSFunction object from a function declaration.""" - if ts_node.type not in [function_type.value for function_type in TSFunctionTypeNames]: - msg = f"Node type={ts_node.type} is not a function declaration" - raise ValueError(msg) - file = ctx.get_node(file_node_id) - if canonical := file._range_index.get_canonical_for_range(ts_node.range, ts_node.kind_id): - return canonical - return cls(ts_node, file_node_id, ctx, parent=parent) - - @staticmethod - @noapidoc - def _get_name_node(ts_node: TSNode) -> TSNode | None: - if ts_node.type == "function_declaration": - return ts_node.child_by_field_name("name") - elif ts_node.type == "function_expression": - if name := ts_node.child_by_field_name("name"): - return name - return ts_node.parent.child_by_field_name("name") - elif ts_node.type == "arrow_function": - ts_node = ts_node.parent - while ts_node.type in ("parenthesized_expression", "binary_expression"): - ts_node = ts_node.parent - if ts_node.type == "pair": - return ts_node.child_by_field_name("key") - elif ts_node.type == "return_statement": - func_expression = next((x for x in ts_node.children if x.type == ("function_expression")), None) - if func_expression: - return func_expression.child_by_field_name("name") - return ts_node.child_by_field_name("name") - - @property - @reader - def function_signature(self) -> str: - """Returns a string representation of the function's signature. - - Generates a string containing the full function signature including name, parameters, and return type - based on the function's type (arrow function, generator function, function expression, etc.). - - Returns: - str: A string containing the complete function signature. For example: 'function foo(bar: string): number' - - Raises: - NotImplementedError: If the function type is not implemented. - """ - if self.function_type == TSFunctionTypeNames.FunctionDeclaration: - func_def_src = f"function {self.name}" - elif self.function_type == TSFunctionTypeNames.GeneratorFunctionDeclaration: - func_def_src = f"function* {self.name}" - elif self.function_type == TSFunctionTypeNames.ArrowFunction: - func_def_src = f"{self.name} = " - elif self.function_type == TSFunctionTypeNames.FunctionExpression: - func_def_src = f"{self.name} = function" - else: - msg = "function type not implemented" - raise NotImplementedError(msg) - if self.parameters is not None: - func_def_src += self.parameters.source - if self.return_type: - func_def_src += ": " + self.return_type.source - return func_def_src - - @cached_property - @reader - def is_private(self) -> bool: - """Determines if a function is private based on its accessibility modifier. - - This property examines the function's accessibility modifier to determine if it's marked as private. In TypeScript, this means the function has the 'private' keyword. - - Returns: - bool: True if the function has a 'private' accessibility modifier, False otherwise. - """ - modifier = self.ts_node.children[0] - return modifier.type == "accessibility_modifier" and modifier.text == b"private" - - @cached_property - @reader - def is_magic(self) -> bool: - """Returns whether this method is a magic method. - - A magic method is a method whose name starts and ends with double underscores, like __init__ or __str__. - In this implementation, all methods are considered non-magic in TypeScript. - - Returns: - bool: False, as TypeScript does not have magic methods. - """ - return False - - @property - @reader - def is_anonymous(self) -> bool: - """Property indicating whether a function is anonymous. - - Returns True if the function has no name or if its name is an empty string. - - Returns: - bool: True if the function is anonymous, False otherwise. - """ - return not self.name or self.name.strip() == "" - - @property - def is_async(self) -> bool: - """Determines if the function is asynchronous. - - Checks the function's node children to determine if the function is marked as asynchronous. - - Returns: - bool: True if the function is asynchronous (has 'async' keyword), False otherwise. - """ - return any("async" == x.type for x in self.ts_node.children) - - @property - @reader - def is_arrow(self) -> bool: - """Returns True iff the function is an arrow function. - - Identifies whether the current function is an arrow function (lambda function) in TypeScript/JavaScript. - - Returns: - bool: True if the function is an arrow function, False otherwise. - """ - return self.function_type == TSFunctionTypeNames.ArrowFunction - - @property - @reader - def is_property(self) -> bool: - """Determines if the function is a property. - - Checks if any of the function's decorators are '@property' or '@cached_property'. - - Returns: - bool: True if the function has a @property or @cached_property decorator, False otherwise. - """ - return any(dec in ("@property", "@cached_property") for dec in self.decorators) - - @property - @reader - def _named_arrow_function(self) -> TSNode | None: - """Returns the name of the named arrow function, if it exists.""" - if self.is_arrow: - node = self.ts_node - if name := self.get_name(): - node = name.ts_node - parent = node.parent - if parent.type == "variable_declarator": - return parent.parent - return None - - @property - @reader - def is_jsx(self) -> bool: - """Determines if the function is a React component by checking if it returns a JSX element. - - A function is considered a React component if it contains at least one JSX element in its body - and either has no name or has a name that starts with an uppercase letter. - - Returns: - bool: True if the function is a React component, False otherwise. - """ - # Must contain a React component - if len(self.jsx_elements) == 0: - return False - # Must be uppercase name - if not self.name: - return True - return self.name[0].isupper() - - #################################################################################################################### - # MANIPULATIONS - #################################################################################################################### - - @writer - def asyncify(self) -> None: - """Modifies the function to be asynchronous, if it is not already. - - This method converts a synchronous function to be asynchronous by adding the 'async' keyword and wrapping - the return type in a Promise if a return type exists. - - Returns: - None - - Note: - If the function is already asynchronous, this method does nothing. - """ - if self.is_async: - return - self.add_keyword("async") - if self.return_type and self.return_type.name != "Promise": - self.return_type.insert_before("Promise<", newline=False) - self.return_type.insert_after(">", newline=False) - - @writer - def arrow_to_named(self, name: str | None = None) -> None: - """Converts an arrow function to a named function in TypeScript/JavaScript. - - Transforms an arrow function into a named function declaration, preserving type parameters, parameters, - return types, and function body. If the function is already asynchronous, the async modifier is preserved. - - Args: - name (str | None): The name for the converted function. If None, uses the name of the variable - the arrow function is assigned to. - - Returns: - None - - Raises: - ValueError: If name is None and the arrow function is not assigned to a named variable. - """ - if not self.is_arrow or self.name is None: - return - - if name is None and self._name_node is None: - msg = "The `name` argument must be provided when converting an arrow function that is not assigned to any variable." - raise ValueError(msg) - - node = self._named_arrow_function - # Replace variable declaration with function declaration - async_prefix = "async " if self.is_async else "" - edit_start = node.start_byte - type_param_node = self.ts_node.child_by_field_name("type_parameters") - if param_node := self.ts_node.child_by_field_name("parameters"): - edit_end = param_node.start_byte - self._edit_byte_range(f"{async_prefix}function {name or self.name}{type_param_node.text.decode('utf-8') if type_param_node else ''}", edit_start, edit_end) - elif param_node := self.ts_node.child_by_field_name("parameter"): - edit_end = param_node.start_byte - self._edit_byte_range(f"{async_prefix}function {name or self.name}{type_param_node.text.decode('utf-8') if type_param_node else ''}(", edit_start, edit_end) - self.insert_at(param_node.end_byte, ")") - - # Remove the arrow => - if self.return_type: - remove_start = self.return_type.end_byte + 1 - else: - remove_start = param_node.end_byte + 1 - self.remove_byte_range(remove_start, self.code_block.start_byte) - - # Add brackets surrounding the code block if not already present - if not self.code_block.source.startswith("{"): - self.insert_at(self.code_block.start_byte, "{ return ") - self.insert_at(node.end_byte, " }") - - # Move over variable type annotations as parameter type annotations - if (type_node := node.named_children[0].child_by_field_name("type")) and len(param_node.named_children) == 1: - destructured_param = self.parameters.ts_node.named_children[0] - self.insert_at(destructured_param.end_byte, type_node.text.decode("utf-8")) - - @noapidoc - @reader - def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: - """Resolves the name of a symbol in the function. - - This method resolves the name of a symbol in the function. If the name is "this", it returns the parent class. - Otherwise, it calls the superclass method to resolve the name. - - Args: - name (str): The name of the symbol to resolve. - start_byte (int | None): The start byte of the symbol to resolve. - strict (bool): If True considers candidates that don't satisfy start byte if none do. - - Returns: - Symbol | Import | WildcardImport: The resolved symbol, import, or wildcard import, or None if not found. - """ - if self.is_method: - if name == "this": - yield self.parent_class - return - yield from super().resolve_name(name, start_byte, strict=strict) - - @staticmethod - def is_valid_node(node: TSNode) -> bool: - """Determines if a given tree-sitter node corresponds to a valid function type. - - This method checks if a tree-sitter node's type matches one of the valid function types defined in the _VALID_TYPE_NAMES set. - - Args: - node (TSNode): The tree-sitter node to validate. - - Returns: - bool: True if the node's type is a valid function type, False otherwise. - """ - return node.type in _VALID_TYPE_NAMES - - @writer - def convert_props_to_interface(self) -> None: - """Converts React component props to TypeScript interfaces. - - For React components, converts inline props type definitions and PropTypes declarations - to a separate interface. The interface will be named {ComponentName}Props and inserted - before the component. - - Handles both simple types and complex types including: - - Inline object type definitions - - PropTypes declarations - - Union types and optional props - - Destructured parameters - - Generic type parameters - - Example: - ```typescript - // Before - function Button({ text, onClick }: { text: string, onClick: () => void }) { - return ; - } - - // After - interface ButtonProps { - text: string; - onClick: () => void; - } - function Button({ text, onClick }: ButtonProps) { - return ; - } - ``` - """ - if self.parameters and len(self.parameters) > 0: - if interface_name := self.convert_to_react_interface(): - if not self.parameters[0].is_destructured: - self.parameters[0].edit(interface_name) - else: - self.insert_at(self.parameters.ts_node.end_byte - 1, f": {interface_name}") - - @property - @reader - def promise_chains(self) -> list[TSPromiseChain]: - """Returns a list of promise chains in the function. - - Returns: - list[TSPromiseChain]: A list of promise chains in the function. - """ - promise_chains = [] - visited_base_functions = set() - function_calls = self.function_calls - - for function_call in function_calls: - if function_call.name == "then" and function_call.base not in visited_base_functions: - promise_chains.append(function_call.promise_chain) - visited_base_functions.add(function_call.base) - - return promise_chains diff --git a/src/codegen/sdk/typescript/import_resolution.py b/src/codegen/sdk/typescript/import_resolution.py deleted file mode 100644 index 82b770a79..000000000 --- a/src/codegen/sdk/typescript/import_resolution.py +++ /dev/null @@ -1,648 +0,0 @@ -from __future__ import annotations - -import os -from collections import deque -from typing import TYPE_CHECKING, Self, override - -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.expressions import Name -from codegen.sdk.core.import_resolution import Import, ImportResolution, WildcardImport -from codegen.sdk.core.interfaces.exportable import Exportable -from codegen.sdk.enums import ImportType, NodeType, SymbolType -from codegen.sdk.utils import find_all_descendants, find_first_ancestor, find_first_descendant -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from collections.abc import Generator - - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.external_module import ExternalModule - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.statements.import_statement import ImportStatement - from codegen.sdk.core.symbol import Symbol - from codegen.sdk.typescript.file import TSFile - from codegen.sdk.typescript.namespace import TSNamespace - from codegen.sdk.typescript.statements.import_statement import TSImportStatement - - -@ts_apidoc -class TSImport(Import["TSFile"], Exportable): - """Extends Import for TypeScript codebases.""" - - @reader - def is_type_import(self) -> bool: - """Checks if an import is a type import. - - Determines whether an import statement is specifically for types. This includes explicit type imports - (e.g., 'import type foo from bar'), exports of types, and dynamic imports followed by property access. - - Returns: - bool: True if the import is a type import, False otherwise. - """ - if self.ts_node.type == "import_statement": - return self.source.startswith("import type ") - elif self.ts_node.type == "export_statement": - return self.source.startswith("export type ") - elif call_node := find_first_descendant(self.ts_node, ["call_expression"]): - # If the import is an import using functions `import` or `require`, - # assume it is a type import if it is followed by a dot notation - while call_node.parent and call_node.parent.type in ["await_expression", "parenthesized_expression"]: - call_node = call_node.parent - sibling = call_node.next_named_sibling - return sibling and sibling.type == "property_identifier" - return False - - @reader - def is_module_import(self) -> bool: - """Determines if an import represents a module-level import. - - Module imports represent imports of an entire file rather than specific symbols from a file. - These imports must traverse through the file to resolve the actual symbol(s) being imported. - - Args: - self (TSImport): The import object to check. - - Returns: - bool: True if the import is a module-level import, False otherwise. - Returns True for: - - Imports of type MODULE, WILDCARD, or DEFAULT_EXPORT - - Side effect imports that are not type imports - """ - if self.import_type in [ImportType.MODULE, ImportType.WILDCARD, ImportType.DEFAULT_EXPORT]: - return True - return self.import_type == ImportType.SIDE_EFFECT and not self.is_type_import() - - @reader - def is_default_import(self) -> bool: - """Determines whether the import is a default export import. - - Checks if the import is importing a default export from a module. The default export - may be a single symbol or an entire module. - - Args: - self (TSImport): The import instance. - - Returns: - bool: True if the import is a default export import, False otherwise. - """ - return self.import_type == ImportType.DEFAULT_EXPORT - - @property - @reader - def namespace(self) -> str | None: - """If import is a module import, returns any namespace prefix that must be used with import reference. - - Returns the namespace prefix for import reference when the import is a module import, specifically when - the import resolves to a file node_type. The namespace is determined by the alias if set, otherwise None. - - Returns: - str | None: The alias name if the import resolves to a file node_type and has an alias, - None otherwise. - """ - resolved_symbol = self.resolved_symbol - if resolved_symbol is not None and resolved_symbol.node_type == NodeType.FILE: - return self.alias.source if self.alias is not None else None - return None - - @property - @reader - def imported_exports(self) -> list[Exportable]: - """Returns the enumerated list of exports imported from a module import. - - Returns a list of exports that this import statement references. The exports can be direct exports - or re-exports from other modules. - - Returns: - list[Exportable]: List of exported symbols. Empty list if this import doesn't reference any exports - or if imported_symbol is None. - """ - if self.imported_symbol is None: - return [] - - if not self.is_module_import(): - return [] if self.imported_symbol.export is None else [self.imported_symbol.export] - - from_file = self.imported_symbol - if from_file.node_type != NodeType.FILE: - return [] - - if self.is_default_import(): - return from_file.default_exports - - return from_file.exports - - @property - @reader - def resolved_symbol(self) -> Symbol | ExternalModule | TSFile | None: - """Returns the resolved symbol that the import is referencing. - - Follows the imported symbol and returns the final symbol it resolves to. For default imports, resolves to the exported symbol. - For module imports with matching symbol names, resolves through module imports to find the matching symbol. - For indirect imports, follows the import chain to find the ultimate symbol. - - Returns: - Union[Symbol, ExternalModule, TSFile, None]: The resolved symbol. Returns None if the import cannot be resolved, - Symbol for resolved import symbols, ExternalModule for external module imports, - or TSFile for module/file imports. - """ - imports_seen = set() - resolved_symbol = self.imported_symbol - - if resolved_symbol is None: - return None - - # If the default import is a single symbol export, resolve to the symbol - if self.is_default_import(): - if resolved_symbol is not None and resolved_symbol.node_type == NodeType.FILE: - file = resolved_symbol - if len(file.default_exports) == 1 and (export_symbol := file.default_exports[0]).is_default_symbol_export(): - while export_symbol and export_symbol.node_type == NodeType.EXPORT: - export_symbol = export_symbol.exported_symbol - resolved_symbol = export_symbol - - # If the imported symbol is a file even though the import is not a module import, - # hop through the file module imports to resolve the symbol that matches the import symbol name - if resolved_symbol and resolved_symbol.node_type == NodeType.FILE and not self.is_module_import(): - # Perform BFS search on the file's module imports to find the resolved symbol - module_imps_seen = set() - module_imports_to_search = deque([imp for imp in resolved_symbol.imports if imp.is_module_import()]) - while module_imports_to_search: - module_imp = module_imports_to_search.popleft() - if module_imp in module_imps_seen: - continue - - module_imps_seen.add(module_imp) - # Search through all the symbols that this module imp is potentially importing! - for export in module_imp.imported_exports: - if export.is_named_export(): - # TODO: Why does this break? When is symbol_name None? - if self.symbol_name is not None and export.name == self.symbol_name.source: - resolved_symbol = export.resolved_symbol - break - else: - exported_symbol = export.exported_symbol - if isinstance(exported_symbol, TSImport) and exported_symbol.is_module_import(): - module_imports_to_search.append(exported_symbol) - - # If the imported symbol is an indirect import, hop through the import resolution edges - while resolved_symbol is not None and resolved_symbol.node_type == NodeType.IMPORT: - if resolved_symbol in imports_seen: - return resolved_symbol - - imports_seen.add(resolved_symbol) - resolved_symbol = resolved_symbol.imported_symbol - - return resolved_symbol - - @reader - def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[TSFile] | None: - """Resolves an import statement to its target file and symbol. - - This method is used by GraphBuilder to resolve import statements to their target files and symbols. It handles both relative and absolute imports, - and supports various import types including named imports, default imports, and module imports. - - Args: - base_path (str | None): The base path to resolve imports from. If None, uses the codebase's base path - or the tsconfig base URL. - - Returns: - ImportResolution[TSFile] | None: An ImportResolution object containing the resolved file and symbol, - or None if the import could not be resolved (treated as an external module). - The ImportResolution contains: - - from_file: The file being imported from - - symbol: The specific symbol being imported (None for module imports) - - imports_file: True if importing the entire file/module - """ - try: - self.file: TSFile # Type cast ts_file - base_path = base_path or self.ctx.projects[0].base_path or "" - - # Get the import source path - import_source = self.module.source.strip('"').strip("'") if self.module else "" - - # Try to resolve the import using the tsconfig paths - if self.file.ts_config: - import_source = self.file.ts_config.translate_import_path(import_source) - - # Check if need to resolve relative import path to absolute path - relative_import = False - if import_source.startswith("."): - relative_import = True - - # Insert base path - # This has the happen before the relative path resolution - if not import_source.startswith(base_path): - import_source = os.path.join(base_path, import_source) - - # If the import is relative, convert it to an absolute path - if relative_import: - import_source = self._relative_to_absolute_import(import_source) - else: - import_source = os.path.normpath(import_source) - - # covers the case where the import is from a directory ex: "import { postExtract } from './post'" - import_name = import_source.split("/")[-1] - if "." not in import_name: - possible_paths = ["index.ts", "index.js", "index.tsx", "index.jsx"] - for p_path in possible_paths: - if self.ctx.to_absolute(os.path.join(import_source, p_path)).exists(): - import_source = os.path.join(import_source, p_path) - break - - # Loop through all extensions and try to find the file - extensions = ["", ".ts", ".d.ts", ".tsx", ".d.tsx", ".js", ".jsx"] - # Try both filename with and without extension - for import_source_base in (import_source, os.path.splitext(import_source)[0]): - for extension in extensions: - import_source_ext = import_source_base + extension - if file := self.ctx.get_file(import_source_ext): - if self.is_module_import(): - return ImportResolution(from_file=file, symbol=None, imports_file=True) - else: - # If the import is a named import, resolve to the named export in the file - if self.symbol_name is None: - return ImportResolution(from_file=file, symbol=None, imports_file=True) - export_symbol = file.get_export(export_name=self.symbol_name.source) - if export_symbol is None: - # If the named export is not found, it is importing a module re-export. - # In this case, resolve to the file itself and dynamically resolve the symbol later. - return ImportResolution(from_file=file, symbol=None, imports_file=True) - return ImportResolution(from_file=file, symbol=export_symbol) - - # If the imported file is not found, treat it as an external module - return None - except AssertionError: - # Codebase is probably trying to import file from outside repo - return None - - @noapidoc - @reader - def _relative_to_absolute_import(self, relative_import: str) -> str: - """Helper to go from a relative import to an absolute one. - Ex: "./foo/bar" in "src/file.ts" would be -> "src/foo/bar" - Ex: "../foo/bar" in "project/src/file.ts" would be -> "project/foo/bar" - """ - import_file_path = self.to_file.file_path # the filepath the import is in - import_dir = os.path.dirname(import_file_path) # the directory of the file this import is in - absolute_import = os.path.join(import_dir, relative_import) # absolute path of the import - normalized_absolute_import = os.path.normpath(absolute_import) # normalized absolute path of the import. removes redundant separators and './' or '../' segments. - return normalized_absolute_import - - @classmethod - @noapidoc - def from_export_statement(cls, source_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSImportStatement) -> list[TSImport]: - """Constructs import objects defined from an export statement""" - export_statement_node = find_first_ancestor(source_node, ["export_statement"]) - imports = [] - if export_clause := next((child for child in export_statement_node.named_children if child.type == "export_clause"), None): - # === [ Named export import ] === - # e.g. export { default as subtract } from './subtract'; - for export_specifier in export_clause.named_children: - name = export_specifier.child_by_field_name("name") - alias = export_specifier.child_by_field_name("alias") or name - import_type = ImportType.DEFAULT_EXPORT if (name and name.text.decode("utf-8") == "default") else ImportType.NAMED_EXPORT - imp = cls(ts_node=export_statement_node, file_node_id=file_node_id, ctx=ctx, parent=parent, module_node=source_node, name_node=name, alias_node=alias, import_type=import_type) - imports.append(imp) - else: - # ==== [ Wildcard export import ] ==== - # Note: re-exporting using wildcard syntax does NOT include the default export! - if namespace_export := next((child for child in export_statement_node.named_children if child.type == "namespace_export"), None): - # Aliased wildcard export (e.g. export * as myNamespace from './m';) - alias = next(child for child in namespace_export.named_children if child.type == "identifier") or namespace_export - imp = cls( - ts_node=export_statement_node, - file_node_id=file_node_id, - ctx=ctx, - parent=parent, - module_node=source_node, - name_node=namespace_export, - alias_node=alias, - import_type=ImportType.WILDCARD, - ) - imports.append(imp) - else: - # No alias wildcard export (e.g. export * from './m';) - imp = cls(ts_node=export_statement_node, file_node_id=file_node_id, ctx=ctx, parent=parent, module_node=source_node, name_node=None, alias_node=None, import_type=ImportType.WILDCARD) - imports.append(imp) - return imports - - @classmethod - @noapidoc - def from_import_statement(cls, import_statement_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSImportStatement) -> list[TSImport]: - source_node = import_statement_node.child_by_field_name("source") - import_clause = next((x for x in import_statement_node.named_children if x.type == "import_clause"), None) - if import_clause is None: - # === [ Side effect module import ] === - # Will not have any import usages in the file! (e.g. import './module';) - return [cls(ts_node=import_statement_node, file_node_id=file_node_id, ctx=ctx, parent=parent, module_node=source_node, name_node=None, alias_node=None, import_type=ImportType.SIDE_EFFECT)] - - imports = [] - for import_type_node in import_clause.named_children: - if import_type_node.type == "identifier": - # === [ Default export import ] === - # e.g. import a from './module' - imp = cls( - ts_node=import_statement_node, - file_node_id=file_node_id, - ctx=ctx, - parent=parent, - module_node=source_node, - name_node=import_type_node, - alias_node=import_type_node, - import_type=ImportType.DEFAULT_EXPORT, - ) - imports.append(imp) - elif import_type_node.type == "named_imports": - # === [ Named export import ] === - # e.g. import { a, b as c } from './module'; - for import_specifier in import_type_node.named_children: - # Skip comment nodes - if import_specifier.type == "comment": - continue - - name_node = import_specifier.child_by_field_name("name") - alias_node = import_specifier.child_by_field_name("alias") or name_node - imp = cls( - ts_node=import_statement_node, - file_node_id=file_node_id, - ctx=ctx, - parent=parent, - module_node=source_node, - name_node=name_node, - alias_node=alias_node, - import_type=ImportType.NAMED_EXPORT, - ) - imports.append(imp) # MODIFY IMPORT HERE ? - elif import_type_node.type == "namespace_import": - # === [ Wildcard module import ] === - # Imports both default and named exports e.g. import * as someAlias from './module'; - alias_node = next(x for x in import_type_node.named_children if x.type == "identifier") - imp = cls( - ts_node=import_statement_node, - file_node_id=file_node_id, - ctx=ctx, - module_node=source_node, - parent=parent, - name_node=import_type_node, - alias_node=alias_node, - import_type=ImportType.WILDCARD, - ) - imports.append(imp) - return imports - - @classmethod - @noapidoc - def from_dynamic_import_statement(cls, import_call_node: TSNode, module_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: ImportStatement) -> list[TSImport]: - """Parses a dynamic import statement, given a reference to the `import`/`require` node and `module` node. - e.g. - const myModule = await import('./someFile')`; - const { exportedFunction, exportedVariable: aliasedVariable } = await import('./someFile'); - import('./someFile'); - - const myModule = require('./someFile')`; - const { exportedFunction, exportedVariable: aliasedVariable } = require('./someFile'); - require('./someFile'); - Note: imports using `require` will import whatever is defined in `module.exports = ...` or `export = ...` - """ - if module_node is None: - # TODO: fixme - return [] - imports = [] - - # TODO: FIX THIS, is a horrible hack to avoid a crash on the next.js - if len(module_node.named_children) == 0: - return [] - - # Grab the first element of dynamic import call expression argument list - module_node = module_node.named_children[0] - - # Get the top most parent of call expression node that bypasses wrappers that doesn't change the semantics - call_node = find_first_ancestor(import_call_node, ["call_expression"]) - while call_node.parent and call_node.parent.type in ["await_expression", "parenthesized_expression", "binary_expression", "ternary_expression"]: - call_node = call_node.parent - - import_statement_node = call_node.parent - if import_statement_node.type == "expression_statement": - # ==== [ Side effect module import ] ==== - # Will not have any import usages in the file! (e.g. await import('./module');) - imp = cls(ts_node=import_statement_node, file_node_id=file_node_id, ctx=ctx, parent=parent, module_node=module_node, name_node=None, alias_node=None, import_type=ImportType.SIDE_EFFECT) - imports.append(imp) - else: - if import_statement_node.type == "member_expression": - # ==== [ Type import ] ==== - # Imports a type defined in module -- in javascript, type imports are entirely emitted - # e.g. type DynamicType = typeof import('./module').SomeType; - # const MyType = typeof import('./module').SomeType; - # const DefaultType = (await import('./module')).default - # import('./module').SomeType - # function foo(param: import('./module').SomeType) {} - name_node = import_statement_node.child_by_field_name("property") - parent_type_names = ["type_alias_declaration", "variable_declarator", "assignment_expression", "expression_statement"] - import_statement_node = find_first_ancestor(import_statement_node, parent_type_names, max_depth=2) or import_statement_node - else: - name_type_name = "left" if import_statement_node.type == "assignment_expression" else "name" - name_node = import_statement_node.child_by_field_name(name_type_name) - - # TODO: Handle dynamic import name not found (CG-8722) - if name_node is None: - alias_node = import_statement_node.child_by_field_name("name") or import_statement_node.child_by_field_name("left") - imp = cls( - ts_node=import_statement_node, file_node_id=file_node_id, ctx=ctx, parent=parent, module_node=module_node, name_node=None, alias_node=alias_node, import_type=ImportType.SIDE_EFFECT - ) - imports.append(imp) - return imports - - # If import statement is a variable declaration, capture the variable scoping keyword (const, let, var, etc) - if import_statement_node.type == "lexical_declaration": - statement_node = import_statement_node - else: - statement_node = import_statement_node.parent if import_statement_node.type in ["variable_declarator", "assignment_expression"] else import_statement_node - - # ==== [ Named dynamic import ] ==== - if name_node.type == "property_identifier": - # If the type import is being stored into a variable, get the alias - if import_statement_node.type in ["type_alias_declaration", "variable_declarator"]: - alias_node = import_statement_node.child_by_field_name("name") - elif import_statement_node.type == "assignment_expression": - alias_node = import_statement_node.child_by_field_name("left") - else: - alias_node = name_node - import_type = ImportType.DEFAULT_EXPORT if name_node.text.decode("utf-8") == "default" else ImportType.NAMED_EXPORT - imp = cls(ts_node=statement_node, file_node_id=file_node_id, ctx=ctx, parent=parent, module_node=module_node, name_node=name_node, alias_node=alias_node, import_type=import_type) - imports.append(imp) - elif name_node.type == "identifier": - # ==== [ Aliased module import ] ==== - # Imports both default and named exports (e.g. const moduleImp = await import('./module');) - imp = cls(ts_node=statement_node, file_node_id=file_node_id, ctx=ctx, parent=parent, module_node=module_node, name_node=name_node, alias_node=name_node, import_type=ImportType.MODULE) - imports.append(imp) - elif name_node.type == "object_pattern": - # ==== [ Deconstructed import ] ==== - for imported_symbol in name_node.named_children: - if imported_symbol.type == "shorthand_property_identifier_pattern": - # ==== [ Named export import ] ==== - # e.g. const { symbol } = await import('./module') - imp = cls( - ts_node=statement_node, - file_node_id=file_node_id, - ctx=ctx, - parent=parent, - module_node=module_node, - name_node=imported_symbol, - alias_node=imported_symbol, - import_type=ImportType.NAMED_EXPORT, - ) - imports.append(imp) - elif imported_symbol.type == "pair_pattern": - # ==== [ Aliased named export import ] ==== - # e.g. const { symbol: aliasedSymbol } = await import('./module') - name_node = imported_symbol.child_by_field_name("key") - alias_node = imported_symbol.child_by_field_name("value") - imp = cls( - ts_node=statement_node, - file_node_id=file_node_id, - ctx=ctx, - parent=parent, - module_node=module_node, - name_node=name_node, - alias_node=alias_node, - import_type=ImportType.NAMED_EXPORT, - ) - imports.append(imp) - else: - continue - # raise ValueError(f"Unexpected alias name node type {imported_symbol.type}") - return imports - - @property - @reader - def import_specifier(self) -> Editable: - """Retrieves the import specifier node for this import. - - Finds and returns the import specifier node containing this import's name and optional alias. - For named imports, this is the import_specifier or export_specifier node. - For other imports, this is the identifier node containing the import name. - - Returns: - Editable: The import specifier node containing this import's name and alias. - For named imports, returns the import_specifier/export_specifier node. - For other imports, returns the identifier node containing the import name. - Returns None if no matching specifier is found. - """ - import_specifiers = find_all_descendants(self.ts_node, {"import_specifier", "export_specifier"}) - for import_specifier in import_specifiers: - alias = import_specifier.child_by_field_name("alias") - if alias is not None: - is_match = self.alias.source == alias.text.decode("utf-8") - else: - name = import_specifier.child_by_field_name("name") - is_match = self.symbol_name.source == name.text.decode("utf-8") - if is_match: - return Name(import_specifier, self.file_node_id, self.ctx, self) - if named := next(iter(find_all_descendants(self.ts_node, {"identifier"})), None): - if named.text.decode("utf-8") == self.symbol_name.source: - return Name(named, self.file_node_id, self.ctx, self) - - @reader - def get_import_string(self, alias: str | None = None, module: str | None = None, import_type: ImportType = ImportType.UNKNOWN, is_type_import: bool = False) -> str: - """Generates an import string for an import statement. - - Generates a string representation of an import statement with optional type and alias information. - - Args: - alias (str | None): Alias name for the imported symbol. Defaults to None. - module (str | None): Module name to import from. Defaults to None. If not provided, uses the file's import module name. - import_type (ImportType): Type of import (e.g. WILDCARD, NAMED_EXPORT). Defaults to ImportType.UNKNOWN. - is_type_import (bool): Whether this is a type import. Defaults to False. - - Returns: - str: A string representation of the import statement. - """ - type_prefix = "type " if is_type_import else "" - import_module = module if module is not None else self.file.import_module_name - - if import_type == ImportType.WILDCARD: - file_as_module = self.file.name - return f"import {type_prefix}* as {file_as_module} from {import_module};" - elif alias is not None and alias != self.name: - return f"import {type_prefix}{{ {self.name} as {alias} }} from {import_module};" - else: - return f"import {type_prefix}{{ {self.name} }} from {import_module};" - - @property - @noapidoc - @override - def names(self) -> Generator[tuple[str, Self | WildcardImport[Self]], None, None]: - if self.import_type == ImportType.SIDE_EFFECT: - return - yield from super().names - - @property - def namespace_imports(self) -> list[TSNamespace]: - """Returns any namespace objects imported by this import statement. - - For example: - import * as MyNS from './mymodule'; - - Returns: - List of namespace objects imported - """ - if not self.is_namespace_import(): - return [] - - from codegen.sdk.typescript.namespace import TSNamespace - - resolved = self.resolved_symbol - if resolved is None or not isinstance(resolved, TSNamespace): - return [] - - return [resolved] - - @property - def is_namespace_import(self) -> bool: - """Returns True if this import is importing a namespace. - - Examples: - import { MathUtils } from './file1'; # True if MathUtils is a namespace - import * as AllUtils from './utils'; # True - """ - # For wildcard imports with namespace alias - if self.import_type == ImportType.WILDCARD and self.namespace: - return True - - # For named imports, check if any imported symbol is a namespace - if self.import_type == ImportType.NAMED_EXPORT: - for name, _ in self.names: - symbol = self.resolved_symbol - if symbol and symbol.symbol_type == SymbolType.Namespace: - return True - - return False - - @override - def set_import_module(self, new_module: str) -> None: - """Sets the module of an import. - - Updates the module of an import statement while maintaining the import symbol. - Uses single quotes by default (TypeScript standard), falling back to double quotes - only if the path contains single quotes. - - Args: - new_module (str): The new module path to import from. - - Returns: - None - """ - if self.module is None: - return - - # If already quoted, use as is - if (new_module.startswith('"') and new_module.endswith('"')) or (new_module.startswith("'") and new_module.endswith("'")): - self.module.source = new_module - return - - # Use double quotes if path contains single quotes, otherwise use single quotes (TypeScript standard) - quote = '"' if "'" in new_module else "'" - self.module.source = f"{quote}{new_module}{quote}" diff --git a/src/codegen/sdk/typescript/interface.py b/src/codegen/sdk/typescript/interface.py deleted file mode 100644 index 4107cde7c..000000000 --- a/src/codegen/sdk/typescript/interface.py +++ /dev/null @@ -1,98 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, TypeVar - -from codegen.sdk.core.autocommit import commiter, reader -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.interface import Interface -from codegen.sdk.core.symbol_groups.parents import Parents -from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock -from codegen.sdk.typescript.expressions.type import TSType -from codegen.sdk.typescript.function import TSFunction -from codegen.sdk.typescript.interfaces.has_block import TSHasBlock -from codegen.sdk.typescript.statements.attribute import TSAttribute -from codegen.sdk.typescript.symbol import TSSymbol -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.statements.statement import Statement - -Parent = TypeVar("Parent", bound="TSHasBlock") - - -@ts_apidoc -class TSInterface(Interface[TSCodeBlock, TSAttribute, TSFunction, TSType], TSSymbol, TSHasBlock): - """Representation of an Interface in TypeScript - - Attributes: - parent_interfaces: All the interfaces that this interface extends. - code_block: The code block that contains the interface's body. - """ - - def __init__( - self, - ts_node: TSNode, - file_id: NodeId, - ctx: CodebaseContext, - parent: Statement[CodeBlock[Parent, ...]], - ) -> None: - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - - super().__init__(ts_node, file_id, ctx, parent) - body_node = ts_node.child_by_field_name("body") - - # Find the nearest parent with a code_block - current_parent = parent - while not hasattr(current_parent, "code_block"): - current_parent = current_parent.parent - - self.code_block = TSCodeBlock(body_node, current_parent.code_block.level + 1, current_parent.code_block, self) - self.code_block.parse() - - @commiter - @noapidoc - def parse(self, ctx: CodebaseContext) -> None: - # =====[ Extends ]===== - # Look for parent interfaces in the "extends" clause - if extends_clause := self.child_by_field_types("extends_type_clause"): - self.parent_interfaces = Parents(extends_clause.ts_node, self.file_node_id, self.ctx, self) - super().parse(ctx) - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - dest = dest or self.self_dest - - # =====[ Extends ]===== - if self.parent_interfaces is not None: - self.parent_interfaces._compute_dependencies(UsageKind.SUBCLASS, dest) - - # =====[ Body ]===== - # Look for type references in the interface body - self.code_block._compute_dependencies(usage_type, dest) - - @staticmethod - @noapidoc - def _get_name_node(ts_node: TSNode) -> TSNode | None: - if ts_node.type == "interface_declaration": - return ts_node.child_by_field_name("name") - return None - - @property - @reader - def attributes(self) -> list[TSAttribute]: - """Retrieves the list of attributes defined in the TypeScript interface. - - Args: - None - - Returns: - list[TSAttribute]: A list of the interface's attributes stored in the code block. - """ - return self.code_block.attributes diff --git a/src/codegen/sdk/typescript/interfaces/has_block.py b/src/codegen/sdk/typescript/interfaces/has_block.py deleted file mode 100644 index be8bb68c4..000000000 --- a/src/codegen/sdk/typescript/interfaces/has_block.py +++ /dev/null @@ -1,172 +0,0 @@ -from __future__ import annotations - -from functools import cached_property -from typing import TYPE_CHECKING, Self - -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.detached_symbols.code_block import CodeBlock -from codegen.sdk.core.interfaces.has_block import HasBlock -from codegen.sdk.core.statements.statement import StatementType -from codegen.sdk.extensions.utils import find_all_descendants -from codegen.sdk.typescript.detached_symbols.decorator import TSDecorator -from codegen.sdk.typescript.statements.comment import TSComment, TSCommentType -from codegen.sdk.typescript.symbol_groups.comment_group import TSCommentGroup -from codegen.sdk.utils import find_index -from codegen.shared.decorators.docs import ts_apidoc - -if TYPE_CHECKING: - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - from codegen.sdk.typescript.detached_symbols.jsx.element import JSXElement - - -@ts_apidoc -class TSHasBlock(HasBlock["TSCodeBlock", TSDecorator]): - """A TypeScript base class that provides block-level code organization and decorator handling capabilities. - - This class extends the concept of block scoping for TypeScript code elements like classes and functions. - It provides functionality for managing code blocks, decorators, JSX elements, and documentation within - those blocks. The class supports operations such as retrieving and manipulating docstrings, - handling JSX components, and managing TypeScript decorators. - """ - - @property - @reader - def is_decorated(self) -> bool: - """Checks if the current symbol has a decorator. - - Determines if the symbol has a preceding decorator node. - - Returns: - bool: True if the symbol has a decorator node as its previous named sibling, - False otherwise. - """ - previous_sibling = self.ts_node.prev_named_sibling - # is decorated if it has a previous named sibling (i.e. the text above the function) and it is type=decorator - return previous_sibling and previous_sibling.type == "decorator" - - @property - @reader - def decorators(self) -> list[TSDecorator]: - """Returns a list of decorators associated with this symbol. - - Retrieves all decorators applied to this symbol by looking at both previous named siblings and decorator fields. - This includes both inline decorators and standalone decorator statements. - - Returns: - list[TSDecorator]: A list of TSDecorator objects representing all decorators applied to this symbol. - Returns an empty list if no decorators are found. - """ - decorators = [] - # Get all previous named siblings that are decorators, break once we hit a non decorator - prev_named_sibling = self.ts_node.prev_named_sibling - while prev_named_sibling and prev_named_sibling.type == "decorator": - decorators.append(TSDecorator(prev_named_sibling, self)) - prev_named_sibling = prev_named_sibling.prev_named_sibling - for child in self.ts_node.children_by_field_name("decorator"): - decorators.append(TSDecorator(child, self)) - return decorators - - @property - @reader - def jsx_elements(self) -> list[JSXElement[Self]]: - """Returns a list of all JSX elements contained within this symbol. - - Searches through the extended nodes of the symbol for any JSX elements or self-closing JSX elements - and returns them as a list of JSXElement objects. - - Args: - None - - Returns: - list[JSXElement[Self]]: A list of JSXElement objects contained within this symbol. - """ - jsx_elements = [] - for node in self.extended_nodes: - jsx_element_nodes = find_all_descendants(node.ts_node, {"jsx_element", "jsx_self_closing_element"}) - jsx_elements.extend([self._parse_expression(x) for x in jsx_element_nodes]) - return jsx_elements - - @reader - def get_component(self, component_name: str) -> JSXElement[Self] | None: - """Returns a specific JSX element from within this symbol's JSX elements. - - Searches through all JSX elements in this symbol's code block and returns the first one that matches - the given component name. - - Args: - component_name (str): The name of the JSX component to find. - - Returns: - JSXElement[Self] | None: The matching JSX element if found, None otherwise. - """ - for component in self.jsx_elements: - if component.name == component_name: - return component - return None - - @cached_property - @reader - def docstring(self) -> TSCommentGroup | None: - """Retrieves the docstring of a function or class. - - Returns any comments immediately preceding this node as a docstring. For nodes that are children of a HasBlock, it returns consecutive comments that end on the line before the node starts. - For other nodes, it returns formatted docstring comments. - - Returns: - TSCommentGroup | None: A CommentGroup representing the docstring if one exists, None otherwise. - """ - if self.parent.parent.parent and isinstance(self.parent.parent, CodeBlock): - comments = [] - sibling_statements = self.parent.parent.statements - index = find_index(self.ts_node, [x.ts_node for x in sibling_statements]) - if index == -1: - return None - - row = self.start_point[0] - for statement in reversed(sibling_statements[:index]): - if statement.end_point[0] != row - 1: - break - row = statement.start_point[0] - if statement.statement_type == StatementType.COMMENT: - comments.append(statement) - - return TSCommentGroup.from_comment_nodes(list(reversed(comments)), self) - - return TSCommentGroup.from_docstring(self) - - @writer - def set_docstring(self, docstring: str, auto_format: bool = True, clean_format: bool = True, leading_star: bool = True, force_multiline: bool = False) -> None: - """Sets or updates a docstring for a code element. - - Adds a new docstring if none exists, or updates the existing docstring. Handles formatting and placement - of the docstring according to the specified parameters. - - Args: - docstring (str): The docstring text to be added or updated. - auto_format (bool, optional): Whether to automatically format the text into a docstring format. Defaults to True. - clean_format (bool, optional): Whether to clean existing formatting from the docstring before inserting. Defaults to True. - leading_star (bool, optional): Whether to add leading "*" to each line of the comment block. Defaults to True. - force_multiline (bool, optional): Whether to force single line comments to be multi-line. Defaults to False. - - Returns: - None - """ - # Clean existing formatting off docstring - if clean_format: - docstring = TSComment.clean_comment(docstring) - - # If the docstring exists, edit it - if self.docstring: - if auto_format: - self.docstring.edit_text(docstring) - else: - self.docstring.edit(docstring) - else: - if auto_format: - docstring = TSComment.generate_comment(docstring, TSCommentType.SLASH_STAR, leading_star=leading_star, force_multiline=force_multiline) - # If a comment exists, insert the docstring after it - if self.comment: - self.comment.insert_after(docstring) - # If no comment exists, insert the docstring before the function - else: - self.extended.insert_before(docstring, fix_indentation=True) diff --git a/src/codegen/sdk/typescript/namespace.py b/src/codegen/sdk/typescript/namespace.py deleted file mode 100644 index 2442ce6da..000000000 --- a/src/codegen/sdk/typescript/namespace.py +++ /dev/null @@ -1,400 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, override - -from codegen.sdk.core.autocommit import commiter -from codegen.sdk.core.autocommit.decorators import writer -from codegen.sdk.core.export import Export -from codegen.sdk.core.interfaces.has_attribute import HasAttribute -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.enums import SymbolType -from codegen.sdk.extensions.autocommit import reader -from codegen.sdk.extensions.sort import sort_editables -from codegen.sdk.extensions.utils import cached_property -from codegen.sdk.typescript.class_definition import TSClass -from codegen.sdk.typescript.enum_definition import TSEnum -from codegen.sdk.typescript.function import TSFunction -from codegen.sdk.typescript.interface import TSInterface -from codegen.sdk.typescript.interfaces.has_block import TSHasBlock -from codegen.sdk.typescript.symbol import TSSymbol -from codegen.sdk.typescript.type_alias import TSTypeAlias -from codegen.shared.decorators.docs import noapidoc, ts_apidoc -from codegen.shared.logging.get_logger import get_logger - -if TYPE_CHECKING: - from collections.abc import Sequence - - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.core.statements.statement import Statement - from codegen.sdk.core.symbol import Symbol - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - from codegen.sdk.typescript.export import TSExport - from codegen.sdk.typescript.import_resolution import TSImport - - -logger = get_logger(__name__) - - -@ts_apidoc -class TSNamespace(TSSymbol, TSHasBlock, HasName, HasAttribute): - """Representation of a namespace module in TypeScript. - - Attributes: - symbol_type: The type of the symbol, set to SymbolType.Namespace. - code_block: The code block associated with this namespace. - """ - - symbol_type = SymbolType.Namespace - code_block: TSCodeBlock - - def __init__(self, ts_node: TSNode, file_id: NodeId, ctx: CodebaseContext, parent: Statement, namespace_node: TSNode | None = None) -> None: - ts_node = namespace_node or ts_node - name_node = ts_node.child_by_field_name("name") - super().__init__(ts_node, file_id, ctx, parent, name_node=name_node) - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - """Computes dependencies for the namespace by analyzing its code block. - - Args: - usage_type: Optional UsageKind specifying how the dependencies are used - dest: Optional HasName destination for the dependencies - """ - # Use self as destination if none provided - dest = dest or self.self_dest - # Compute dependencies from namespace's code block - self.code_block._compute_dependencies(usage_type, dest) - - @cached_property - def symbols(self) -> list[Symbol]: - """Returns all symbols defined within this namespace, including nested ones.""" - all_symbols = [] - for stmt in self.code_block.statements: - if stmt.ts_node_type == "export_statement": - for export in stmt.exports: - all_symbols.append(export.declared_symbol) - elif hasattr(stmt, "assignments"): - all_symbols.extend(stmt.assignments) - else: - all_symbols.append(stmt) - return all_symbols - - def get_symbol(self, name: str, recursive: bool = True, get_private: bool = False) -> Symbol | None: - """Get an exported or private symbol by name from this namespace. Returns only exported symbols by default. - - Args: - name: Name of the symbol to find - recursive: If True, also search in nested namespaces - get_private: If True, also search in private symbols - - Returns: - Symbol | None: The found symbol, or None if not found - """ - # First check direct symbols in this namespace - for symbol in self.symbols: - # Handle TSAssignmentStatement case - if hasattr(symbol, "assignments"): - for assignment in symbol.assignments: - if assignment.name == name: - # If we are looking for private symbols then return it, else only return exported symbols - if get_private: - return assignment - elif assignment.is_exported: - return assignment - - # Handle regular symbol case - if hasattr(symbol, "name") and symbol.name == name: - if get_private: - return symbol - elif symbol.is_exported: - return symbol - - # If recursive and this is a namespace, check its symbols - if recursive and isinstance(symbol, TSNamespace): - nested_symbol = symbol.get_symbol(name, recursive=True, get_private=get_private) - return nested_symbol - - return None - - @reader(cache=False) - @noapidoc - def get_nodes(self, *, sort_by_id: bool = False, sort: bool = True) -> Sequence[Importable]: - """Returns all nodes in the namespace, sorted by position in the namespace.""" - file_nodes = self.file.get_nodes(sort_by_id=sort_by_id, sort=sort) - start_limit = self.start_byte - end_limit = self.end_byte - namespace_nodes = [] - for file_node in file_nodes: - if file_node.start_byte > start_limit: - if file_node.end_byte < end_limit: - namespace_nodes.append(file_node) - else: - break - return namespace_nodes - - @cached_property - @reader(cache=False) - def exports(self) -> list[TSExport]: - """Returns all Export symbols in the namespace. - - Retrieves a list of all top-level export declarations in the current TypeScript namespace. - - Returns: - list[TSExport]: A list of TSExport objects representing all top-level export declarations in the namespace. - """ - # Filter to only get exports that are direct children of the namespace's code block - return sort_editables(filter(lambda node: isinstance(node, Export), self.get_nodes(sort=False)), by_id=True) - - @cached_property - def functions(self) -> list[TSFunction]: - """Get all functions defined in this namespace. - - Returns: - List of Function objects in this namespace - """ - return [symbol for symbol in self.symbols if isinstance(symbol, TSFunction)] - - def get_function(self, name: str, recursive: bool = True) -> TSFunction | None: - """Get a function by name from this namespace. - - Args: - name: Name of the function to find - recursive: If True, also search in nested namespaces - """ - symbol = self.get_symbol(name, recursive=recursive) - return symbol if isinstance(symbol, TSFunction) else None - - @cached_property - def classes(self) -> list[TSClass]: - """Get all classes defined in this namespace. - - Returns: - List of Class objects in this namespace - """ - return [symbol for symbol in self.symbols if isinstance(symbol, TSClass)] - - def get_class(self, name: str, recursive: bool = True) -> TSClass | None: - """Get a class by name from this namespace. - - Args: - name: Name of the class to find - recursive: If True, also search in nested namespaces - """ - symbol = self.get_symbol(name, recursive=recursive) - return symbol if isinstance(symbol, TSClass) else None - - def get_interface(self, name: str, recursive: bool = True) -> TSInterface | None: - """Get an interface by name from this namespace. - - Args: - name: Name of the interface to find - recursive: If True, also search in nested namespaces - """ - symbol = self.get_symbol(name, recursive=recursive) - return symbol if isinstance(symbol, TSInterface) else None - - def get_type(self, name: str, recursive: bool = True) -> TSTypeAlias | None: - """Get a type alias by name from this namespace. - - Args: - name: Name of the type to find - recursive: If True, also search in nested namespaces - """ - symbol = self.get_symbol(name, recursive=recursive) - return symbol if isinstance(symbol, TSTypeAlias) else None - - def get_enum(self, name: str, recursive: bool = True) -> TSEnum | None: - """Get an enum by name from this namespace. - - Args: - name: Name of the enum to find - recursive: If True, also search in nested namespaces - """ - symbol = self.get_symbol(name, recursive=recursive) - return symbol if isinstance(symbol, TSEnum) else None - - def get_namespace(self, name: str, recursive: bool = True) -> TSNamespace | None: - """Get a namespace by name from this namespace. - - Args: - name: Name of the namespace to find - recursive: If True, also search in nested namespaces - - Returns: - TSNamespace | None: The found namespace, or None if not found - """ - # First check direct symbols in this namespace - for symbol in self.symbols: - if isinstance(symbol, TSNamespace) and symbol.name == name: - return symbol - - # If recursive and this is a namespace, check its symbols - if recursive and isinstance(symbol, TSNamespace): - nested_namespace = symbol.get_namespace(name, recursive=True) - return nested_namespace - - return None - - def get_nested_namespaces(self) -> list[TSNamespace]: - """Get all nested namespaces within this namespace. - - Returns: - list[TSNamespace]: List of all nested namespace objects - """ - nested = [] - for symbol in self.symbols: - if isinstance(symbol, TSNamespace): - nested.append(symbol) - nested.extend(symbol.get_nested_namespaces()) - return nested - - @writer - def add_symbol_from_source(self, source: str) -> None: - """Adds a symbol to a namespace from a string representation. - - This method adds a new symbol definition to the namespace by appending its source code string. The symbol will be added - after existing symbols if present, otherwise at the beginning of the namespace. - - Args: - source (str): String representation of the symbol to be added. This should be valid source code for - the file's programming language. - - Returns: - None: The symbol is added directly to the namespace's content. - """ - symbols = self.symbols - if len(symbols) > 0: - symbols[-1].insert_after("\n" + source, fix_indentation=True) - else: - self.insert_after("\n" + source) - - @commiter - def add_symbol(self, symbol: TSSymbol, should_export: bool = True) -> TSSymbol | None: - """Adds a new symbol to the namespace, optionally exporting it if applicable. If the symbol already exists in the namespace, returns the existing symbol. - - Args: - symbol: The symbol to add to the namespace (either a TSSymbol instance or source code string) - export: Whether to export the symbol. Defaults to True. - - Returns: - TSSymbol | None: The existing symbol if it already exists in the file or None if it was added. - """ - existing_symbol = self.get_symbol(symbol.name) - if existing_symbol is not None: - return existing_symbol - - if not self.file.symbol_can_be_added(symbol): - msg = f"Symbol {symbol.name} cannot be added to this file type." - raise ValueError(msg) - - source = symbol.source - if isinstance(symbol, TSFunction) and symbol.is_arrow: - raw_source = symbol._named_arrow_function.text.decode("utf-8") - else: - raw_source = symbol.ts_node.text.decode("utf-8") - if should_export and hasattr(symbol, "export") and (not symbol.is_exported or raw_source not in symbol.export.source): - source = source.replace(source, f"export {source}") - self.add_symbol_from_source(source) - - @commiter - def remove_symbol(self, symbol_name: str) -> TSSymbol | None: - """Removes a symbol from the namespace by name. - - Args: - symbol_name: Name of the symbol to remove - - Returns: - The removed symbol if found, None otherwise - """ - symbol = self.get_symbol(symbol_name) - if symbol: - # Remove from code block statements - for i, stmt in enumerate(self.code_block.statements): - if symbol.source == stmt.source: - logger.debug(f"stmt to be removed: {stmt}") - self.code_block.statements.pop(i) - return symbol - return None - - @commiter - def rename_symbol(self, old_name: str, new_name: str) -> None: - """Renames a symbol within the namespace. - - Args: - old_name: Current symbol name - new_name: New symbol name - """ - symbol = self.get_symbol(old_name) - if symbol: - symbol.rename(new_name) - - @commiter - @noapidoc - def export_symbol(self, name: str) -> None: - """Marks a symbol as exported in the namespace. - - Args: - name: Name of symbol to export - """ - symbol = self.get_symbol(name, get_private=True) - if not symbol or symbol.is_exported: - return - - export_source = f"export {symbol.source}" - symbol.parent.edit(export_source) - - @cached_property - @noapidoc - @reader(cache=True) - def valid_import_names(self) -> dict[str, TSSymbol | TSImport]: - """Returns set of valid import names for this namespace. - - This includes all exported symbols plus the namespace name itself - for namespace imports. - """ - valid_export_names = {} - valid_export_names[self.name] = self - for export in self.exports: - for name, dest in export.names: - valid_export_names[name] = dest - return valid_export_names - - def resolve_import(self, import_name: str) -> Symbol | None: - """Resolves an import name to a symbol within this namespace. - - Args: - import_name: Name to resolve - - Returns: - Resolved symbol or None if not found - """ - # First check direct symbols - for symbol in self.symbols: - if symbol.is_exported and symbol.name == import_name: - return symbol - - # Then check nested namespaces - for nested in self.get_nested_namespaces(): - resolved = nested.resolve_import(import_name) - if resolved is not None: - return resolved - - return None - - @override - def resolve_attribute(self, name: str) -> Symbol | None: - """Resolves an attribute access on the namespace. - - Args: - name: Name of the attribute to resolve - - Returns: - The resolved symbol or None if not found - """ - return self.valid_import_names.get(name, None) diff --git a/src/codegen/sdk/typescript/placeholder/placeholder_return_type.py b/src/codegen/sdk/typescript/placeholder/placeholder_return_type.py deleted file mode 100644 index 8a7a8bd8c..000000000 --- a/src/codegen/sdk/typescript/placeholder/placeholder_return_type.py +++ /dev/null @@ -1,44 +0,0 @@ -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.placeholder.placeholder import Placeholder -from codegen.shared.decorators.docs import ts_apidoc - -if TYPE_CHECKING: - from codegen.sdk.core.interfaces.editable import Editable - -Parent = TypeVar("Parent", bound="Editable") - - -@ts_apidoc -class TSReturnTypePlaceholder(Placeholder[Parent], Generic[Parent]): - """A placeholder class for function return type annotations in TypeScript. - - This class represents a placeholder for function return type annotations, allowing for modification - and addition of return type annotations after the parameter list. It provides functionality to - add or modify return type annotations with proper formatting. - """ - - def edit(self, new_src: str, fix_indentation: bool = False, priority: int = 0, dedupe: bool = True) -> None: - """Modifies the return type annotation of a function. - - Adds or modifies the return type annotation of a function after its parameter list. - - Args: - new_src (str): The return type annotation to add. If it doesn't start with ':', a ':' will be prepended. - fix_indentation (bool, optional): Whether to fix the indentation of the added code. Defaults to False. - priority (int, optional): The priority of this edit. Defaults to 0. - dedupe (bool, optional): Whether to remove duplicate edits. Defaults to True. - - Returns: - None - - Note: - If new_src is empty or None, the method returns without making any changes. - """ - if new_src == "" or new_src is None: - return - if not new_src.startswith(": "): - new_src = ": " + new_src - - param_node = self._parent_node.child_by_field_name("parameters") - param_node.insert_after(new_src, newline=False) diff --git a/src/codegen/sdk/typescript/statements/__init__.py b/src/codegen/sdk/typescript/statements/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/sdk/typescript/statements/assignment_statement.py b/src/codegen/sdk/typescript/statements/assignment_statement.py deleted file mode 100644 index cf4926a8c..000000000 --- a/src/codegen/sdk/typescript/statements/assignment_statement.py +++ /dev/null @@ -1,129 +0,0 @@ -from __future__ import annotations - -from collections import deque -from typing import TYPE_CHECKING, Self - -from codegen.sdk.core.expressions.multi_expression import MultiExpression -from codegen.sdk.core.statements.assignment_statement import AssignmentStatement -from codegen.sdk.extensions.autocommit import reader -from codegen.sdk.typescript.assignment import TSAssignment -from codegen.shared.decorators.docs import noapidoc, ts_apidoc -from codegen.shared.logging.get_logger import get_logger - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - from codegen.sdk.typescript.interfaces.has_block import TSHasBlock - - -logger = get_logger(__name__) - - -@ts_apidoc -class TSAssignmentStatement(AssignmentStatement["TSCodeBlock", TSAssignment]): - """A class that represents a TypeScript assignment statement in a codebase, such as `const x = 1` or `const { a: b } = myFunc()`.""" - - assignment_types = {"assignment_expression", "augmented_assignment_expression", "variable_declarator", "public_field_definition", "property_signature"} - - @classmethod - @reader - @noapidoc - def from_assignment(cls, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSCodeBlock, pos: int, assignment_node: TSNode) -> TSAssignmentStatement: - """Creates an assignment statement node from a TreeSitter assignment node. - - This class method constructs a TSAssignmentStatement from a TreeSitter node representing an assignment. The method validates that the assignment node type is - one of the supported types: assignment_expression, augmented_assignment_expression, variable_declarator, public_field_definition, or property_signature. - - Args: - ts_node (TSNode): The TreeSitter node representing the entire statement. - file_node_id (NodeId): The identifier for the file containing this node. - ctx (CodebaseContext): The codebase context being constructed. - parent (TSHasBlock): The parent block containing this statement. - code_block (TSCodeBlock): The code block containing this statement. - pos (int): The position of this statement within its code block. - assignment_node (TSNode): The TreeSitter node representing the assignment. - - Returns: - TSAssignmentStatement: A new assignment statement node. - - Raises: - ValueError: If the assignment_node.type is not one of the supported assignment types. - """ - if assignment_node.type not in cls.assignment_types: - msg = f"Invalid assignment node type: {assignment_node.type}" - raise ValueError(msg) - - return cls(ts_node, file_node_id, ctx, parent, pos, assignment_node=assignment_node) - - def _parse_assignments(self, assignment_node: TSNode) -> MultiExpression[Self, TSAssignment]: - if assignment_node.type in ["assignment_expression", "augmented_assignment_expression"]: - return TSAssignment.from_assignment(assignment_node, self.file_node_id, self.ctx, self) - elif assignment_node.type in ["variable_declarator", "public_field_definition", "property_signature"]: - return TSAssignment.from_named_expression(assignment_node, self.file_node_id, self.ctx, self) - - logger.info(f"Unknown assignment type: {assignment_node.type}") - return MultiExpression(assignment_node, self.file_node_id, self.ctx, self.parent, [self.parent._parse_expression(assignment_node)]) - - def _DEPRECATED_parse_assignments(self) -> MultiExpression[TSHasBlock, TSAssignment]: - if self.ts_node.type in ["lexical_declaration", "variable_declaration"]: - return MultiExpression(self.ts_node, self.file_node_id, self.ctx, self.parent, self._DEPRECATED_parse_assignment_declarations()) - elif self.ts_node.type in ["expression_statement"]: - return MultiExpression(self.ts_node, self.file_node_id, self.ctx, self.parent, self._DEPRECATED_parse_assignment_expression()) - elif self.ts_node.type in ["public_field_definition", "property_signature", "enum_assignment"]: - return MultiExpression(self.ts_node, self.file_node_id, self.ctx, self.parent, self._DEPRECATED_parse_attribute_assignments()) - else: - msg = f"Unknown assignment type: {self.ts_node.type}" - raise ValueError(msg) - - def _DEPRECATED_parse_attribute_assignments(self) -> list[TSAssignment]: - left = self.ts_node.child_by_field_name("name") - right = self.ts_node.child_by_field_name("value") - return [TSAssignment(self.ts_node, self.file_node_id, self.ctx, self, left, right, left)] - - def _DEPRECATED_parse_assignment_declarations(self) -> list[TSAssignment]: - assignments = [] - for variable_declarator in self.ts_node.named_children: - if variable_declarator.type != "variable_declarator": - continue - left = variable_declarator.child_by_field_name("name") - type_node = variable_declarator.child_by_field_name("type") - right = variable_declarator.child_by_field_name("value") - if len(left.named_children) > 0: - to_parse: deque[tuple[TSNode, TSNode | None]] = deque([(left, type_node)]) - while to_parse: - child, _type = to_parse.popleft() - for identifier in child.named_children: - if identifier.type == "pair_pattern": - value = identifier.child_by_field_name("value") - to_parse.append((value, _type)) # TODO:CG-10064 - if value.type == "identifier": - # TODO: Support type resolution for aliased object unpacks - assignments.append(TSAssignment(variable_declarator, self.file_node_id, self.ctx, self, left, right, value)) - else: - key = identifier.child_by_field_name("key") - assignments.append(TSAssignment(variable_declarator, self.file_node_id, self.ctx, self, left, right, key)) - else: - assignments.append(TSAssignment(variable_declarator, self.file_node_id, self.ctx, self, left, right, identifier)) - - else: - assignments.append(TSAssignment(variable_declarator, self.file_node_id, self.ctx, self, left, right, left)) - while right and right.type == "assignment_expression": - left = right.child_by_field_name("left") - right = right.child_by_field_name("right") - assignments.append(TSAssignment(variable_declarator, self.file_node_id, self.ctx, self, left, right, left)) - - return assignments - - def _DEPRECATED_parse_assignment_expression(self) -> list[TSAssignment]: - assignments = [] - for child in self.ts_node.named_children: - if child.type not in ["assignment_expression", "augmented_assignment_expression"]: - continue - left = child.child_by_field_name("left") - right = child.child_by_field_name("right") - assignments.append(TSAssignment(child, self.file_node_id, self.ctx, self, left, right, left)) - - return assignments diff --git a/src/codegen/sdk/typescript/statements/attribute.py b/src/codegen/sdk/typescript/statements/attribute.py deleted file mode 100644 index 9de9ae260..000000000 --- a/src/codegen/sdk/typescript/statements/attribute.py +++ /dev/null @@ -1,86 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from codegen.sdk._proxy import proxy_property -from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.statements.attribute import Attribute -from codegen.sdk.typescript.assignment import TSAssignment -from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock -from codegen.sdk.typescript.statements.assignment_statement import TSAssignmentStatement -from codegen.shared.decorators.docs import ts_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.typescript.interfaces.has_block import TSHasBlock - - -@ts_apidoc -class TSAttribute(Attribute[TSCodeBlock, TSAssignment], TSAssignmentStatement): - """Typescript implementation of Attribute detached symbol.""" - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSCodeBlock, pos: int) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos=pos, assignment_node=ts_node) - self.type = self.assignments[0].type - - @reader - def _get_name_node(self) -> TSNode: - """Returns the ID node from the root node of the symbol""" - return self.ts_node.child_by_field_name("name") - - @proxy_property - @reader - def local_usages(self: TSAttribute[TSHasBlock, TSCodeBlock]) -> list[Editable]: - """Returns local usages of a TypeScript attribute within its code block. - - Searches through all statements in the attribute's parent code block and finds instances where the attribute is referenced with 'this.' prefix. Excludes the attribute's own - declaration/assignment. - - Args: - self (TSAttribute[TSHasBlock, TSCodeBlock]): The TypeScript attribute instance. - - Returns: - list[Editable]: A sorted list of unique Editable instances representing local usages of the attribute, ordered by their position in the source code. - - Note: - This method can be called as both a property or a method. If used as a property, it is equivalent to invoking it without arguments. - """ - usages = [] - for statement in self.parent.statements: - var_references = statement.find(f"this.{self.name}", exact=True) - for var_reference in var_references: - # Exclude the variable usage in the assignment itself - if self.ts_node.byte_range[0] <= var_reference.ts_node.start_byte and self.ts_node.byte_range[1] >= var_reference.ts_node.end_byte: - continue - usages.append(var_reference) - return sorted(dict.fromkeys(usages), key=lambda x: x.ts_node.start_byte) - - @property - def is_private(self) -> bool: - """Determines if this attribute has a private accessibility modifier. - - Args: - self: The TypeScript attribute instance. - - Returns: - bool: True if the attribute has a 'private' accessibility modifier, False otherwise. - """ - modifier = self.ts_node.children[0] - return modifier.type == "accessibility_modifier" and modifier.text == b"private" - - @property - def is_optional(self) -> bool: - """Returns True if this attribute is marked as optional in TypeScript. - - Checks if the attribute has a question mark (`?`) symbol after its name, indicating it's an optional field. - - Returns: - bool: True if the attribute is optional, False otherwise. - """ - if sibling := self.get_name().next_sibling: - return sibling.ts_node.type == "?" - return False diff --git a/src/codegen/sdk/typescript/statements/block_statement.py b/src/codegen/sdk/typescript/statements/block_statement.py deleted file mode 100644 index 98c995a74..000000000 --- a/src/codegen/sdk/typescript/statements/block_statement.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.statements.block_statement import BlockStatement -from codegen.sdk.typescript.interfaces.has_block import TSHasBlock -from codegen.shared.decorators.docs import apidoc - -if TYPE_CHECKING: - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - -Parent = TypeVar("Parent", bound="TSCodeBlock") - - -@apidoc -class TSBlockStatement(BlockStatement[Parent], TSHasBlock, Generic[Parent]): - """Statement which contains a block.""" diff --git a/src/codegen/sdk/typescript/statements/catch_statement.py b/src/codegen/sdk/typescript/statements/catch_statement.py deleted file mode 100644 index e6027d3a7..000000000 --- a/src/codegen/sdk/typescript/statements/catch_statement.py +++ /dev/null @@ -1,36 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.statements.catch_statement import CatchStatement -from codegen.sdk.typescript.statements.block_statement import TSBlockStatement -from codegen.shared.decorators.docs import apidoc, noapidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - -Parent = TypeVar("Parent", bound="TSCodeBlock") - - -@apidoc -class TSCatchStatement(CatchStatement[Parent], TSBlockStatement, Generic[Parent]): - """Typescript catch clause. - - Attributes: - code_block: The code block that may trigger an exception - condition: The condition which triggers this clause - """ - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent, pos: int | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - self.condition = self.child_by_field_name("parameter") - - @property - @noapidoc - def other_possible_blocks(self) -> list[ConditionalBlock]: - return [self.parent] diff --git a/src/codegen/sdk/typescript/statements/comment.py b/src/codegen/sdk/typescript/statements/comment.py deleted file mode 100644 index e465dbc97..000000000 --- a/src/codegen/sdk/typescript/statements/comment.py +++ /dev/null @@ -1,161 +0,0 @@ -from __future__ import annotations - -from enum import StrEnum - -from codegen.sdk.core.autocommit import commiter, reader -from codegen.sdk.core.statements.comment import Comment, lowest_indentation -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - - -@ts_apidoc -class TSCommentType(StrEnum): - """An enumeration representing different types of comments in TypeScript. - - Represents the possible types of comments that can be used in TypeScript code, - including double slash comments (//), slash star comments (/* */), and unknown - comment types. - - Attributes: - DOUBLE_SLASH (str): Represents a single-line comment starting with //. - SLASH_STAR (str): Represents a multi-line comment enclosed in /* */. - UNKNOWN (str): Represents an unknown or unrecognized comment type. - """ - - DOUBLE_SLASH = "DOUBLE_SLASH" - SLASH_STAR = "SLASH_STAR" - UNKNOWN = "UNKNOWN" - - -@ts_apidoc -class TSComment(Comment): - """Abstract representation of typescript comments""" - - @property - @reader - def comment_type(self) -> TSCommentType: - """Determines the type of comment in a TypeScript source code. - - Parses the comment markers to determine if it's a single-line comment (//) or a multi-line comment (/* */). If no known comment markers are found, returns UNKNOWN. - - Args: - self: The TSComment instance. - - Returns: - TSCommentType: The type of the comment. Can be DOUBLE_SLASH for single-line comments, - SLASH_STAR for multi-line comments, or UNKNOWN if no known comment markers are found. - """ - if self.source.startswith("//"): - return TSCommentType.DOUBLE_SLASH - elif self.source.startswith("/*"): - return TSCommentType.SLASH_STAR - return TSCommentType.UNKNOWN - - @noapidoc - @commiter - def _parse_comment(self) -> str: - """Parse out the comment into its text content""" - # Remove comment markers - if self.comment_type == TSCommentType.DOUBLE_SLASH: - if self.source.startswith("// "): - return self.source[3:] - elif self.source.startswith("//"): - return self.source[2:] - else: - return self.source - elif self.comment_type == TSCommentType.SLASH_STAR: - formatted_text = self.source - # Remove comment markers - if self.source.startswith("/** "): - formatted_text = self.source[4:] - elif self.source.startswith("/**"): - formatted_text = self.source[3:] - elif self.source.startswith("/* "): - formatted_text = self.source[3:] - elif self.source.startswith("/*"): - formatted_text = self.source[2:] - if formatted_text.endswith(" */"): - formatted_text = formatted_text[:-3] - elif formatted_text.endswith("*/"): - formatted_text = formatted_text[:-2] - formatted_text = formatted_text.strip("\n") - formatted_split = formatted_text.split("\n") - # Get indentation level - padding = lowest_indentation(formatted_split) - # Remove indentation - formatted_text = "\n".join([line[padding:] for line in formatted_split]) - # Remove leading "* " from each line - text_lines = [] - for line in formatted_text.split("\n"): - if line.lstrip().startswith("* "): - text_lines.append(line.lstrip()[2:]) - elif line.lstrip().startswith("*"): - text_lines.append(line.lstrip()[1:]) - else: - text_lines.append(line) - return "\n".join(text_lines).rstrip() - else: - # Return the source if the comment type is unknown - return self.source - - @noapidoc - @reader - def _unparse_comment(self, new_src: str): - """Unparses cleaned text content into a comment block""" - should_add_leading_star = any([line.lstrip().startswith("*") for line in self.source.split("\n")[:-1]]) if len(self.source.split("\n")) > 1 else True - return self.generate_comment(new_src, self.comment_type, leading_star=should_add_leading_star) - - @staticmethod - def generate_comment(new_src: str, comment_type: TSCommentType, leading_star: bool = True, force_multiline: bool = False) -> str: - """Generates a TypeScript comment block from the given text content. - - Creates a comment block in either single-line (//) or multi-line (/* */) format based on the specified comment type. - - Args: - new_src (str): The text content to be converted into a comment. - comment_type (TSCommentType): The type of comment to generate (DOUBLE_SLASH or SLASH_STAR). - leading_star (bool, optional): Whether to add leading "*" to each line in multi-line comments. Defaults to True. - force_multiline (bool, optional): Whether to force multi-line format for single-line content. Defaults to False. - - Returns: - str: The formatted comment block as a string. - """ - # Generate the comment block based on the comment type - if comment_type == TSCommentType.DOUBLE_SLASH: - # Add the comment character to each line - new_src = "\n".join([f"// {line}" for line in new_src.split("\n")]) - elif comment_type == TSCommentType.SLASH_STAR: - # Add triple quotes to the text - if "\n" in new_src or force_multiline: - # Check if we should add leading "* " to each line - if leading_star: - new_src = "\n".join([(" * " + x).rstrip() for x in new_src.split("\n")]) - new_src = "/**\n" + new_src + "\n */" - else: - new_src = "/*\n" + new_src + "\n*/" - else: - new_src = "/* " + new_src + " */" - return new_src - - @staticmethod - def clean_comment(comment: str) -> str: - """Cleans comment markers and whitespace from a comment string. - - Removes various types of comment markers ('/', '/*', '/**', '*/') and trims whitespace - from the beginning and end of the comment text. - - Args: - comment (str): The raw comment string to be cleaned. - - Returns: - str: The cleaned comment text with comment markers and excess whitespace removed. - """ - comment = comment.lstrip() - if comment.startswith("//"): - comment = comment[2:] - if comment.startswith("/**"): - comment = comment[3:] - if comment.startswith("/*"): - comment = comment[2:] - if comment.endswith("*/"): - comment = comment[:-2] - return comment.strip() diff --git a/src/codegen/sdk/typescript/statements/for_loop_statement.py b/src/codegen/sdk/typescript/statements/for_loop_statement.py deleted file mode 100644 index 23a6e0ca4..000000000 --- a/src/codegen/sdk/typescript/statements/for_loop_statement.py +++ /dev/null @@ -1,120 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement -from codegen.sdk.extensions.autocommit import commiter, reader -from codegen.sdk.typescript.statements.block_statement import TSBlockStatement -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.expressions import Expression - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - - -@ts_apidoc -class TSForLoopStatement(ForLoopStatement["TSCodeBlock"], TSBlockStatement["TSCodeBlock"]): - """Abstract representation of the for loop in TypeScript. - - Attributes: - item: An item in the iterable object. Only applicable for `for...of` loops. - iterable: The iterable that is being iterated over. Only applicable for `for...of` loops. - - initializer: The counter variable. Applicable for traditional for loops. - condition: The condition for the loop. Applicable for traditional for loops. - increment: The increment expression. Applicable for traditional for loops. - """ - - # TODO: parse as statement - item: Expression[TSForLoopStatement] | None = None - # TODO: parse as statement - iterable: Expression[TSForLoopStatement] | None = None - - initializer: Expression[TSForLoopStatement] | None = None - condition: Expression[TSForLoopStatement] | None = None - increment: Expression[TSForLoopStatement] | None = None - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSCodeBlock, pos: int | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - if ts_node.type == "for_statement": - self.initializer = self.child_by_field_name("initializer") - self.condition = self.child_by_field_name("condition") - self.increment = self.child_by_field_name("increment") - elif ts_node.type == "for_in_statement": - self.item = self.child_by_field_name("left") - self.iterable = self.child_by_field_name("right") - else: - msg = f"Invalid for loop type: {ts_node.type}" - raise ValueError(msg) - - @property - @reader - def is_for_in_loop(self) -> bool: - """Determines whether the current for loop is a `for...in` loop. - - A property that identifies if the current for loop is a 'for...in' loop by checking its tree-sitter node type. - - Returns: - bool: True if the for loop is a 'for...in' loop, False otherwise. - """ - return self.ts_node.type == "for_in_statement" - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Retrieves all function calls within a for loop statement. - - For a for...in loop, collects function calls from the iterable expression. - For a traditional for loop, collects function calls from the initializer, - condition, and increment expressions. Also includes function calls from - the superclass implementation. - - Returns: - list[FunctionCall]: A list of all FunctionCall objects found within the for loop statement. - """ - fcalls = [] - if self.is_for_in_loop: - fcalls.extend(self.iterable.function_calls) - else: - fcalls.extend(self.initializer.function_calls) - fcalls.extend(self.condition.function_calls) - if self.increment: - fcalls.extend(self.increment.function_calls) - fcalls.extend(super().function_calls) - return fcalls - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - if self.is_for_in_loop: - self.item._compute_dependencies(usage_type, dest) - self.iterable._compute_dependencies(usage_type, dest) - else: - self.initializer._compute_dependencies(usage_type, dest) - self.condition._compute_dependencies(usage_type, dest) - if self.increment: - self.increment._compute_dependencies(usage_type, dest) - super()._compute_dependencies(usage_type, dest) - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - symbols = [] - if self.is_for_in_loop: - symbols.extend(self.item.descendant_symbols) - symbols.extend(self.iterable.descendant_symbols) - else: - symbols.extend(self.initializer.descendant_symbols) - symbols.extend(self.condition.descendant_symbols) - if self.increment: - symbols.extend(self.increment.descendant_symbols) - symbols.extend(super().descendant_symbols) - return symbols diff --git a/src/codegen/sdk/typescript/statements/if_block_statement.py b/src/codegen/sdk/typescript/statements/if_block_statement.py deleted file mode 100644 index 2a1318f1a..000000000 --- a/src/codegen/sdk/typescript/statements/if_block_statement.py +++ /dev/null @@ -1,139 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.statements.if_block_statement import IfBlockStatement -from codegen.sdk.core.statements.statement import StatementType -from codegen.shared.decorators.docs import apidoc -from codegen.shared.logging.get_logger import get_logger - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - - -logger = get_logger(__name__) - - -Parent = TypeVar("Parent", bound="TSCodeBlock") - - -@apidoc -class TSIfBlockStatement(IfBlockStatement[Parent, "TSIfBlockStatement"], Generic[Parent]): - """Typescript implementation of the if/elif/else statement block. - For example, if there is a code block like: - if (condition1) { - block1 - } else if (condition2) { - block2 - } else { - block3 - } - This class represents the entire block, including the conditions and nested code blocks. - """ - - statement_type = StatementType.IF_BLOCK_STATEMENT - _else_clause_node: TSNode | None = None - - def __init__( - self, - ts_node: TSNode, - file_node_id: NodeId, - ctx: CodebaseContext, - parent: Parent, - pos: int, - else_clause_node: TSNode | None = None, - main_if_block: TSIfBlockStatement | None = None, - ) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - self._else_clause_node = else_clause_node - self._main_if_block = main_if_block - # Call .value to unwrap the parenthesis - condition = self.child_by_field_name("condition") - self.condition = condition.value if condition else None - self.consequence_block = self._parse_consequence_block() - self._alternative_blocks = self._parse_alternative_blocks() if self.is_if_statement else None - self.consequence_block.parse() - - @reader - def _parse_consequence_block(self) -> TSCodeBlock: - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - - if self.is_if_statement or self.is_elif_statement: - consequence_node = self.ts_node.child_by_field_name("consequence") - else: - consequence_node = self.ts_node.named_children[0] - return TSCodeBlock(consequence_node, self.parent.level + 1, self.parent, self) - - @reader - def _parse_alternative_blocks(self) -> list[TSIfBlockStatement]: - if self.is_else_statement or self.is_elif_statement: - return [] - - if_blocks = [] - alt_block = self - while alt_node := alt_block.ts_node.child_by_field_name("alternative"): - if (if_node := alt_node.named_children[0]).type == "if_statement": - # Elif statements are represented as if statements with an else clause as the parent node - alt_block = TSIfBlockStatement(if_node, self.file_node_id, self.ctx, self.parent, self.index, else_clause_node=alt_node, main_if_block=self._main_if_block or self) - else: - # Else clause - alt_block = TSIfBlockStatement(alt_node, self.file_node_id, self.ctx, self.parent, self.index, main_if_block=self._main_if_block or self) - if_blocks.append(alt_block) - return if_blocks - - @property - @reader - def is_if_statement(self) -> bool: - """Determines if the current block is a standalone 'if' statement. - - Args: - None - - Returns: - bool: True if the current block is a standalone 'if' statement, False otherwise. - """ - return self.ts_node.type == "if_statement" and self._else_clause_node is None - - @property - @reader - def is_else_statement(self) -> bool: - """Determines if the current block is an else block. - - A property that checks if the current TreeSitter node represents an else clause in an if/elif/else statement structure. - - Returns: - bool: True if the current block is an else block, False otherwise. - """ - return self.ts_node.type == "else_clause" - - @property - @reader - def is_elif_statement(self) -> bool: - """Determines if the current block is an elif block. - - This method checks if the current block is an elif block by verifying that it is both an if_statement and has an else clause node associated with it. - - Returns: - bool: True if the current block is an elif block, False otherwise. - """ - return self.ts_node.type == "if_statement" and self._else_clause_node is not None - - @writer - def _else_if_to_if(self) -> None: - """Converts an elif block to an if block. - - Args: - None - - Returns: - None - """ - if not self.is_elif_statement: - return - - self.remove_byte_range(self.ts_node.start_byte - len("else "), self.ts_node.start_byte) diff --git a/src/codegen/sdk/typescript/statements/import_statement.py b/src/codegen/sdk/typescript/statements/import_statement.py deleted file mode 100644 index a54070588..000000000 --- a/src/codegen/sdk/typescript/statements/import_statement.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from codegen.sdk.core.expressions.builtin import Builtin -from codegen.sdk.core.statements.import_statement import ImportStatement -from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.sdk.typescript.import_resolution import TSImport -from codegen.shared.decorators.docs import ts_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - from codegen.sdk.typescript.file import TSFile - - -@ts_apidoc -class TSImportStatement(ImportStatement["TSFile", TSImport, "TSCodeBlock"], Builtin): - """A class representing an import statement in TypeScript, managing both static and dynamic imports. - - This class handles various types of TypeScript imports including regular import statements, - dynamic imports, and export statements. It provides functionality to manage and track imports - within a TypeScript file, enabling operations like analyzing dependencies, moving imports, - and modifying import statements. - - Attributes: - imports (Collection): A collection of TypeScript imports contained within the statement. - """ - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSCodeBlock, pos: int, *, source_node: TSNode | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - imports = [] - if ts_node.type == "import_statement": - imports.extend(TSImport.from_import_statement(ts_node, file_node_id, ctx, self)) - elif ts_node.type in ["call_expression", "lexical_declaration", "expression_statement", "type_alias_declaration"]: - import_call_node = source_node.child_by_field_name("function") - arguments = source_node.child_by_field_name("arguments") - imports.extend(TSImport.from_dynamic_import_statement(import_call_node, arguments, file_node_id, ctx, self)) - elif ts_node.type == "export_statement": - imports.extend(TSImport.from_export_statement(source_node, file_node_id, ctx, self)) - self.imports = Collection(ts_node, file_node_id, ctx, self, delimiter="\n", children=imports) - for imp in self.imports: - imp.import_statement = self diff --git a/src/codegen/sdk/typescript/statements/labeled_statement.py b/src/codegen/sdk/typescript/statements/labeled_statement.py deleted file mode 100644 index a898418d4..000000000 --- a/src/codegen/sdk/typescript/statements/labeled_statement.py +++ /dev/null @@ -1,63 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar - -from codegen.sdk.core.expressions import Expression, Name -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.core.statements.statement import Statement, StatementType -from codegen.shared.decorators.docs import ts_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - - -Parent = TypeVar("Parent", bound="TSCodeBlock") - - -@ts_apidoc -class TSLabeledStatement(Statement[Parent], HasName, Generic[Parent]): - """Statement with a named label. It resolves to various types of statements like loops, switch cases, etc. - - Examples: - ``` - outerLoop: for (let i = 0; i < 5; i++) { - innerLoop: for (let j = 0; j < 5; j++) { - if (i === 2 && j === 2) { - break outerLoop; // This will break out of the outer loop - } - console.log(`i: ${i}, j: ${j}`); - } - } - ``` - ``` - emptyStatement: { pass } - ``` - - Attributes: - body: The body of the labeled statement, which can be an Expression or None. - """ - - statement_type = StatementType.LABELED_STATEMENT - body: Expression | None - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent, pos: int) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - self._name_node = Name(ts_node.child_by_field_name("label"), file_node_id, ctx, self) - body_node = self.ts_node.child_by_field_name("body") - self.body = self._parse_expression(body_node) if body_node else None - - @property - def label(self) -> str: - """Returns the label of the labeled statement. - - Acts as a property getter that returns the name of the labeled statement. For example, in code like - 'outerLoop: for...', this would return 'outerLoop'. - - Returns: - str: The label name of the statement. - """ - return self.name diff --git a/src/codegen/sdk/typescript/statements/switch_case.py b/src/codegen/sdk/typescript/statements/switch_case.py deleted file mode 100644 index cdd43e1dd..000000000 --- a/src/codegen/sdk/typescript/statements/switch_case.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import TYPE_CHECKING - -from tree_sitter import Node as TSNode - -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.core.statements.switch_case import SwitchCase -from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock -from codegen.sdk.typescript.statements.block_statement import TSBlockStatement -from codegen.shared.decorators.docs import ts_apidoc - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.typescript.statements.switch_statement import TSSwitchStatement - - -@ts_apidoc -class TSSwitchCase(SwitchCase[TSCodeBlock["TSSwitchStatement"]], TSBlockStatement): - """Typescript switch case. - - Attributes: - default: is this a default case? - """ - - default: bool - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: TSCodeBlock, pos: int | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - self.condition = self.child_by_field_name("value") - self.default = self.ts_node.type == "switch_default" diff --git a/src/codegen/sdk/typescript/statements/switch_statement.py b/src/codegen/sdk/typescript/statements/switch_statement.py deleted file mode 100644 index 0dbec180f..000000000 --- a/src/codegen/sdk/typescript/statements/switch_statement.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Self - -from codegen.sdk.core.statements.switch_statement import SwitchStatement -from codegen.sdk.typescript.statements.switch_case import TSSwitchCase -from codegen.shared.decorators.docs import ts_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - - -@ts_apidoc -class TSSwitchStatement(SwitchStatement["TSCodeBlock[Self]", "TSCodeBlock", TSSwitchCase]): - """Typescript switch statement""" - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSCodeBlock, pos: int | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - self.value = self.child_by_field_name("value") - code_block = self.ts_node.child_by_field_name("body") - self.cases = [] - for node in code_block.named_children: - self.cases.append(TSSwitchCase(node, file_node_id, ctx, self)) diff --git a/src/codegen/sdk/typescript/statements/try_catch_statement.py b/src/codegen/sdk/typescript/statements/try_catch_statement.py deleted file mode 100644 index 947ed3fbd..000000000 --- a/src/codegen/sdk/typescript/statements/try_catch_statement.py +++ /dev/null @@ -1,104 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Self, override - -from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement -from codegen.sdk.extensions.autocommit import commiter, reader -from codegen.sdk.typescript.statements.block_statement import TSBlockStatement -from codegen.sdk.typescript.statements.catch_statement import TSCatchStatement -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from collections.abc import Sequence - - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.dataclasses.usage import UsageKind - from codegen.sdk.core.detached_symbols.function_call import FunctionCall - from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock - from codegen.sdk.core.interfaces.has_name import HasName - from codegen.sdk.core.interfaces.importable import Importable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - - -@ts_apidoc -class TSTryCatchStatement(TryCatchStatement["TSCodeBlock"], TSBlockStatement): - """Abstract representation of the try/catch/finally block in TypeScript. - - Attributes: - catch: The catch block. - """ - - catch: TSCatchStatement[Self] | None = None - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSCodeBlock, pos: int | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - if handler_node := self.ts_node.child_by_field_name("handler"): - self.catch = TSCatchStatement(handler_node, file_node_id, ctx, self) - if finalizer_node := self.ts_node.child_by_field_name("finalizer"): - self.finalizer = TSBlockStatement(finalizer_node, file_node_id, ctx, self.code_block) - - @property - @reader - def function_calls(self) -> list[FunctionCall]: - """Gets all function calls within a try-catch-finally statement. - - This property retrieves all function calls from the try block, catch block (if present), and finally block (if present). - - Returns: - list[FunctionCall]: A list of function calls found within the try-catch-finally statement, including those from - the try block, catch block (if it exists), and finally block (if it exists). - """ - fcalls = super().function_calls - if self.catch: - fcalls.extend(self.catch.function_calls) - if self.finalizer: - fcalls.extend(self.finalizer.function_calls) - return fcalls - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - super()._compute_dependencies(usage_type, dest) - if self.catch: - self.catch._compute_dependencies(usage_type, dest) - if self.finalizer: - self.finalizer._compute_dependencies(usage_type, dest) - - @property - @noapidoc - def descendant_symbols(self) -> list[Importable]: - symbols = super().descendant_symbols - if self.catch: - symbols.extend(self.catch.descendant_symbols) - if self.finalizer: - symbols.extend(self.finalizer.descendant_symbols) - return symbols - - @property - @reader - @override - def nested_code_blocks(self) -> list[TSCodeBlock]: - """Returns all nested CodeBlocks within the statement. - - Retrieves a list of all the code blocks nested within this try/catch/finally statement, including the catch and finally blocks if they exist. - - Returns: - list[TSCodeBlock]: A list of nested code blocks, including the catch and finally blocks. - """ - nested_blocks = super().nested_code_blocks - if self.catch: - nested_blocks.append(self.catch.code_block) - if self.finalizer: - nested_blocks.append(self.finalizer.code_block) - return nested_blocks - - @property - @noapidoc - def other_possible_blocks(self) -> Sequence[ConditionalBlock]: - if self.catch: - return [self.catch] - else: - return [] diff --git a/src/codegen/sdk/typescript/statements/while_statement.py b/src/codegen/sdk/typescript/statements/while_statement.py deleted file mode 100644 index fdbb1dee8..000000000 --- a/src/codegen/sdk/typescript/statements/while_statement.py +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from codegen.sdk.core.statements.while_statement import WhileStatement -from codegen.sdk.typescript.interfaces.has_block import TSHasBlock -from codegen.shared.decorators.docs import ts_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - - -@ts_apidoc -class TSWhileStatement(WhileStatement["TSCodeBlock"], TSHasBlock): - """A TypeScript while statement class that represents while loops and manages their condition and code block. - - This class provides functionality for handling while statements in TypeScript code, - including managing the loop's condition and associated code block. It extends the base - WhileStatement class with TypeScript-specific behavior. - - Attributes: - condition (str | None): The condition expression of the while loop. - """ - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSCodeBlock, pos: int | None = None) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, pos) - condition = self.child_by_field_name("condition") - self.condition = condition.value if condition else None diff --git a/src/codegen/sdk/typescript/symbol.py b/src/codegen/sdk/typescript/symbol.py deleted file mode 100644 index e3cc89828..000000000 --- a/src/codegen/sdk/typescript/symbol.py +++ /dev/null @@ -1,514 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Literal, Self, Unpack - -from codegen.sdk.core.assignment import Assignment -from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.dataclasses.usage import UsageKind, UsageType -from codegen.sdk.core.detached_symbols.function_call import FunctionCall -from codegen.sdk.core.expressions import Value -from codegen.sdk.core.expressions.chained_attribute import ChainedAttribute -from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.interfaces.exportable import Exportable -from codegen.sdk.core.symbol import Symbol -from codegen.sdk.core.type_alias import TypeAlias -from codegen.sdk.enums import ImportType, NodeType -from codegen.sdk.typescript.import_resolution import TSImport -from codegen.sdk.typescript.statements.comment import TSComment, TSCommentType -from codegen.sdk.typescript.symbol_groups.comment_group import TSCommentGroup -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from tree_sitter import Node as TSNode - - from codegen.sdk.codebase.flagging.code_flag import CodeFlag - from codegen.sdk.codebase.flagging.enums import FlagKwargs - from codegen.sdk.core.detached_symbols.parameter import Parameter - from codegen.sdk.core.file import SourceFile - from codegen.sdk.core.import_resolution import Import - from codegen.sdk.core.interfaces.editable import Editable - from codegen.sdk.core.node_id_factory import NodeId - from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - from codegen.sdk.typescript.interfaces.has_block import TSHasBlock - - -@ts_apidoc -class TSSymbol(Symbol["TSHasBlock", "TSCodeBlock"], Exportable): - """A TypeScript symbol representing a code element with advanced manipulation capabilities. - - This class extends Symbol and Exportable to provide TypeScript-specific functionality for managing - code symbols. It offers methods for handling imports, comments, code refactoring, and file operations - like moving symbols between files while maintaining their dependencies and references. - - The class provides functionality for managing both inline and block comments, setting and retrieving - import strings, and maintaining semicolon presence. It includes capabilities for moving symbols between - files with options to handle dependencies and import strategy selection. - """ - - @reader - def get_import_string(self, alias: str | None = None, module: str | None = None, import_type: ImportType = ImportType.UNKNOWN, is_type_import: bool = False) -> str: - """Generates the appropriate import string for a symbol. - - Constructs and returns an import statement string based on the provided parameters, formatting it according - to TypeScript import syntax rules. - - Args: - alias (str | None, optional): The alias to use for the imported symbol. Defaults to None. - module (str | None, optional): The module to import from. If None, uses the file's import module name. - Defaults to None. - import_type (ImportType, optional): The type of import to generate (e.g., WILDCARD). Defaults to - ImportType.UNKNOWN. - is_type_import (bool, optional): Whether this is a type-only import. Defaults to False. - - Returns: - str: A formatted import statement string. - """ - type_prefix = "type " if is_type_import else "" - import_module = module if module is not None else self.file.import_module_name - - if import_type == ImportType.WILDCARD: - file_as_module = self.file.name - return f"import {type_prefix}* as {file_as_module} from {import_module};" - elif alias is not None and alias != self.name: - return f"import {type_prefix}{{ {self.name} as {alias} }} from {import_module};" - else: - return f"import {type_prefix}{{ {self.name} }} from {import_module};" - - @property - @reader(cache=False) - def extended_nodes(self) -> list[Editable]: - """Returns the list of nodes associated with this symbol including extended nodes. - - This property returns a list of Editable nodes that includes any wrapping or extended symbols like `export`, `public`, or decorators. - For example, if the symbol is within an `export_statement` or `lexical_declaration`, those nodes will be included in the list. - - Args: - No arguments. - - Returns: - list[Editable]: A list of Editable nodes including the symbol's extended nodes like export statements and decorators. - """ - nodes = super().extended_nodes - - # Check if the symbol is wrapped by another node like 'export_statement' - new_ts_node = self.ts_node - while (parent := new_ts_node.parent).type in ("export_statement", "lexical_declaration", "variable_declarator"): - new_ts_node = parent - - return [Value(new_ts_node, self.file_node_id, self.ctx, self.parent) if node.ts_node == self.ts_node else node for node in nodes] - - @property - @reader - def comment(self) -> TSCommentGroup | None: - """Retrieves the comment group associated with the symbol. - - Returns the TSCommentGroup object that contains any comments associated with the symbol. - A comment group represents one or more related comments that precede the symbol in the code. - - Returns: - TSCommentGroup | None: The comment group for the symbol if one exists, None otherwise. - """ - return TSCommentGroup.from_symbol_comments(self) - - @property - @reader - def inline_comment(self) -> TSCommentGroup | None: - """Property that retrieves the inline comment group associated with the symbol. - - Args: - None - - Returns: - TSCommentGroup | None: The inline comment group associated with the symbol if it exists, - otherwise None. - """ - return TSCommentGroup.from_symbol_inline_comments(self) - - @writer - def set_comment(self, comment: str, auto_format: bool = True, clean_format: bool = True, comment_type: TSCommentType = TSCommentType.DOUBLE_SLASH) -> None: - """Sets a comment to the symbol. - - Adds or updates a comment for a code symbol. If a comment already exists, it will be edited. If no - comment exists, a new comment group will be created. - - Args: - comment (str): The comment text to be added. - auto_format (bool, optional): Whether to automatically format the text into a comment syntax. - Defaults to True. - clean_format (bool, optional): Whether to clean the format of the comment before inserting. - Defaults to True. - comment_type (TSCommentType, optional): The style of comment to add. - Defaults to TSCommentType.DOUBLE_SLASH. - - Returns: - None - - Raises: - None - """ - if clean_format: - comment = TSComment.clean_comment(comment) - - # If comment already exists, add the comment to the existing comment group - if self.comment: - if auto_format: - self.comment.edit_text(comment) - else: - self.comment.edit(comment, fix_indentation=True) - else: - if auto_format: - comment = TSComment.generate_comment(comment, comment_type) - self.insert_before(comment, fix_indentation=True) - - @writer - def add_comment(self, comment: str, auto_format: bool = True, clean_format: bool = True, comment_type: TSCommentType = TSCommentType.DOUBLE_SLASH) -> None: - """Adds a new comment to the symbol. - - Appends a comment to an existing comment group or creates a new comment group if none exists. - - Args: - comment (str): The comment text to be added. - auto_format (bool): Whether to automatically format the text into a comment style. Defaults to True. - clean_format (bool): Whether to clean the format of the comment before inserting. Defaults to True. - comment_type (TSCommentType): Type of comment to add. Defaults to TSCommentType.DOUBLE_SLASH. - - Returns: - None - - Raises: - None - """ - if clean_format: - comment = TSComment.clean_comment(comment) - if auto_format: - comment = TSComment.generate_comment(comment, comment_type) - - # If comment already exists, add the comment to the existing comment group - if self.comment: - self.comment.insert_after(comment, fix_indentation=True) - else: - self.insert_before(comment, fix_indentation=True) - - @writer - def set_inline_comment(self, comment: str, auto_format: bool = True, clean_format: bool = True, node: TSNode | None = None) -> None: - """Sets an inline comment to the symbol. - - Sets or replaces an inline comment for a symbol at its current position. If an inline comment - already exists, it is replaced with the new comment. If no inline comment exists, a new one - will be created adjacent to the symbol. - - Args: - comment (str): The inline comment text to be added. - auto_format (bool, optional): Whether to automatically format the text as a comment. - Defaults to True. - clean_format (bool, optional): Whether to clean the comment format before inserting. - Defaults to True. - node (TSNode | None, optional): The specific node to attach the comment to. - Defaults to None. - - Returns: - None - - Raises: - None - """ - if clean_format: - comment = TSComment.clean_comment(comment) - - if self.inline_comment: - if auto_format: - self.inline_comment.edit_text(comment) - else: - self.inline_comment.edit(comment) - else: - if auto_format: - comment = " " + TSComment.generate_comment(comment, TSCommentType.DOUBLE_SLASH) - node = node or self.ts_node - Value(node, self.file_node_id, self.ctx, self).insert_after(comment, fix_indentation=False, newline=False) - - @property - @reader - def semicolon_node(self) -> Editable | None: - """Retrieves the semicolon node associated with a TypeScript symbol. - - A semicolon node is a TreeSitter node of type ';' that appears immediately after the symbol node. - - Returns: - Editable | None: The semicolon node wrapped as an Editable if it exists, None otherwise. - """ - sibbling = self.ts_node.next_sibling - if sibbling and sibbling.type == ";": - return Value(sibbling, self.file_node_id, self.ctx, self) - return None - - @property - @reader - def has_semicolon(self) -> bool: - """Checks whether the current symbol has a semicolon at the end. - - This property determines if a semicolon is present at the end of the symbol by checking - if the semicolon_node property exists. - - Returns: - bool: True if the symbol has a semicolon at the end, False otherwise. - """ - return self.semicolon_node is not None - - @noapidoc - def _move_to_file( - self, - file: SourceFile, - encountered_symbols: set[Symbol | Import], - include_dependencies: bool = True, - strategy: Literal["add_back_edge", "update_all_imports", "duplicate_dependencies"] = "update_all_imports", - ) -> tuple[NodeId, NodeId]: - # TODO: Prevent creation of import loops (!) - raise a ValueError and make the agent fix it - # =====[ Arg checking ]===== - if file == self.file: - return file.file_node_id, self.node_id - - # =====[ Move over dependencies recursively ]===== - if include_dependencies: - try: - for dep in self.dependencies: - if dep in encountered_symbols: - continue - - # =====[ Symbols - move over ]===== - elif isinstance(dep, TSSymbol): - if dep.is_top_level: - encountered_symbols.add(dep) - dep._move_to_file(file, encountered_symbols=encountered_symbols, include_dependencies=True, strategy=strategy) - - # =====[ Imports - copy over ]===== - elif isinstance(dep, TSImport): - if dep.imported_symbol: - file.add_import(dep.imported_symbol, alias=dep.alias.source, import_type=dep.import_type) - else: - file.add_import(dep.source) - - else: - msg = f"Unknown dependency type {type(dep)}" - raise ValueError(msg) - except Exception as e: - print(f"Failed to move dependencies of {self.name}: {e}") - else: - try: - for dep in self.dependencies: - if isinstance(dep, Assignment): - msg = "Assignment not implemented yet" - raise NotImplementedError(msg) - - # =====[ Symbols - move over ]===== - elif isinstance(dep, Symbol) and dep.is_top_level: - file.add_import(imp=dep, alias=dep.name, import_type=ImportType.NAMED_EXPORT, is_type_import=isinstance(dep, TypeAlias)) - - if not dep.is_exported: - dep.file.add_export_to_symbol(dep) - pass - - # =====[ Imports - copy over ]===== - elif isinstance(dep, TSImport): - if dep.imported_symbol: - file.add_import(dep.imported_symbol, alias=dep.alias.source, import_type=dep.import_type, is_type_import=dep.is_type_import()) - else: - file.add_import(dep.source) - - except Exception as e: - print(f"Failed to move dependencies of {self.name}: {e}") - - # =====[ Make a new symbol in the new file ]===== - # This will update all edges etc. - file.add_symbol(self) - import_line = self.get_import_string(module=file.import_module_name) - - # =====[ Checks if symbol is used in original file ]===== - # Takes into account that it's dependencies will be moved - is_used_in_file = any(usage.file == self.file and usage.node_type == NodeType.SYMBOL and usage not in encountered_symbols for usage in self.symbol_usages) - - # ======[ Strategy: Duplicate Dependencies ]===== - if strategy == "duplicate_dependencies": - # If not used in the original file. or if not imported from elsewhere, we can just remove the original symbol - if not is_used_in_file and not any(usage.kind is UsageKind.IMPORTED and usage.usage_symbol not in encountered_symbols for usage in self.usages): - self.remove() - - # ======[ Strategy: Add Back Edge ]===== - # Here, we will add a "back edge" to the old file importing the self - elif strategy == "add_back_edge": - if is_used_in_file: - self.file.add_import(import_line) - if self.is_exported: - self.file.add_import(f"export {{ {self.name} }}") - elif self.is_exported: - module_name = file.name - self.file.add_import(f"export {{ {self.name} }} from '{module_name}'") - # Delete the original symbol - self.remove() - - # ======[ Strategy: Update All Imports ]===== - # Update the imports in all the files which use this symbol to get it from the new file now - elif strategy == "update_all_imports": - for usage in self.usages: - if isinstance(usage.usage_symbol, TSImport): - # Add updated import - if usage.usage_symbol.resolved_symbol is not None and usage.usage_symbol.resolved_symbol.node_type == NodeType.SYMBOL and usage.usage_symbol.resolved_symbol == self: - usage.usage_symbol.file.add_import(import_line) - usage.usage_symbol.remove() - elif usage.usage_type == UsageType.CHAINED: - # Update all previous usages of import * to the new import name - if usage.match and "." + self.name in usage.match: - if isinstance(usage.match, FunctionCall): - usage.match.get_name().edit(self.name) - if isinstance(usage.match, ChainedAttribute): - usage.match.edit(self.name) - usage.usage_symbol.file.add_import(import_line) - if is_used_in_file: - self.file.add_import(import_line) - # Delete the original symbol - self.remove() - - def _convert_proptype_to_typescript(self, prop_type: Editable, param: Parameter | None, level: int) -> str: - """Converts a PropType definition to its TypeScript equivalent.""" - # Handle basic types - type_map = {"string": "string", "number": "number", "bool": "boolean", "object": "object", "array": "any[]", "func": "CallableFunction"} - if prop_type.source in type_map: - return type_map[prop_type.source] - if isinstance(prop_type, ChainedAttribute): - if prop_type.attribute.source == "node": - return "T" - if prop_type.attribute.source == "element": - self.file.add_import("import React from 'react';\n") - return "React.ReactElement" - if prop_type.attribute.source in type_map: - return type_map[prop_type.attribute.source] - # if prop_type.attribute.source == "func": - # params = [] - # if param: - # for usage in param.usages: - # call = None - # if isinstance(usage.match, FunctionCall): - # call = usage.match - # elif isinstance(usage.match.parent, FunctionCall): - # call = usage.match.parent - # if call: - # for arg in call.args: - # resolved_value = arg.value.resolved_value - # if resolved_value.rstrip("[]") not in ("number", "string", "boolean", "any", "object"): - # resolved_value = "any" - # params.append(f"{arg.name or arg.source}: {resolved_value}") - # return f"({",".join(params)}) => void" - return "Function" - if prop_type.attribute.source == "isRequired": - return self._convert_proptype_to_typescript(prop_type.object, param, level) - if isinstance(prop_type, FunctionCall): - if prop_type.name == "isRequired": - return self._convert_proptype_to_typescript(prop_type.args[0].value, param, level) - # Handle arrays - if prop_type.name == "arrayOf": - item = self._convert_proptype_to_typescript(prop_type.args[0].value, param, level) - # needs_parens = isinstance(prop_type.args[0].value, FunctionCall) - needs_parens = False - return f"({item})[]" if needs_parens else f"{item}[]" - - # Handle oneOf - if prop_type.name == "oneOf": - values = [arg.source for arg in prop_type.args[0].value] - # Add parentheses if one of the values is a function - return " | ".join(f"({t})" if "() => void" == t else t for t in values) - # Handle anyOf (alias for oneOf) - if prop_type.name == "anyOf": - values = [arg.source for arg in prop_type.args[0].value] - # Add parentheses if one of the values is a function - return " | ".join(f"({t})" if "() => void" == t else t for t in values) - - # Handle oneOfType - if prop_type.name == "oneOfType": - types = [self._convert_proptype_to_typescript(arg, param, level) for arg in prop_type.args[0].value] - # Only add parentheses if one of the types is a function - return " | ".join(f"({t})" if "() => void" == t else t for t in types) - - # Handle shape - if prop_type.name == "shape": - return self._convert_dict(prop_type.args[0].value, level) - if prop_type.name == "objectOf": - return self._convert_object_of(prop_type.args[0].value, level) - return "any" - - def _convert_dict(self, value: Type, level: int) -> str: - """Converts a dictionary of PropTypes to a TypeScript interface string.""" - result = "{\n" - for key, value in value.items(): - is_required = isinstance(value, ChainedAttribute) and value.attribute.source == "isRequired" - optional = "" if is_required else "?" - indent = " " * level - param = next((p for p in self.parameters if p.name == key), None) if self.parameters else None - result += f"{indent}{key}{optional}: {self._convert_proptype_to_typescript(value, param, level + 1)};\n" - indent = " " * (level - 1) - - result += f"{indent}}}" - return result - - def _convert_object_of(self, value: Type, level: int) -> str: - """Converts a dictionary of PropTypes to a TypeScript interface string.""" - indent = " " * level - prev_indent = " " * (level - 1) - type_value = self._convert_proptype_to_typescript(value, None, level + 1) - return f"{{\n{indent}[key: string]: {type_value};\n{prev_indent}}}" - - def _get_static_prop_types(self) -> Type | None: - """Returns a dictionary of prop types for a React component.""" - for usage in self.usages: - if isinstance(usage.usage_symbol, Assignment) and usage.usage_symbol.name == "propTypes": - assert isinstance(usage.usage_symbol.value, Type), usage.usage_symbol.value.__class__ - return usage.usage_symbol.value - return None - - @noapidoc - def convert_to_react_interface(self) -> str | None: - if not self.is_jsx: - return None - - component_name = self.name - # Handle class components with static propTypes - if proptypes := self._get_static_prop_types(): - generics = "" - generic_name = "" - if "PropTypes.node" in proptypes.source: - generics = "" - generic_name = "" - self.file.add_import("import React from 'react';\n") - interface_name = f"{component_name}Props" - # Create interface definition - interface_def = f"interface {interface_name}{generics} {self._convert_dict(proptypes, 1)}" - - # Insert interface and update component - self.insert_before(interface_def + "\n") - - proptypes.parent_statement.remove() - for imp in self.file.imports: - if imp.module.source.strip("'").strip('"') in ("react", "prop-types"): - imp.remove_if_unused() - return interface_name + generic_name - - @writer - def flag(self, **kwargs: Unpack[FlagKwargs]) -> CodeFlag[Self]: - """Flags a TypeScript symbol by adding a flag comment and returning a CodeFlag. - - This implementation first creates the CodeFlag through the standard flagging system, - then adds a TypeScript-specific comment to visually mark the flagged code. - - Args: - **kwargs: Flag keyword arguments including optional 'message' - - Returns: - CodeFlag[Self]: The code flag object for tracking purposes - """ - # First create the standard CodeFlag through the base implementation - code_flag = super().flag(**kwargs) - - # Add a TypeScript comment to visually mark the flag - message = kwargs.get("message", "") - if message: - self.set_inline_comment(f"🚩 {message}") - - return code_flag diff --git a/src/codegen/sdk/typescript/symbol_groups/comment_group.py b/src/codegen/sdk/typescript/symbol_groups/comment_group.py deleted file mode 100644 index 3f23b276b..000000000 --- a/src/codegen/sdk/typescript/symbol_groups/comment_group.py +++ /dev/null @@ -1,127 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from codegen.sdk.core.symbol_groups.comment_group import CommentGroup -from codegen.sdk.typescript.statements.comment import TSComment, TSCommentType -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - -if TYPE_CHECKING: - from codegen.sdk.typescript.symbol import TSSymbol - - -@ts_apidoc -class TSCommentGroup(CommentGroup): - """A group of related symbols that represent a comment or docstring in TypeScript - - For example: - ``` - // Comment 1 - // Comment 2 - // Comment 3 - ``` - would be 3 individual comments (accessible via `symbols`), but together they form a `CommentGroup` (accessible via `self). - """ - - @staticmethod - @noapidoc - def _get_sibbling_comments(symbol: TSSymbol) -> list[TSComment]: - # Locate the body that contains the comment nodes - current_node = symbol.ts_node - parent_node = symbol.ts_node.parent - while parent_node and parent_node.type not in ["program", "class_body", "block", "statement_block"]: - current_node = parent_node - parent_node = parent_node.parent - - if not parent_node: - return None - - # Find the correct index of function_node in parent_node's children - function_index = parent_node.children.index(current_node) - - if function_index is None: - return None # function_node is not a child of parent_node - - if function_index == 0: - return None # No nodes before this function, hence no comments - - comment_nodes = [] - # Iterate backwards from the function node to collect all preceding comment nodes - for i in range(function_index - 1, -1, -1): - if parent_node.children[i].type == "comment": - # Check if the comment is directly above each other - if parent_node.children[i].end_point[0] == parent_node.children[i + 1].start_point[0] - 1: - comment = TSComment.from_code_block(parent_node.children[i], symbol) - comment_nodes.insert(0, comment) - else: - break # Stop if there is a break in the comments - else: - break # Stop if a non-comment node is encountered - - return comment_nodes - - @classmethod - @noapidoc - def from_symbol_comments(cls, symbol: TSSymbol): - comment_nodes = cls._get_sibbling_comments(symbol) - if not comment_nodes: - return None - return cls(comment_nodes, symbol.file_node_id, symbol.ctx, symbol) - - @classmethod - @noapidoc - def from_symbol_inline_comments(cls, symbol: TSSymbol): - # Locate the body that contains the comment nodes - current_node = symbol.ts_node - parent_node = symbol.ts_node.parent - while parent_node and parent_node.type not in ["program", "class_body", "block", "statement_block"]: - current_node = parent_node - parent_node = parent_node.parent - - if not parent_node: - return None - - # Find the correct index of function_node in parent_node's children - function_index = parent_node.children.index(current_node) - - if function_index is None: - return None # function_node is not a child of parent_node - - comment_nodes = [] - # Check if there are any comments after the function node - if function_index + 1 < len(parent_node.children): - if parent_node.children[function_index + 1].type == "comment": - # Check if the comment is on the same line - if parent_node.children[function_index].end_point[0] == parent_node.children[function_index + 1].start_point[0]: - comment = TSComment.from_code_block(parent_node.children[function_index + 1], symbol) - comment_nodes.append(comment) - - if not comment_nodes: - return None - - return cls(comment_nodes, symbol.file_node_id, symbol.ctx, symbol) - - @classmethod - @noapidoc - def from_docstring(cls, symbol: TSSymbol) -> TSCommentGroup | None: - """Returns the docstring of the function""" - comment_nodes = cls._get_sibbling_comments(symbol) - if not comment_nodes: - return None - # Docstring comments are filtered by SLASH_STAR comments - docstring_nodes = [comment for comment in comment_nodes if comment.comment_type == TSCommentType.SLASH_STAR] - if not docstring_nodes: - return None - return cls(docstring_nodes, symbol.file_node_id, symbol.ctx, symbol) - - @classmethod - @noapidoc - def from_comment_nodes(cls, comment_nodes: list[TSComment], symbol: TSSymbol): - if not comment_nodes: - return None - - # Docstring comments are filtered by SLASH_STAR comments - docstring_nodes = [comment for comment in comment_nodes if comment.comment_type == TSCommentType.SLASH_STAR] - if not docstring_nodes: - return None - return cls(docstring_nodes, symbol.file_node_id, symbol.ctx, symbol) diff --git a/src/codegen/sdk/typescript/symbol_groups/dict.py b/src/codegen/sdk/typescript/symbol_groups/dict.py deleted file mode 100644 index d35f2a5c4..000000000 --- a/src/codegen/sdk/typescript/symbol_groups/dict.py +++ /dev/null @@ -1,144 +0,0 @@ -from typing import TYPE_CHECKING, Self, TypeVar, override - -from tree_sitter import Node as TSNode - -from codegen.sdk.core.autocommit import writer -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.expressions.string import String -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.interfaces.has_attribute import HasAttribute -from codegen.sdk.core.node_id_factory import NodeId -from codegen.sdk.core.symbol_groups.dict import Dict, Pair -from codegen.sdk.extensions.autocommit import reader -from codegen.shared.decorators.docs import apidoc, noapidoc, ts_apidoc -from codegen.shared.logging.get_logger import get_logger - -if TYPE_CHECKING: - from codegen.sdk.codebase.codebase_context import CodebaseContext - -Parent = TypeVar("Parent", bound="Editable") -TExpression = TypeVar("TExpression", bound=Expression) - -logger = get_logger(__name__) - - -@ts_apidoc -class TSPair(Pair): - """A TypeScript pair node that represents key-value pairs in object literals. - - A specialized class extending `Pair` for handling TypeScript key-value pairs, - particularly in object literals. It provides functionality for handling both - regular key-value pairs and shorthand property identifiers, with support for - reducing boolean conditions. - - Attributes: - shorthand (bool): Indicates whether this pair uses shorthand property syntax. - """ - - shorthand: bool - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent) -> None: - super().__init__(ts_node, file_node_id, ctx, parent) - self.shorthand = ts_node.type == "shorthand_property_identifier" - - def _get_key_value(self) -> tuple[Expression[Self] | None, Expression[Self] | None]: - from codegen.sdk.typescript.function import TSFunction - - key, value = None, None - - if self.ts_node.type == "pair": - key = self.child_by_field_name("key") - value = self.child_by_field_name("value") - if TSFunction.is_valid_node(value.ts_node): - value = self._parse_expression(value.ts_node) - elif self.ts_node.type == "shorthand_property_identifier": - key = value = self._parse_expression(self.ts_node) - elif TSFunction.is_valid_node(self.ts_node): - value = self._parse_expression(self.ts_node) - key = value.get_name() - else: - return super()._get_key_value() - return key, value - - @writer - def reduce_condition(self, bool_condition: bool, node: Editable | None = None) -> None: - """Reduces an editable to the following condition""" - if self.shorthand and node == self.value: - # Object shorthand - self.parent[self.key.source] = self.ctx.node_classes.bool_conversion[bool_condition] - else: - super().reduce_condition(bool_condition, node) - - -@apidoc -class TSDict(Dict, HasAttribute): - """A typescript dict object. You can use standard operations to operate on this dict (IE len, del, set, get, etc)""" - - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent, delimiter: str = ",", pair_type: type[Pair] = TSPair) -> None: - super().__init__(ts_node, file_node_id, ctx, parent, delimiter=delimiter, pair_type=pair_type) - - def __getitem__(self, __key: str) -> TExpression: - for pair in self._underlying: - pair_match = None - - if isinstance(pair, Pair): - if isinstance(pair.key, String): - if pair.key.content == str(__key): - pair_match = pair - elif pair.key is not None: - if pair.key.source == str(__key): - pair_match = pair - - if pair_match: - if pair_match.value is not None: - return pair_match.value - else: - return pair_match.key - msg = f"Key {__key} not found in {list(self.keys())} {self._underlying!r}" - raise KeyError(msg) - - def __setitem__(self, __key: str, __value: TExpression) -> None: - new_value = __value.source if isinstance(__value, Editable) else str(__value) - for pair in self._underlying: - pair_match = None - - if isinstance(pair, Pair): - if isinstance(pair.key, String): - if pair.key.content == str(__key): - pair_match = pair - elif pair.key is not None: - if pair.key.source == str(__key): - pair_match = pair - - if pair_match: - # CASE: {a: b} - if not pair_match.shorthand: - if __key == new_value: - pair_match.edit(f"{__key}") - else: - pair.value.edit(f"{new_value}") - # CASE: {a} - else: - if __key == new_value: - pair_match.edit(f"{__key}") - else: - pair_match.edit(f"{__key}: {new_value}") - break - # CASE: {} - else: - if not self.ctx.node_classes.int_dict_key: - try: - int(__key) - __key = f"'{__key}'" - except ValueError: - pass - if __key == new_value: - self._underlying.append(f"{__key}") - else: - self._underlying.append(f"{__key}: {new_value}") - - @reader - @noapidoc - @override - def resolve_attribute(self, name: str) -> "Expression | None": - return self.get(name, None) diff --git a/src/codegen/sdk/typescript/ts_config.py b/src/codegen/sdk/typescript/ts_config.py deleted file mode 100644 index 99dde9469..000000000 --- a/src/codegen/sdk/typescript/ts_config.py +++ /dev/null @@ -1,485 +0,0 @@ -import os -from functools import cache -from pathlib import Path -from typing import TYPE_CHECKING - -import pyjson5 - -from codegen.sdk.core.directory import Directory -from codegen.sdk.core.file import File -from codegen.shared.decorators.docs import ts_apidoc -from codegen.shared.logging.get_logger import get_logger - -if TYPE_CHECKING: - from codegen.sdk.typescript.config_parser import TSConfigParser - from codegen.sdk.typescript.file import TSFile - -logger = get_logger(__name__) - - -@ts_apidoc -class TSConfig: - """TypeScript configuration file specified in tsconfig.json, used for import resolution and computing dependencies. - - Attributes: - config_file: The configuration file object representing the tsconfig.json file. - config_parser: The parser used to interpret the TypeScript configuration. - config: A dictionary containing the parsed configuration settings. - """ - - config_file: File - config_parser: "TSConfigParser" - config: dict - - # Base config values - _base_config: "TSConfig | None" = None - _base_url: str | None = None - _out_dir: str | None = None - _root_dir: str | None = None - _root_dirs: list[str] = [] - _paths: dict[str, list[str]] = {} - _references: list[tuple[str, Directory | File]] = [] - - # Self config values - _self_base_url: str | None = None - _self_out_dir: str | None = None - _self_root_dir: str | None = None - _self_root_dirs: list[str] = [] - _self_paths: dict[str, list[str]] = {} - _self_references: list[Directory | File] = [] - - # Precomputed import aliases - _computed_path_import_aliases: bool = False - _path_import_aliases: dict[str, list[str]] = {} - _reference_import_aliases: dict[str, list[str]] = {} - # Optimization hack. If all the path alises start with `@` or `~`, then we can skip any path that doesn't start with `@` or `~` - # when computing the import resolution. - _import_optimization_enabled: bool = False - - def __init__(self, config_file: File, config_parser: "TSConfigParser"): - self.config_file = config_file - self.config_parser = config_parser - # Try to parse the config file as JSON5. Fallback to empty dict if it fails. - # We use json5 because it supports comments in the config file. - try: - self.config = pyjson5.loads(config_file.content) - except pyjson5.Json5Exception: - logger.exception(f"Failed to parse tsconfig.json file: {config_file.filepath}") - self.config = {} - - # Precompute the base config, base url, paths, and references - self._precompute_config_values() - - def __repr__(self): - return f"TSConfig({self.config_file.filepath})" - - def _precompute_config_values(self): - """Precomputes the base config, base url, paths, and references.""" - # Precompute the base config - self._base_config = None - extends = self.config.get("extends", None) - if isinstance(extends, list): - # TODO: Support multiple extends - extends = extends[0] # Grab the first config in the list - base_config_path = self._parse_parent_config_path(extends) - - if base_config_path and base_config_path.exists(): - self._base_config = self.config_parser.get_config(base_config_path) - - # Precompute the base url - self._base_url = None - self._self_base_url = None - if base_url := self.config.get("compilerOptions", {}).get("baseUrl", None): - self._base_url = base_url - self._self_base_url = base_url - elif base_url := {} if self.base_config is None else self.base_config.base_url: - self._base_url = base_url - - # Precompute the outDir - self._out_dir = None - self._self_out_dir = None - if out_dir := self.config.get("compilerOptions", {}).get("outDir", None): - self._out_dir = out_dir - self._self_out_dir = out_dir - elif out_dir := {} if self.base_config is None else self.base_config.out_dir: - self._out_dir = out_dir - - # Precompute the rootDir - self._root_dir = None - self._self_root_dir = None - if root_dir := self.config.get("compilerOptions", {}).get("rootDir", None): - self._root_dir = root_dir - self._self_root_dir = root_dir - elif root_dir := {} if self.base_config is None else self.base_config.root_dir: - self._root_dir = root_dir - - # Precompute the rootDirs - self._root_dirs = [] - self._self_root_dirs = [] - if root_dirs := self.config.get("compilerOptions", {}).get("rootDirs", None): - self._root_dirs = root_dirs - self._self_root_dirs = root_dirs - elif root_dirs := [] if self.base_config is None else self.base_config.root_dirs: - self._root_dirs = root_dirs - - # Precompute the paths - base_paths = {} if self.base_config is None else self.base_config.paths - self_paths = self.config.get("compilerOptions", {}).get("paths", {}) - self._paths = {**base_paths, **self_paths} - self._self_paths = self_paths - - # Precompute the references - self_references = [] - references = self.config.get("references", None) - if references is not None: - for reference in references: - if ref_path := reference.get("path", None): - abs_ref_path = str(self.config_file.ctx.to_relative(self._relative_to_absolute_directory_path(ref_path))) - if directory := self.config_file.ctx.get_directory(self.config_file.ctx.to_absolute(abs_ref_path)): - self_references.append((ref_path, directory)) - elif ts_config := self.config_parser.get_config(abs_ref_path): - self_references.append((ref_path, ts_config.config_file)) - elif file := self.config_file.ctx.get_file(abs_ref_path): - self_references.append((ref_path, file)) - self._references = [*self_references] # MAYBE add base references here? This breaks the reference chain though. - self._self_references = self_references - - def _precompute_import_aliases(self): - """Precomputes the import aliases.""" - if self._computed_path_import_aliases: - return - - # Force compute alias of the base config - if self.base_config is not None: - self.base_config._precompute_import_aliases() - - # Precompute the formatted paths based on compilerOptions/paths - base_path_import_aliases = {} if self.base_config is None else self.base_config.path_import_aliases - self_path_import_aliases = {} - for pattern, relative_paths in self._self_paths.items(): - formatted_pattern = pattern.replace("*", "").rstrip("/").replace("//", "/") - formatted_relative_paths = [] - for relative_path in relative_paths: - cleaned_relative_path = relative_path.replace("*", "").rstrip("/").replace("//", "/") - if self._self_base_url: - cleaned_relative_path = os.path.join(self._self_base_url, cleaned_relative_path) - formatted_absolute_path = self._relative_to_absolute_directory_path(cleaned_relative_path) - formatted_relative_path = str(self.config_file.ctx.to_relative(formatted_absolute_path)) - # Fix absolute path if its base - if formatted_relative_path == ".": - formatted_relative_path = "" - formatted_relative_paths.append(formatted_relative_path) - self_path_import_aliases[formatted_pattern] = formatted_relative_paths - self._path_import_aliases = {**base_path_import_aliases, **self_path_import_aliases} - - # Precompute the formatted paths based on references - base_reference_import_aliases = {} if self.base_config is None else self.base_config.reference_import_aliases - self_reference_import_aliases = {} - # For each reference, try to grab its tsconfig. - for ref_path, reference in self._self_references: - # TODO: THIS ENTIRE PROCESS IS KINDA HACKY. - # If the reference is a file, get its directory. - if isinstance(reference, File): - reference_dir = self.config_file.ctx.get_directory(os.path.dirname(reference.filepath)) - elif isinstance(reference, Directory): - reference_dir = reference - else: - logger.warning(f"Unknown reference type during self_reference_import_aliases computation in _precompute_import_aliases: {type(reference)}") - continue - - # With the directory, try to grab the next available file and get its tsconfig. - if reference_dir and reference_dir.files(recursive=True): - next_file: TSFile = reference_dir.files(recursive=True)[0] - else: - logger.warning(f"No next file found for reference during self_reference_import_aliases computation in _precompute_import_aliases: {reference.dirpath}") - continue - target_ts_config = next_file.ts_config - if target_ts_config is None: - logger.warning(f"No tsconfig found for reference during self_reference_import_aliases computation in _precompute_import_aliases: {reference.dirpath}") - continue - - # With the tsconfig, grab its rootDirs and outDir - target_root_dirs = target_ts_config.root_dirs if target_ts_config.root_dirs else ["."] - target_out_dir = target_ts_config.out_dir - - # Calculate the formatted pattern and formatted relative paths - formatted_relative_paths = [os.path.normpath(os.path.join(reference_dir.path, root_dir)) for root_dir in target_root_dirs] - - # Loop through each possible path part of the reference - # For example, if the reference is "../../a/b/c" and the out dir is "dist" - # then the possible reference aliases are: - # - "a/b/c/dist" - # - "b/c/dist" - # - "c/dist" - # (ignoring any .. segments) - path_parts = [p for p in ref_path.split(os.path.sep) if p and not p.startswith("..")] - for i in range(len(path_parts)): - target_path = os.path.sep.join(path_parts[i:]) - if target_path: - formatted_target_path = os.path.normpath(os.path.join(target_path, target_out_dir) if target_out_dir else target_path) - self_reference_import_aliases[formatted_target_path] = formatted_relative_paths - - self._reference_import_aliases = {**base_reference_import_aliases, **self_reference_import_aliases} - - # Precompute _import_optimization_enabled - self._import_optimization_enabled = all(k.startswith("@") or k.startswith("~") for k in list(self.path_import_aliases.keys()) + list(self.reference_import_aliases.keys())) - - # Mark that we've precomputed the import aliases - self._computed_path_import_aliases = True - - def _parse_parent_config_path(self, config_filepath: str | None) -> Path | None: - """Returns a TSConfig object from a file path.""" - if config_filepath is None: - return None - - path = self._relative_to_absolute_directory_path(config_filepath) - return Path(path if path.suffix == ".json" else f"{path}.json") - - def _relative_to_absolute_directory_path(self, relative_path: str) -> Path: - """Helper to go from a relative module to an absolute one. - Ex: "../pkg-common/" would be -> "src/dir/pkg-common/" - """ - # TODO: This could also use its parent config to resolve the path - relative = self.config_file.path.parent / relative_path.strip('"') - return self.config_file.ctx.to_absolute(relative) - - def translate_import_path(self, import_path: str) -> str: - """Translates an import path to an absolute path using the tsconfig paths. - - Takes an import path and translates it to an absolute path using the configured paths in the tsconfig file. If the import - path matches a path alias, it will be resolved according to the tsconfig paths mapping. - - For example, converts `@abc/my/pkg/src` to `a/b/c/my/pkg/src` or however it's defined in the tsconfig. - - Args: - import_path (str): The import path to translate. - - Returns: - str: The translated absolute path. If no matching path alias is found, returns the original import path unchanged. - """ - # Break out early if we can - if self._import_optimization_enabled and not import_path.startswith("@") and not import_path.startswith("~"): - return import_path - - # Step 1: Try to resolve with import_resolution_overrides - if self.config_file.ctx.config.import_resolution_overrides: - if path_check := TSConfig._find_matching_path(frozenset(self.config_file.ctx.config.import_resolution_overrides.keys()), import_path): - to_base = self.config_file.ctx.config.import_resolution_overrides[path_check] - - # Get the remaining path after the matching prefix - remaining_path = import_path[len(path_check) :].lstrip("/") - - # Join the path together - import_path = os.path.join(to_base, remaining_path) - - return import_path - - # Step 2: Keep traveling down the parent config paths until we find a match a reference_import_aliases - if path_check := TSConfig._find_matching_path(frozenset(self.reference_import_aliases.keys()), import_path): - # TODO: This assumes that there is only one to_base path for the given from_base path - to_base = self.reference_import_aliases[path_check][0] - - # Get the remaining path after the matching prefix - remaining_path = import_path[len(path_check) :].lstrip("/") - - # Join the path together - import_path = os.path.join(to_base, remaining_path) - - return import_path - - # Step 3: Keep traveling down the parent config paths until we find a match a path_import_aliases - if path_check := TSConfig._find_matching_path(frozenset(self.path_import_aliases.keys()), import_path): - # TODO: This assumes that there is only one to_base path for the given from_base path - to_base = self.path_import_aliases[path_check][0] - - # Get the remaining path after the matching prefix - remaining_path = import_path[len(path_check) :].lstrip("/") - - # Join the path together - import_path = os.path.join(to_base, remaining_path) - - return import_path - - # Step 4: Try to resolve with base path for non-relative imports - return self.resolve_base_url(import_path) - - def translate_absolute_path(self, absolute_path: str) -> str: - """Translates an absolute path to an import path using the tsconfig paths. - - Takes an absolute path and translates it to an import path using the configured paths in the tsconfig file. - - For example, converts `a/b/c/my/pkg/src` to `@abc/my/pkg/src` or however it's defined in the tsconfig. - - Args: - import_path (str): The absolute path to translate. - - Returns: - str: The translated import path. - """ - path_aliases = self._path_import_aliases - for alias, paths in path_aliases.items(): - for path in paths: - if absolute_path.startswith(path): - # Pick the first alias that matches - return absolute_path.replace(path, alias, 1) - - return absolute_path - - def resolve_base_url(self, import_path: str) -> str: - """Resolves an import path with the base url. - - If a base url is not defined, try to resolve it with its base config. - """ - # Do nothing if the import path is relative - if import_path.startswith("."): - return import_path - - # If the current config has a base url, use itq - if self._self_base_url: - if not import_path.startswith(self._self_base_url): - import_path = os.path.join(self._self_base_url, import_path) - import_path = str(self._relative_to_absolute_directory_path(import_path)) - return import_path - # If there is a base config, try to resolve it with its base url - elif self.base_config: - return self.base_config.resolve_base_url(import_path) - # Otherwise, do nothing - else: - return import_path - - @staticmethod - @cache - def _find_matching_path(path_import_aliases: set[str], path_check: str): - """Recursively find the longest matching path in path_import_aliases.""" - # Base case - if not path_check or path_check == "/": - return None - - # Recursive case - if path_check in path_import_aliases: - return path_check - elif f"{path_check}/" in path_import_aliases: - return f"{path_check}/" - else: - return TSConfig._find_matching_path(path_import_aliases, os.path.dirname(path_check)) - - @property - def base_config(self) -> "TSConfig | None": - """Returns the base TSConfig that this config inherits from. - - Gets the base configuration file that this TSConfig extends. The base configuration is used for inheriting settings like paths, baseUrl,and other compiler options. - - Returns: - TSConfig | None: The parent TSConfig object if this config extends another config file, None otherwise. - """ - return self._base_config - - @property - def base_url(self) -> str | None: - """Returns the base URL defined in the TypeScript configuration. - - This property retrieves the baseUrl from the project's TypeScript configuration file. - The baseUrl is used for resolving non-relative module names. - - Returns: - str | None: The base URL if defined in the config file or inherited from a base config, - None if not specified. - """ - return self._base_url - - @property - def out_dir(self) -> str | None: - """Returns the outDir defined in the TypeScript configuration. - - The outDir specifies the output directory for all emitted files. When specified, .js (as well as .d.ts, .js.map, etc.) - files will be emitted into this directory. The directory structure of the source files is preserved. - - Returns: - str | None: The output directory path if specified in the config file or inherited from a base config, - None if not specified. - """ - return self._out_dir - - @property - def root_dir(self) -> str | None: - """Returns the rootDir defined in the TypeScript configuration. - - The rootDir specifies the root directory of input files. This is used to control the output directory structure - with outDir. When TypeScript compiles files, it maintains the directory structure of the source files relative - to rootDir when generating output. - - Returns: - str | None: The root directory path if specified in the config file or inherited from a base config, - None if not specified. - """ - return self._root_dir - - @property - def root_dirs(self) -> list[str]: - """Returns the rootDirs defined in the TypeScript configuration. - - The rootDirs allows a list of root directories to be specified that are merged and treated as one virtual directory. - This can be used when your project structure doesn't match your runtime expectations. For example, when you have - both generated and hand-written source files that need to appear to be in the same directory at runtime. - - Returns: - list[str]: A list of root directory paths specified in the config file or inherited from a base config. - Returns an empty list if not specified. - """ - if self._root_dirs is not None: - return self._root_dirs - elif self.root_dir is not None: - return [self.root_dir] - return [] - - @property - def paths(self) -> dict[str, list[str]]: - """Returns all custom module path mappings defined in the tsconfig file. - - Retrieves path mappings from both the current tsconfig file and any inherited base config file, - translating all relative paths to absolute paths. - - Returns: - dict[str, list[str]]: A dictionary mapping path patterns to lists of absolute path destinations. - Each key is a path pattern (e.g., '@/*') and each value is a list of corresponding - absolute path destinations. - """ - return self._paths - - @property - def references(self) -> list[Directory | File]: - """Returns a list of directories that this TypeScript configuration file depends on. - - The references are defined in the 'references' field of the tsconfig.json file. These directories - are used to resolve import conflicts and narrow the search space for import resolution. - - Returns: - list[Directory | File | TSConfig]: A list of Directory, File, or TSConfig objects representing the dependent directories. - """ - return self._references - - @property - def path_import_aliases(self) -> dict[str, list[str]]: - """Returns a formatted version of the paths property from a TypeScript configuration file. - - Processes the paths dictionary by formatting path patterns and their corresponding target paths. All wildcards (*), trailing slashes, and double - slashes are removed from both the path patterns and their target paths. Target paths are also converted from relative to absolute paths. - - Returns: - dict[str, list[str]]: A dictionary where keys are formatted path patterns and values are lists of formatted absolute target paths. - """ - return self._path_import_aliases - - @property - def reference_import_aliases(self) -> dict[str, list[str]]: - """Returns a formatted version of the references property from a TypeScript configuration file. - - Processes the references dictionary by formatting reference paths and their corresponding target paths. For each - reference, retrieves its tsconfig file and path mappings. Also includes any path mappings inherited from base - configs. - - Returns: - dict[str, list[str]]: A dictionary where keys are formatted reference paths (e.g. 'module/dist') and values - are lists of absolute target paths derived from the referenced tsconfig's rootDirs and outDir settings. - """ - return {k: [str(self.config_file.ctx.to_relative(v)) for v in vs] for k, vs in self._reference_import_aliases.items()} diff --git a/src/codegen/sdk/typescript/type_alias.py b/src/codegen/sdk/typescript/type_alias.py deleted file mode 100644 index d4d671909..000000000 --- a/src/codegen/sdk/typescript/type_alias.py +++ /dev/null @@ -1,73 +0,0 @@ -from codegen.sdk.core.autocommit import commiter, reader -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.interfaces.has_name import HasName -from codegen.sdk.core.type_alias import TypeAlias -from codegen.sdk.enums import SymbolType -from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock -from codegen.sdk.typescript.interfaces.has_block import TSHasBlock -from codegen.sdk.typescript.statements.attribute import TSAttribute -from codegen.sdk.typescript.symbol import TSSymbol -from codegen.shared.decorators.docs import noapidoc, ts_apidoc - - -@ts_apidoc -class TSTypeAlias(TypeAlias[TSCodeBlock, TSAttribute], TSSymbol, TSHasBlock): - """Representation of an Interface in TypeScript. - - Attributes: - symbol_type: The type of symbol, set to SymbolType.Type. - """ - - symbol_type = SymbolType.Type - - @noapidoc - @commiter - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - dest = dest or self.self_dest - # =====[ Type Identifiers ]===== - # Look for type references in the interface body - self.value._compute_dependencies(UsageKind.TYPE_DEFINITION, dest) - self.code_block._compute_dependencies(UsageKind.TYPE_DEFINITION, dest) - # body = self.ts_node.child_by_field_name("value") - # if body: - # # Handle type queries (typeof) - # type_queries = find_all_descendants(body, ["type_query"]) - # for type_query in type_queries: - # query_identifiers = find_all_descendants(type_query, ["identifier"]) - # self._add_symbol_usages(query_identifiers, SymbolUsageType.TYPE) - # - # type_identifiers = find_all_descendants(body, ["type_identifier"]) - # self._add_symbol_usages(type_identifiers, SymbolUsageType.TYPE) - if self.type_parameters: - self.type_parameters._compute_dependencies(UsageKind.GENERIC, dest) - - @reader - def _parse_code_block(self) -> TSCodeBlock: - """Returns the code block of the function""" - value_node = self.ts_node.child_by_field_name("value") - return super()._parse_code_block(value_node) - - @property - @reader - def attributes(self) -> list[TSAttribute]: - """Retrieves all attributes belonging to this type alias. - - Returns a list of attributes that are defined within the type alias's code block. - These attributes represent named values or properties associated with the type alias. - - Returns: - list[TSAttribute[TSTypeAlias, None]]: A list of TSAttribute objects representing the type alias's attributes. - """ - return self.code_block.attributes - - @reader - def get_attribute(self, name: str) -> TSAttribute | None: - """Retrieves a specific attribute from a TypeScript type alias by its name. - - Args: - name (str): The name of the attribute to retrieve. - - Returns: - TSAttribute[TSTypeAlias, None] | None: The attribute with the specified name if found, None otherwise. - """ - return next((x for x in self.attributes if x.name == name), None) diff --git a/src/codegen/sdk/utils.py b/src/codegen/sdk/utils.py deleted file mode 100644 index 7476e6e8a..000000000 --- a/src/codegen/sdk/utils.py +++ /dev/null @@ -1,341 +0,0 @@ -import os -import re -import shutil -import statistics -from collections.abc import Iterable -from contextlib import contextmanager -from xml.dom.minidom import parseString - -import dicttoxml -import xmltodict -from tree_sitter import Node as TSNode - -from codegen.sdk.extensions.utils import find_all_descendants, find_first_descendant, get_all_identifiers -from codegen.sdk.typescript.enums import TSFunctionTypeNames -from codegen.shared.enums.programming_language import ProgrammingLanguage - -""" -Utility functions for traversing the tree sitter structure. -Do not include language specific traversals, or string manipulations here. -""" - - -class XMLUtils: - @staticmethod - def dict_to_xml(data: dict | list, format: bool = False, **kwargs) -> str: - result = dicttoxml.dicttoxml(data, return_bytes=False, **kwargs) - if not isinstance(result, str): - msg = "Failed to convert dict to XML" - raise ValueError(msg) - if format: - result = parseString(result).toprettyxml() - return result - - @staticmethod - def add_cdata_to_function_body(xml_string): - pattern = r"()(.*?)()" - replacement = r"\1\3" - updated_xml_string = re.sub(pattern, replacement, xml_string, flags=re.DOTALL) - return updated_xml_string - - @staticmethod - def add_cdata_to_tags(xml_string: str, tags: Iterable[str]) -> str: - patterns = [rf"(<{tag}>)(.*?)()" for tag in tags] - updated_xml_string = xml_string - - for pattern in patterns: - replacement = r"\1\3" - updated_xml_string = re.sub(pattern, replacement, updated_xml_string, flags=re.DOTALL) - - return updated_xml_string - - @staticmethod - def xml_to_dict(xml_string: str, **kwargs) -> dict: - return xmltodict.parse(XMLUtils.add_cdata_to_tags(xml_string, ["function_body", "reasoning"]), **kwargs) - - @staticmethod - def strip_after_tag(xml_string, tag): - pattern = re.compile(f"<{tag}.*?>.*", re.DOTALL) - match = pattern.search(xml_string) - if match: - return xml_string[: match.start()] - else: - return xml_string - - @staticmethod - def strip_tag(xml_string: str, tag: str): - pattern = re.compile(f"<{tag}>.*?", re.DOTALL) - return pattern.sub("", xml_string).strip() - - @staticmethod - def strip_all_tags(xml_string: str): - pattern = re.compile(r"<[^>]*>") - return pattern.sub("", xml_string).strip() - - @staticmethod - def extract_elements(xml_string: str, tag: str, keep_tag: bool = False) -> list[str]: - pattern = re.compile(f"<{tag}.*?", re.DOTALL) - matches = pattern.findall(xml_string) - if keep_tag: - return matches - else: - return [match.strip(f"<{tag}>").strip(f"") for match in matches] - - -def find_first_function_descendant(node: TSNode) -> TSNode: - type_names = [function_type.value for function_type in TSFunctionTypeNames] - return find_first_descendant(node=node, type_names=type_names, max_depth=2) - - -def find_import_node(node: TSNode) -> TSNode | None: - """Get the import node from a node that may contain an import. - Returns None if the node does not contain an import. - - Returns: - TSNode | None: The import_statement or call_expression node if it's an import, None otherwise - """ - # Static imports - if node.type == "import_statement": - return node - - # Dynamic imports and requires can be either: - # 1. Inside expression_statement -> call_expression - # 2. Direct call_expression - - # we only parse imports inside expressions and variable declarations - - if member_expression := find_first_descendant(node, ["member_expression"]): - # there may be multiple call expressions (for cases such as import(a).then(module => module).then(module => module) - descendants = find_all_descendants(member_expression, ["call_expression"], stop_at_first="statement_block") - if descendants: - import_node = descendants[-1] - else: - # this means this is NOT a dynamic import() - return None - else: - import_node = find_first_descendant(node, ["call_expression"]) - - # thus we only consider the deepest one - if import_node: - function = import_node.child_by_field_name("function") - if function and (function.type == "import" or (function.type == "identifier" and function.text.decode("utf-8") == "require")): - return import_node - - return None - - -def find_index(target: TSNode, siblings: list[TSNode]) -> int: - """Returns the index of the target node in the list of siblings, or -1 if not found. Recursive implementation.""" - if target in siblings: - return siblings.index(target) - - for i, sibling in enumerate(siblings): - index = find_index(target, sibling.named_children if target.is_named else sibling.children) - if index != -1: - return i - return -1 - - -def find_first_ancestor(node: TSNode, type_names: list[str], max_depth: int | None = None) -> TSNode | None: - depth = 0 - while node is not None and (max_depth is None or depth <= max_depth): - if node.type in type_names: - return node - node = node.parent - depth += 1 - return None - - -def find_first_child_by_field_name(node: TSNode, field_name: str) -> TSNode | None: - child = node.child_by_field_name(field_name) - if child is not None: - return child - for child in node.children: - first_descendant = find_first_child_by_field_name(child, field_name) - if first_descendant is not None: - return first_descendant - return None - - -def has_descendant(node: TSNode, type_name: str) -> bool: - def traverse(current_node: TSNode, depth: int = 0) -> bool: - if current_node.type == type_name: - return True - return any(traverse(child, depth + 1) for child in current_node.children) - - return traverse(node) - - -def get_first_identifier(node: TSNode) -> TSNode | None: - """Get the text of the first identifier child of a tree-sitter node. Recursive implementation""" - if node.type in ("identifier", "shorthand_property_identifier_pattern"): - return node - for child in node.children: - output = get_first_identifier(child) - if output is not None: - return output - return None - - -def descendant_for_byte_range(node: TSNode, start_byte: int, end_byte: int, allow_comment_boundaries: bool = True) -> TSNode | None: - """Proper implementation of descendant_for_byte_range, which returns the lowest node that contains the byte range.""" - ts_match = node.descendant_for_byte_range(start_byte, end_byte) - - # We don't care if the match overlaps with comments - if allow_comment_boundaries: - return ts_match - - # Want to prevent it from matching with part of the match within a comment - else: - if not ts_match.children: - return ts_match - comments = find_all_descendants(ts_match, "comment") - # see if any of these comments partially overlaps with the match - if any(comment.start_byte < start_byte < comment.end_byte or comment.start_byte < end_byte < comment.end_byte for comment in comments): - return None - return ts_match - - -@contextmanager -def shadow_files(files: str | list[str]): - """Creates shadow copies of the given files. Restores the original files after the context manager is exited. - - Returns list of filenames of shadowed files. - """ - if isinstance(files, str): - files = [files] - shadowed_files = {} - # Generate shadow file names - for file_name in files: - shadow_file_name = file_name + ".gs_internal.bak" - shadowed_files[file_name] = shadow_file_name - # Shadow files - try: - # Backup the original files - for file_name, shadow_file_name in shadowed_files.items(): - shutil.copy(file_name, shadow_file_name) - yield shadowed_files.values() - finally: - # Restore the original files - for file_name, shadow_file_name in shadowed_files.items(): - # If shadow file was created, restore the original file and delete the shadow file - if os.path.exists(shadow_file_name): - # Delete the original file if it exists - if os.path.exists(file_name): - os.remove(file_name) - # Copy the shadow file to the original file path - shutil.copy(shadow_file_name, file_name) - # Delete the shadow file - os.remove(shadow_file_name) - - -def calculate_base_path(full_path, relative_path): - """Calculate the base path represented by './' in a relative path. - - :param full_path: The full path to a file or directory - :param relative_path: A relative path starting with './' - :return: The base path represented by './' in the relative path - """ - # Normalize paths to handle different path separators - full_path = os.path.normpath(full_path) - relative_path = os.path.normpath(relative_path) - - # Split paths into components - full_components = full_path.split(os.sep) - relative_components = relative_path.split(os.sep) - - # Remove './' from the start of relative path if present - if relative_components[0] == ".": - relative_components = relative_components[1:] - - # Calculate the number of components to keep from the full path - keep_components = len(full_components) - len(relative_components) - - # Join the components to form the base path - base_path = os.sep.join(full_components[:keep_components]) - - return base_path - - -__all__ = [ - "find_all_descendants", - "find_first_ancestor", - "find_first_child_by_field_name", - "find_first_descendant", - "get_all_identifiers", - "has_descendant", -] - - -def get_language_file_extensions(language: ProgrammingLanguage): - """Returns the file extensions for the given language.""" - from codegen.sdk.python import PyFile - from codegen.sdk.typescript.file import TSFile - - if language == ProgrammingLanguage.PYTHON: - return set(PyFile.get_extensions()) - elif language == ProgrammingLanguage.TYPESCRIPT: - return set(TSFile.get_extensions()) - - -def truncate_line(input: str, max_chars: int) -> str: - input = str(input) - if len(input) > max_chars: - return input[:max_chars] + f"...(truncated from {len(input)} characters)." - return input - - -def is_minified_js(content): - """Analyzes a string to determine if it contains minified JavaScript code. - - Args: - content: String containing JavaScript code to analyze - - Returns: - bool: True if the content appears to be minified JavaScript, False otherwise - """ - try: - # Skip empty content - if not content.strip(): - return False - - # Characteristics of minified JS files - lines = content.split("\n") - - # 1. Check for average line length (minified files have very long lines) - line_lengths = [len(line) for line in lines if line.strip()] - if not line_lengths: # Handle empty content case - return False - - avg_line_length = statistics.mean(line_lengths) - - # 2. Check for semicolon-to-newline ratio (minified often has ; instead of newlines) - semicolons = content.count(";") - newlines = len(lines) - 1 - semicolon_ratio = semicolons / max(newlines, 1) # Avoid division by zero - - # 3. Check whitespace ratio (minified has low whitespace) - whitespace_chars = len(re.findall(r"[\s]", content)) - total_chars = len(content) - whitespace_ratio = whitespace_chars / total_chars if total_chars else 0 - - # 4. Check for common minification patterns - has_common_patterns = bool(re.search(r"[\w\)]\{[\w:]+\}", content)) # Condensed object notation - - # 5. Check for short variable names (common in minified code) - variable_names = re.findall(r"var\s+(\w+)", content) - avg_var_length = statistics.mean([len(name) for name in variable_names]) if variable_names else 0 - - # Decision logic - tuned threshold values - is_minified = ( - (avg_line_length > 250) # Very long average line length - and (semicolon_ratio > 0.8 or has_common_patterns) # High semicolon ratio or minification patterns - and (whitespace_ratio < 0.08) # Very low whitespace ratio - and (avg_var_length < 3 or not variable_names) # Extremely short variable names or no vars - ) - - return is_minified - - except Exception as e: - print(f"Error analyzing content: {e}") - return False diff --git a/src/codegen/sdk/writer_decorators.py b/src/codegen/sdk/writer_decorators.py deleted file mode 100644 index b92da3596..000000000 --- a/src/codegen/sdk/writer_decorators.py +++ /dev/null @@ -1,10 +0,0 @@ -from codegen.shared.enums.programming_language import ProgrammingLanguage - - -def canonical(codemod): - """Decorator for canonical Codemods that will be used for AI-agent prompts.""" - codemod._canonical = True - if not hasattr(codemod, "language") or codemod.language not in (ProgrammingLanguage.PYTHON, ProgrammingLanguage.TYPESCRIPT): - msg = "Canonical codemods must have a `language` attribute (PYTHON or TYPESCRIPT)." - raise AttributeError(msg) - return codemod diff --git a/src/codegen/visualizations/enums.py b/src/codegen/visualizations/enums.py deleted file mode 100644 index fda633034..000000000 --- a/src/codegen/visualizations/enums.py +++ /dev/null @@ -1,27 +0,0 @@ -from dataclasses import dataclass -from enum import StrEnum - - -@dataclass(frozen=True) -class VizNode: - name: str | None = None - text: str | None = None - code: str | None = None - color: str | None = None - shape: str | None = None - start_point: tuple | None = None - emoji: str | None = None - end_point: tuple | None = None - file_path: str | None = None - symbol_name: str | None = None - - -@dataclass(frozen=True) -class GraphJson: - type: str - data: dict - - -class GraphType(StrEnum): - TREE = "tree" - GRAPH = "graph" diff --git a/src/codegen/visualizations/py.typed b/src/codegen/visualizations/py.typed deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codegen/visualizations/visualization_manager.py b/src/codegen/visualizations/visualization_manager.py deleted file mode 100644 index 7be3cf8fb..000000000 --- a/src/codegen/visualizations/visualization_manager.py +++ /dev/null @@ -1,63 +0,0 @@ -import os - -import plotly.graph_objects as go -from networkx import Graph - -from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.sdk.core.interfaces.editable import Editable -from codegen.shared.logging.get_logger import get_logger -from codegen.visualizations.viz_utils import graph_to_json - -logger = get_logger(__name__) - - -class VisualizationManager: - op: RepoOperator - - def __init__( - self, - op: RepoOperator, - ) -> None: - self.op = op - - @property - def viz_path(self) -> str: - return os.path.join(self.op.base_dir, "codegen-graphviz") - - @property - def viz_file_path(self) -> str: - return os.path.join(self.viz_path, "graph.json") - - def clear_graphviz_data(self) -> None: - if self.op.folder_exists(self.viz_path): - self.op.emptydir(self.viz_path) - - def write_graphviz_data(self, G: Graph | go.Figure, root: Editable | str | int | None = None) -> None: - """Writes the graph data to a file. - - Args: - ---- - G (Graph | go.Figure): A NetworkX Graph object representing the graph to be visualized. - root (str | None): The root node to visualize. Defaults to None. - - Returns: - ------ - None - """ - # Convert the graph to a JSON-serializable format - if isinstance(G, Graph): - graph_json = graph_to_json(G, root) - elif isinstance(G, go.Figure): - graph_json = G.to_json() - - # Check if the visualization path exists, if so, empty it - if self.op.folder_exists(self.viz_path): - self.op.emptydir(self.viz_path) - else: - # If the path doesn't exist, create it - self.op.mkdir(self.viz_path) - - # Write the graph data to a file - with open(self.viz_file_path, "w") as f: - f.write(graph_json) - f.flush() # Ensure data is written to disk diff --git a/src/codegen/visualizations/viz_utils.py b/src/codegen/visualizations/viz_utils.py deleted file mode 100644 index f1cefeee9..000000000 --- a/src/codegen/visualizations/viz_utils.py +++ /dev/null @@ -1,70 +0,0 @@ -import json -import os -from dataclasses import asdict -from typing import TYPE_CHECKING - -import networkx as nx -from networkx import DiGraph, Graph - -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.output.utils import DeterministicJSONEncoder -from codegen.visualizations.enums import GraphJson, GraphType - -if TYPE_CHECKING: - from codegen.git.repo_operator.repo_operator import RepoOperator - -#################################################################################################################### -# READING GRAPH VISUALIZATION DATA -#################################################################################################################### - - -def get_graph_json(op: "RepoOperator"): - if os.path.exists(op.viz_file_path): - with open(op.viz_file_path) as f: - graph_json = json.load(f) - return graph_json - else: - return None - - -#################################################################################################################### -# NETWORKX GRAPH TO JSON -#################################################################################################################### - - -def get_node_options(node: Editable | str | int): - if isinstance(node, Editable): - return asdict(node.viz) - return {} - - -def get_node_id(node: Editable | str | int): - if isinstance(node, Importable): - return node.node_id - elif isinstance(node, Editable): - return str(node.span) - elif isinstance(node, str) or isinstance(node, int): - return node - - -def graph_to_json(G1: Graph, root: Editable | str | int | None = None): - G2 = DiGraph() - for node_tuple in G1.nodes(data=True): - options = get_node_options(node_tuple[0]) - options.update(node_tuple[1]) - G2.add_node(get_node_id(node_tuple[0]), **options) - - for edge_tuple in G1.edges(data=True): - options = edge_tuple[2] - if "symbol" in options: - print(get_node_options(options["symbol"])) - options.update(get_node_options(options["symbol"])) - del options["symbol"] - G2.add_edge(get_node_id(edge_tuple[0]), get_node_id(edge_tuple[1]), **options) - - if root: - root = get_node_id(root) - return json.dumps(asdict(GraphJson(type=GraphType.TREE.value, data=nx.tree_data(G2, root))), cls=DeterministicJSONEncoder, indent=2) - else: - return json.dumps(asdict(GraphJson(type=GraphType.GRAPH.value, data=nx.node_link_data(G2))), cls=DeterministicJSONEncoder, indent=2) diff --git a/src/codemods/README.md b/src/codemods/README.md deleted file mode 100644 index ea6eaca11..000000000 --- a/src/codemods/README.md +++ /dev/null @@ -1,44 +0,0 @@ -# Codemod Test Runner - -Put your codemod in the canonical folder if it is canonical, otherwise put it in misc. -You can also add company folders if you wish -`{codemod_folder} = src/codemods/{type}/{name}` -`{codemod_tests_folder} = tests/integration/codemod/{type}/{name}` - -## Repos - -These are the inputs to run it against. -To add a test case, create a folder called `{codemod_folder}/test_{repo_name}` - -### JSON test cases - -Add a repo to the repos folder or use an existing one. Use the current ones as reference - -### Local Test Cases - -Add a folder to the test folder containing the original state of the repository -`{codemod_tests_folder}/test_{repo_name}/original`. Then add files ( -ie: `{codemod_tests_folder}/test_{repo_name}/original/codebase.py`) to test on. - -## Expected outputs - -### Diffs - -Diffs are difficult to parse, but you can add `{codemod_tests_folder}/test_{repo_name}/expected_diff.patch` - -### Files - -You can add all the changed files to a folder called `{codemod_folder}/test_{repo_name}/expected`. The test runner will attempt to -convert the previous into this format. - -### Leaving it blank - -This will cause a warning but is helpful for performance testing - -## Profiles - -`.profiles` will have HTML profiles for each codemod/test case. These get overwritten on each run - -## Diff output - -`.diffs` will have HTML diffs for each codemod/test case. These get overwritten on each run diff --git a/src/codemods/canonical/__init__.py b/src/codemods/canonical/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codemods/canonical/add_function_parameter_type_annotations/__init__.py b/src/codemods/canonical/add_function_parameter_type_annotations/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codemods/canonical/add_function_parameter_type_annotations/add_function_parameter_type_annotations.py b/src/codemods/canonical/add_function_parameter_type_annotations/add_function_parameter_type_annotations.py deleted file mode 100644 index 5cb33414b..000000000 --- a/src/codemods/canonical/add_function_parameter_type_annotations/add_function_parameter_type_annotations.py +++ /dev/null @@ -1,51 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that adds type annotations for function parameters named 'db' to be of type 'SessionLocal' from 'app.db'. The codemod should -also ensure that the necessary import statement is added if it is not already present. Include examples of the code before and after the -transformation.""", - uid="d62a3590-14ef-4759-853c-39c5cf755ce5", -) -@canonical -class AddFunctionParameterTypeAnnotations(Codemod, Skill): - """Adds type annotation for function parameters that takes in a 'db' parameter, which is a `SessionLocal` from `app.db`. - It also adds the necessary import if not already present. - - Before: - ``` - def some_function(db): - pass - ``` - - After: - ``` - from app.db import SessionLocal - - def some_function(db: SessionLocal): - pass - ``` - """ - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - # Iterate over all functions in the codebase - for function in codebase.functions: - # Check each parameter of the function - for param in function.parameters: - # Identify parameters named 'db' - if param.name == "db": - # Change the type annotation to 'SessionLocal' - param.set_type_annotation("SessionLocal") - # Ensure the necessary import is present - file = function.file - if "SessionLocal" not in [imp.name for imp in file.imports]: - file.add_import("from app.db import SessionLocal") diff --git a/src/codemods/canonical/add_internal_to_non_exported_components/__init__.py b/src/codemods/canonical/add_internal_to_non_exported_components/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codemods/canonical/add_internal_to_non_exported_components/add_internal_to_non_exported_components.py b/src/codemods/canonical/add_internal_to_non_exported_components/add_internal_to_non_exported_components.py deleted file mode 100644 index 11a6ecc2c..000000000 --- a/src/codemods/canonical/add_internal_to_non_exported_components/add_internal_to_non_exported_components.py +++ /dev/null @@ -1,44 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a codemod that iterates through a codebase and renames all non-exported React function components by appending 'Internal' to their names. The -codemod should check each function to determine if it is a JSX component and not exported, then rename it accordingly.""", - uid="302d8f7c-c848-4020-9dea-30e8e622d709", -) -@canonical -class AddInternalToNonExportedComponents(Codemod, Skill): - """This codemod renames all React function components that are not exported from their file to be suffixed with 'Internal'. - - Example: - Before: - ``` - const Inner = () =>
; - const Outer = () =>
; - export default Outer; - ``` - After: - ``` - const InnerInternal = () =>
; - const Outer = () =>
; - export default Outer; - ``` - """ - - language = ProgrammingLanguage.TYPESCRIPT - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def execute(self, codebase: Codebase) -> None: - # Iterate over all files - for file in codebase.files: - for function in file.functions: - # Check if the function is a React component and is not exported - if function.is_jsx and not function.is_exported: - # Rename the function to include 'Internal' - function.rename(f"{function.name}Internal") diff --git a/src/codemods/canonical/bang_bang_to_boolean/__init__.py b/src/codemods/canonical/bang_bang_to_boolean/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codemods/canonical/bang_bang_to_boolean/bang_bang_to_boolean.py b/src/codemods/canonical/bang_bang_to_boolean/bang_bang_to_boolean.py deleted file mode 100644 index e7c769e9c..000000000 --- a/src/codemods/canonical/bang_bang_to_boolean/bang_bang_to_boolean.py +++ /dev/null @@ -1,37 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a TypeScript codemod that transforms instances of '!!(expression)' into 'Boolean(expression)'. The codemod should search through all -TypeScript files in a codebase, using a regular expression to identify the pattern. Upon finding a match, it should replace '!!' with 'Boolean(' and -append a closing parenthesis to complete the transformation.""", - uid="d1ece8d3-7da9-4696-9288-4087737e2952", -) -@canonical -class BangBangToBoolean(Codemod, Skill): - """This codemod converts !!(expression) to Boolean(expression)""" - - language = ProgrammingLanguage.TYPESCRIPT - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def execute(self, codebase: Codebase) -> None: - # Regular expression pattern as a string to find '!!' followed by an identifier or any bracketed expression - pattern = r"!!\s*(\w+|\([^\)]*\))" - - # Iterate over all files in the codebase - for file in codebase.files: - # Check if the file is a TypeScript file - if file.extension == ".ts": - # Search for the pattern in the file's source code using the string pattern - matches = file.search(pattern, include_strings=False, include_comments=False) - for match in matches: - # Replace the '!!' with 'Boolean(' - match.replace("!!", "Boolean(", count=1) - # Wrap the expression in closing parenthesis - match.insert_after(")", newline=False) diff --git a/src/codemods/canonical/built_in_type_annotation/built_in_type_annotation.py b/src/codemods/canonical/built_in_type_annotation/built_in_type_annotation.py deleted file mode 100644 index e04461cfc..000000000 --- a/src/codemods/canonical/built_in_type_annotation/built_in_type_annotation.py +++ /dev/null @@ -1,44 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that replaces type annotations from the typing module with their corresponding built-in types. The codemod should iterate -through all files in a codebase, check for imports from the typing module, remove those imports, and replace any usages of typing.List, typing.Dict, -typing.Set, and typing.Tuple with list, dict, set, and tuple respectively.""", - uid="b2cd98af-d3c5-4e45-b396-e7abf06df924", -) -@canonical -class BuiltInTypeAnnotation(Codemod, Skill): - """Replaces type annotations using typing module with builtin types. - - Examples: - typing.List -> list - typing.Dict -> dict - typing.Set -> set - typing.Tuple -> tuple - """ - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - import_replacements = {"List": "list", "Dict": "dict", "Set": "set", "Tuple": "tuple"} - # Iterate over all files in the codebase - for file in codebase.files: - # Iterate over all imports in the file - for imported in file.imports: - # Check if the import is from the typing module and is a builtin type - if imported.module == "typing" and imported.name in import_replacements: - # Remove the type import - imported.remove() - # Iterate over all symbols that use this imported module - for usage in imported.usages: - # Replace the usage with the builtin type - if usage.match.source == imported.name: - usage.match.edit(import_replacements[imported.name]) diff --git a/src/codemods/canonical/change_component_tag_names/__init__.py b/src/codemods/canonical/change_component_tag_names/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codemods/canonical/change_component_tag_names/change_component_tag_names.py b/src/codemods/canonical/change_component_tag_names/change_component_tag_names.py deleted file mode 100644 index de97853a6..000000000 --- a/src/codemods/canonical/change_component_tag_names/change_component_tag_names.py +++ /dev/null @@ -1,59 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a codemod that updates all instances of the JSX element to within React components in a TypeScript -codebase. Ensure that the new component is imported if it is not already present. The codemod should check for the existence of the - component and raise an error if it is not found.""", - uid="ab5879e3-e3ea-4231-b928-b756473f290d", -) -@canonical -class ChangeJSXElementName(Codemod, Skill): - """This codemod updates specific JSX elements inside of React components - - In particular, this: - <> - test - - - - gets updated to: - <> - test - - - - Inside of all React components in the codebase. - """ - - language = ProgrammingLanguage.TYPESCRIPT - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def execute(self, codebase: Codebase): - # Grab the NewName component - PrivateRoutesContainer = codebase.get_symbol("PrivateRoutesContainer", optional=True) - if PrivateRoutesContainer is None or not PrivateRoutesContainer.is_jsx: - msg = "PrivateRoutesContainer component not found in codebase" - raise ValueError(msg) - - # Iterate over all functions in the codebase - for file in codebase.files: - # Iterate over each function in the file - for function in file.functions: - # Check if the function is a React component - if function.is_jsx: - # Iterate over all JSXElements in the React component - for element in function.jsx_elements: - # Check if the element named improperly - if element.name == "PrivateRoute": - # Update the JSXElement's name - element.set_name("PrivateRoutesContainer") - # Add the import if it doesn't exist - if not file.has_import("PrivateRoutesContainer"): - file.add_import(PrivateRoutesContainer) diff --git a/src/codemods/canonical/classnames_to_backtick.py b/src/codemods/canonical/classnames_to_backtick.py deleted file mode 100644 index 62a28e95d..000000000 --- a/src/codemods/canonical/classnames_to_backtick.py +++ /dev/null @@ -1,50 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a TypeScript codemod that converts all `className='...'` props in JSX elements to use backticks. The codemod should iterate through all files -in a codebase, identify JSX components, and for each JSX element, check its props. If a prop is named `className` and its value is not already wrapped -in curly braces, replace the quotes with backticks, updating the prop value accordingly.""", - uid="bf22f4d7-a93a-458f-be78-470c24487d4c", -) -@canonical -class ClassNamesToBackTick(Codemod, Skill): - """This Codemod converts all `classNames="..."` props in JSX elements to use backticks. - - Example: - Before: -
- - After: -
- - """ - - language = ProgrammingLanguage.TYPESCRIPT - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def execute(self, codebase: Codebase) -> None: - # Iterate over all files in the codebase - for file in codebase.files: - # Check if the file is likely to contain JSX elements (commonly in .tsx files) - for function in file.functions: - # Check if the function is a JSX component - if function.is_jsx: - # Iterate over all JSX elements in the function - for element in function.jsx_elements: - # Access the props of the JSXElement - for prop in element.props: - # Check if the prop is named 'className' - if prop.name == "className": - # Get the current value of the prop - if not prop.value.startswith("{"): - # Replace single or double quotes with backticks - new_value = "{`" + prop.value.strip("\"'") + "`}" - # Update the attribute value - prop.set_value(new_value) diff --git a/src/codemods/canonical/convert_array_type_to_square_bracket/__init__.py b/src/codemods/canonical/convert_array_type_to_square_bracket/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codemods/canonical/convert_array_type_to_square_bracket/convert_array_type_to_square_bracket.py b/src/codemods/canonical/convert_array_type_to_square_bracket/convert_array_type_to_square_bracket.py deleted file mode 100644 index 8efd7debf..000000000 --- a/src/codemods/canonical/convert_array_type_to_square_bracket/convert_array_type_to_square_bracket.py +++ /dev/null @@ -1,38 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.core.expressions.generic_type import GenericType -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a TypeScript codemod that converts types from `Array` to `T[]`. The codemod should iterate through all files in a codebase, checking each -function's return type and parameters. If a return type or parameter type is of the form `Array`, it should be transformed to `T[]`. Ensure that the -codemod handles edge cases, such as nested Array types, appropriately.""", - uid="97184a15-5992-405b-be7b-30122556fe8b", -) -@canonical -class ConvertArrayTypeToSquareBracket(Codemod, Skill): - """This codemod converts types of the form `Array` to `T[]`, while avoiding edge cases like nested Array types""" - - language = ProgrammingLanguage.TYPESCRIPT - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def execute(self, codebase: Codebase) -> None: - # Iterate over all files in the codebase - for file in codebase.files: - # Iterate over all functions in the file - for func in file.functions: - # Check if the return type is of the form Array - if (return_type := func.return_type) and isinstance(return_type, GenericType) and return_type.name == "Array": - # Array<..> syntax only allows one type argument - func.set_return_type(f"({return_type.parameters[0].source})[]") - - # Process each parameter in the function - for param in func.parameters: - if (param_type := param.type) and isinstance(param_type, GenericType) and param_type.name == "Array": - # Array<..> syntax only allows one type argument - param_type.edit(f"({param_type.parameters[0].source})[]") diff --git a/src/codemods/canonical/convert_attribute_to_decorator/__init__.py b/src/codemods/canonical/convert_attribute_to_decorator/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codemods/canonical/convert_attribute_to_decorator/convert_attribute_to_decorator.py b/src/codemods/canonical/convert_attribute_to_decorator/convert_attribute_to_decorator.py deleted file mode 100644 index a0fecf515..000000000 --- a/src/codemods/canonical/convert_attribute_to_decorator/convert_attribute_to_decorator.py +++ /dev/null @@ -1,59 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that transforms class attributes initializing specific Session objects into decorators. The codemod should iterate through -all classes in a codebase, check for attributes with values 'NullSession' or 'SecureCookieSession', import the corresponding decorators, add them to -the class, and remove the original attributes. Ensure the decorators are imported from 'src.flask.sessions'.""", - uid="b200fb43-dad4-4241-a0b2-75a6fbf5aca6", -) -@canonical -class ConvertAttributeToDecorator(Codemod, Skill): - """This converts any class attributes that initializes a set of Session objects to a decorator. - - For example, before: - - class MySession(SessionInterface): - session_class = NullSession - ... - - After: - @null_session - class MySession(SessionInterface): - ... - - That is, it deletes the attribute and adds the appropriate decorator via the `cls.add_decorator` method. - Note that `cls.file.add_import(import_str)` is the method used to add import for the decorator. - """ - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - attr_value_to_decorator = { - "NullSession": "null_session", - "SecureCookieSession": "secure_cookie_session", - } - # Iterate over all classes in the codebase - for cls in codebase.classes: - # Check if the class contains any targeted attributes - for attribute in cls.attributes: - if attribute.right is None: - continue - - if attribute.right.source in attr_value_to_decorator: - decorator_name = attr_value_to_decorator[attribute.right.source] - # Import the necessary decorators - required_import = f"from src.flask.sessions import {decorator_name}" - cls.file.add_import(required_import) - - # Add the appropriate decorator - cls.add_decorator(f"@{decorator_name}") - # Remove the attribute - attribute.remove() diff --git a/src/codemods/canonical/convert_comments_to_JSDoc_style/__init__.py b/src/codemods/canonical/convert_comments_to_JSDoc_style/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codemods/canonical/convert_comments_to_JSDoc_style/convert_comments_to_JSDoc_style.py b/src/codemods/canonical/convert_comments_to_JSDoc_style/convert_comments_to_JSDoc_style.py deleted file mode 100644 index 1443aa8b5..000000000 --- a/src/codemods/canonical/convert_comments_to_JSDoc_style/convert_comments_to_JSDoc_style.py +++ /dev/null @@ -1,45 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a codemod that converts comments on exported functions and classes in a TypeScript codebase to JSDoc style. The codemod should iterate -through all functions and classes, check if they are exported, and if they lack docstrings. If comments are present and do not contain 'eslint', -escape any occurrences of '*/' in the comments to prevent breaking the JSDoc block, then convert the comments to JSDoc format. Finally, remove the -original comments after conversion.""", - uid="846a3894-b534-4de2-9810-94bc691a5687", -) -@canonical -class ConvertCommentsToJSDocStyle(Codemod, Skill): - """This codemod converts the comments on any exported function or class to JSDoc style if they aren't already in JSDoc style. - - A JSDoc style comment is one that uses /** */ instead of // - - It also accounts for some common edgecases like avoiding eslint comments or comments which include a */ in them that needs to be escaped - """ - - language = ProgrammingLanguage.TYPESCRIPT - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def execute(self, codebase: Codebase) -> None: - # Iterate over all functions and classes in the codebase - for symbol in codebase.functions + codebase.classes: - # Check if the symbol is exported - if symbol.is_exported: - # Check if the symbol is missing docstrings - if not symbol.docstring: - # Check if the symbol has comments - if symbol.comment: - # If eslint comments are present, skip conversion - if "eslint" not in symbol.comment.text: - # Escape any `*/` found in the comment to prevent breaking the JSDoc block - escaped_comment = symbol.comment.text.replace("*/", r"*\/") - # Convert comment to JSdoc docstrings - # symbol.set_docstring(escaped_comment, force_multiline=True) - symbol.set_docstring(escaped_comment, force_multiline=True) - symbol.comment.remove() diff --git a/src/codemods/canonical/convert_docstring_to_google_style/__init__.py b/src/codemods/canonical/convert_docstring_to_google_style/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codemods/canonical/convert_docstring_to_google_style/convert_docstring_to_google_style.py b/src/codemods/canonical/convert_docstring_to_google_style/convert_docstring_to_google_style.py deleted file mode 100644 index ce7ba980f..000000000 --- a/src/codemods/canonical/convert_docstring_to_google_style/convert_docstring_to_google_style.py +++ /dev/null @@ -1,30 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod class named `ConvertDocstringToGoogleStyle` that inherits from `Codemod` and `Skill`. The class should have a docstring -explaining its purpose: converting docstrings of functions and classes to Google style if they aren't already. The `execute` method should iterate -over the functions in a given `codebase`, check if each function has a docstring, and if so, convert it to Google style using a method -`to_google_docstring`.""", - uid="99da3cd9-6ba8-4a4e-8ceb-8c1b2a60562d", -) -@canonical -class ConvertDocstringToGoogleStyle(Codemod, Skill): - """This codemod converts docstrings on any function or class to Google docstring style if they aren't already. - - A Google docstring style is one that specifies the args, return value, and raised exceptions in a structured format. - """ - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - for function in codebase.functions: - if (docstring := function.docstring) is not None: - function.set_docstring(docstring.to_google_docstring(function)) diff --git a/src/codemods/canonical/delete_unused_functions/delete_unused_functions.py b/src/codemods/canonical/delete_unused_functions/delete_unused_functions.py deleted file mode 100644 index e396aa4fa..000000000 --- a/src/codemods/canonical/delete_unused_functions/delete_unused_functions.py +++ /dev/null @@ -1,34 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that deletes all unused functions from a codebase. The codemod should iterate through each file in the codebase, check for -top-level functions, and remove any function that has no usages or call-sites. Ensure that the implementation follows best practices for identifying -unused functions.""", - uid="4024ceb5-54de-49de-b8f5-122ca2d3a6ee", -) -@canonical -class DeleteUnusedFunctionsCodemod(Codemod, Skill): - """This Codemod deletes all functions that are not used in the codebase (no usages). - In general, when deleting unused things, it's good practice to check both usages and call-sites, even though - call-sites should be basically a subset of usages (every call-site should correspond to a usage). - This is not always the case, however, so it's good to check both. - """ - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - for file in codebase.files: - # Iterate over top-level functions in the file - for function in file.functions: - # Check conditions: function has no usages/call-sites - if not function.usages: - # Remove the function from the codebase when it has no call sites - function.remove() diff --git a/src/codemods/canonical/emojify_py_files_codemod/emojify_py_files_codemod.py b/src/codemods/canonical/emojify_py_files_codemod/emojify_py_files_codemod.py deleted file mode 100644 index 08336c629..000000000 --- a/src/codemods/canonical/emojify_py_files_codemod/emojify_py_files_codemod.py +++ /dev/null @@ -1,28 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that iterates over all Python files in a codebase and adds a rainbow emoji comment at the beginning of each file. The -codemod should be implemented in the `execute` function of the `EmojifyPyFilesCodemod` class, which inherits from `Codemod` and `Skill`. Ensure that -the new content for each file starts with the comment '#🌈' followed by the original content of the file.""", - uid="5d8f1994-7f74-42e8-aaa8-0c41ced228ef", -) -@canonical -class EmojifyPyFilesCodemod(Codemod, Skill): - """Trivial codemod to add a rainbow emoji in a comment at the beginning of all Python files.""" - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - # iterate over files - for file in codebase.files: - # add the rainbow emoji to the top of the file - new_content = "#🌈" + "\n" + file.content - file.edit(new_content) diff --git a/src/codemods/canonical/enum_mover/enum_mover.py b/src/codemods/canonical/enum_mover/enum_mover.py deleted file mode 100644 index dae5ff394..000000000 --- a/src/codemods/canonical/enum_mover/enum_mover.py +++ /dev/null @@ -1,49 +0,0 @@ -from codegen.sdk.core.codebase import CodebaseType -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that iterates through all classes in a codebase, identifies subclasses of Enum, and moves them to a designated enums.py -file. Ensure that the codemod checks if the class is already in the correct file, flags it for movement if necessary, and creates the enums.py file if -it does not exist.""", - uid="55bc76e5-15d2-4da6-bac1-59b408a59be7", -) -@canonical -class EnumMover(Codemod, Skill): - """This codemod moves all enums (Enum subclasses) to a designated enums.py file within the same directory of the - file they're defined in. It ensures that the enums are moved to the correct file and creates the enums.py file if - it does not exist. Furthermore, it flags the class for movement which is necessary for splitting up the - modifications into separate pull requests. - """ - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: CodebaseType): - # Iterate over all classes in the codebase - for cls in codebase.classes: - # Check if the class is a subclass of Enum - if cls.is_subclass_of("Enum"): - # Determine the target file path for enums.py - target_filepath = "/".join(cls.file.filepath.split("/")[:-1]) + "/enums.py" - - # Check if the current class is already in the correct enums.py file - if cls.file.filepath.endswith("enums.py"): - continue - - # Flag the class for potential movement - flag = codebase.flag_instance(symbol=cls) - if codebase.should_fix(flag): - # Check if the enums.py file exists, if not, create it - if not codebase.has_file(target_filepath): - enums_file = codebase.create_file(target_filepath, "") - else: - enums_file = codebase.get_file(target_filepath) - - # Move the enum class to the enums.py file - cls.move_to_file(enums_file) diff --git a/src/codemods/canonical/insert_arguments_to_decorator/__init__.py b/src/codemods/canonical/insert_arguments_to_decorator/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codemods/canonical/insert_arguments_to_decorator/insert_arguments_to_decorator.py b/src/codemods/canonical/insert_arguments_to_decorator/insert_arguments_to_decorator.py deleted file mode 100644 index 24b6a55ce..000000000 --- a/src/codemods/canonical/insert_arguments_to_decorator/insert_arguments_to_decorator.py +++ /dev/null @@ -1,45 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that iterates through a codebase, identifying all instances of the `@app.function` decorator. For each decorator, check if -the `cloud` and `region` arguments are present. If they are missing, append `cloud='aws'` and `region='us-east-1'` to the decorator's arguments. -Ensure that the modifications are made only when the arguments are not already included.""", - uid="de868e09-796c-421b-9efd-151f94f08aef", -) -@canonical -class InsertArgumentsToDecorator(Codemod, Skill): - """This codemod inserts the cloud and region arguments to every app.function decorator. - it decides whether to insert the arguments based on whether they are already present in the decorator. - if they are not present, it inserts them. - for example: - - -@app.function(image=runner_image, secrets=[modal.Secret.from_name("aws-secret")]) - +@app.function(image=runner_image, secrets=[modal.Secret.from_name("aws-secret")], cloud="aws", region="us-east-1") - - """ - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - # Iterate over all files in the codebase - for file in codebase.files: - # Iterate over all functions in each file - for function in file.functions: - # Check each decorator for the function - for decorator in function.decorators: - # Identify decorators that are app.function and modify them - if decorator.source.startswith("@app.function("): - # Parse the existing decorator to add or update the cloud and region parameters - # Check if 'cloud' and 'region' are already in the decorator - if "cloud=" not in decorator.source: - decorator.call.args.append('cloud="aws"') - if "region=" not in decorator.source: - decorator.call.args.append('region="us-east-1"') diff --git a/src/codemods/canonical/invite_factory_create_params/__init__.py b/src/codemods/canonical/invite_factory_create_params/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codemods/canonical/invite_factory_create_params/invite_factory_create_params.py b/src/codemods/canonical/invite_factory_create_params/invite_factory_create_params.py deleted file mode 100644 index 20d39396a..000000000 --- a/src/codemods/canonical/invite_factory_create_params/invite_factory_create_params.py +++ /dev/null @@ -1,69 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.core.detached_symbols.function_call import FunctionCall -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that updates calls to `InviteFactory.create`, `InviteFactory.build`, and `InviteFactory(...)` to use the `invitee` parameter -instead of `invitee_id`, `invitee['email']`, or `invitee.id`. The codemod should iterate through all files in a codebase, find the relevant function -calls, and modify the arguments accordingly. Specifically, it should replace `invitee_id` with `invitee`, and adjust the value to remove `.id` or -`['email']` as needed.""", - uid="1c43f274-e4bc-49c7-abca-8b273e9cad9a", -) -@canonical -class InviteFactoryCreateParams(Codemod, Skill): - """This codemod updates calls to InviteFactory.create, InviteFactory.build and InviteFactory(...) to use the `invitee` parameter instead of `invitee_id`, `invitee["email"]`, or `invitee.id`. - - For example: - - InviteFactory.create(invitee_id=user_deleted_recently.id) - - Becomes: - - InviteFactory.create(invitee=user_deleted_recently) - - Note that this involves grabbing the function calls by using `file.find` and `file.search` to find the function calls, and then using `FunctionCall.from_usage` to create a `FunctionCall` object from the usage. This is because **the current version of GraphSitter does not support finding method usages** - """ # noqa: E501 - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - # Iterate over all files - for file in codebase.files: - # Find invocations of InviteFactory.create and InviteFactory.build in the file - usages = file.find("InviteFactory.create", exact=True) # returns an Editable - usages += file.find("InviteFactory.build", exact=True) - usages += file.search(r"\bInviteFactory\(") - - # Iterate over all these function calls - for usage in usages: - # Create a function call from this `usage` - function_call = FunctionCall.from_usage(usage) - if function_call is None: - continue - - # Grab the invitee_id argument - invitee_arg = function_call.get_arg_by_parameter_name("invitee_id") - # If it exists... - if invitee_arg: - # Grab the current value - arg_value = invitee_arg.value - - # Replace the arg value with the correct value - if arg_value.endswith(".id"): - # replace `xyz.id` with `xyz` - invitee_arg.set_value(arg_value.replace(".id", "")) - elif arg_value.endswith('["email"]'): - # replace `xyz["email"]` with `xyz` - invitee_arg.set_value(arg_value.replace('["email"]', "")) - else: - continue - - # Update the arg keyword from `invitee_id` => 'invitee' - invitee_arg.rename("invitee") diff --git a/src/codemods/canonical/js_to_esm_codemod/js_to_esm_codemod.py b/src/codemods/canonical/js_to_esm_codemod/js_to_esm_codemod.py deleted file mode 100644 index 9e4745ff0..000000000 --- a/src/codemods/canonical/js_to_esm_codemod/js_to_esm_codemod.py +++ /dev/null @@ -1,33 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python function named `execute` within a class `JsToEsmCodemod` that iterates through all files in a given `codebase`. For each file, check -if its name contains '.router'. If it does, convert the file to ESM format and update its filename to have a '.ts' extension, preserving the original -directory structure.""", - uid="f93122d3-f469-4740-a8bf-f53016de41b2", -) -@canonical -class JsToEsmCodemod(Codemod, Skill): - """This codemod will convert all JS files that have .router in their name to be proper ESM modules""" - - language = ProgrammingLanguage.TYPESCRIPT - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def execute(self, codebase: Codebase) -> None: - # iterate all files in the codebase - for file in codebase.files: - # Check if the file is not a router file - if ".router" in file.name: - # Convert the file to ESM - file.convert_js_to_esm() - # Update filename - new_file_dir = "/".join(file.filepath.split("/")[:-1]) - new_file_name = ".".join(file.name.split(".")[:3]) - file.update_filepath(f"{new_file_dir}/{new_file_name}.ts") diff --git a/src/codemods/canonical/mark_as_internal_codemod/mark_as_internal_codemod.py b/src/codemods/canonical/mark_as_internal_codemod/mark_as_internal_codemod.py deleted file mode 100644 index c8fbd8dbc..000000000 --- a/src/codemods/canonical/mark_as_internal_codemod/mark_as_internal_codemod.py +++ /dev/null @@ -1,49 +0,0 @@ -from pathlib import Path - -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a TypeScript codemod that marks functions as internal by adding the @internal tag to their docstrings. The codemod should check if a function -is only used within the same directory or subdirectory, ensuring it is not exported, re-exported, or overloaded. If the function's docstring does not -already contain the @internal tag, append it appropriately.""", - uid="fe61add3-ab41-49ec-9c26-c2d13e2647d1", -) -@canonical -class MarkAsInternalCodemod(Codemod, Skill): - """Mark all functions that are only used in the same directory or subdirectory as an internal function. - To mark function as internal by adding the @internal tag to the docstring. - """ - - language = ProgrammingLanguage.TYPESCRIPT - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def execute(self, codebase: Codebase) -> None: - # Check if the caller and callee are in the same directory - def check_caller_directory(caller_file: str, callee_file: str) -> bool: - caller_path = Path(caller_file).resolve() - callee_path = Path(callee_file).resolve() - return str(caller_path).startswith(str(callee_path.parent)) - - # Iterate over all the functions in the codebase - for function in codebase.functions: - # Ignore functions that are exported - if function.is_exported: - # Check if all usages of the function are in the same file - if all([check_caller_directory(caller.file.filepath, function.file.filepath) for caller in function.symbol_usages]): - # Check if function is not re-exported - if not function.is_reexported and not function.is_overload: - # Check if function is not already marked as internal - docstring = function.docstring.text if function.docstring else "" - if "@internal" not in docstring: - # Add @internal to the docstring - if function.docstring: - function.set_docstring(f"{function.docstring.text}\n\n@internal") - else: - function.set_docstring("@internal") diff --git a/src/codemods/canonical/mark_internal_to_module/mark_internal_to_module.py b/src/codemods/canonical/mark_internal_to_module/mark_internal_to_module.py deleted file mode 100644 index 220cf645c..000000000 --- a/src/codemods/canonical/mark_internal_to_module/mark_internal_to_module.py +++ /dev/null @@ -1,32 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that iterates through all functions in the `app` directory of a codebase. For each function that is not private and is not -being imported anywhere, rename it to be internal by prefixing its name with an underscore. Ensure that the function checks the file path to confirm -it belongs to the `app` directory and uses a method to find import usages.""", - uid="cb5c6f1d-0a00-46e3-ac0d-c540ab665041", -) -@canonical -class MarkInternalToModule(Codemod, Skill): - """This codemod looks at all functions in the `app` directory and marks them as internal if they are not being imported anywhere""" - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - for function in codebase.functions: - if "app" in function.file.filepath: - # Check if the function is not internal - if not function.is_private and function.name is not None: - # Check if the function is not being imported anywhere - if not any(usage.kind in (UsageKind.IMPORTED, UsageKind.IMPORTED_WILDCARD) for usage in function.usages): - # Rename the function to be internal - function.rename("_" + function.name) diff --git a/src/codemods/canonical/mark_is_boolean/__init__.py b/src/codemods/canonical/mark_is_boolean/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codemods/canonical/mark_is_boolean/mark_is_boolean.py b/src/codemods/canonical/mark_is_boolean/mark_is_boolean.py deleted file mode 100644 index e236dce67..000000000 --- a/src/codemods/canonical/mark_is_boolean/mark_is_boolean.py +++ /dev/null @@ -1,45 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a TypeScript codemod that renames function parameters of boolean type that do not start with 'is'. The codemod should iterate through all -files in a codebase, check each function's parameters, and if a parameter is boolean and does not start with 'is', it should be renamed to start with -'is' followed by the capitalized parameter name. Additionally, all function calls using the old parameter name should be updated to use the new name.""", - uid="e848b784-c703-4f4f-bfa4-e3876b2468d1", -) -@canonical -class MarkIsBoolean(Codemod, Skill): - """This (TypeScript) Codemod illustrates how to rename function parameters that are boolean types but do not start with 'is'. - - In a real application, you would probably also check for other valid prefixes, like `should` etc. - """ - - language = ProgrammingLanguage.TYPESCRIPT - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def execute(self, codebase: Codebase) -> None: - # Iterate over all files in the codebase - for file in codebase.files: - # Iterate over all functions in the file - for function in file.functions: - # Iterate over all parameters in each function - for param in function.parameters: - # Check if the parameter is a boolean type - if param.type == "boolean" or param.default in ["true", "false"]: - # Check if the parameter name does not start with 'is' - if not param.name.startswith("is"): - # Generate the new parameter name - new_name = "is" + param.name.capitalize() - # Rename the parameter and update all usages - param.rename(new_name) - # Update all function calls with the new parameter name - for call in function.call_sites: - arg = call.get_arg_by_parameter_name(param.name) - if arg: - arg.rename(new_name) diff --git a/src/codemods/canonical/migrate_class_attributes/migrate_class_attributes.py b/src/codemods/canonical/migrate_class_attributes/migrate_class_attributes.py deleted file mode 100644 index 51bd22fe8..000000000 --- a/src/codemods/canonical/migrate_class_attributes/migrate_class_attributes.py +++ /dev/null @@ -1,61 +0,0 @@ -import logging -import textwrap - -from codegen.sdk.core.codebase import PyCodebaseType -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - -logger = logging.getLogger(__name__) - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that migrates class attributes from a source class named 'RequestResetPassword' to a destination class named -'UserGroupsSettingsControlPanel'. The migrated attributes should be made private in the source class by renaming them with a leading underscore. -Additionally, create a hybrid property for each migrated attribute in the source class, including getter and setter methods that manage the private -attribute and maintain a copy in the source class.""", - uid="739061ae-4f4f-48eb-a825-7424417ce540", -) -@canonical -class MigrateClassAttributes(Codemod, Skill): - """Migrates class attributes from a source class to another class. - Any migrated attributes are made private in the source class. - """ - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: PyCodebaseType) -> None: - # Get the source and destination classes - source_class = codebase.get_class("RequestResetPassword") - dest_class = codebase.get_class("UserGroupsSettingsControlPanel") - dest_attr_names = [x.name for x in dest_class.attributes] - - # Iterate over all attributes in the source class - for attribute in source_class.attributes(private=False): - # Skip attributes that are already added - if attribute.name in dest_attr_names: - continue - - # Add the attribute to the destination class (and bring its dependencies with it) - dest_class.add_attribute(attribute, include_dependencies=True) - - # Make this attribute private (_name) in the source class - attribute.rename(f"_{attribute.name}") - - # Add a "shadow copy write" to the source class - return_type = attribute.assignment.type.source if attribute.assignment.type else "None" - source_class.add_attribute_from_source(f"""{attribute.name} = hybrid_property(fget=get_{attribute.name}, fset=set_{attribute.name})""") - source_class.methods.append( - textwrap.dedent(f""" - def get_{attribute.name}(self) -> {return_type}: - return self._{attribute.name} - - def set_{attribute.name}(self, value: str) -> None: - self._{attribute.name} = value - self.copy.{attribute.name} = value - """) - ) diff --git a/src/codemods/canonical/move_enums_codemod/move_enums_codemod.py b/src/codemods/canonical/move_enums_codemod/move_enums_codemod.py deleted file mode 100644 index 07e406916..000000000 --- a/src/codemods/canonical/move_enums_codemod/move_enums_codemod.py +++ /dev/null @@ -1,42 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that moves all enum classes from various files in a codebase to a single file named 'enums.py'. The codemod should check if -'enums.py' already exists in the current directory; if not, it should create it. For each enum class found, the codemod should move the class along -with its dependencies to 'enums.py' and add a back edge import to the original file.""", - uid="47e9399c-b8d5-4f39-a5cf-fd40c51620b0", -) -@canonical -class MoveEnumsCodemod(Codemod, Skill): - """Moves all enums to a file called enums.py in current directory if it doesn't already exist""" - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - for file in codebase.files: - if not file.name.endswith("enums.py"): - for cls in file.classes: - # check if the class inherits from the Enum class - if cls.is_subclass_of("Enum"): - # generate the new filename for the enums.py file - new_filename = "/".join(file.filepath.split("/")[:-1]) + "/enums.py" - - # check if the enums.py file exists - if not codebase.has_file(new_filename): - # if it doesn't exist, create a new file - dst_file = codebase.create_file(new_filename, "from enum import Enum\n\n") - else: - # if it exists, get a reference to the existing file - dst_file = codebase.get_file(new_filename) - - # move the enum class and its dependencies to the enums.py file - # add a "back edge" import to the original file - cls.move_to_file(dst_file, include_dependencies=True, strategy="add_back_edge") diff --git a/src/codemods/canonical/move_functions_to_new_file/move_functions_to_new_file.py b/src/codemods/canonical/move_functions_to_new_file/move_functions_to_new_file.py deleted file mode 100644 index c91d46962..000000000 --- a/src/codemods/canonical/move_functions_to_new_file/move_functions_to_new_file.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import TYPE_CHECKING - -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - -if TYPE_CHECKING: - from codegen.sdk.core.file import SourceFile - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that moves all functions starting with 'pylsp_' from existing files in a codebase to a new file named 'pylsp_shared.py'. -Ensure that all imports across the codebase are updated to reflect the new location of these functions. The codemod should iterate through each file -in the codebase, create the new file, and move the matching functions while including their dependencies.""", - uid="b29f6b8b-0837-4548-b770-b597bbcd3e02", -) -@canonical -class MoveFunctionsToNewFile(Codemod, Skill): - """This codemod moves functions that starts with "pylsp_" in their names to a new file called pylsp_shared.py - - When it moves them to this file, all imports across the codebase will get updated to reflect the new location. - """ - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase): - # Create a new file for storing the functions that contain pylsp util functions - new_file: SourceFile = codebase.create_file("pylsp/pylsp_shared.py", "") - for file in codebase.files: - # Move function's name contains 'pylsp_' as a prefix - for function in file.functions: - if function.name.startswith("pylsp_"): - # Move each function that matches the criteria to the new file - function.move_to_file(new_file, include_dependencies=True, strategy="update_all_imports") diff --git a/src/codemods/canonical/openapi_add_response_none/openapi_add_response_none.py b/src/codemods/canonical/openapi_add_response_none/openapi_add_response_none.py deleted file mode 100644 index cd3eb8809..000000000 --- a/src/codemods/canonical/openapi_add_response_none/openapi_add_response_none.py +++ /dev/null @@ -1,75 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.core.detached_symbols.decorator import Decorator -from codegen.sdk.core.symbol import Symbol -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that adds a `@xys_ns.response(200)` decorator to Flask Resource methods that lack return status codes. The codemod should -check for Flask Resource classes and their HTTP methods (GET, POST, PUT, PATCH, DELETE). If a method does not have any `@response` decorators and has -a valid return statement, the codemod should extract the namespace from the class's `@xys.route` decorator and add the `@xys_ns.response(200)` -decorator to the method.""", - uid="c1596668-8169-44b4-9e0e-b244eb7671d9", -) -@canonical -class OpenAPIAddResponseNone(Codemod, Skill): - """This one adds a `@xys_ns.response(200)` decorator to Flask Resource methods that do not contain any return status codes - - Before: - - @xyz_ns.route("/ping", methods=["GET"]) - class XYZResource(Resource): - - @decorator - def get(self): - return "pong" - - After: - - @xyz_ns.route("/ping", methods=["GET"]) - class XYZResource(Resource): - - @decorator - @xyz_ns.response(200) - def get(self): - return "pong" - """ - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase): - def get_response_decorators(method: Symbol) -> list[Decorator]: - """Returns a list of decorators that contain the string '.response' in the source code""" - return [d for d in method.decorators if ".response" in d.source] - - def get_namespace_decorator(symbol: Symbol) -> Decorator | None: - """Returns the first decorator that contains the string '.route' in the source code""" - matches = [d for d in symbol.decorators if ".route" in d.source] - if len(matches) == 0: - return None - return matches[0] - - for cls in codebase.classes: - # Get Flask Resource classes - if cls.superclasses and any("Resource" in sc.source for sc in cls.superclasses): - for method in cls.methods: - # Filter to HTTP methods - if method.name in ("get", "post", "put", "patch", "delete"): - # Check if it has no `@response` decorators - response_decorators = get_response_decorators(method) - if len(response_decorators) == 0: - # Make sure it has `@xys.route` on the class - ns_decorator = get_namespace_decorator(cls) - if ns_decorator is not None: - # Check if returns a status code - if method.return_statements and not any(ret.value and ret.value.ts_node_type == "expression_list" for ret in method.return_statements): - # Extract the namespace name - ns_name = ns_decorator.source.split("@")[1].split(".")[0] - # Add the decorator - method.add_decorator(f"@{ns_name}.response(200)") diff --git a/src/codemods/canonical/openapi_no_reference_request/openapi_no_reference_request.py b/src/codemods/canonical/openapi_no_reference_request/openapi_no_reference_request.py deleted file mode 100644 index 0849c840f..000000000 --- a/src/codemods/canonical/openapi_no_reference_request/openapi_no_reference_request.py +++ /dev/null @@ -1,49 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.core.detached_symbols.decorator import Decorator -from codegen.sdk.core.symbol import Symbol -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that adds `@my_namespace.expect(None)` to all Flask route methods (GET, POST, PUT, PATCH, DELETE) in classes ending with -'Resource' that do not access the request object. Ensure that these methods do not already have an `expect` decorator or similar decorators like -`load_with`, `use_args`, or `use_kwargs`. The codemod should also check for the presence of a namespace decorator in the class to determine the -correct namespace to use.""", - uid="5341d15f-92c7-4a3e-b409-416603dfa7f6", -) -@canonical -class OpenAPINoReferenceRequest(Codemod, Skill): - """As part of the OpenAPI typing initiative for Flask endpoints, this codemod will add `@my_namespace.expect(None)` to all Flask routes that do not interact with the request object.""" - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - request_accesses = ["request_get_json", "request.json", "request.args", "request.form", "request.files", "request", "self.request"] - - def get_namespace_decorator(symbol: Symbol) -> Decorator | None: - matches = [d for d in symbol.decorators if "_ns.route" in d.source] - if len(matches) == 0: - return None - return matches[0] - - for cls in codebase.classes: - if cls.name.endswith("Resource"): - for method in cls.methods: - if method.name in ("get", "post", "put", "patch", "delete"): - # Check if it has any request accesses - if not any([access in method.source for access in request_accesses]): - # Check if it has an existing `expect` - decorators = method.decorators - if not any([x in decorator.source for decorator in decorators for x in ["load_with", "expect", "use_args", "use_kwargs"]]): - # Make sure it has `@xys_ns.route` on the class - ns_decorator = get_namespace_decorator(cls) - if ns_decorator is not None: - ns_name = ns_decorator.source.split("@")[1].split(".")[0] - # Add the decorator - method.add_decorator(f"@{ns_name}.expect(None)") diff --git a/src/codemods/canonical/pascal_case_symbols/__init__.py b/src/codemods/canonical/pascal_case_symbols/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codemods/canonical/pascal_case_symbols/pascal_case_symbols.py b/src/codemods/canonical/pascal_case_symbols/pascal_case_symbols.py deleted file mode 100644 index 5ff9dbea8..000000000 --- a/src/codemods/canonical/pascal_case_symbols/pascal_case_symbols.py +++ /dev/null @@ -1,41 +0,0 @@ -from codegen.sdk.core.class_definition import Class -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.core.interface import Interface -from codegen.sdk.core.type_alias import TypeAlias -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a TypeScript codemod that converts all Classes, Interfaces, and Types in a codebase to PascalCase. The codemod should iterate through all -symbols in the codebase, check if each symbol is a Class, Interface, or Type using the `isinstance` function, and if the symbol's name is not -capitalized, it should convert the name to PascalCase by capitalizing the first letter of each word and removing underscores. Finally, the codemod -should rename the symbol and update all references accordingly.""", - uid="bbb9e26a-7911-4b94-a4eb-207b9d32d18f", -) -@canonical -class PascalCaseSymbols(Codemod, Skill): - """This (Typescript) codemod converts all Classes, Interfaces and Types to be in PascalCase using simple logic. - - Note the use of the `isinstance(symbol, (Class | Interface | Type))` syntax to check if the symbol is a Class, Interface, or Type. - You should always use the abstract base class to check for the type of a symbol. - """ - - language = ProgrammingLanguage.TYPESCRIPT - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def execute(self, codebase: Codebase) -> None: - # Iterate over all symbols in the codebase - for symbol in codebase.symbols: - # Check if the symbol is a Class, Interface, or Type with `isinstance` syntax - if isinstance(symbol, (Class | Interface | TypeAlias)): - # Check if the name isn't capitalized - if not symbol.name[0].isupper(): - # Generate the PascalCase name - new_name = "".join(word.capitalize() for word in symbol.name.replace("_", " ").split()) - # Rename the symbol and update all references - symbol.rename(new_name) diff --git a/src/codemods/canonical/pivot_return_types/__init__.py b/src/codemods/canonical/pivot_return_types/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codemods/canonical/pivot_return_types/pivot_return_types.py b/src/codemods/canonical/pivot_return_types/pivot_return_types.py deleted file mode 100644 index aeb1cdee8..000000000 --- a/src/codemods/canonical/pivot_return_types/pivot_return_types.py +++ /dev/null @@ -1,49 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that transforms all functions returning a string type to return a custom FastStr type instead. The codemod should iterate -through the codebase, check for functions with a return type of 'str', update the return type to 'FastStr', add the necessary import statement for -FastStr, and modify all return statements to wrap the returned value in the FastStr constructor.""", - uid="a357f5c4-2ff0-4fb2-a5c6-be051428604a", -) -@canonical -class PivotReturnTypes(Codemod, Skill): - """This codemod allows us to take all functions that return str and safely convert it to a custom FastStr type. - It does so by wrapping the return statement value in the CustomStr constructor and update the return type annotation. - - def f() -> str: - ... - return content - - Becomes - - def f() -> FastStr: - ... - return FastStr(str=content) - """ - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - # Iterate over all functions in the codebase - for function in codebase.functions: - # Check if the function's return type annotation is 'str' - if (return_type := function.return_type) and return_type.source == "str": - # Update the return type to 'FastStr' - function.set_return_type("FastStr") - - # Add import for 'FastStr' if it doesn't exist - function.file.add_import("from app.models.fast_str import FastStr") - - # Modify all return statements within the function - for return_stmt in function.code_block.return_statements: - # Wrap return statements with FastStr constructor - return_stmt.set_value(f"FastStr(str={return_stmt.value})") diff --git a/src/codemods/canonical/refactor_react_components_into_separate_files/refactor_react_components_into_separate_files.py b/src/codemods/canonical/refactor_react_components_into_separate_files/refactor_react_components_into_separate_files.py deleted file mode 100644 index bb03666a9..000000000 --- a/src/codemods/canonical/refactor_react_components_into_separate_files/refactor_react_components_into_separate_files.py +++ /dev/null @@ -1,48 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python function that refactors React components in a codebase. The function should iterate through all files, identify React function -components, and separate non-default exported components into new files. Ensure that the new files are named after the components and that all imports -are updated accordingly. Include necessary error handling and commit changes to the codebase after each move.""", - uid="b64406f4-a670-4d65-8356-c6db25c4f4b7", -) -@canonical -class RefactorReactComponentsIntoSeparateFiles(Codemod, Skill): - """This codemod breaks up JSX/TSX files by moving components that aren't exported by default - into separate files. - """ - - language = ProgrammingLanguage.TYPESCRIPT - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def execute(self, codebase: Codebase) -> None: - # Iterate over all files in the codebase - for file in codebase.files: - # Find all React function components in the file - react_components = [func for func in file.functions if func.is_jsx and func.name is not None] - - # Identify the default exported component - default_component = next((comp for comp in react_components if comp.is_exported and comp.export.is_default_export()), None) - if default_component is None: - continue - - # Move non-default components to new files - for component in react_components: - if component != default_component and component in file.symbols: - # Create a new file for the component - new_file_path = "/".join(file.filepath.split("/")[:-1]) + "/" + component.name + ".tsx" - if not codebase.has_file(new_file_path): - new_file = codebase.create_file(new_file_path) - - # Move the component to the new file and update all imports - component.move_to_file(new_file, strategy="update_all_imports") - - # Commit is NECESSARY since subsequent steps depend on current symbol locations - codebase.commit() diff --git a/src/codemods/canonical/remove_indirect_imports/remove_indirect_imports.py b/src/codemods/canonical/remove_indirect_imports/remove_indirect_imports.py deleted file mode 100644 index ffd866851..000000000 --- a/src/codemods/canonical/remove_indirect_imports/remove_indirect_imports.py +++ /dev/null @@ -1,53 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.core.external_module import ExternalModule -from codegen.sdk.core.import_resolution import Import -from codegen.sdk.core.symbol import Symbol -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python function named `execute` within a class `RemoveIndirectImports` that processes a codebase to remove all indirect imports. The -function should iterate through all files in the codebase, check each import to determine if it points to another import, and replace it with a direct -import. Handle cases where the resolved import is either an external module or a symbol, ensuring that the import is updated accordingly.""", - uid="0648c80e-a569-4aa5-b241-38a2dd320e9a", -) -@canonical -class RemoveIndirectImports(Codemod, Skill): - """This codemod removes all indirect imports from a codebase (i.e. an import that points to another import), - replacing them instead with direct imports - """ - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - # iterate over all files -> imports - for file in codebase.files: - for original_import in file.imports: - # Grab the symbol being imported - imported_symbol = original_import.imported_symbol - - # Check if the symbol being imported is itself import - if isinstance(imported_symbol, Import): - # We've found an import that points to another import which means it's an indirect import! - # Get the symbol that the import eventually resolves to - imported_symbol = original_import.resolved_symbol - - # Case: we can't find the final destination symbol - if imported_symbol is None: - continue - - # Case: the resolved import is an external module. - elif isinstance(imported_symbol, ExternalModule): - original_import.edit(imported_symbol.source) - - # Case: the resolved import is Symbol. - elif isinstance(imported_symbol, Symbol): - # Replace the module in the import with the final destination symbol's module - # e.g. `from abc import ABC` -> `from xyz import ABC` or equivalent in your language. - original_import.set_import_module(imported_symbol.file.import_module_name) diff --git a/src/codemods/canonical/rename_function_parameters/rename_function_parameters.py b/src/codemods/canonical/rename_function_parameters/rename_function_parameters.py deleted file mode 100644 index 2779453ac..000000000 --- a/src/codemods/canonical/rename_function_parameters/rename_function_parameters.py +++ /dev/null @@ -1,33 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that iterates through all files in a codebase, identifies function parameters containing the substring 'obj', and renames -them to 'new_obj'. The codemod should be structured as a class that inherits from Codemod and Skill, with an execute method that performs the -renaming operation.""", - uid="1576b2fd-8a00-44e4-9659-eb0f585e015a", -) -@canonical -class RenameFunctionParameters(Codemod, Skill): - """This takes all functions that renames any parameter that contains 'obj' and replaces with 'new_obj'""" - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - # Iterate over all files - for file in codebase.files: - for function in file.functions: - # Search for parameter names that contain 'obj' - params_to_rename = [p for p in function.parameters if "obj" in p.name] - if params_to_rename: - # Rename the parameters - for param in params_to_rename: - new_param_name = param.name.replace("obj", "new_obj") - param.rename(new_param_name) diff --git a/src/codemods/canonical/rename_local_variables/rename_local_variables.py b/src/codemods/canonical/rename_local_variables/rename_local_variables.py deleted file mode 100644 index 8ce143ac5..000000000 --- a/src/codemods/canonical/rename_local_variables/rename_local_variables.py +++ /dev/null @@ -1,48 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that iterates through a codebase, identifying functions with local variables containing the name 'position'. For each -identified function, rename all occurrences of the local variable 'position' to 'pos', ensuring that the renaming is applied to all relevant usages -within the function.""", - uid="79c10c00-bbce-4bdb-8c39-d91586307a2b", -) -@canonical -class RenameLocalVariables(Codemod, Skill): - """This codemod renames all local variables in functions that contain 'position' to 'pos' - - Example: - Before: - ``` - def some_function(x, y, position): - position_x = x + position - position_y = y + position - return position_x, position_y - ``` - After: - ``` - def some_function(x, y, position): - pos_x = x + position - pos_y = y + position - return pos_x, pos_y - ``` - """ - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - # iterate over files - for file in codebase.files: - for function in file.functions: - # Check if any local variable names contain "position" - position_usages = function.code_block.get_variable_usages("position", fuzzy_match=True) - if len(position_usages) > 0: - # Rename - function.rename_local_variable("position", "pos", fuzzy_match=True) diff --git a/src/codemods/canonical/replace_prop_values/replace_prop_values.py b/src/codemods/canonical/replace_prop_values/replace_prop_values.py deleted file mode 100644 index 0609d9e13..000000000 --- a/src/codemods/canonical/replace_prop_values/replace_prop_values.py +++ /dev/null @@ -1,36 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a TypeScript codemod that iterates through a codebase, identifies JSX functions, and replaces any occurrences of the prop value 'text-center' -with 'text-left' in all JSX elements.""", - uid="c1914552-556b-4ae0-99f0-33cb7bfb702e", -) -@canonical -class ReplacePropValues(Codemod, Skill): - """Replaces any JSX props with text-center to text-left""" - - language = ProgrammingLanguage.TYPESCRIPT - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def execute(self, codebase: Codebase) -> None: - # Iterate over all files in the codebase - for file in codebase.files: - # Iterate over all functions in the file - for function in file.functions: - # Filter for JSX functions - if function.is_jsx: - # Iterate over all JSX elements in the function - for jsx_element in function.jsx_elements: - # Iterate over all the props of the component - for prop in jsx_element.props: - # Check if prop has a value - if prop.value: - # Replace text-center with text-left - prop.value.replace("text-center", "text-left") diff --git a/src/codemods/canonical/return_none_type_annotation/return_none_type_annotation.py b/src/codemods/canonical/return_none_type_annotation/return_none_type_annotation.py deleted file mode 100644 index d5adb15cb..000000000 --- a/src/codemods/canonical/return_none_type_annotation/return_none_type_annotation.py +++ /dev/null @@ -1,36 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that iterates through all functions and methods in a codebase. For each function or method that lacks return statements and -a return type annotation, set the return type to 'None'. Ensure the implementation handles both standalone functions and methods within classes.""", - uid="fcac16ed-a915-472a-9dfe-1562452d9ab3", -) -@canonical -class ReturnNoneTypeAnnotation(Codemod, Skill): - """This codemod sets the return type of functions that do not have any return statements""" - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - # Iterate over all functions in the codebase - for function in codebase.functions: - # Look at ones that do not have return statements and no return type annotation - if len(function.return_statements) == 0 and not function.return_type: - # Set the return type to None - function.set_return_type("None") - - # Do the same for methods (have to call it `cls`, not `class`, since `class` is a reserved keyword) - for cls in codebase.classes: - for method in cls.methods: - # Look at ones that do not have return statements and no return type annotation - if len(method.return_statements) == 0 and not method.return_type: - # Set the return type to None - method.set_return_type("None") diff --git a/src/codemods/canonical/split_decorators/__init__.py b/src/codemods/canonical/split_decorators/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codemods/canonical/split_decorators/split_decorators.py b/src/codemods/canonical/split_decorators/split_decorators.py deleted file mode 100644 index 8c8b7fc59..000000000 --- a/src/codemods/canonical/split_decorators/split_decorators.py +++ /dev/null @@ -1,52 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that transforms a single decorator call into multiple calls. The codemod should iterate through all classes in a codebase, -identify decorators matching the pattern '@generic_repr', and replace them with separate decorators for each argument passed to the original -decorator. Ensure that the original decorator's ordering is preserved by editing in-place.""", - uid="3f6325b8-02c3-4d90-a726-830f8bccce3a", -) -@canonical -class SplitDecorators(Codemod, Skill): - """This codemod splits a single decorator call into multiple - - For example: - @generic_repr("id", "name", "email") - def f(): - ... - - Becomes: - @generic_repr("id") - @generic_repr("name") - @generic_repr("email") - def f(): - ... - - Note that we edit the original decorator in-place (`decorator.edit(...)`), so as to keep the original decorator's ordering! - - If we instead did `add_decorator` etc., we would have to figure out where to insert the new decorators. - """ - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - # Iterate over all classes in the codebase - for cls in codebase.classes: - # Find all decorators of the function that match the pattern for `@allow_update` - this is a list of Decorator instances with '{' in the source - target_decorators = [decorator for decorator in cls.decorators if "@generic_repr" in decorator.source] - for decorator in target_decorators: - new_decorators = [] - for arg in decorator.call.args: - new_decorator_source = f"@generic_repr({arg})" - new_decorators.append(new_decorator_source) - - # Remove the original decorator as it will be replaced - decorator.edit("\n".join(new_decorators), fix_indentation=True) diff --git a/src/codemods/canonical/split_file/split_file.py b/src/codemods/canonical/split_file/split_file.py deleted file mode 100644 index 180f62131..000000000 --- a/src/codemods/canonical/split_file/split_file.py +++ /dev/null @@ -1,38 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that splits a large file by moving all subclasses of 'Enum' from 'sqlglot/optimizer/scope.py' to a new file named -'sqlglot/optimizer/enums.py'. The codemod should check if the large file exists, raise a FileNotFoundError if it does not, and then create the new -file before iterating through the classes in the large file to move the relevant subclasses.""", - uid="a7c7388d-f473-4a37-b316-e881079fe093", -) -@canonical -class SplitFile(Codemod, Skill): - """This codemod moves symbols from one large to a new file with the goal of breaking up a large file.""" - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase): - # Grab large file to split - file = codebase.get_file("sqlglot/optimizer/scope.py", optional=True) - if file is None: - msg = "The file `sqlglot/optimizer/scope.py` was not found." - raise FileNotFoundError(msg) - - # Create a new file for storing all our 'Enum' classes - new_file = codebase.create_file("sqlglot/optimizer/enums.py") - - # iterate over all classes - for cls in file.classes: - # Check inheritance - if cls.is_subclass_of("Enum"): - # Move symbol - cls.move_to_file(new_file) diff --git a/src/codemods/canonical/split_file_and_rename_symbols/split_file_and_rename_symbols.py b/src/codemods/canonical/split_file_and_rename_symbols/split_file_and_rename_symbols.py deleted file mode 100644 index 80d7f8263..000000000 --- a/src/codemods/canonical/split_file_and_rename_symbols/split_file_and_rename_symbols.py +++ /dev/null @@ -1,59 +0,0 @@ -from codegen.sdk.core.codebase import CodebaseType -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that splits a file by moving classes containing 'Configuration' to a new file named 'configuration.py'. After moving, commit -the changes to ensure the new classes are recognized. Then, rename all 'Configuration' classes in the new file to 'Config'. Finally, update the -original file's path from 'types.py' to 'schemas.py'.""", - uid="816415d9-27e8-4228-b284-1b18b3072f0d", -) -@canonical -class SplitFileAndRenameSymbols(Codemod, Skill): - """Split file and rename moved symbols - - This codemod first moves several symbols to new files and then renames them. - - This requires a codebase.commit() call between the move and the rename step. - """ - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: CodebaseType): - # Get file to split up - source_file = codebase.get_file("redash/models/types.py", optional=True) - if source_file is None: - msg = "[1] The file `redash/models/types.py` was not found." - raise FileNotFoundError(msg) - - # Get file symbols will be moved to - configuration_file = codebase.create_file("redash/models/configuration.py") - - # Move all the classes that contain with `Configuration` to the new configuration file - for cls in source_file.classes: - # Move the `_filter` functions - if "Configuration" in cls.name: - # Move the function to the filters file and rename it - # move_to_file should also take care of updating the imports of the functions, and bringing over any imports or local references the function needs - cls.move_to_file(configuration_file, include_dependencies=True, strategy="update_all_imports") - - # Commit is NECESSARY for the codebase graph to be aware of the new classes moved into configuration file - codebase.commit() - - # re-acquire the configuration file with the latest changes - configuration_file = codebase.get_file("redash/models/configuration.py") - - # rename all the `Configuration` classes to `Config` - for cls in configuration_file.classes: - if cls.name == "Configuration": - cls.rename("Config") - - # re-acquire the source file with the latest changes - source_file = codebase.get_file("redash/models/types.py") - source_file.update_filepath("redash/models/schemas.py") diff --git a/src/codemods/canonical/split_large_files/split_large_files.py b/src/codemods/canonical/split_large_files/split_large_files.py deleted file mode 100644 index 33f846421..000000000 --- a/src/codemods/canonical/split_large_files/split_large_files.py +++ /dev/null @@ -1,49 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a TypeScript codemod that processes a codebase to split large files. The codemod should define constants for maximum file length (500 lines) -and maximum symbol length (50 lines). It should iterate through all files in the codebase, checking if a file exceeds the maximum length. If a file -has more than 3 symbols that exceed the maximum symbol length, create a new directory for the file (removing the .ts extension) and move each long -symbol into its own new file within that directory. Ensure to add a back edge to the original file for each moved symbol.""", - uid="b5bbec91-5bfe-4b4b-b62e-0a1ec94089b5", -) -@canonical -class SplitLargeFiles(Codemod, Skill): - """This codemod splits all large files.""" - - language = ProgrammingLanguage.TYPESCRIPT - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def execute(self, codebase: Codebase): - # Define constants for maximum lengths - MAX_FILE_LENGTH = 500 - MAX_SYMBOL_LENGTH = 50 - - # Iterate over all files in the codebase - for file in codebase.files: - # Check if the file has more than the maximum file length - if len(file.content.splitlines()) > MAX_FILE_LENGTH: - # Count the number of symbols with more than the maximum symbol length - long_symbols_count = sum(1 for symbol in file.symbols if len(symbol.source.splitlines()) > MAX_SYMBOL_LENGTH) - # Proceed if there are more than 3 long symbols - if long_symbols_count > 3: - # Create a new directory for the file - dir_name = file.filepath.replace(".ts", "") - codebase.create_directory(dir_name, exist_ok=True) - # Iterate over symbols in the file - for symbol in file.symbols: - # Skip any symbol named 'Space' - if len(symbol.source.splitlines()) > MAX_SYMBOL_LENGTH: - # Create a new file for the symbol - new_file = codebase.create_file(f"{dir_name}/{symbol.name}.ts", sync=False) - # Move the symbol to the new file - symbol.move_to_file(new_file) - # Add a back edge to the original file - file.add_import(symbol) diff --git a/src/codemods/canonical/swap_call_site_imports/__init__.py b/src/codemods/canonical/swap_call_site_imports/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codemods/canonical/swap_call_site_imports/swap_call_site_imports.py b/src/codemods/canonical/swap_call_site_imports/swap_call_site_imports.py deleted file mode 100644 index 14c4e96f4..000000000 --- a/src/codemods/canonical/swap_call_site_imports/swap_call_site_imports.py +++ /dev/null @@ -1,63 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that replaces all imports of a legacy function with its new replacement. The codemod should find all call sites of the -legacy function, update the import module to the new module, and handle the edge case where the legacy function is called within the same file it is -defined. In this case, the codemod should remove the legacy function and add an import for its replacement. The legacy function is located in -'redash/settings/helpers.py' and is named 'array_from_string'. The new import module is 'redash.settings.collections'. Include comments to explain -each step.""", - uid="8fa00be7-adad-473d-8436-fc5f70e6ac6d", -) -@canonical -class SwapCallSiteImports(Codemod, Skill): - """This codemod replaces all imports of a legacy function with it's new replacement. - - This involves: - - Finding all the call sites of the old function - - Updating the import module of the old function import to the new module - - Edge case: legacy function is called within the same file it's defined in - - There won't be an import to the legacy function in this file (b/c it's where it's defined) - - For this case we have to both remove the legacy function and add an import to it's replacement. - - Example: - Before: - from mod import func - - func() - - After: - from new_mode import func - - func() - """ - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - legacy_func_file = codebase.get_file("redash/settings/helpers.py") - legacy_function = legacy_func_file.get_function("array_from_string") - - # Find all call sites of the legacy function - for call_site in legacy_function.call_sites: - # Get the import of the legacy function in the call site file - legacy_import = next((x for x in call_site.file.imports if x.resolved_symbol == legacy_function), None) - - # Update the import module of the old function import to the new module - if legacy_import: - legacy_import.set_import_module("redash.settings.collections") - - # Edge case: legacy function is called within the same file it's defined in - if call_site.file == legacy_function.file: - # Remove the legacy function - legacy_function.remove() - - # Add import of the new function - call_site.file.add_import(f"from settings.collections import {legacy_function.name}") diff --git a/src/codemods/canonical/swap_class_attribute_usages/swap_class_attribute_usages.py b/src/codemods/canonical/swap_class_attribute_usages/swap_class_attribute_usages.py deleted file mode 100644 index 60c520986..000000000 --- a/src/codemods/canonical/swap_class_attribute_usages/swap_class_attribute_usages.py +++ /dev/null @@ -1,63 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that transfers attributes from one class to another. The codemod should rename parameters of functions that use the first -class (GraphRagConfig) to use the second class (CacheConfig) instead. It should also handle variable renaming to avoid conflicts, update function -definitions, add necessary imports, and modify function call sites accordingly.""", - uid="4a3569c2-cf58-4bdc-822b-7a5747f476ab", -) -@canonical -class SwapClassAttributeUsages(Codemod, Skill): - """This codemod takes two classes (class A and class B) and transfers one class's attributes to the other. - It does this by: - - Renaming any parameters that are passing the class A and replaces it to take in class B instead - """ - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - class_a_symb = codebase.get_symbol("GraphRagConfig") - class_b_symb = codebase.get_symbol("CacheConfig") - - for function in codebase.functions: - parameters = function.parameters - if any(p.type == class_a_symb for p in parameters): - # Rename existing instances of `cache_config`=> `cache_config_` (prevents mypy issue) - name_conflict_vars = function.code_block.get_local_var_assignments("cache_config") - for name_conflict_var in name_conflict_vars: - name_conflict_var.rename("cache_config_") - - # Get the parameter to update - class_a_param = function.get_parameter_by_type(class_a_symb) - - # Update original function definition - class_a_param.edit("cache_config: CacheConfig") - - # Add import of `CacheConfig` to function definition file - function.file.add_import(class_b_symb) - - # Check if the function body is using `cache_config` - if len(function.code_block.get_variable_usages(class_a_param.name)) > 0: - # Add "wrapper" inside the function - # This creates the `cache_config` variable internally - proxy_var_declaration = f"""{class_a_param.name} = cache_config.settings # added by Codegen""" - function.prepend_statements(proxy_var_declaration) - - # Update all callsites of original function to take in `cache_config` instead of `graph_rag_config` - fcalls = function.call_sites - for fcall in fcalls: - arg = fcall.get_arg_by_parameter_name(class_a_param.name) - if not arg: - continue - if arg.is_named: - arg.edit(f"cache_config={arg.value}.cache_config") - else: - arg.edit(f"{arg.value}.cache_config") diff --git a/src/codemods/canonical/update_optional_type_annotations/__init__.py b/src/codemods/canonical/update_optional_type_annotations/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codemods/canonical/update_optional_type_annotations/update_optional_type_annotations.py b/src/codemods/canonical/update_optional_type_annotations/update_optional_type_annotations.py deleted file mode 100644 index d1ceea089..000000000 --- a/src/codemods/canonical/update_optional_type_annotations/update_optional_type_annotations.py +++ /dev/null @@ -1,55 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.core.expressions import Type -from codegen.sdk.core.expressions.generic_type import GenericType -from codegen.sdk.core.expressions.union_type import UnionType -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that updates type annotations in a codebase. The codemod should replace instances of 'Optional[X]' with 'X | None' and -handle other generic types and unions appropriately. Ensure that the codemod iterates through all files, processes functions and methods, checks for -typed parameters, and modifies their annotations as needed. Additionally, include an import statement for future annotations if any changes are made.""", - uid="0e2d60db-bff0-4020-bda7-f264ff6c7f46", -) -@canonical -class UpdateOptionalTypeAnnotations(Codemod, Skill): - """Replaces type annotations with builtin ones, e.g.: - def f(x: Optional[int]): - becomes - def f(x: int | None): - """ - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - def update_type_annotation(type: Type) -> str: - if "Optional" in type.source: - if isinstance(type, GenericType): - if type.name == "Optional": - return update_type_annotation(type.parameters[0]) + " | None" - else: - return f"{type.name}[{', '.join(update_type_annotation(param) for param in type.parameters)}]" - if isinstance(type, UnionType): - return " | ".join(update_type_annotation(param) for param in type) - return type.source - - # Iterate over all files in the codebase - for file in codebase.files: - # Process standalone functions and methods within classes - for function in file.functions + [method for cls in file.classes for method in cls.methods]: - # Iterate over all parameters in the function - if function.parameters: - for parameter in function.parameters: - if parameter.is_typed: - # Check if the parameter has a type annotation - new_type = update_type_annotation(parameter.type) - if parameter.type != new_type: - # Add the future annotations import - file.add_import("from __future__ import annotations\n") - parameter.type.edit(new_type) diff --git a/src/codemods/canonical/update_union_types/__init__.py b/src/codemods/canonical/update_union_types/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/codemods/canonical/update_union_types/update_union_types.py b/src/codemods/canonical/update_union_types/update_union_types.py deleted file mode 100644 index 98f7085c9..000000000 --- a/src/codemods/canonical/update_union_types/update_union_types.py +++ /dev/null @@ -1,41 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that updates type annotations from the old Union[x, y] syntax to the new x | y syntax for migration from Python 3.9 to -Python 3.10. The codemod should iterate through all files in a codebase, check for imports of Union from typing, and replace occurrences of Union in -both generic type and subscript forms. Ensure that the new syntax is correctly formatted, handling cases with multiple types and removing any empty -strings from trailing commas.""", - uid="7637d11a-b907-4716-a09f-07776f81a359", -) -@canonical -class UpdateUnionTypes(Codemod, Skill): - """This updates the Union [ x , y ] syntax for x | y for migrations for python 3.9 to python 3.10""" - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - for file in codebase.files: - # Check if the file imports Union from typing - if "Union" in [imp.name for imp in file.imports]: - # Search for Union type annotations in the file - for editable in file.find("Union["): - if editable.ts_node_type == "generic_type": - new_type = editable.source.replace("Union[", "").replace("]", "", 1).replace(", ", " | ") - editable.replace(editable.source, new_type) - elif editable.ts_node_type == "subscript": - # Handle subscript case (like TypeAlias = Union[...]) - types = editable.source[6:-1].split(",") - # Remove any empty strings that might result from trailing commas - types = [t.strip() for t in types if t.strip()] - new_type = " | ".join(types) - if len(types) > 1: - new_type = f"({new_type})" - editable.replace(editable.source, new_type) diff --git a/src/codemods/canonical/use_named_kwargs/use_named_kwargs.py b/src/codemods/canonical/use_named_kwargs/use_named_kwargs.py deleted file mode 100644 index ed886aff3..000000000 --- a/src/codemods/canonical/use_named_kwargs/use_named_kwargs.py +++ /dev/null @@ -1,57 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.core.external_module import ExternalModule -from codegen.sdk.python.class_definition import PyClass -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a Python codemod that converts all function calls in a codebase to use named keyword arguments if they have two or more positional arguments. -The codemod should iterate through all files and functions, checking each function call to determine if it meets the criteria for conversion. Ensure -that the conversion is skipped if all arguments are already named, if there are fewer than two arguments, if the function definition cannot be found, -if the function is a class without a constructor, or if the function is part of an external module.""", - uid="1a4b9e66-1df5-4ad1-adbb-034976add8e0", -) -@canonical -class UseNamedKwargs(Codemod, Skill): - """Converts all functions to use named kwargs if there are more than >= 2 args being used. - - In general you can use FunctionCall.convert_args_to_kwargs() once you have filtered properly - """ - - language = ProgrammingLanguage.PYTHON - - @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase) -> None: - # Iterate over all files - for file in codebase.files: - # TODO: doesn't handle global function calls - # Iterate over all functions - for function in file.functions: - # look at the function calls - for call in function.function_calls: - # Skip if all args are already named - if all(arg.is_named for arg in call.args): - continue - - # Skip if call sites has < 2 args - if len(call.args) < 2: - continue - - # Skip if we can't find the def of the function - function_def = call.function_definition - if not function_def: - continue - - # Skip if function_def is a class and the class has no constructor - if isinstance(function_def, PyClass) and not function_def.constructor: - continue - - if isinstance(function_def, ExternalModule): - continue - - call.convert_args_to_kwargs() diff --git a/src/codemods/canonical/wrap_with_component/wrap_with_component.py b/src/codemods/canonical/wrap_with_component/wrap_with_component.py deleted file mode 100644 index b1bed4cc1..000000000 --- a/src/codemods/canonical/wrap_with_component/wrap_with_component.py +++ /dev/null @@ -1,51 +0,0 @@ -from codegen.sdk.core.codebase import Codebase -from codegen.sdk.writer_decorators import canonical -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codemods.codemod import Codemod -from tests.shared.skills.decorators import skill, skill_impl -from tests.shared.skills.skill import Skill - - -@skill( - canonical=True, - prompt="""Generate a codemod in TypeScript that wraps all instances of the JSX element -

Current step: {this.props.by}

- - ); - } - - onClick() { - this.setState({ counter: this.state.counter + this.props.by }); - } -} - """ - os.chdir(tmpdir) # TODO: CG-10643 - - with get_codebase_session(tmpdir=tmpdir, files={"dir/file1.tsx": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: - file: TSFile = codebase.get_file("dir/file1.tsx") - component = file.get_class("C") - component.source = component.class_component_to_function_component() - - # language=typescript - assert ( - file.content - == """ -import React from "react"; - -type Props = { - by: number; -}; - -type State = { - counter: number; -}; - -export const C: React.FC = props => { - const { - by = 1 - } = props; - - const [counter, setCounter] = React.useState(0); - - function onClick() { - setCounter(counter + by); - } - - return <> - -

Current step: {by}

- ; -}; - """ - ) diff --git a/tests/integration/codegen/test_imports.py b/tests/integration/codegen/test_imports.py deleted file mode 100644 index b7b3fb039..000000000 --- a/tests/integration/codegen/test_imports.py +++ /dev/null @@ -1,19 +0,0 @@ -import os - -import codegen -from codegen.sdk.code_generation.current_code_codebase import get_graphsitter_repo_path -from codegen.sdk.core.codebase import Codebase - - -def test_codegen_imports(): - # Test decorated function - @codegen.function(name="sample_codemod") - def run(codebase): - pass - - # Test class - cls = codegen.Function - assert cls is not None - os.chdir(get_graphsitter_repo_path()) # TODO: CG-10643 - codebase = Codebase("./") - assert codebase is not None diff --git a/tests/integration/codegen/test_placeholder.py b/tests/integration/codegen/test_placeholder.py new file mode 100644 index 000000000..868b5237c --- /dev/null +++ b/tests/integration/codegen/test_placeholder.py @@ -0,0 +1,10 @@ +"""Placeholder integration test to make the workflow pass.""" + + +def test_placeholder(): + """Placeholder test that always passes. + + This test exists to ensure the integration test workflow completes successfully + when there are no actual integration tests to run. + """ + assert True, "Placeholder test should always pass" diff --git a/tests/integration/codemod/.gitignore b/tests/integration/codemod/.gitignore deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/integration/codemod/__init__.py b/tests/integration/codemod/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/integration/codemod/canonical/bang_bang_to_boolean/test_vite/expected_diff.patch.skip b/tests/integration/codemod/canonical/bang_bang_to_boolean/test_vite/expected_diff.patch.skip deleted file mode 100644 index 1da08456c..000000000 --- a/tests/integration/codemod/canonical/bang_bang_to_boolean/test_vite/expected_diff.patch.skip +++ /dev/null @@ -1,503 +0,0 @@ -diff --git a/packages/plugin-legacy/src/index.ts b/packages/plugin-legacy/src/index.ts -index c4a648753128ac5da7d9a74e6b1cef32a7e9ebea..e2cb2a354bd461f96107ac0add2e480ee119f2f8 100644 ---- a/packages/plugin-legacy/src/index.ts -+++ b/packages/plugin-legacy/src/index.ts -@@ -65,7 +65,7 @@ function toOutputFilePathInHtml( - hostId, - hostType, - type, -- ssr: !!config.build.ssr, -+ ssr: Boolean(config.build.ssr), - }) - if (typeof result === 'object') { - if (result.runtime) { -@@ -538,12 +538,12 @@ function viteLegacyPlugin(options: Options = {}): Plugin[] { - options.polyfills !== false && !Array.isArray(options.polyfills) - - // transform the legacy chunk with @babel/preset-env -- const sourceMaps = !!config.build.sourcemap -+ const sourceMaps = Boolean(config.build.sourcemap) - const babel = await loadBabel() - const result = babel.transform(raw, { - babelrc: false, - configFile: false, -- compact: !!config.build.minify, -+ compact: Boolean(config.build.minify), - sourceMaps, - inputSourceMap: undefined, // sourceMaps ? chunk.map : undefined, `.map` TODO: moved to OutputChunk? - presets: [ -@@ -877,7 +877,7 @@ function isLegacyBundle( - (output) => output.type === 'chunk' && output.isEntry, - ) - -- return !!entryChunk && entryChunk.fileName.includes('-legacy') -+ return Boolean(entryChunk) && entryChunk.fileName.includes('-legacy') - } - - return false -diff --git a/packages/vite/src/node/build.ts b/packages/vite/src/node/build.ts -index d86393d36656fcf9fd328b718cbf684c04d00ae8..b4333ab42b72be6bcda50c67b748299ef5d33414 100644 ---- a/packages/vite/src/node/build.ts -+++ b/packages/vite/src/node/build.ts -@@ -418,7 +418,7 @@ export function resolveBuildOptions( - } - - if (resolved.cssMinify == null) { -- resolved.cssMinify = !!resolved.minify -+ resolved.cssMinify = Boolean(resolved.minify) - } - - return resolved -@@ -475,7 +475,7 @@ export async function build( - ) - const options = config.build - const { logger } = config -- const ssr = !!options.ssr -+ const ssr = Boolean(options.ssr) - const libOptions = options.lib - - logger.info( -@@ -1210,7 +1210,7 @@ export function toOutputFilePathInJS( - hostId, - hostType, - type, -- ssr: !!config.build.ssr, -+ ssr: Boolean(config.build.ssr), - }) - if (typeof result === 'object') { - if (result.runtime) { -@@ -1257,7 +1257,7 @@ export function toOutputFilePathWithoutRuntime( - hostId, - hostType, - type, -- ssr: !!config.build.ssr, -+ ssr: Boolean(config.build.ssr), - }) - if (typeof result === 'object') { - if (result.runtime) { -diff --git a/packages/vite/src/node/config.ts b/packages/vite/src/node/config.ts -index e38e5b5959809ba9a1f08ba10ee14a8722fdd308..3ed2c71a92f6efddded098fffda0b6390d0349cd 100644 ---- a/packages/vite/src/node/config.ts -+++ b/packages/vite/src/node/config.ts -@@ -457,7 +457,7 @@ export async function resolveConfig( - let config = inlineConfig - let configFileDependencies: string[] = [] - let mode = inlineConfig.mode || defaultMode -- const isNodeEnvSet = !!process.env.NODE_ENV -+ const isNodeEnvSet = Boolean(process.env.NODE_ENV) - const packageCache: PackageCache = new Map() - - // some dependencies e.g. @vue/compiler-* relies on NODE_ENV for getting -@@ -469,7 +469,7 @@ export async function resolveConfig( - const configEnv: ConfigEnv = { - mode, - command, -- isSsrBuild: command === 'build' && !!config.build?.ssr, -+ isSsrBuild: command === 'build' && Boolean(config.build?.ssr), - isPreview, - } - -@@ -1152,11 +1152,11 @@ async function bundleConfigFile( - if (!isImport) { - let canResolveWithImport = false - try { -- canResolveWithImport = !!resolveByViteResolver( -+ canResolveWithImport = Boolean(resolveByViteResolver( - id, - importer, - false, -- ) -+ )) - } catch {} - if (canResolveWithImport) { - throw new Error( -diff --git a/packages/vite/src/node/fsUtils.ts b/packages/vite/src/node/fsUtils.ts -index a295d4fc41adb66a139f6ec3c28f68ee57886d25..c7eba3b027f0a346429d6c924904c273a80ebc83 100644 ---- a/packages/vite/src/node/fsUtils.ts -+++ b/packages/vite/src/node/fsUtils.ts -@@ -252,7 +252,7 @@ export function createCachedFsUtils(config: ResolvedConfig): FsUtils { - // fallback to built-in fs for out-of-root and symlinked files - return fs.existsSync(file) - } -- return !!direntCache -+ return Boolean(direntCache) - }, - tryResolveRealFile( - file: string, -diff --git a/packages/vite/src/node/optimizer/optimizer.ts b/packages/vite/src/node/optimizer/optimizer.ts -index 3f76e480a45e75338010944b163da39847e4d633..33db424d2181f6f313721f049ac863f2919368f7 100644 ---- a/packages/vite/src/node/optimizer/optimizer.ts -+++ b/packages/vite/src/node/optimizer/optimizer.ts -@@ -159,7 +159,7 @@ async function createDepsOptimizer( - let enqueuedRerun: (() => void) | undefined - let currentlyProcessing = false - -- let firstRunCalled = !!cachedMetadata -+ let firstRunCalled = Boolean(cachedMetadata) - let warnAboutMissedDependencies = false - - // If this is a cold run, we wait for static imports discovered -diff --git a/packages/vite/src/node/plugins/clientInjections.ts b/packages/vite/src/node/plugins/clientInjections.ts -index c66f3877eca822c14989eac10c480f2639d1080a..6f305f2b71a3eded0527163027b76eb7612479f4 100644 ---- a/packages/vite/src/node/plugins/clientInjections.ts -+++ b/packages/vite/src/node/plugins/clientInjections.ts -@@ -32,7 +32,7 @@ export function clientInjectionsPlugin(config: ResolvedConfig): Plugin { - const protocol = hmrConfig?.protocol || null - const timeout = hmrConfig?.timeout || 30000 - const overlay = hmrConfig?.overlay !== false -- const isHmrServerSpecified = !!hmrConfig?.server -+ const isHmrServerSpecified = Boolean(hmrConfig?.server) - const hmrConfigName = path.basename(config.configFile || 'vite.config.js') - - // hmr.clientPort -> hmr.port -diff --git a/packages/vite/src/node/plugins/css.ts b/packages/vite/src/node/plugins/css.ts -index 26ba17c192f84ea79fed0d46c371925c6fa117d1..de40ca7d2486c1405008fba8ab137d3250953c12 100644 ---- a/packages/vite/src/node/plugins/css.ts -+++ b/packages/vite/src/node/plugins/css.ts -@@ -2184,11 +2184,11 @@ const makeScssWorker = ( - shouldUseFake(_sassPath, _data, options) { - // functions and importer is a function and is not serializable - // in that case, fallback to running in main thread -- return !!( -+ return Boolean(( - (options.functions && Object.keys(options.functions).length > 0) || - (options.importer && - (!Array.isArray(options.importer) || options.importer.length > 0)) -- ) -+ )) - }, - max: maxWorkers, - }, -@@ -2286,11 +2286,11 @@ const makeModernScssWorker = ( - shouldUseFake(_sassPath, _data, options) { - // functions and importer is a function and is not serializable - // in that case, fallback to running in main thread -- return !!( -+ return Boolean(( - (options.functions && Object.keys(options.functions).length > 0) || - (options.importers && - (!Array.isArray(options.importers) || options.importers.length > 0)) -- ) -+ )) - }, - max: maxWorkers, - }, -@@ -2749,10 +2749,10 @@ const makeStylWorker = (maxWorkers: number | undefined) => { - shouldUseFake(_stylusPath, _content, _root, options) { - // define can include functions and those are not serializable - // in that case, fallback to running in main thread -- return !!( -+ return Boolean(( - options.define && - Object.values(options.define).some((d) => typeof d === 'function') -- ) -+ )) - }, - max: maxWorkers, - }, -@@ -2949,7 +2949,7 @@ async function compileLightningCSS( - filename, - code: Buffer.from(src), - targets: config.css?.lightningcss?.targets, -- minify: config.isProduction && !!config.build.cssMinify, -+ minify: config.isProduction && Boolean(config.build.cssMinify), - analyzeDependencies: true, - }) - : await ( -@@ -2986,10 +2986,10 @@ async function compileLightningCSS( - return id - }, - }, -- minify: config.isProduction && !!config.build.cssMinify, -+ minify: config.isProduction && Boolean(config.build.cssMinify), - sourceMap: - config.command === 'build' -- ? !!config.build.sourcemap -+ ? Boolean(config.build.sourcemap) - : config.css?.devSourcemap, - analyzeDependencies: true, - cssModules: cssModuleRE.test(id) -diff --git a/packages/vite/src/node/plugins/define.ts b/packages/vite/src/node/plugins/define.ts -index 585bc0154fa263e270f20eab587042fb66fd8f3d..5d53513ed4806d7ff61563541ab078d8d1625274 100644 ---- a/packages/vite/src/node/plugins/define.ts -+++ b/packages/vite/src/node/plugins/define.ts -@@ -166,7 +166,7 @@ export async function replaceDefine( - platform: 'neutral', - define, - sourcefile: id, -- sourcemap: config.command === 'build' ? !!config.build.sourcemap : true, -+ sourcemap: config.command === 'build' ? Boolean(config.build.sourcemap) : true, - }) - - // remove esbuild's source entries -diff --git a/packages/vite/src/node/plugins/html.ts b/packages/vite/src/node/plugins/html.ts -index b7109debc3863a9ddab4d278912a83f0ac1ac115..5e11b9e009a88b65ba387701d9a73d8cbffa4a68 100644 ---- a/packages/vite/src/node/plugins/html.ts -+++ b/packages/vite/src/node/plugins/html.ts -@@ -435,7 +435,7 @@ export function buildHtmlPlugin(config: ResolvedConfig): Plugin { - getScriptInfo(node) - - const url = src && src.value -- const isPublicFile = !!(url && checkPublicFile(url, config)) -+ const isPublicFile = Boolean((url && checkPublicFile(url, config))) - if (isPublicFile) { - // referencing public dir url, prefix with base - overwriteAttrValue( -diff --git a/packages/vite/src/node/plugins/importAnalysisBuild.ts b/packages/vite/src/node/plugins/importAnalysisBuild.ts -index 7dcad179654fe4afcc406b46fed262a58af5b81f..ae11759929ec1c8856ae12d91fe15db6d16666c3 100644 ---- a/packages/vite/src/node/plugins/importAnalysisBuild.ts -+++ b/packages/vite/src/node/plugins/importAnalysisBuild.ts -@@ -101,7 +101,7 @@ function preload( - seen[dep] = true - const isCss = dep.endsWith('.css') - const cssSelector = isCss ? '[rel="stylesheet"]' : '' -- const isBaseRelative = !!importerUrl -+ const isBaseRelative = Boolean(importerUrl) - - // check if the file is already preloaded by SSR markup - if (isBaseRelative) { -@@ -162,16 +162,16 @@ function preload( - * Build only. During serve this is performed as part of ./importAnalysis. - */ - export function buildImportAnalysisPlugin(config: ResolvedConfig): Plugin { -- const ssr = !!config.build.ssr -+ const ssr = Boolean(config.build.ssr) - const isWorker = config.isWorker -- const insertPreload = !(ssr || !!config.build.lib || isWorker) -+ const insertPreload = !(ssr || Boolean(config.build.lib) || isWorker) - - const resolveModulePreloadDependencies = - config.build.modulePreload && config.build.modulePreload.resolveDependencies - const renderBuiltUrl = config.experimental.renderBuiltUrl -- const customModulePreloadPaths = !!( -+ const customModulePreloadPaths = Boolean(( - resolveModulePreloadDependencies || renderBuiltUrl -- ) -+ )) - const isRelativeBase = config.base === './' || config.base === '' - const optimizeModulePreloadRelativePaths = - isRelativeBase && !customModulePreloadPaths -diff --git a/packages/vite/src/node/plugins/importMetaGlob.ts b/packages/vite/src/node/plugins/importMetaGlob.ts -index d596d39d1a62e9875558a058db9736e1ef4fa95d..ce56279df2bcf62382ae583d2066e195656b9ff0 100644 ---- a/packages/vite/src/node/plugins/importMetaGlob.ts -+++ b/packages/vite/src/node/plugins/importMetaGlob.ts -@@ -396,7 +396,7 @@ export async function transformGlobImport( - await fg(globsResolved, { - cwd, - absolute: true, -- dot: !!options.exhaustive, -+ dot: Boolean(options.exhaustive), - ignore: options.exhaustive - ? [] - : [join(cwd, '**/node_modules/**')], -diff --git a/packages/vite/src/node/plugins/terser.ts b/packages/vite/src/node/plugins/terser.ts -index 90c29b26c7501ec2c2bad08fd985d85cf8ca580a..d045ca27c386e3e24a06791d420b66770421df72 100644 ---- a/packages/vite/src/node/plugins/terser.ts -+++ b/packages/vite/src/node/plugins/terser.ts -@@ -84,7 +84,7 @@ export function terserPlugin(config: ResolvedConfig): Plugin { - const res = await worker.run(terserPath, code, { - safari10: true, - ...terserOptions, -- sourceMap: !!outputOptions.sourcemap, -+ sourceMap: Boolean(outputOptions.sourcemap), - module: outputOptions.format.startsWith('es'), - toplevel: outputOptions.format === 'cjs', - }) -diff --git a/packages/vite/src/node/server/index.ts b/packages/vite/src/node/server/index.ts -index b5e1d9c57e5f196cc2cca347cec1a14496420263..5623153d4a3946be42096cfd2727cbf49f139d1d 100644 ---- a/packages/vite/src/node/server/index.ts -+++ b/packages/vite/src/node/server/index.ts -@@ -684,7 +684,7 @@ export async function _createServer( - }, - async restart(forceOptimize?: boolean) { - if (!server._restartPromise) { -- server._forceOptimizeOnRestart = !!forceOptimize -+ server._forceOptimizeOnRestart = Boolean(forceOptimize) - server._restartPromise = restartServer(server).finally(() => { - server._restartPromise = null - server._forceOptimizeOnRestart = false -@@ -865,7 +865,7 @@ export async function _createServer( - - // base - if (config.base !== '/') { -- middlewares.use(baseMiddleware(config.rawBase, !!middlewareMode)) -+ middlewares.use(baseMiddleware(config.rawBase, Boolean(middlewareMode))) - } - - // open in editor support -@@ -920,7 +920,7 @@ export async function _createServer( - } - - // error handler -- middlewares.use(errorMiddleware(server, !!middlewareMode)) -+ middlewares.use(errorMiddleware(server, Boolean(middlewareMode))) - - // httpServer.listen can be called multiple times - // when port when using next port number -diff --git a/packages/vite/src/node/server/moduleGraph.ts b/packages/vite/src/node/server/moduleGraph.ts -index 442ece308dbaff945085f9e797892530aca64c95..9bed1e7feb64ce71974eda8e8adb1109e209e656 100644 ---- a/packages/vite/src/node/server/moduleGraph.ts -+++ b/packages/vite/src/node/server/moduleGraph.ts -@@ -501,7 +501,7 @@ export class ModuleGraph { - ssr?: boolean, - alreadyResolved?: PartialResolvedId, - ): Promise { -- const resolved = alreadyResolved ?? (await this.resolveId(url, !!ssr)) -+ const resolved = alreadyResolved ?? (await this.resolveId(url, Boolean(ssr))) - const resolvedId = resolved?.id || url - if ( - url !== resolvedId && -diff --git a/packages/vite/src/node/server/pluginContainer.ts b/packages/vite/src/node/server/pluginContainer.ts -index 3251790d1698644c05ea4da3e2091b15b814aae2..4aec442b9d452d98f46a74f901cfebafb6b0e16c 100644 ---- a/packages/vite/src/node/server/pluginContainer.ts -+++ b/packages/vite/src/node/server/pluginContainer.ts -@@ -311,8 +311,8 @@ class PluginContainer { - ): Promise { - const skip = options?.skip - const ssr = options?.ssr -- const scan = !!options?.scan -- const ctx = new ResolveIdContext(this, !!ssr, skip, scan) -+ const scan = Boolean(options?.scan) -+ const ctx = new ResolveIdContext(this, Boolean(ssr), skip, scan) - - const resolveStart = debugResolve ? performance.now() : 0 - let id: string | null = null -@@ -331,7 +331,7 @@ class PluginContainer { - handler.call(ctx as any, rawId, importer, { - attributes: options?.attributes ?? {}, - custom: options?.custom, -- isEntry: !!options?.isEntry, -+ isEntry: Boolean(options?.isEntry), - ssr, - scan, - }), -@@ -383,7 +383,7 @@ class PluginContainer { - }, - ): Promise { - const ssr = options?.ssr -- const ctx = new LoadPluginContext(this, !!ssr) -+ const ctx = new LoadPluginContext(this, Boolean(ssr)) - - for (const plugin of this.getSortedPlugins('load')) { - if (this._closed && !ssr) throwClosedServerError() -@@ -421,7 +421,7 @@ class PluginContainer { - id, - code, - inMap as SourceMap, -- !!ssr, -+ Boolean(ssr), - ) - ctx._addedImports = this._getAddedImports(id) - -@@ -539,7 +539,7 @@ class PluginContext implements Omit { - let out = await this._container.resolveId(id, importer, { - attributes: options?.attributes, - custom: options?.custom, -- isEntry: !!options?.isEntry, -+ isEntry: Boolean(options?.isEntry), - skip, - ssr: this.ssr, - scan: this._scan, -diff --git a/packages/vite/src/node/server/searchRoot.ts b/packages/vite/src/node/server/searchRoot.ts -index edb7a76946266e1cc2acbbea8df196aa4e50ebf3..eefe04680579da2450c6543faabcd2235c966d12 100644 ---- a/packages/vite/src/node/server/searchRoot.ts -+++ b/packages/vite/src/node/server/searchRoot.ts -@@ -29,7 +29,7 @@ function hasWorkspacePackageJSON(root: string): boolean { - } - try { - const content = JSON.parse(fs.readFileSync(path, 'utf-8')) || {} -- return !!content.workspaces -+ return Boolean(content.workspaces) - } catch { - return false - } -diff --git a/packages/vite/src/node/server/transformRequest.ts b/packages/vite/src/node/server/transformRequest.ts -index dc98c1795daf26a8344158bee040b58d6955e55d..fdd06e9ed24159273802afead00a04cf053027a0 100644 ---- a/packages/vite/src/node/server/transformRequest.ts -+++ b/packages/vite/src/node/server/transformRequest.ts -@@ -132,7 +132,7 @@ async function doTransform( - url = removeTimestampQuery(url) - - const { config, pluginContainer } = server -- const ssr = !!options.ssr -+ const ssr = Boolean(options.ssr) - - if (ssr && isDepsOptimizerEnabled(config, true)) { - await initDevSsrDepsOptimizer(config, server) -@@ -237,7 +237,7 @@ async function loadAndTransform( - const { logger } = config - const prettyUrl = - debugLoad || debugTransform ? prettifyUrl(url, config.root) : '' -- const ssr = !!options.ssr -+ const ssr = Boolean(options.ssr) - - const file = cleanUrl(id) - -diff --git a/packages/vite/src/node/ssr/ssrExternal.ts b/packages/vite/src/node/ssr/ssrExternal.ts -index 5681e000502a5f2d1a44183711c8c1c128b6d677..53eff3822d9a998ff0eda5cddffd8c1b3085099f 100644 ---- a/packages/vite/src/node/ssr/ssrExternal.ts -+++ b/packages/vite/src/node/ssr/ssrExternal.ts -@@ -59,7 +59,7 @@ export function createIsConfiguredAsSsrExternal( - return false - } - try { -- return !!tryNodeResolve( -+ return Boolean(tryNodeResolve( - id, - // Skip passing importer in build to avoid externalizing non-hoisted dependencies - // unresolvable from root (which would be unresolvable from output bundles also) -@@ -73,8 +73,8 @@ export function createIsConfiguredAsSsrExternal( - true, - // Allow linked packages to be externalized if they are explicitly - // configured as external -- !!configuredAsExternal, -- )?.external -+ Boolean(configuredAsExternal), -+ )?.external) - } catch (e) { - debug?.( - `Failed to node resolve "${id}". Skipping externalizing it by default.`, -diff --git a/packages/vite/src/node/utils.ts b/packages/vite/src/node/utils.ts -index 393bc391799aad12b47cfd2f2495e95529ec4348..248e1c2852428990b7f24b8fa605977325299ed7 100644 ---- a/packages/vite/src/node/utils.ts -+++ b/packages/vite/src/node/utils.ts -@@ -168,7 +168,7 @@ export function createDebugger( - let enabled = log.enabled - if (enabled && onlyWhenFocused) { - const ns = typeof onlyWhenFocused === 'string' ? onlyWhenFocused : namespace -- enabled = !!DEBUG?.includes(ns) -+ enabled = Boolean(DEBUG?.includes(ns)) - } - - if (enabled) { -@@ -745,7 +745,7 @@ function splitSrcSetDescriptor(srcs: string): ImageCandidate[] { - descriptor: src.slice(url.length).trim(), - } - }) -- .filter(({ url }) => !!url) -+ .filter(({ url }) => Boolean(url)) - } - - export function processSrcSet( -diff --git a/playground/vitestSetup.ts b/playground/vitestSetup.ts -index eb28b5f544d453bc234ee6f3da780f2d4fa44cd4..9c25b1cac3de9e99909374f4f1fa6fd11fa898d1 100644 ---- a/playground/vitestSetup.ts -+++ b/playground/vitestSetup.ts -@@ -27,7 +27,7 @@ import { beforeAll, inject } from 'vitest' - - export const workspaceRoot = path.resolve(__dirname, '../') - --export const isBuild = !!process.env.VITE_TEST_BUILD -+export const isBuild = Boolean(process.env.VITE_TEST_BUILD) - export const isServe = !isBuild - export const isWindows = process.platform === 'win32' - export const viteBinPath = path.posix.join( -@@ -267,7 +267,7 @@ export async function startDefaultServe(): Promise { - }, - ) - const rollupOutput = await build(buildConfig) -- const isWatch = !!resolvedConfig!.build.watch -+ const isWatch = Boolean(resolvedConfig)!.build.watch - // in build watch,call startStaticServer after the build is complete - if (isWatch) { - watcher = rollupOutput as RollupWatcher diff --git a/tests/integration/codemod/canonical/move_functions_to_new_file/test_pylsp/expected_diff.patch.skip b/tests/integration/codemod/canonical/move_functions_to_new_file/test_pylsp/expected_diff.patch.skip deleted file mode 100644 index 207b7ce7b..000000000 --- a/tests/integration/codemod/canonical/move_functions_to_new_file/test_pylsp/expected_diff.patch.skip +++ /dev/null @@ -1,8378 +0,0 @@ -diff --git a/pylsp/hookspecs.py b/pylsp/hookspecs.py -index 41508be..e7f3f1f 100644 ---- a/pylsp/hookspecs.py -+++ b/pylsp/hookspecs.py -@@ -2,134 +2,3 @@ - # Copyright 2021- Python Language Server Contributors. - - from pylsp import hookspec -- -- --@hookspec --def pylsp_code_actions(config, workspace, document, range, context): -- pass -- -- --@hookspec --def pylsp_code_lens(config, workspace, document) -> None: -- pass -- -- --@hookspec --def pylsp_commands(config, workspace) -> None: -- """The list of command strings supported by the server. -- -- Returns: -- List[str]: The supported commands. -- """ -- -- --@hookspec --def pylsp_completions(config, workspace, document, position, ignored_names) -> None: -- pass -- -- --@hookspec(firstresult=True) --def pylsp_completion_item_resolve(config, workspace, document, completion_item) -> None: -- pass -- -- --@hookspec --def pylsp_definitions(config, workspace, document, position) -> None: -- pass -- -- --@hookspec --def pylsp_dispatchers(config, workspace) -> None: -- pass -- -- --@hookspec --def pylsp_document_did_open(config, workspace, document) -> None: -- pass -- -- --@hookspec --def pylsp_document_did_save(config, workspace, document) -> None: -- pass -- -- --@hookspec --def pylsp_document_highlight(config, workspace, document, position) -> None: -- pass -- -- --@hookspec --def pylsp_document_symbols(config, workspace, document) -> None: -- pass -- -- --@hookspec(firstresult=True) --def pylsp_execute_command(config, workspace, command, arguments) -> None: -- pass -- -- --@hookspec --def pylsp_experimental_capabilities(config, workspace) -> None: -- pass -- -- --@hookspec --def pylsp_folding_range(config, workspace, document) -> None: -- pass -- -- --@hookspec(firstresult=True) --def pylsp_format_document(config, workspace, document, options) -> None: -- pass -- -- --@hookspec(firstresult=True) --def pylsp_format_range(config, workspace, document, range, options) -> None: -- pass -- -- --@hookspec(firstresult=True) --def pylsp_hover(config, workspace, document, position) -> None: -- pass -- -- --@hookspec --def pylsp_initialize(config, workspace) -> None: -- pass -- -- --@hookspec --def pylsp_initialized() -> None: -- pass -- -- --@hookspec --def pylsp_lint(config, workspace, document, is_saved) -> None: -- pass -- -- --@hookspec --def pylsp_references( -- config, workspace, document, position, exclude_declaration --) -> None: -- pass -- -- --@hookspec(firstresult=True) --def pylsp_rename(config, workspace, document, position, new_name) -> None: -- pass -- -- --@hookspec --def pylsp_settings(config) -> None: -- pass -- -- --@hookspec(firstresult=True) --def pylsp_signature_help(config, workspace, document, position) -> None: -- pass -- -- --@hookspec --def pylsp_workspace_configuration_changed(config, workspace) -> None: -- pass -diff --git a/pylsp/plugins/autopep8_format.py b/pylsp/plugins/autopep8_format.py -index 2b3491d..69aba5c 100644 ---- a/pylsp/plugins/autopep8_format.py -+++ b/pylsp/plugins/autopep8_format.py -@@ -1,6 +1,8 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import log -+from pylsp.pylsp_shared import _format - import logging - - import pycodestyle -@@ -9,87 +11,3 @@ from autopep8 import fix_code - - from pylsp import hookimpl - from pylsp._utils import get_eol_chars -- --log = logging.getLogger(__name__) -- -- --@hookimpl(tryfirst=True) # Prefer autopep8 over YAPF --def pylsp_format_document(config, workspace, document, options): -- with workspace.report_progress("format: autopep8"): -- log.info("Formatting document %s with autopep8", document) -- return _format(config, document) -- -- --@hookimpl(tryfirst=True) # Prefer autopep8 over YAPF --def pylsp_format_range(config, workspace, document, range, options): -- log.info("Formatting document %s in range %s with autopep8", document, range) -- -- # First we 'round' the range up/down to full lines only -- range["start"]["character"] = 0 -- range["end"]["line"] += 1 -- range["end"]["character"] = 0 -- -- # Add 1 for 1-indexing vs LSP's 0-indexing -- line_range = (range["start"]["line"] + 1, range["end"]["line"]) -- return _format(config, document, line_range=line_range) -- -- --def _format(config, document, line_range=None): -- options = _autopep8_config(config, document) -- if line_range: -- options["line_range"] = list(line_range) -- -- # Temporarily re-monkey-patch the continued_indentation checker - #771 -- del pycodestyle._checks["logical_line"][pycodestyle.continued_indentation] -- pycodestyle.register_check(autopep8_c_i) -- -- # Autopep8 doesn't work with CR line endings, so we replace them by '\n' -- # and restore them below. -- replace_cr = False -- source = document.source -- eol_chars = get_eol_chars(source) -- if eol_chars == "\r": -- replace_cr = True -- source = source.replace("\r", "\n") -- -- new_source = fix_code(source, options=options) -- -- # Switch it back -- del pycodestyle._checks["logical_line"][autopep8_c_i] -- pycodestyle.register_check(pycodestyle.continued_indentation) -- -- if new_source == source: -- return [] -- -- if replace_cr: -- new_source = new_source.replace("\n", "\r") -- -- # I'm too lazy at the moment to parse diffs into TextEdit items -- # So let's just return the entire file... -- return [ -- { -- "range": { -- "start": {"line": 0, "character": 0}, -- # End char 0 of the line after our document -- "end": {"line": len(document.lines), "character": 0}, -- }, -- "newText": new_source, -- } -- ] -- -- --def _autopep8_config(config, document=None): -- # We user pycodestyle settings to avoid redefining things -- path = document.path if document is not None else None -- settings = config.plugin_settings("pycodestyle", document_path=path) -- options = { -- "exclude": settings.get("exclude"), -- "hang_closing": settings.get("hangClosing"), -- "ignore": settings.get("ignore"), -- "max_line_length": settings.get("maxLineLength"), -- "select": settings.get("select"), -- "aggressive": settings.get("aggressive"), -- } -- -- # Filter out null options -- return {k: v for k, v in options.items() if v} -diff --git a/pylsp/plugins/definition.py b/pylsp/plugins/definition.py -index 67abfb7..266dbe0 100644 ---- a/pylsp/plugins/definition.py -+++ b/pylsp/plugins/definition.py -@@ -2,6 +2,8 @@ - # Copyright 2021- Python Language Server Contributors. - from __future__ import annotations - -+from pylsp.pylsp_shared import _resolve_definition -+from pylsp.pylsp_shared import _not_internal_definition - import logging - from typing import TYPE_CHECKING, Any, Dict, List - -@@ -17,68 +19,3 @@ if TYPE_CHECKING: - from pylsp.workspace import Document - - log = logging.getLogger(__name__) -- -- --MAX_JEDI_GOTO_HOPS = 100 -- -- --def _resolve_definition( -- maybe_defn: Name, script: Script, settings: Dict[str, Any] --) -> Name: -- for _ in range(MAX_JEDI_GOTO_HOPS): -- if maybe_defn.is_definition() or maybe_defn.module_path != script.path: -- break -- defns = script.goto( -- follow_imports=settings.get("follow_imports", True), -- follow_builtin_imports=settings.get("follow_builtin_imports", True), -- line=maybe_defn.line, -- column=maybe_defn.column, -- ) -- if len(defns) == 1: -- maybe_defn = defns[0] -- else: -- break -- return maybe_defn -- -- --@hookimpl --def pylsp_definitions( -- config: Config, document: Document, position: Dict[str, int] --) -> List[Dict[str, Any]]: -- settings = config.plugin_settings("jedi_definition") -- code_position = _utils.position_to_jedi_linecolumn(document, position) -- script = document.jedi_script(use_document_path=True) -- auto_import_modules = jedi.settings.auto_import_modules -- -- try: -- jedi.settings.auto_import_modules = [] -- definitions = script.goto( -- follow_imports=settings.get("follow_imports", True), -- follow_builtin_imports=settings.get("follow_builtin_imports", True), -- **code_position, -- ) -- definitions = [_resolve_definition(d, script, settings) for d in definitions] -- finally: -- jedi.settings.auto_import_modules = auto_import_modules -- -- follow_builtin_defns = settings.get("follow_builtin_definitions", True) -- return [ -- { -- "uri": uris.uri_with(document.uri, path=str(d.module_path)), -- "range": { -- "start": {"line": d.line - 1, "character": d.column}, -- "end": {"line": d.line - 1, "character": d.column + len(d.name)}, -- }, -- } -- for d in definitions -- if d.is_definition() and (follow_builtin_defns or _not_internal_definition(d)) -- ] -- -- --def _not_internal_definition(definition: Name) -> bool: -- return ( -- definition.line is not None -- and definition.column is not None -- and definition.module_path is not None -- and not definition.in_builtin_module() -- ) -diff --git a/pylsp/plugins/flake8_lint.py b/pylsp/plugins/flake8_lint.py -index 74e2664..c118d10 100644 ---- a/pylsp/plugins/flake8_lint.py -+++ b/pylsp/plugins/flake8_lint.py -@@ -3,6 +3,11 @@ - - """Linter pluging for flake8""" - -+from pylsp.pylsp_shared import log -+from pylsp.pylsp_shared import run_flake8 -+from pylsp.pylsp_shared import build_args -+from pylsp.pylsp_shared import parse_stdout -+from pylsp.pylsp_shared import PYFLAKES_ERROR_MESSAGES - import logging - import os.path - import re -@@ -13,231 +18,3 @@ from subprocess import PIPE, Popen - from flake8.plugins.pyflakes import FLAKE8_PYFLAKES_CODES - - from pylsp import hookimpl, lsp --from pylsp.plugins.pyflakes_lint import PYFLAKES_ERROR_MESSAGES -- --log = logging.getLogger(__name__) -- --FIX_IGNORES_RE = re.compile(r"([^a-zA-Z0-9_,]*;.*(\W+||$))") --UNNECESSITY_CODES = { -- "F401", # `module` imported but unused -- "F504", # % format unused named arguments -- "F522", # .format(...) unused named arguments -- "F523", # .format(...) unused positional arguments -- "F841", # local variable `name` is assigned to but never used --} --# NOTE: If the user sets the flake8 executable with workspace configuration, the --# error codes in this set may be inaccurate. --ERROR_CODES = ( -- # Errors from the pyflakes plugin of flake8 -- {FLAKE8_PYFLAKES_CODES.get(m.__name__, "E999") for m in PYFLAKES_ERROR_MESSAGES} -- # Syntax error from flake8 itself -- | {"E999"} --) -- -- --@hookimpl --def pylsp_settings(): -- # Default flake8 to disabled -- return {"plugins": {"flake8": {"enabled": False}}} -- -- --@hookimpl --def pylsp_lint(workspace, document): -- with workspace.report_progress("lint: flake8"): -- config = workspace._config -- settings = config.plugin_settings("flake8", document_path=document.path) -- log.debug("Got flake8 settings: %s", settings) -- -- ignores = settings.get("ignore", []) -- per_file_ignores = settings.get("perFileIgnores") -- -- if per_file_ignores: -- prev_file_pat = None -- for path in per_file_ignores: -- try: -- file_pat, errors = path.split(":") -- prev_file_pat = file_pat -- except ValueError: -- # It's legal to just specify another error type for the same -- # file pattern: -- if prev_file_pat is None: -- log.warning("skipping a Per-file-ignore with no file pattern") -- continue -- file_pat = prev_file_pat -- errors = path -- if PurePath(document.path).match(file_pat): -- ignores.extend(errors.split(",")) -- -- opts = { -- "config": settings.get("config"), -- "exclude": settings.get("exclude"), -- "extend-ignore": settings.get("extendIgnore"), -- "extend-select": settings.get("extendSelect"), -- "filename": settings.get("filename"), -- "hang-closing": settings.get("hangClosing"), -- "ignore": ignores or None, -- "max-complexity": settings.get("maxComplexity"), -- "max-line-length": settings.get("maxLineLength"), -- "indent-size": settings.get("indentSize"), -- "select": settings.get("select"), -- } -- -- # flake takes only absolute path to the config. So we should check and -- # convert if necessary -- if opts.get("config") and not os.path.isabs(opts.get("config")): -- opts["config"] = os.path.abspath( -- os.path.expanduser(os.path.expandvars(opts.get("config"))) -- ) -- log.debug("using flake8 with config: %s", opts["config"]) -- -- # Call the flake8 utility then parse diagnostics from stdout -- flake8_executable = settings.get("executable", "flake8") -- -- args = build_args(opts) -- -- # ensure the same source is used for flake8 execution and result parsing; -- # single source access improves performance as it is only one disk access -- source = document.source -- output = run_flake8(flake8_executable, args, document, source) -- return parse_stdout(source, output) -- -- --def run_flake8(flake8_executable, args, document, source): -- """Run flake8 with the provided arguments, logs errors -- from stderr if any. -- """ -- # a quick temporary fix to deal with Atom -- args = [ -- (i if not i.startswith("--ignore=") else FIX_IGNORES_RE.sub("", i)) -- for i in args -- if i is not None -- ] -- -- if document.path and document.path.startswith(document._workspace.root_path): -- args.extend( -- [ -- "--stdin-display-name", -- os.path.relpath(document.path, document._workspace.root_path), -- ] -- ) -- -- # if executable looks like a path resolve it -- if not os.path.isfile(flake8_executable) and os.sep in flake8_executable: -- flake8_executable = os.path.abspath( -- os.path.expanduser(os.path.expandvars(flake8_executable)) -- ) -- -- log.debug("Calling %s with args: '%s'", flake8_executable, args) -- popen_kwargs = {} -- if cwd := document._workspace.root_path: -- popen_kwargs["cwd"] = cwd -- try: -- cmd = [flake8_executable] -- cmd.extend(args) -- p = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, **popen_kwargs) -- except IOError: -- log.debug( -- "Can't execute %s. Trying with '%s -m flake8'", -- flake8_executable, -- sys.executable, -- ) -- cmd = [sys.executable, "-m", "flake8"] -- cmd.extend(args) -- p = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, **popen_kwargs) -- (stdout, stderr) = p.communicate(source.encode()) -- if stderr: -- log.error("Error while running flake8 '%s'", stderr.decode()) -- return stdout.decode() -- -- --def build_args(options): -- """Build arguments for calling flake8. -- -- Args: -- options: dictionary of argument names and their values. -- """ -- args = ["-"] # use stdin -- for arg_name, arg_val in options.items(): -- if arg_val is None: -- continue -- arg = None -- if isinstance(arg_val, list): -- arg = "--{}={}".format(arg_name, ",".join(arg_val)) -- elif isinstance(arg_val, bool): -- if arg_val: -- arg = "--{}".format(arg_name) -- else: -- arg = "--{}={}".format(arg_name, arg_val) -- args.append(arg) -- return args -- -- --def parse_stdout(source, stdout): -- """ -- Build a diagnostics from flake8's output, it should extract every result and format -- it into a dict that looks like this: -- { -- 'source': 'flake8', -- 'code': code, # 'E501' -- 'range': { -- 'start': { -- 'line': start_line, -- 'character': start_column, -- }, -- 'end': { -- 'line': end_line, -- 'character': end_column, -- }, -- }, -- 'message': msg, -- 'severity': lsp.DiagnosticSeverity.*, -- } -- -- Args: -- document: The document to be linted. -- stdout: output from flake8 -- Returns: -- A list of dictionaries. -- """ -- -- document_lines = source.splitlines(True) -- diagnostics = [] -- lines = stdout.splitlines() -- for raw_line in lines: -- parsed_line = re.match(r"(.*):(\d*):(\d*): (\w*) (.*)", raw_line) -- if not parsed_line: -- log.debug("Flake8 output parser can't parse line '%s'", raw_line) -- continue -- -- parsed_line = parsed_line.groups() -- if len(parsed_line) != 5: -- log.debug("Flake8 output parser can't parse line '%s'", raw_line) -- continue -- -- _, line, character, code, msg = parsed_line -- line = int(line) - 1 -- character = int(character) - 1 -- # show also the code in message -- msg = code + " " + msg -- severity = lsp.DiagnosticSeverity.Warning -- if code in ERROR_CODES: -- severity = lsp.DiagnosticSeverity.Error -- diagnostic = { -- "source": "flake8", -- "code": code, -- "range": { -- "start": {"line": line, "character": character}, -- "end": { -- "line": line, -- # no way to determine the column -- "character": len(document_lines[line]), -- }, -- }, -- "message": msg, -- "severity": severity, -- } -- if code in UNNECESSITY_CODES: -- diagnostic["tags"] = [lsp.DiagnosticTag.Unnecessary] -- diagnostics.append(diagnostic) -- -- return diagnostics -diff --git a/pylsp/plugins/folding.py b/pylsp/plugins/folding.py -index 123ba4a..71ddcef 100644 ---- a/pylsp/plugins/folding.py -+++ b/pylsp/plugins/folding.py -@@ -1,210 +1,10 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import __compute_folding_ranges - import re - - import parso - import parso.python.tree as tree_nodes - - from pylsp import hookimpl -- --SKIP_NODES = (tree_nodes.Module, tree_nodes.IfStmt, tree_nodes.TryStmt) --IDENTATION_REGEX = re.compile(r"(\s+).+") -- -- --@hookimpl --def pylsp_folding_range(document): -- program = document.source + "\n" -- lines = program.splitlines() -- tree = parso.parse(program) -- ranges = __compute_folding_ranges(tree, lines) -- -- results = [] -- for start_line, end_line in ranges: -- start_line -= 1 -- end_line -= 1 -- # If start/end character is not defined, then it defaults to the -- # corresponding line last character -- results.append( -- { -- "startLine": start_line, -- "endLine": end_line, -- } -- ) -- return results -- -- --def __merge_folding_ranges(left, right): -- for start in list(left.keys()): -- right_start = right.pop(start, None) -- if right_start is not None: -- left[start] = max(right_start, start) -- left.update(right) -- return left -- -- --def __empty_identation_stack( -- identation_stack, level_limits, current_line, folding_ranges --): -- while identation_stack != []: -- upper_level = identation_stack.pop(0) -- level_start = level_limits.pop(upper_level) -- folding_ranges.append((level_start, current_line)) -- return folding_ranges -- -- --def __match_identation_stack( -- identation_stack, level, level_limits, folding_ranges, current_line --): -- upper_level = identation_stack.pop(0) -- while upper_level >= level: -- level_start = level_limits.pop(upper_level) -- folding_ranges.append((level_start, current_line)) -- upper_level = identation_stack.pop(0) -- identation_stack.insert(0, upper_level) -- return identation_stack, folding_ranges -- -- --def __compute_folding_ranges_identation(text): -- lines = text.splitlines() -- folding_ranges = [] -- identation_stack = [] -- level_limits = {} -- current_level = 0 -- current_line = 0 -- while lines[current_line] == "": -- current_line += 1 -- for i, line in enumerate(lines): -- if i < current_line: -- continue -- i += 1 -- identation_match = IDENTATION_REGEX.match(line) -- if identation_match is not None: -- whitespace = identation_match.group(1) -- level = len(whitespace) -- if level > current_level: -- level_limits[current_level] = current_line -- identation_stack.insert(0, current_level) -- current_level = level -- elif level < current_level: -- identation_stack, folding_ranges = __match_identation_stack( -- identation_stack, level, level_limits, folding_ranges, current_line -- ) -- current_level = level -- else: -- folding_ranges = __empty_identation_stack( -- identation_stack, level_limits, current_line, folding_ranges -- ) -- current_level = 0 -- if line.strip() != "": -- current_line = i -- folding_ranges = __empty_identation_stack( -- identation_stack, level_limits, current_line, folding_ranges -- ) -- return dict(folding_ranges) -- -- --def __check_if_node_is_valid(node): -- valid = True -- if isinstance(node, tree_nodes.PythonNode): -- kind = node.type -- valid = kind not in { -- "decorated", -- "parameters", -- "dictorsetmaker", -- "testlist_comp", -- } -- if kind == "suite": -- if isinstance(node.parent, tree_nodes.Function): -- valid = False -- return valid -- -- --def __handle_skip(stack, skip): -- body = stack[skip] -- children = [body] -- if hasattr(body, "children"): -- children = body.children -- stack = stack[:skip] + children + stack[skip + 1 :] -- node = body -- end_line, _ = body.end_pos -- return node, end_line -- -- --def __handle_flow_nodes(node, end_line, stack): -- from_keyword = False -- if isinstance(node, tree_nodes.Keyword): -- from_keyword = True -- if node.value in {"if", "elif", "with", "while"}: -- node, end_line = __handle_skip(stack, 2) -- elif node.value in {"except"}: -- first_node = stack[0] -- if isinstance(first_node, tree_nodes.Operator): -- node, end_line = __handle_skip(stack, 1) -- else: -- node, end_line = __handle_skip(stack, 2) -- elif node.value in {"for"}: -- node, end_line = __handle_skip(stack, 4) -- elif node.value in {"else"}: -- node, end_line = __handle_skip(stack, 1) -- return end_line, from_keyword, node, stack -- -- --def __compute_start_end_lines(node, stack): -- start_line, _ = node.start_pos -- end_line, _ = node.end_pos -- modified = False -- end_line, from_keyword, node, stack = __handle_flow_nodes(node, end_line, stack) -- -- last_leaf = node.get_last_leaf() -- last_newline = isinstance(last_leaf, tree_nodes.Newline) -- last_operator = isinstance(last_leaf, tree_nodes.Operator) -- node_is_operator = isinstance(node, tree_nodes.Operator) -- last_operator = last_operator or not node_is_operator -- -- end_line -= 1 -- -- if isinstance(node.parent, tree_nodes.PythonNode) and not from_keyword: -- kind = node.type -- if kind in {"suite", "atom", "atom_expr", "arglist"}: -- if len(stack) > 0: -- next_node = stack[0] -- next_line, _ = next_node.start_pos -- if next_line > end_line: -- end_line += 1 -- modified = True -- if not last_newline and not modified and not last_operator: -- end_line += 1 -- return start_line, end_line, stack -- -- --def __compute_folding_ranges(tree, lines): -- folding_ranges = {} -- stack = [tree] -- -- while len(stack) > 0: -- node = stack.pop(0) -- if isinstance(node, tree_nodes.Newline): -- # Skip newline nodes -- continue -- if isinstance(node, tree_nodes.PythonErrorNode): -- # Fallback to indentation-based (best-effort) folding -- start_line, _ = node.start_pos -- start_line -= 1 -- padding = [""] * start_line -- text = "\n".join(padding + lines[start_line:]) + "\n" -- identation_ranges = __compute_folding_ranges_identation(text) -- folding_ranges = __merge_folding_ranges(folding_ranges, identation_ranges) -- break -- if not isinstance(node, SKIP_NODES): -- valid = __check_if_node_is_valid(node) -- if valid: -- start_line, end_line, stack = __compute_start_end_lines(node, stack) -- if end_line > start_line: -- current_end = folding_ranges.get(start_line, -1) -- folding_ranges[start_line] = max(current_end, end_line) -- if hasattr(node, "children"): -- stack = node.children + stack -- -- folding_ranges = sorted(folding_ranges.items()) -- return folding_ranges -diff --git a/pylsp/plugins/highlight.py b/pylsp/plugins/highlight.py -index c4c1240..4bdbe18 100644 ---- a/pylsp/plugins/highlight.py -+++ b/pylsp/plugins/highlight.py -@@ -6,31 +6,3 @@ import logging - from pylsp import _utils, hookimpl, lsp - - log = logging.getLogger(__name__) -- -- --@hookimpl --def pylsp_document_highlight(document, position): -- code_position = _utils.position_to_jedi_linecolumn(document, position) -- usages = document.jedi_script().get_references(**code_position) -- -- def is_valid(definition): -- return definition.line is not None and definition.column is not None -- -- def local_to_document(definition): -- return ( -- not definition.module_path or str(definition.module_path) == document.path -- ) -- -- return [ -- { -- "range": { -- "start": {"line": d.line - 1, "character": d.column}, -- "end": {"line": d.line - 1, "character": d.column + len(d.name)}, -- }, -- "kind": lsp.DocumentHighlightKind.Write -- if d.is_definition() -- else lsp.DocumentHighlightKind.Read, -- } -- for d in usages -- if is_valid(d) and local_to_document(d) -- ] -diff --git a/pylsp/plugins/hover.py b/pylsp/plugins/hover.py -index ca69d1b..fe8d9c5 100644 ---- a/pylsp/plugins/hover.py -+++ b/pylsp/plugins/hover.py -@@ -6,45 +6,3 @@ import logging - from pylsp import _utils, hookimpl - - log = logging.getLogger(__name__) -- -- --@hookimpl --def pylsp_hover(config, document, position): -- code_position = _utils.position_to_jedi_linecolumn(document, position) -- definitions = document.jedi_script(use_document_path=True).infer(**code_position) -- word = document.word_at_position(position) -- -- # Find first exact matching definition -- definition = next((x for x in definitions if x.name == word), None) -- -- # Ensure a definition is used if only one is available -- # even if the word doesn't match. An example of this case is 'np' -- # where 'numpy' doesn't match with 'np'. Same for NumPy ufuncs -- if len(definitions) == 1: -- definition = definitions[0] -- -- if not definition: -- return {"contents": ""} -- -- hover_capabilities = config.capabilities.get("textDocument", {}).get("hover", {}) -- supported_markup_kinds = hover_capabilities.get("contentFormat", ["markdown"]) -- preferred_markup_kind = _utils.choose_markup_kind(supported_markup_kinds) -- -- # Find first exact matching signature -- signature = next( -- ( -- x.to_string() -- for x in definition.get_signatures() -- if (x.name == word and x.type not in ["module"]) -- ), -- "", -- ) -- -- return { -- "contents": _utils.format_docstring( -- # raw docstring returns only doc, without signature -- definition.docstring(raw=True), -- preferred_markup_kind, -- signatures=[signature] if signature else None, -- ) -- } -diff --git a/pylsp/plugins/jedi_completion.py b/pylsp/plugins/jedi_completion.py -index 2796a09..e2fc847 100644 ---- a/pylsp/plugins/jedi_completion.py -+++ b/pylsp/plugins/jedi_completion.py -@@ -1,6 +1,9 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import use_snippets -+from pylsp.pylsp_shared import _resolve_completion -+from pylsp.pylsp_shared import _format_completion - import logging - import os - -@@ -10,290 +13,3 @@ from pylsp import _utils, hookimpl, lsp - from pylsp.plugins._resolvers import LABEL_RESOLVER, SNIPPET_RESOLVER - - log = logging.getLogger(__name__) -- --# Map to the LSP type --# > Valid values for type are ``module``, `` class ``, ``instance``, ``function``, --# > ``param``, ``path``, ``keyword``, ``property`` and ``statement``. --# see: https://jedi.readthedocs.io/en/latest/docs/api-classes.html#jedi.api.classes.BaseName.type --_TYPE_MAP = { -- "module": lsp.CompletionItemKind.Module, -- "namespace": lsp.CompletionItemKind.Module, # to be added in Jedi 0.18+ -- "class": lsp.CompletionItemKind.Class, -- "instance": lsp.CompletionItemKind.Reference, -- "function": lsp.CompletionItemKind.Function, -- "param": lsp.CompletionItemKind.Variable, -- "path": lsp.CompletionItemKind.File, -- "keyword": lsp.CompletionItemKind.Keyword, -- "property": lsp.CompletionItemKind.Property, # added in Jedi 0.18 -- "statement": lsp.CompletionItemKind.Variable, --} -- --# Types of parso nodes for which snippet is not included in the completion --_IMPORTS = ("import_name", "import_from") -- --# Types of parso node for errors --_ERRORS = ("error_node",) -- -- --@hookimpl --def pylsp_completions(config, document, position): -- """Get formatted completions for current code position""" -- settings = config.plugin_settings("jedi_completion", document_path=document.path) -- resolve_eagerly = settings.get("eager", False) -- code_position = _utils.position_to_jedi_linecolumn(document, position) -- -- code_position["fuzzy"] = settings.get("fuzzy", False) -- completions = document.jedi_script(use_document_path=True).complete(**code_position) -- -- if not completions: -- return None -- -- completion_capabilities = config.capabilities.get("textDocument", {}).get( -- "completion", {} -- ) -- item_capabilities = completion_capabilities.get("completionItem", {}) -- snippet_support = item_capabilities.get("snippetSupport") -- supported_markup_kinds = item_capabilities.get("documentationFormat", ["markdown"]) -- preferred_markup_kind = _utils.choose_markup_kind(supported_markup_kinds) -- -- should_include_params = settings.get("include_params") -- should_include_class_objects = settings.get("include_class_objects", False) -- should_include_function_objects = settings.get("include_function_objects", False) -- -- max_to_resolve = settings.get("resolve_at_most", 25) -- modules_to_cache_for = settings.get("cache_for", None) -- if modules_to_cache_for is not None: -- LABEL_RESOLVER.cached_modules = modules_to_cache_for -- SNIPPET_RESOLVER.cached_modules = modules_to_cache_for -- -- include_params = ( -- snippet_support and should_include_params and use_snippets(document, position) -- ) -- include_class_objects = ( -- snippet_support -- and should_include_class_objects -- and use_snippets(document, position) -- ) -- include_function_objects = ( -- snippet_support -- and should_include_function_objects -- and use_snippets(document, position) -- ) -- -- ready_completions = [ -- _format_completion( -- c, -- markup_kind=preferred_markup_kind, -- include_params=include_params if c.type in ["class", "function"] else False, -- resolve=resolve_eagerly, -- resolve_label_or_snippet=(i < max_to_resolve), -- snippet_support=snippet_support, -- ) -- for i, c in enumerate(completions) -- ] -- -- # TODO split up once other improvements are merged -- if include_class_objects: -- for i, c in enumerate(completions): -- if c.type == "class": -- completion_dict = _format_completion( -- c, -- markup_kind=preferred_markup_kind, -- include_params=False, -- resolve=resolve_eagerly, -- resolve_label_or_snippet=(i < max_to_resolve), -- snippet_support=snippet_support, -- ) -- completion_dict["kind"] = lsp.CompletionItemKind.TypeParameter -- completion_dict["label"] += " object" -- ready_completions.append(completion_dict) -- -- if include_function_objects: -- for i, c in enumerate(completions): -- if c.type == "function": -- completion_dict = _format_completion( -- c, -- markup_kind=preferred_markup_kind, -- include_params=False, -- resolve=resolve_eagerly, -- resolve_label_or_snippet=(i < max_to_resolve), -- snippet_support=snippet_support, -- ) -- completion_dict["kind"] = lsp.CompletionItemKind.TypeParameter -- completion_dict["label"] += " object" -- ready_completions.append(completion_dict) -- -- for completion_dict in ready_completions: -- completion_dict["data"] = {"doc_uri": document.uri} -- -- # most recently retrieved completion items, used for resolution -- document.shared_data["LAST_JEDI_COMPLETIONS"] = { -- # label is the only required property; here it is assumed to be unique -- completion["label"]: (completion, data) -- for completion, data in zip(ready_completions, completions) -- } -- -- return ready_completions or None -- -- --@hookimpl --def pylsp_completion_item_resolve(config, completion_item, document): -- """Resolve formatted completion for given non-resolved completion""" -- shared_data = document.shared_data["LAST_JEDI_COMPLETIONS"].get( -- completion_item["label"] -- ) -- -- completion_capabilities = config.capabilities.get("textDocument", {}).get( -- "completion", {} -- ) -- item_capabilities = completion_capabilities.get("completionItem", {}) -- supported_markup_kinds = item_capabilities.get("documentationFormat", ["markdown"]) -- preferred_markup_kind = _utils.choose_markup_kind(supported_markup_kinds) -- -- if shared_data: -- completion, data = shared_data -- return _resolve_completion(completion, data, markup_kind=preferred_markup_kind) -- return completion_item -- -- --def is_exception_class(name): -- """ -- Determine if a class name is an instance of an Exception. -- -- This returns `False` if the name given corresponds with a instance of -- the 'Exception' class, `True` otherwise -- """ -- try: -- return name in [cls.__name__ for cls in Exception.__subclasses__()] -- except AttributeError: -- # Needed in case a class don't uses new-style -- # class definition in Python 2 -- return False -- -- --def use_snippets(document, position): -- """ -- Determine if it's necessary to return snippets in code completions. -- -- This returns `False` if a completion is being requested on an import -- statement, `True` otherwise. -- """ -- line = position["line"] -- lines = document.source.split("\n", line) -- act_lines = [lines[line][: position["character"]]] -- line -= 1 -- last_character = "" -- while line > -1: -- act_line = lines[line] -- if ( -- act_line.rstrip().endswith("\\") -- or act_line.rstrip().endswith("(") -- or act_line.rstrip().endswith(",") -- ): -- act_lines.insert(0, act_line) -- line -= 1 -- if act_line.rstrip().endswith("("): -- # Needs to be added to the end of the code before parsing -- # to make it valid, otherwise the node type could end -- # being an 'error_node' for multi-line imports that use '(' -- last_character = ")" -- else: -- break -- if "(" in act_lines[-1].strip(): -- last_character = ")" -- code = "\n".join(act_lines).rsplit(";", maxsplit=1)[-1].strip() + last_character -- tokens = parso.parse(code) -- expr_type = tokens.children[0].type -- return expr_type not in _IMPORTS and not (expr_type in _ERRORS and "import" in code) -- -- --def _resolve_completion(completion, d, markup_kind: str): -- completion["detail"] = _detail(d) -- try: -- docs = _utils.format_docstring( -- d.docstring(raw=True), -- signatures=[signature.to_string() for signature in d.get_signatures()], -- markup_kind=markup_kind, -- ) -- except Exception: -- docs = "" -- completion["documentation"] = docs -- return completion -- -- --def _format_completion( -- d, -- markup_kind: str, -- include_params=True, -- resolve=False, -- resolve_label_or_snippet=False, -- snippet_support=False, --): -- completion = { -- "label": _label(d, resolve_label_or_snippet), -- "kind": _TYPE_MAP.get(d.type), -- "sortText": _sort_text(d), -- "insertText": d.name, -- } -- -- if resolve: -- completion = _resolve_completion(completion, d, markup_kind) -- -- # Adjustments for file completions -- if d.type == "path": -- path = os.path.normpath(d.name) -- -- # If the completion ends with os.sep, it means it's a directory. So we add os.sep at the end -- # to ease additional file completions. -- if d.name.endswith(os.sep): -- if os.name == "nt": -- path = path + "\\" -- else: -- path = path + "/" -- -- # Escape to prevent conflicts with the code snippets grammer -- # See also https://github.com/python-lsp/python-lsp-server/issues/373 -- if snippet_support: -- path = path.replace("\\", "\\\\") -- path = path.replace("/", "\\/") -- -- completion["insertText"] = path -- -- if include_params and not is_exception_class(d.name): -- snippet = _snippet(d, resolve_label_or_snippet) -- completion.update(snippet) -- -- return completion -- -- --def _label(definition, resolve=False): -- if not resolve: -- return definition.name -- sig = LABEL_RESOLVER.get_or_create(definition) -- if sig: -- return sig -- return definition.name -- -- --def _snippet(definition, resolve=False): -- if not resolve: -- return {} -- snippet = SNIPPET_RESOLVER.get_or_create(definition) -- return snippet -- -- --def _detail(definition): -- try: -- return definition.parent().full_name or "" -- except AttributeError: -- return definition.full_name or "" -- -- --def _sort_text(definition): -- """Ensure builtins appear at the bottom. -- Description is of format : . -- """ -- -- # If its 'hidden', put it next last -- prefix = "z{}" if definition.name.startswith("_") else "a{}" -- return prefix.format(definition.name) -diff --git a/pylsp/plugins/jedi_rename.py b/pylsp/plugins/jedi_rename.py -index b35e321..7f34f19 100644 ---- a/pylsp/plugins/jedi_rename.py -+++ b/pylsp/plugins/jedi_rename.py -@@ -1,56 +1,8 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import log -+from pylsp.pylsp_shared import _num_lines - import logging - - from pylsp import _utils, hookimpl, uris -- --log = logging.getLogger(__name__) -- -- --@hookimpl --def pylsp_rename(config, workspace, document, position, new_name): -- log.debug( -- "Executing rename of %s to %s", document.word_at_position(position), new_name -- ) -- kwargs = _utils.position_to_jedi_linecolumn(document, position) -- kwargs["new_name"] = new_name -- try: -- refactoring = document.jedi_script().rename(**kwargs) -- except NotImplementedError as exc: -- raise Exception( -- "No support for renaming in Python 2/3.5 with Jedi. " -- "Consider using the pylsp-rope plugin instead" -- ) from exc -- log.debug("Finished rename: %s", refactoring.get_diff()) -- changes = [] -- -- changed_files = refactoring.get_changed_files() -- for file_path, changed_file in changed_files.items(): -- uri = uris.from_fs_path(str(file_path)) -- doc = workspace.get_maybe_document(uri) -- changes.append( -- { -- "textDocument": {"uri": uri, "version": doc.version if doc else None}, -- "edits": [ -- { -- "range": { -- "start": {"line": 0, "character": 0}, -- "end": { -- "line": _num_lines(changed_file.get_new_code()), -- "character": 0, -- }, -- }, -- "newText": changed_file.get_new_code(), -- } -- ], -- } -- ) -- return {"documentChanges": changes} -- -- --def _num_lines(file_contents): -- "Count the number of lines in the given string." -- if _utils.get_eol_chars(file_contents): -- return len(file_contents.splitlines()) -- return 0 -diff --git a/pylsp/plugins/mccabe_lint.py b/pylsp/plugins/mccabe_lint.py -index 0e2cba2..d0df911 100644 ---- a/pylsp/plugins/mccabe_lint.py -+++ b/pylsp/plugins/mccabe_lint.py -@@ -1,56 +1,12 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import log -+from pylsp.pylsp_shared import THRESHOLD -+from pylsp.pylsp_shared import DEFAULT_THRESHOLD - import ast - import logging - - import mccabe - - from pylsp import hookimpl, lsp -- --log = logging.getLogger(__name__) -- --THRESHOLD = "threshold" --DEFAULT_THRESHOLD = 15 -- -- --@hookimpl --def pylsp_lint(config, workspace, document): -- with workspace.report_progress("lint: mccabe"): -- threshold = config.plugin_settings("mccabe", document_path=document.path).get( -- THRESHOLD, DEFAULT_THRESHOLD -- ) -- log.debug("Running mccabe lint with threshold: %s", threshold) -- -- try: -- tree = compile(document.source, document.path, "exec", ast.PyCF_ONLY_AST) -- except SyntaxError: -- # We'll let the other linters point this one out -- return None -- -- visitor = mccabe.PathGraphingAstVisitor() -- visitor.preorder(tree, visitor) -- -- diags = [] -- for graph in visitor.graphs.values(): -- if graph.complexity() >= threshold: -- diags.append( -- { -- "source": "mccabe", -- "range": { -- "start": { -- "line": graph.lineno - 1, -- "character": graph.column, -- }, -- "end": { -- "line": graph.lineno - 1, -- "character": len(document.lines[graph.lineno]), -- }, -- }, -- "message": "Cyclomatic complexity too high: %s (threshold %s)" -- % (graph.complexity(), threshold), -- "severity": lsp.DiagnosticSeverity.Warning, -- } -- ) -- -- return diags -diff --git a/pylsp/plugins/preload_imports.py b/pylsp/plugins/preload_imports.py -index ebcd9ad..6601e71 100644 ---- a/pylsp/plugins/preload_imports.py -+++ b/pylsp/plugins/preload_imports.py -@@ -1,79 +1,8 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import MODULES -+from pylsp.pylsp_shared import log - import logging - - from pylsp import hookimpl -- --log = logging.getLogger(__name__) -- --MODULES = [ -- "OpenGL", -- "PIL", -- "array", -- "audioop", -- "binascii", -- "cPickle", -- "cStringIO", -- "cmath", -- "collections", -- "datetime", -- "errno", -- "exceptions", -- "gc", -- "imageop", -- "imp", -- "itertools", -- "marshal", -- "math", -- "matplotlib", -- "mmap", -- "mpmath", -- "msvcrt", -- "networkx", -- "nose", -- "nt", -- "numpy", -- "operator", -- "os", -- "os.path", -- "pandas", -- "parser", -- "rgbimg", -- "scipy", -- "signal", -- "skimage", -- "sklearn", -- "statsmodels", -- "strop", -- "sympy", -- "sys", -- "thread", -- "time", -- "wx", -- "xxsubtype", -- "zipimport", -- "zlib", --] -- -- --@hookimpl --def pylsp_settings(): -- # Setup default modules to preload, and rope extension modules -- return { -- "plugins": {"preload": {"modules": MODULES}}, -- "rope": {"extensionModules": MODULES}, -- } -- -- --@hookimpl --def pylsp_initialize(config) -> None: -- for mod_name in config.plugin_settings("preload").get("modules", []): -- try: -- __import__(mod_name) -- log.debug("Preloaded module %s", mod_name) -- except Exception: -- # Catch any exception since not only ImportError can be raised here -- # For example, old versions of NumPy can cause a ValueError. -- # See spyder-ide/spyder#13985 -- pass -diff --git a/pylsp/plugins/pycodestyle_lint.py b/pylsp/plugins/pycodestyle_lint.py -index 7a514ad..9d06708 100644 ---- a/pylsp/plugins/pycodestyle_lint.py -+++ b/pylsp/plugins/pycodestyle_lint.py -@@ -1,6 +1,9 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import log -+from pylsp.pylsp_shared import _get_severity -+from pylsp.pylsp_shared import PyCodeStyleDiagnosticReport - import logging - - import pycodestyle -@@ -19,95 +22,3 @@ else: - if autopep8_c_i in pycodestyle._checks["logical_line"]: - del pycodestyle._checks["logical_line"][autopep8_c_i] - pycodestyle.register_check(pycodestyle.continued_indentation) -- --log = logging.getLogger(__name__) -- -- --@hookimpl --def pylsp_lint(workspace, document): -- with workspace.report_progress("lint: pycodestyle"): -- config = workspace._config -- settings = config.plugin_settings("pycodestyle", document_path=document.path) -- log.debug("Got pycodestyle settings: %s", settings) -- -- opts = { -- "exclude": settings.get("exclude"), -- "filename": settings.get("filename"), -- "hang_closing": settings.get("hangClosing"), -- "ignore": settings.get("ignore"), -- "max_line_length": settings.get("maxLineLength"), -- "indent_size": settings.get("indentSize"), -- "select": settings.get("select"), -- } -- kwargs = {k: v for k, v in opts.items() if v} -- styleguide = pycodestyle.StyleGuide(kwargs) -- -- # Use LF to lint file because other line endings can give false positives. -- # See spyder-ide/spyder#19565 for context. -- source = document.source -- eol_chars = get_eol_chars(source) -- if eol_chars in ["\r", "\r\n"]: -- source = source.replace(eol_chars, "\n") -- lines = source.splitlines(keepends=True) -- else: -- lines = document.lines -- -- c = pycodestyle.Checker( -- filename=document.path, -- lines=lines, -- options=styleguide.options, -- report=PyCodeStyleDiagnosticReport(styleguide.options), -- ) -- c.check_all() -- diagnostics = c.report.diagnostics -- -- return diagnostics -- -- --class PyCodeStyleDiagnosticReport(pycodestyle.BaseReport): -- def __init__(self, options) -> None: -- self.diagnostics = [] -- super().__init__(options=options) -- -- def error(self, line_number, offset, text, check): -- code = text[:4] -- if self._ignore_code(code): -- return -- -- # Don't care about expected errors or warnings -- if code in self.expected: -- return -- -- # PyCodeStyle will sometimes give you an error the line after the end of the file -- # e.g. no newline at end of file -- # In that case, the end offset should just be some number ~100 -- # (because why not? There's nothing to underline anyways) -- err_range = { -- "start": {"line": line_number - 1, "character": offset}, -- "end": { -- # FIXME: It's a little naiive to mark until the end of the line, can we not easily do better? -- "line": line_number - 1, -- "character": 100 -- if line_number > len(self.lines) -- else len(self.lines[line_number - 1]), -- }, -- } -- diagnostic = { -- "source": "pycodestyle", -- "range": err_range, -- "message": text, -- "code": code, -- # Are style errors really ever errors? -- "severity": _get_severity(code), -- } -- if code.startswith("W6"): -- diagnostic["tags"] = [lsp.DiagnosticTag.Deprecated] -- self.diagnostics.append(diagnostic) -- -- --def _get_severity(code): -- # Are style errors ever really errors? -- if code[0] == "E" or code[0] == "W": -- return lsp.DiagnosticSeverity.Warning -- # If no severity is specified, why wouldn't this be informational only? -- return lsp.DiagnosticSeverity.Information -diff --git a/pylsp/plugins/pydocstyle_lint.py b/pylsp/plugins/pydocstyle_lint.py -index a310ac8..e000d4a 100644 ---- a/pylsp/plugins/pydocstyle_lint.py -+++ b/pylsp/plugins/pydocstyle_lint.py -@@ -1,6 +1,11 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import log -+from pylsp.pylsp_shared import DEFAULT_MATCH_RE -+from pylsp.pylsp_shared import DEFAULT_MATCH_DIR_RE -+from pylsp.pylsp_shared import _parse_diagnostic -+from pylsp.pylsp_shared import _patch_sys_argv - import contextlib - import logging - import os -@@ -11,117 +16,6 @@ import pydocstyle - - from pylsp import hookimpl, lsp - --log = logging.getLogger(__name__) -- - # PyDocstyle is a little verbose in debug message - pydocstyle_logger = logging.getLogger(pydocstyle.utils.__name__) - pydocstyle_logger.setLevel(logging.INFO) -- --DEFAULT_MATCH_RE = pydocstyle.config.ConfigurationParser.DEFAULT_MATCH_RE --DEFAULT_MATCH_DIR_RE = pydocstyle.config.ConfigurationParser.DEFAULT_MATCH_DIR_RE -- -- --@hookimpl --def pylsp_settings(): -- # Default pydocstyle to disabled -- return {"plugins": {"pydocstyle": {"enabled": False}}} -- -- --@hookimpl --def pylsp_lint(config, workspace, document): -- with workspace.report_progress("lint: pydocstyle"): -- settings = config.plugin_settings("pydocstyle", document_path=document.path) -- log.debug("Got pydocstyle settings: %s", settings) -- -- # Explicitly passing a path to pydocstyle means it doesn't respect the --match flag, so do it ourselves -- filename_match_re = re.compile(settings.get("match", DEFAULT_MATCH_RE) + "$") -- if not filename_match_re.match(os.path.basename(document.path)): -- return [] -- -- # Likewise with --match-dir -- dir_match_re = re.compile(settings.get("matchDir", DEFAULT_MATCH_DIR_RE) + "$") -- if not dir_match_re.match(os.path.basename(os.path.dirname(document.path))): -- return [] -- -- args = [document.path] -- -- if settings.get("convention"): -- args.append("--convention=" + settings["convention"]) -- -- if settings.get("addSelect"): -- args.append("--add-select=" + ",".join(settings["addSelect"])) -- if settings.get("addIgnore"): -- args.append("--add-ignore=" + ",".join(settings["addIgnore"])) -- -- elif settings.get("select"): -- args.append("--select=" + ",".join(settings["select"])) -- elif settings.get("ignore"): -- args.append("--ignore=" + ",".join(settings["ignore"])) -- -- log.info("Using pydocstyle args: %s", args) -- -- conf = pydocstyle.config.ConfigurationParser() -- with _patch_sys_argv(args): -- # TODO(gatesn): We can add more pydocstyle args here from our pylsp config -- conf.parse() -- -- # Will only yield a single filename, the document path -- diags = [] -- for ( -- filename, -- checked_codes, -- ignore_decorators, -- property_decorators, -- ignore_self_only_init, -- ) in conf.get_files_to_check(): -- errors = pydocstyle.checker.ConventionChecker().check_source( -- document.source, -- filename, -- ignore_decorators=ignore_decorators, -- property_decorators=property_decorators, -- ignore_self_only_init=ignore_self_only_init, -- ) -- -- try: -- for error in errors: -- if error.code not in checked_codes: -- continue -- diags.append(_parse_diagnostic(document, error)) -- except pydocstyle.parser.ParseError: -- # In the case we cannot parse the Python file, just continue -- pass -- -- log.debug("Got pydocstyle errors: %s", diags) -- return diags -- -- --def _parse_diagnostic(document, error): -- lineno = error.definition.start - 1 -- line = document.lines[0] if document.lines else "" -- -- start_character = len(line) - len(line.lstrip()) -- end_character = len(line) -- -- return { -- "source": "pydocstyle", -- "code": error.code, -- "message": error.message, -- "severity": lsp.DiagnosticSeverity.Warning, -- "range": { -- "start": {"line": lineno, "character": start_character}, -- "end": {"line": lineno, "character": end_character}, -- }, -- } -- -- --@contextlib.contextmanager --def _patch_sys_argv(arguments) -> None: -- old_args = sys.argv -- -- # Preserve argv[0] since it's the executable -- sys.argv = old_args[0:1] + arguments -- -- try: -- yield -- finally: -- sys.argv = old_args -diff --git a/pylsp/plugins/pyflakes_lint.py b/pylsp/plugins/pyflakes_lint.py -index 8a04276..f0332ba 100644 ---- a/pylsp/plugins/pyflakes_lint.py -+++ b/pylsp/plugins/pyflakes_lint.py -@@ -1,97 +1,9 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import PYFLAKES_ERROR_MESSAGES -+from pylsp.pylsp_shared import PyflakesDiagnosticReport - from pyflakes import api as pyflakes_api - from pyflakes import messages - - from pylsp import hookimpl, lsp -- --# Pyflakes messages that should be reported as Errors instead of Warns --PYFLAKES_ERROR_MESSAGES = ( -- messages.UndefinedName, -- messages.UndefinedExport, -- messages.UndefinedLocal, -- messages.DuplicateArgument, -- messages.FutureFeatureNotDefined, -- messages.ReturnOutsideFunction, -- messages.YieldOutsideFunction, -- messages.ContinueOutsideLoop, -- messages.BreakOutsideLoop, -- messages.TwoStarredExpressions, --) -- -- --@hookimpl --def pylsp_lint(workspace, document): -- with workspace.report_progress("lint: pyflakes"): -- reporter = PyflakesDiagnosticReport(document.lines) -- pyflakes_api.check( -- document.source.encode("utf-8"), document.path, reporter=reporter -- ) -- return reporter.diagnostics -- -- --class PyflakesDiagnosticReport: -- def __init__(self, lines) -> None: -- self.lines = lines -- self.diagnostics = [] -- -- def unexpectedError(self, _filename, msg) -> None: # pragma: no cover -- err_range = { -- "start": {"line": 0, "character": 0}, -- "end": {"line": 0, "character": 0}, -- } -- self.diagnostics.append( -- { -- "source": "pyflakes", -- "range": err_range, -- "message": msg, -- "severity": lsp.DiagnosticSeverity.Error, -- } -- ) -- -- def syntaxError(self, _filename, msg, lineno, offset, text) -> None: -- # We've seen that lineno and offset can sometimes be None -- lineno = lineno or 1 -- offset = offset or 0 -- # could be None if the error is due to an invalid encoding -- # see e.g. https://github.com/python-lsp/python-lsp-server/issues/429 -- text = text or "" -- -- err_range = { -- "start": {"line": lineno - 1, "character": offset}, -- "end": {"line": lineno - 1, "character": offset + len(text)}, -- } -- self.diagnostics.append( -- { -- "source": "pyflakes", -- "range": err_range, -- "message": msg, -- "severity": lsp.DiagnosticSeverity.Error, -- } -- ) -- -- def flake(self, message) -> None: -- """Get message like :: """ -- err_range = { -- "start": {"line": message.lineno - 1, "character": message.col}, -- "end": { -- "line": message.lineno - 1, -- "character": len(self.lines[message.lineno - 1]), -- }, -- } -- -- severity = lsp.DiagnosticSeverity.Warning -- for message_type in PYFLAKES_ERROR_MESSAGES: -- if isinstance(message, message_type): -- severity = lsp.DiagnosticSeverity.Error -- break -- -- self.diagnostics.append( -- { -- "source": "pyflakes", -- "range": err_range, -- "message": message.message % message.message_args, -- "severity": severity, -- } -- ) -diff --git a/pylsp/plugins/pylint_lint.py b/pylsp/plugins/pylint_lint.py -index beffe6f..5ff622a 100644 ---- a/pylsp/plugins/pylint_lint.py -+++ b/pylsp/plugins/pylint_lint.py -@@ -4,6 +4,13 @@ - - """Linter plugin for pylint.""" - -+from pylsp.pylsp_shared import log -+from pylsp.pylsp_shared import DEPRECATION_CODES -+from pylsp.pylsp_shared import UNNECESSITY_CODES -+from pylsp.pylsp_shared import PylintLinter -+from pylsp.pylsp_shared import _build_pylint_flags -+from pylsp.pylsp_shared import build_args_stdio -+from pylsp.pylsp_shared import pylint_lint_stdin - import collections - import logging - import os -@@ -19,8 +26,6 @@ try: - except Exception: - import json - --log = logging.getLogger(__name__) -- - # Pylint fails to suppress STDOUT when importing whitelisted C - # extensions, mangling their output into the expected JSON which breaks the - # parser. The most prominent example (and maybe the only one out there) is -@@ -29,326 +34,3 @@ log = logging.getLogger(__name__) - # fix for a very specific upstream issue. - # Related: https://github.com/PyCQA/pylint/issues/3518 - os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "hide" --DEPRECATION_CODES = { -- "W0402", # Uses of a deprecated module %r -- "W1505", # Using deprecated method %s() -- "W1511", # Using deprecated argument %s of method %s() -- "W1512", # Using deprecated class %s of module %s -- "W1513", # Using deprecated decorator %s() --} --UNNECESSITY_CODES = { -- "W0611", # Unused import %s -- "W0612", # Unused variable %r -- "W0613", # Unused argument %r -- "W0614", # Unused import %s from wildcard import -- "W1304", # Unused-format-string-argument --} -- -- --class PylintLinter: -- last_diags = collections.defaultdict(list) -- -- @classmethod -- def lint(cls, document, is_saved, flags=""): -- """Plugin interface to pylsp linter. -- -- Args: -- document: The document to be linted. -- is_saved: Whether or not the file has been saved to disk. -- flags: Additional flags to pass to pylint. Not exposed to -- pylsp_lint, but used for testing. -- -- Returns: -- A list of dicts with the following format: -- -- { -- 'source': 'pylint', -- 'range': { -- 'start': { -- 'line': start_line, -- 'character': start_column, -- }, -- 'end': { -- 'line': end_line, -- 'character': end_column, -- }, -- } -- 'message': msg, -- 'severity': lsp.DiagnosticSeverity.*, -- } -- """ -- if not is_saved: -- # Pylint can only be run on files that have been saved to disk. -- # Rather than return nothing, return the previous list of -- # diagnostics. If we return an empty list, any diagnostics we'd -- # previously shown will be cleared until the next save. Instead, -- # continue showing (possibly stale) diagnostics until the next -- # save. -- return cls.last_diags[document.path] -- -- cmd = [ -- sys.executable, -- "-c", -- "import sys; from pylint.lint import Run; Run(sys.argv[1:])", -- "-f", -- "json", -- document.path, -- ] + (shlex.split(str(flags)) if flags else []) -- log.debug("Calling pylint with '%s'", " ".join(cmd)) -- -- cwd = document._workspace.root_path -- if not cwd: -- cwd = os.path.dirname(__file__) -- -- with Popen( -- cmd, stdout=PIPE, stderr=PIPE, cwd=cwd, universal_newlines=True -- ) as process: -- json_out, err = process.communicate() -- -- if err != "": -- log.error("Error calling pylint: '%s'", err) -- -- # pylint prints nothing rather than [] when there are no diagnostics. -- # json.loads will not parse an empty string, so just return. -- if not json_out.strip(): -- cls.last_diags[document.path] = [] -- return [] -- -- # Pylint's JSON output is a list of objects with the following format. -- # -- # { -- # "obj": "main", -- # "path": "foo.py", -- # "message": "Missing function docstring", -- # "message-id": "C0111", -- # "symbol": "missing-docstring", -- # "column": 0, -- # "type": "convention", -- # "line": 5, -- # "module": "foo" -- # } -- # -- # The type can be any of: -- # -- # * convention -- # * information -- # * error -- # * fatal -- # * refactor -- # * warning -- diagnostics = [] -- for diag in json.loads(json_out): -- # pylint lines index from 1, pylsp lines index from 0 -- line = diag["line"] - 1 -- -- err_range = { -- "start": { -- "line": line, -- # Index columns start from 0 -- "character": diag["column"], -- }, -- "end": { -- "line": line, -- # It's possible that we're linting an empty file. Even an empty -- # file might fail linting if it isn't named properly. -- "character": len(document.lines[line]) if document.lines else 0, -- }, -- } -- -- if diag["type"] == "convention": -- severity = lsp.DiagnosticSeverity.Information -- elif diag["type"] == "information": -- severity = lsp.DiagnosticSeverity.Information -- elif diag["type"] == "error": -- severity = lsp.DiagnosticSeverity.Error -- elif diag["type"] == "fatal": -- severity = lsp.DiagnosticSeverity.Error -- elif diag["type"] == "refactor": -- severity = lsp.DiagnosticSeverity.Hint -- elif diag["type"] == "warning": -- severity = lsp.DiagnosticSeverity.Warning -- -- code = diag["message-id"] -- -- diagnostic = { -- "source": "pylint", -- "range": err_range, -- "message": "[{}] {}".format(diag["symbol"], diag["message"]), -- "severity": severity, -- "code": code, -- } -- -- if code in UNNECESSITY_CODES: -- diagnostic["tags"] = [lsp.DiagnosticTag.Unnecessary] -- if code in DEPRECATION_CODES: -- diagnostic["tags"] = [lsp.DiagnosticTag.Deprecated] -- -- diagnostics.append(diagnostic) -- cls.last_diags[document.path] = diagnostics -- return diagnostics -- -- --def _build_pylint_flags(settings): -- """Build arguments for calling pylint.""" -- pylint_args = settings.get("args") -- if pylint_args is None: -- return "" -- return " ".join(pylint_args) -- -- --@hookimpl --def pylsp_settings(): -- # Default pylint to disabled because it requires a config -- # file to be useful. -- return { -- "plugins": { -- "pylint": { -- "enabled": False, -- "args": [], -- # disabled by default as it can slow down the workflow -- "executable": None, -- } -- } -- } -- -- --@hookimpl --def pylsp_lint(config, workspace, document, is_saved): -- """Run pylint linter.""" -- with workspace.report_progress("lint: pylint"): -- settings = config.plugin_settings("pylint") -- log.debug("Got pylint settings: %s", settings) -- # pylint >= 2.5.0 is required for working through stdin and only -- # available with python3 -- if settings.get("executable") and sys.version_info[0] >= 3: -- flags = build_args_stdio(settings) -- pylint_executable = settings.get("executable", "pylint") -- return pylint_lint_stdin(pylint_executable, document, flags) -- flags = _build_pylint_flags(settings) -- return PylintLinter.lint(document, is_saved, flags=flags) -- -- --def build_args_stdio(settings): -- """Build arguments for calling pylint. -- -- :param settings: client settings -- :type settings: dict -- -- :return: arguments to path to pylint -- :rtype: list -- """ -- pylint_args = settings.get("args") -- if pylint_args is None: -- return [] -- return pylint_args -- -- --def pylint_lint_stdin(pylint_executable, document, flags): -- """Run pylint linter from stdin. -- -- This runs pylint in a subprocess with popen. -- This allows passing the file from stdin and as a result -- run pylint on unsaved files. Can slowdown the workflow. -- -- :param pylint_executable: path to pylint executable -- :type pylint_executable: string -- :param document: document to run pylint on -- :type document: pylsp.workspace.Document -- :param flags: arguments to path to pylint -- :type flags: list -- -- :return: linting diagnostics -- :rtype: list -- """ -- pylint_result = _run_pylint_stdio(pylint_executable, document, flags) -- return _parse_pylint_stdio_result(document, pylint_result) -- -- --def _run_pylint_stdio(pylint_executable, document, flags): -- """Run pylint in popen. -- -- :param pylint_executable: path to pylint executable -- :type pylint_executable: string -- :param document: document to run pylint on -- :type document: pylsp.workspace.Document -- :param flags: arguments to path to pylint -- :type flags: list -- -- :return: result of calling pylint -- :rtype: string -- """ -- log.debug("Calling %s with args: '%s'", pylint_executable, flags) -- try: -- cmd = [pylint_executable] -- cmd.extend(flags) -- cmd.extend(["--from-stdin", document.path]) -- p = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE) -- except IOError: -- log.debug("Can't execute %s. Trying with 'python -m pylint'", pylint_executable) -- cmd = [sys.executable, "-m", "pylint"] -- cmd.extend(flags) -- cmd.extend(["--from-stdin", document.path]) -- p = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE) -- (stdout, stderr) = p.communicate(document.source.encode()) -- if stderr: -- log.error("Error while running pylint '%s'", stderr.decode()) -- return stdout.decode() -- -- --def _parse_pylint_stdio_result(document, stdout): -- """Parse pylint results. -- -- :param document: document to run pylint on -- :type document: pylsp.workspace.Document -- :param stdout: pylint results to parse -- :type stdout: string -- -- :return: linting diagnostics -- :rtype: list -- """ -- diagnostics = [] -- lines = stdout.splitlines() -- for raw_line in lines: -- parsed_line = re.match(r"(.*):(\d*):(\d*): (\w*): (.*)", raw_line) -- if not parsed_line: -- log.debug("Pylint output parser can't parse line '%s'", raw_line) -- continue -- -- parsed_line = parsed_line.groups() -- if len(parsed_line) != 5: -- log.debug("Pylint output parser can't parse line '%s'", raw_line) -- continue -- -- _, line, character, code, msg = parsed_line -- line = int(line) - 1 -- character = int(character) -- severity_map = { -- "C": lsp.DiagnosticSeverity.Information, -- "E": lsp.DiagnosticSeverity.Error, -- "F": lsp.DiagnosticSeverity.Error, -- "I": lsp.DiagnosticSeverity.Information, -- "R": lsp.DiagnosticSeverity.Hint, -- "W": lsp.DiagnosticSeverity.Warning, -- } -- severity = severity_map[code[0]] -- diagnostic = { -- "source": "pylint", -- "code": code, -- "range": { -- "start": {"line": line, "character": character}, -- "end": { -- "line": line, -- # no way to determine the column -- "character": len(document.lines[line]) - 1, -- }, -- }, -- "message": msg, -- "severity": severity, -- } -- if code in UNNECESSITY_CODES: -- diagnostic["tags"] = [lsp.DiagnosticTag.Unnecessary] -- if code in DEPRECATION_CODES: -- diagnostic["tags"] = [lsp.DiagnosticTag.Deprecated] -- diagnostics.append(diagnostic) -- -- return diagnostics -diff --git a/pylsp/plugins/references.py b/pylsp/plugins/references.py -index a4c61b5..514f60e 100644 ---- a/pylsp/plugins/references.py -+++ b/pylsp/plugins/references.py -@@ -6,28 +6,3 @@ import logging - from pylsp import _utils, hookimpl, uris - - log = logging.getLogger(__name__) -- -- --@hookimpl --def pylsp_references(document, position, exclude_declaration): -- code_position = _utils.position_to_jedi_linecolumn(document, position) -- usages = document.jedi_script().get_references(**code_position) -- -- if exclude_declaration: -- # Filter out if the usage is the actual declaration of the thing -- usages = [d for d in usages if not d.is_definition()] -- -- # Filter out builtin modules -- return [ -- { -- "uri": uris.uri_with(document.uri, path=str(d.module_path)) -- if d.module_path -- else document.uri, -- "range": { -- "start": {"line": d.line - 1, "character": d.column}, -- "end": {"line": d.line - 1, "character": d.column + len(d.name)}, -- }, -- } -- for d in usages -- if not d.in_builtin_module() -- ] -diff --git a/pylsp/plugins/rope_autoimport.py b/pylsp/plugins/rope_autoimport.py -index 12f5d80..5447f5b 100644 ---- a/pylsp/plugins/rope_autoimport.py -+++ b/pylsp/plugins/rope_autoimport.py -@@ -1,5 +1,16 @@ - # Copyright 2022- Python Language Server Contributors. - -+from pylsp.pylsp_shared import log -+from pylsp.pylsp_shared import MAX_RESULTS_COMPLETIONS -+from pylsp.pylsp_shared import _should_insert -+from pylsp.pylsp_shared import _score_pow -+from pylsp.pylsp_shared import _score_max -+from pylsp.pylsp_shared import _process_statements -+from pylsp.pylsp_shared import get_names -+from pylsp.pylsp_shared import AutoimportCache -+from pylsp.pylsp_shared import cache -+from pylsp.pylsp_shared import MAX_RESULTS_CODE_ACTIONS -+from pylsp.pylsp_shared import get_name_or_module - import logging - import threading - from typing import Any, Dict, Generator, List, Optional, Set, Union -@@ -17,391 +28,3 @@ from pylsp.config.config import Config - from pylsp.workspace import Document, Workspace - - from ._rope_task_handle import PylspTaskHandle -- --log = logging.getLogger(__name__) -- --_score_pow = 5 --_score_max = 10**_score_pow --MAX_RESULTS_COMPLETIONS = 1000 --MAX_RESULTS_CODE_ACTIONS = 5 -- -- --class AutoimportCache: -- """Handles the cache creation.""" -- -- def __init__(self) -> None: -- self.thread = None -- -- def reload_cache( -- self, -- config: Config, -- workspace: Workspace, -- files: Optional[List[Document]] = None, -- single_thread: Optional[bool] = True, -- ): -- if self.is_blocked(): -- return -- -- memory: bool = config.plugin_settings("rope_autoimport").get("memory", False) -- rope_config = config.settings().get("rope", {}) -- autoimport = workspace._rope_autoimport(rope_config, memory) -- resources: Optional[List[Resource]] = ( -- None -- if files is None -- else [document._rope_resource(rope_config) for document in files] -- ) -- -- if single_thread: -- self._reload_cache(workspace, autoimport, resources) -- else: -- # Creating the cache may take 10-20s for a environment with 5k python modules. That's -- # why we decided to move cache creation into its own thread. -- self.thread = threading.Thread( -- target=self._reload_cache, args=(workspace, autoimport, resources) -- ) -- self.thread.start() -- -- def _reload_cache( -- self, -- workspace: Workspace, -- autoimport: AutoImport, -- resources: Optional[List[Resource]] = None, -- ) -> None: -- task_handle = PylspTaskHandle(workspace) -- autoimport.generate_cache(task_handle=task_handle, resources=resources) -- autoimport.generate_modules_cache(task_handle=task_handle) -- -- def is_blocked(self): -- return self.thread and self.thread.is_alive() -- -- --@hookimpl --def pylsp_settings() -> Dict[str, Dict[str, Dict[str, Any]]]: -- # Default rope_completion to disabled -- return { -- "plugins": { -- "rope_autoimport": { -- "enabled": False, -- "memory": False, -- "completions": { -- "enabled": True, -- }, -- "code_actions": { -- "enabled": True, -- }, -- } -- } -- } -- -- --def _should_insert(expr: tree.BaseNode, word_node: tree.Leaf) -> bool: -- """ -- Check if we should insert the word_node on the given expr. -- -- Works for both correct and incorrect code. This is because the -- user is often working on the code as they write it. -- """ -- if not word_node: -- return False -- if len(expr.children) == 0: -- return True -- first_child = expr.children[0] -- if isinstance(first_child, tree.EndMarker): -- if "#" in first_child.prefix: -- return False # Check for single line comment -- if first_child == word_node: -- return True # If the word is the first word then its fine -- if len(expr.children) > 1: -- if any( -- node.type == "operator" and "." in node.value or node.type == "trailer" -- for node in expr.children -- ): -- return False # Check if we're on a method of a function -- if isinstance(first_child, (tree.PythonErrorNode, tree.PythonNode)): -- # The tree will often include error nodes like this to indicate errors -- # we want to ignore errors since the code is being written -- return _should_insert(first_child, word_node) -- return _handle_first_child(first_child, expr, word_node) -- -- --def _handle_first_child( -- first_child: NodeOrLeaf, expr: tree.BaseNode, word_node: tree.Leaf --) -> bool: -- """Check if we suggest imports given the following first child.""" -- if isinstance(first_child, tree.Import): -- return False -- if isinstance(first_child, (tree.PythonLeaf, tree.PythonErrorLeaf)): -- # Check if the first item is a from or import statement even when incomplete -- if first_child.value in ("import", "from"): -- return False -- if isinstance(first_child, tree.Keyword): -- if first_child.value == "def": -- return _should_import_function(word_node, expr) -- if first_child.value == "class": -- return _should_import_class(word_node, expr) -- return True -- -- --def _should_import_class(word_node: tree.Leaf, expr: tree.BaseNode) -> bool: -- prev_node = None -- for node in expr.children: -- if isinstance(node, tree.Name): -- if isinstance(prev_node, tree.Operator): -- if node == word_node and prev_node.value == "(": -- return True -- prev_node = node -- -- return False -- -- --def _should_import_function(word_node: tree.Leaf, expr: tree.BaseNode) -> bool: -- prev_node = None -- for node in expr.children: -- if _handle_argument(node, word_node): -- return True -- if isinstance(prev_node, tree.Operator): -- if prev_node.value == "->": -- if node == word_node: -- return True -- prev_node = node -- return False -- -- --def _handle_argument(node: NodeOrLeaf, word_node: tree.Leaf): -- if isinstance(node, tree.PythonNode): -- if node.type == "tfpdef": -- if node.children[2] == word_node: -- return True -- if node.type == "parameters": -- for parameter in node.children: -- if _handle_argument(parameter, word_node): -- return True -- return False -- -- --def _process_statements( -- suggestions: List[SearchResult], -- doc_uri: str, -- word: str, -- autoimport: AutoImport, -- document: Document, -- feature: str = "completions", --) -> Generator[Dict[str, Any], None, None]: -- for suggestion in suggestions: -- insert_line = autoimport.find_insertion_line(document.source) - 1 -- start = {"line": insert_line, "character": 0} -- edit_range = {"start": start, "end": start} -- edit = {"range": edit_range, "newText": suggestion.import_statement + "\n"} -- score = _get_score( -- suggestion.source, suggestion.import_statement, suggestion.name, word -- ) -- if score > _score_max: -- continue -- # TODO make this markdown -- if feature == "completions": -- yield { -- "label": suggestion.name, -- "kind": suggestion.itemkind, -- "sortText": _sort_import(score), -- "data": {"doc_uri": doc_uri}, -- "detail": _document(suggestion.import_statement), -- "additionalTextEdits": [edit], -- } -- elif feature == "code_actions": -- yield { -- "title": suggestion.import_statement, -- "kind": "quickfix", -- "edit": {"changes": {doc_uri: [edit]}}, -- # data is a supported field for codeAction responses -- # See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocument_codeAction -- "data": {"sortText": _sort_import(score)}, -- } -- else: -- raise ValueError(f"Unknown feature: {feature}") -- -- --def get_names(script: Script) -> Set[str]: -- """Get all names to ignore from the current file.""" -- raw_names = script.get_names(definitions=True) -- log.debug(raw_names) -- return {name.name for name in raw_names} -- -- --@hookimpl --def pylsp_completions( -- config: Config, -- workspace: Workspace, -- document: Document, -- position, -- ignored_names: Union[Set[str], None], --): -- """Get autoimport suggestions.""" -- if ( -- not config.plugin_settings("rope_autoimport") -- .get("completions", {}) -- .get("enabled", True) -- ) or cache.is_blocked(): -- return [] -- -- line = document.lines[position["line"]] -- expr = parso.parse(line) -- word_node = expr.get_leaf_for_position((1, position["character"])) -- if not _should_insert(expr, word_node): -- return [] -- word = word_node.value -- log.debug(f"autoimport: searching for word: {word}") -- rope_config = config.settings(document_path=document.path).get("rope", {}) -- ignored_names: Set[str] = ignored_names or get_names( -- document.jedi_script(use_document_path=True) -- ) -- autoimport = workspace._rope_autoimport(rope_config) -- suggestions = list(autoimport.search_full(word, ignored_names=ignored_names)) -- results = sorted( -- _process_statements( -- suggestions, document.uri, word, autoimport, document, "completions" -- ), -- key=lambda statement: statement["sortText"], -- ) -- if len(results) > MAX_RESULTS_COMPLETIONS: -- results = results[:MAX_RESULTS_COMPLETIONS] -- return results -- -- --def _document(import_statement: str) -> str: -- return """# Auto-Import\n""" + import_statement -- -- --def _get_score( -- source: int, full_statement: str, suggested_name: str, desired_name --) -> int: -- import_length = len("import") -- full_statement_score = len(full_statement) - import_length -- suggested_name_score = (len(suggested_name) - len(desired_name)) ** 2 -- source_score = 20 * source -- return suggested_name_score + full_statement_score + source_score -- -- --def _sort_import(score: int) -> str: -- score = max(min(score, (_score_max) - 1), 0) -- # Since we are using ints, we need to pad them. -- # We also want to prioritize autoimport behind everything since its the last priority. -- # The minimum is to prevent score from overflowing the pad -- return "[z" + str(score).rjust(_score_pow, "0") -- -- --def get_name_or_module(document, diagnostic) -> str: -- start = diagnostic["range"]["start"] -- return ( -- parso.parse(document.lines[start["line"]]) -- .get_leaf_for_position((1, start["character"] + 1)) -- .value -- ) -- -- --@hookimpl --def pylsp_code_actions( -- config: Config, -- workspace: Workspace, -- document: Document, -- range: Dict, -- context: Dict, --) -> List[Dict]: -- """ -- Provide code actions through rope. -- -- Parameters -- ---------- -- config : pylsp.config.config.Config -- Current config. -- workspace : pylsp.workspace.Workspace -- Current workspace. -- document : pylsp.workspace.Document -- Document to apply code actions on. -- range : Dict -- Range argument given by pylsp. Not used here. -- context : Dict -- CodeActionContext given as dict. -- -- Returns -- ------- -- List of dicts containing the code actions. -- """ -- if ( -- not config.plugin_settings("rope_autoimport") -- .get("code_actions", {}) -- .get("enabled", True) -- ) or cache.is_blocked(): -- return [] -- -- log.debug(f"textDocument/codeAction: {document} {range} {context}") -- code_actions = [] -- for diagnostic in context.get("diagnostics", []): -- if "undefined name" not in diagnostic.get("message", "").lower(): -- continue -- -- word = get_name_or_module(document, diagnostic) -- log.debug(f"autoimport: searching for word: {word}") -- rope_config = config.settings(document_path=document.path).get("rope", {}) -- autoimport = workspace._rope_autoimport(rope_config) -- suggestions = list(autoimport.search_full(word)) -- log.debug("autoimport: suggestions: %s", suggestions) -- results = sorted( -- _process_statements( -- suggestions, -- document.uri, -- word, -- autoimport, -- document, -- "code_actions", -- ), -- key=lambda statement: statement["data"]["sortText"], -- ) -- -- if len(results) > MAX_RESULTS_CODE_ACTIONS: -- results = results[:MAX_RESULTS_CODE_ACTIONS] -- code_actions.extend(results) -- -- return code_actions -- -- --@hookimpl --def pylsp_initialize(config: Config, workspace: Workspace) -> None: -- """Initialize AutoImport. -- -- Generates the cache for local and global items. -- """ -- cache.reload_cache(config, workspace) -- -- --@hookimpl --def pylsp_document_did_open(config: Config, workspace: Workspace) -> None: -- """Initialize AutoImport. -- -- Generates the cache for local and global items. -- """ -- cache.reload_cache(config, workspace) -- -- --@hookimpl --def pylsp_document_did_save( -- config: Config, workspace: Workspace, document: Document --) -> None: -- """Update the names associated with this document.""" -- cache.reload_cache(config, workspace, [document]) -- -- --@hookimpl --def pylsp_workspace_configuration_changed(config: Config, workspace: Workspace) -> None: -- """ -- Initialize autoimport if it has been enabled through a -- workspace/didChangeConfiguration message from the frontend. -- -- Generates the cache for local and global items. -- """ -- if config.plugin_settings("rope_autoimport").get("enabled", False): -- cache.reload_cache(config, workspace) -- else: -- log.debug("autoimport: Skipping cache reload.") -- -- --cache: AutoimportCache = AutoimportCache() -diff --git a/pylsp/plugins/rope_completion.py b/pylsp/plugins/rope_completion.py -index b3a1f06..7317474 100644 ---- a/pylsp/plugins/rope_completion.py -+++ b/pylsp/plugins/rope_completion.py -@@ -1,161 +1,12 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import log -+from pylsp.pylsp_shared import _resolve_completion -+from pylsp.pylsp_shared import _sort_text -+from pylsp.pylsp_shared import _kind - import logging - - from rope.contrib.codeassist import code_assist, sorted_proposals - - from pylsp import _utils, hookimpl, lsp -- --log = logging.getLogger(__name__) -- -- --@hookimpl --def pylsp_settings(): -- # Default rope_completion to disabled -- return {"plugins": {"rope_completion": {"enabled": False, "eager": False}}} -- -- --def _resolve_completion(completion, data, markup_kind): -- try: -- doc = _utils.format_docstring(data.get_doc(), markup_kind=markup_kind) -- except Exception as e: -- log.debug("Failed to resolve Rope completion: %s", e) -- doc = "" -- completion["detail"] = "{0} {1}".format(data.scope or "", data.name) -- completion["documentation"] = doc -- return completion -- -- --@hookimpl --def pylsp_completions(config, workspace, document, position): -- settings = config.plugin_settings("rope_completion", document_path=document.path) -- resolve_eagerly = settings.get("eager", False) -- -- # Rope is a bit rubbish at completing module imports, so we'll return None -- word = document.word_at_position( -- { -- # The -1 should really be trying to look at the previous word, but that might be quite expensive -- # So we only skip import completions when the cursor is one space after `import` -- "line": position["line"], -- "character": max(position["character"] - 1, 0), -- } -- ) -- if word == "import": -- return None -- -- offset = document.offset_at_position(position) -- rope_config = config.settings(document_path=document.path).get("rope", {}) -- rope_project = workspace._rope_project_builder(rope_config) -- document_rope = document._rope_resource(rope_config) -- -- completion_capabilities = config.capabilities.get("textDocument", {}).get( -- "completion", {} -- ) -- item_capabilities = completion_capabilities.get("completionItem", {}) -- supported_markup_kinds = item_capabilities.get("documentationFormat", ["markdown"]) -- preferred_markup_kind = _utils.choose_markup_kind(supported_markup_kinds) -- -- try: -- definitions = code_assist( -- rope_project, document.source, offset, document_rope, maxfixes=3 -- ) -- except Exception as e: -- log.debug("Failed to run Rope code assist: %s", e) -- return [] -- -- definitions = sorted_proposals(definitions) -- new_definitions = [] -- for d in definitions: -- item = { -- "label": d.name, -- "kind": _kind(d), -- "sortText": _sort_text(d), -- "data": {"doc_uri": document.uri}, -- } -- if resolve_eagerly: -- item = _resolve_completion(item, d, preferred_markup_kind) -- new_definitions.append(item) -- -- # most recently retrieved completion items, used for resolution -- document.shared_data["LAST_ROPE_COMPLETIONS"] = { -- # label is the only required property; here it is assumed to be unique -- completion["label"]: (completion, data) -- for completion, data in zip(new_definitions, definitions) -- } -- -- definitions = new_definitions -- -- return definitions or None -- -- --@hookimpl --def pylsp_completion_item_resolve(config, completion_item, document): -- """Resolve formatted completion for given non-resolved completion""" -- shared_data = document.shared_data["LAST_ROPE_COMPLETIONS"].get( -- completion_item["label"] -- ) -- -- completion_capabilities = config.capabilities.get("textDocument", {}).get( -- "completion", {} -- ) -- item_capabilities = completion_capabilities.get("completionItem", {}) -- supported_markup_kinds = item_capabilities.get("documentationFormat", ["markdown"]) -- preferred_markup_kind = _utils.choose_markup_kind(supported_markup_kinds) -- -- if shared_data: -- completion, data = shared_data -- return _resolve_completion(completion, data, preferred_markup_kind) -- return completion_item -- -- --def _sort_text(definition): -- """Ensure builtins appear at the bottom. -- Description is of format : . -- """ -- if definition.name.startswith("_"): -- # It's a 'hidden' func, put it next last -- return "z" + definition.name -- if definition.scope == "builtin": -- return "y" + definition.name -- -- # Else put it at the front -- return "a" + definition.name -- -- --def _kind(d): -- """Return the LSP type""" -- MAP = { -- "none": lsp.CompletionItemKind.Value, -- "type": lsp.CompletionItemKind.Class, -- "tuple": lsp.CompletionItemKind.Class, -- "dict": lsp.CompletionItemKind.Class, -- "dictionary": lsp.CompletionItemKind.Class, -- "function": lsp.CompletionItemKind.Function, -- "lambda": lsp.CompletionItemKind.Function, -- "generator": lsp.CompletionItemKind.Function, -- "class": lsp.CompletionItemKind.Class, -- "instance": lsp.CompletionItemKind.Reference, -- "method": lsp.CompletionItemKind.Method, -- "builtin": lsp.CompletionItemKind.Class, -- "builtinfunction": lsp.CompletionItemKind.Function, -- "module": lsp.CompletionItemKind.Module, -- "file": lsp.CompletionItemKind.File, -- "xrange": lsp.CompletionItemKind.Class, -- "slice": lsp.CompletionItemKind.Class, -- "traceback": lsp.CompletionItemKind.Class, -- "frame": lsp.CompletionItemKind.Class, -- "buffer": lsp.CompletionItemKind.Class, -- "dictproxy": lsp.CompletionItemKind.Class, -- "funcdef": lsp.CompletionItemKind.Function, -- "property": lsp.CompletionItemKind.Property, -- "import": lsp.CompletionItemKind.Module, -- "keyword": lsp.CompletionItemKind.Keyword, -- "constant": lsp.CompletionItemKind.Variable, -- "variable": lsp.CompletionItemKind.Variable, -- "value": lsp.CompletionItemKind.Value, -- "param": lsp.CompletionItemKind.Variable, -- "statement": lsp.CompletionItemKind.Keyword, -- } -- -- return MAP.get(d.type) -diff --git a/pylsp/plugins/signature.py b/pylsp/plugins/signature.py -index 7ad5b20..58d3417 100644 ---- a/pylsp/plugins/signature.py -+++ b/pylsp/plugins/signature.py -@@ -1,81 +1,10 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import _param_docs - import logging - import re - - from pylsp import _utils, hookimpl - - log = logging.getLogger(__name__) -- --SPHINX = re.compile(r"\s*:param\s+(?P\w+):\s*(?P[^\n]+)") --EPYDOC = re.compile(r"\s*@param\s+(?P\w+):\s*(?P[^\n]+)") --GOOGLE = re.compile(r"\s*(?P\w+).*:\s*(?P[^\n]+)") -- --DOC_REGEX = [SPHINX, EPYDOC, GOOGLE] -- -- --@hookimpl --def pylsp_signature_help(config, document, position): -- code_position = _utils.position_to_jedi_linecolumn(document, position) -- signatures = document.jedi_script().get_signatures(**code_position) -- -- if not signatures: -- return {"signatures": []} -- -- signature_capabilities = config.capabilities.get("textDocument", {}).get( -- "signatureHelp", {} -- ) -- signature_information_support = signature_capabilities.get( -- "signatureInformation", {} -- ) -- supported_markup_kinds = signature_information_support.get( -- "documentationFormat", ["markdown"] -- ) -- preferred_markup_kind = _utils.choose_markup_kind(supported_markup_kinds) -- -- s = signatures[0] -- -- docstring = s.docstring() -- -- # Docstring contains one or more lines of signature, followed by empty line, followed by docstring -- function_sig_lines = (docstring.split("\n\n") or [""])[0].splitlines() -- function_sig = " ".join([line.strip() for line in function_sig_lines]) -- sig = { -- "label": function_sig, -- "documentation": _utils.format_docstring( -- s.docstring(raw=True), markup_kind=preferred_markup_kind -- ), -- } -- -- # If there are params, add those -- if s.params: -- sig["parameters"] = [ -- { -- "label": p.name, -- "documentation": _utils.format_docstring( -- _param_docs(docstring, p.name), markup_kind=preferred_markup_kind -- ), -- } -- for p in s.params -- ] -- -- # We only return a single signature because Python doesn't allow overloading -- sig_info = {"signatures": [sig], "activeSignature": 0} -- -- if s.index is not None and s.params: -- # Then we know which parameter we're looking at -- sig_info["activeParameter"] = s.index -- -- return sig_info -- -- --def _param_docs(docstring, param_name): -- for line in docstring.splitlines(): -- for regex in DOC_REGEX: -- m = regex.match(line) -- if not m: -- continue -- if m.group("param") != param_name: -- continue -- return m.group("doc") or "" -diff --git a/pylsp/plugins/symbols.py b/pylsp/plugins/symbols.py -index 4e1890c..24344a0 100644 ---- a/pylsp/plugins/symbols.py -+++ b/pylsp/plugins/symbols.py -@@ -1,6 +1,12 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import _include_def -+from pylsp.pylsp_shared import _container -+from pylsp.pylsp_shared import _range -+from pylsp.pylsp_shared import _tuple_range -+from pylsp.pylsp_shared import _SYMBOL_KIND_MAP -+from pylsp.pylsp_shared import _kind - import logging - from pathlib import Path - -@@ -8,207 +14,3 @@ from pylsp import hookimpl - from pylsp.lsp import SymbolKind - - log = logging.getLogger(__name__) -- -- --@hookimpl --def pylsp_document_symbols(config, document): -- symbols_settings = config.plugin_settings("jedi_symbols") -- all_scopes = symbols_settings.get("all_scopes", True) -- add_import_symbols = symbols_settings.get("include_import_symbols", True) -- definitions = document.jedi_names(all_scopes=all_scopes) -- symbols = [] -- exclude = set({}) -- redefinitions = {} -- -- while definitions != []: -- d = definitions.pop(0) -- -- # Skip symbols imported from other modules. -- if not add_import_symbols: -- # Skip if there's an import in the code the symbol is defined. -- code = d.get_line_code() -- if " import " in code or "import " in code: -- continue -- -- # Skip imported symbols comparing module names. -- sym_full_name = d.full_name -- if sym_full_name is not None: -- document_dot_path = document.dot_path -- -- # We assume a symbol is imported from another module to start -- # with. -- imported_symbol = True -- -- # The last element of sym_full_name is the symbol itself, so -- # we need to discard it to do module comparisons below. -- if "." in sym_full_name: -- sym_module_name = sym_full_name.rpartition(".")[0] -- else: -- sym_module_name = sym_full_name -- -- # This is necessary to display symbols in init files (the checks -- # below fail without it). -- if document_dot_path.endswith("__init__"): -- document_dot_path = document_dot_path.rpartition(".")[0] -- -- # document_dot_path is the module where the symbol is imported, -- # whereas sym_module_name is the one where it was declared. -- if document_dot_path in sym_module_name: -- # If document_dot_path is in sym_module_name, we can safely assume -- # that the symbol was declared in the document. -- imported_symbol = False -- elif sym_module_name.split(".")[0] in document_dot_path.split("."): -- # If the first module in sym_module_name is one of the modules in -- # document_dot_path, we need to check if sym_module_name starts -- # with the modules in document_dot_path. -- document_mods = document_dot_path.split(".") -- for i in range(1, len(document_mods) + 1): -- submod = ".".join(document_mods[-i:]) -- if sym_module_name.startswith(submod): -- imported_symbol = False -- break -- -- # When there's no __init__.py next to a file or in one of its -- # parents, the checks above fail. However, Jedi has a nice way -- # to tell if the symbol was declared in the same file: if -- # sym_module_name starts by __main__. -- if imported_symbol: -- if not sym_module_name.startswith("__main__"): -- continue -- else: -- # We need to skip symbols if their definition doesn't have `full_name` info, they -- # are detected as a definition, but their description (e.g. `class Foo`) doesn't -- # match the code where they're detected by Jedi. This happens for relative imports. -- if _include_def(d): -- if d.description not in d.get_line_code(): -- continue -- else: -- continue -- -- if _include_def(d) and Path(document.path) == Path(d.module_path): -- tuple_range = _tuple_range(d) -- if tuple_range in exclude: -- continue -- -- kind = redefinitions.get(tuple_range, None) -- if kind is not None: -- exclude |= {tuple_range} -- -- if d.type == "statement": -- if d.description.startswith("self"): -- kind = "field" -- -- symbol = { -- "name": d.name, -- "containerName": _container(d), -- "location": { -- "uri": document.uri, -- "range": _range(d), -- }, -- "kind": _kind(d) if kind is None else _SYMBOL_KIND_MAP[kind], -- } -- symbols.append(symbol) -- -- if d.type == "class": -- try: -- defined_names = list(d.defined_names()) -- for method in defined_names: -- if method.type == "function": -- redefinitions[_tuple_range(method)] = "method" -- elif method.type == "statement": -- redefinitions[_tuple_range(method)] = "field" -- else: -- redefinitions[_tuple_range(method)] = method.type -- definitions = list(defined_names) + definitions -- except Exception: -- pass -- return symbols -- -- --def _include_def(definition): -- return ( -- # Don't tend to include parameters as symbols -- definition.type != "param" -- and -- # Unused vars should also be skipped -- definition.name != "_" -- and _kind(definition) is not None -- ) -- -- --def _container(definition): -- try: -- # Jedi sometimes fails here. -- parent = definition.parent() -- # Here we check that a grand-parent exists to avoid declaring symbols -- # as children of the module. -- if parent.parent(): -- return parent.name -- except: -- return None -- -- return None -- -- --def _range(definition): -- # This gets us more accurate end position -- definition = definition._name.tree_name.get_definition() -- (start_line, start_column) = definition.start_pos -- (end_line, end_column) = definition.end_pos -- return { -- "start": {"line": start_line - 1, "character": start_column}, -- "end": {"line": end_line - 1, "character": end_column}, -- } -- -- --def _tuple_range(definition): -- definition = definition._name.tree_name.get_definition() -- return (definition.start_pos, definition.end_pos) -- -- --_SYMBOL_KIND_MAP = { -- "none": SymbolKind.Variable, -- "type": SymbolKind.Class, -- "tuple": SymbolKind.Class, -- "dict": SymbolKind.Class, -- "dictionary": SymbolKind.Class, -- "function": SymbolKind.Function, -- "lambda": SymbolKind.Function, -- "generator": SymbolKind.Function, -- "class": SymbolKind.Class, -- "instance": SymbolKind.Class, -- "method": SymbolKind.Method, -- "builtin": SymbolKind.Class, -- "builtinfunction": SymbolKind.Function, -- "module": SymbolKind.Module, -- "file": SymbolKind.File, -- "xrange": SymbolKind.Array, -- "slice": SymbolKind.Class, -- "traceback": SymbolKind.Class, -- "frame": SymbolKind.Class, -- "buffer": SymbolKind.Array, -- "dictproxy": SymbolKind.Class, -- "funcdef": SymbolKind.Function, -- "property": SymbolKind.Property, -- "import": SymbolKind.Module, -- "keyword": SymbolKind.Variable, -- "constant": SymbolKind.Constant, -- "variable": SymbolKind.Variable, -- "value": SymbolKind.Variable, -- "param": SymbolKind.Variable, -- "statement": SymbolKind.Variable, -- "boolean": SymbolKind.Boolean, -- "int": SymbolKind.Number, -- "longlean": SymbolKind.Number, -- "float": SymbolKind.Number, -- "complex": SymbolKind.Number, -- "string": SymbolKind.String, -- "unicode": SymbolKind.String, -- "list": SymbolKind.Array, -- "field": SymbolKind.Field, --} -- -- --def _kind(d): -- """Return the VSCode Symbol Type""" -- return _SYMBOL_KIND_MAP.get(d.type) -diff --git a/pylsp/plugins/yapf_format.py b/pylsp/plugins/yapf_format.py -index 72aa740..363f2ab 100644 ---- a/pylsp/plugins/yapf_format.py -+++ b/pylsp/plugins/yapf_format.py -@@ -1,6 +1,8 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import log -+from pylsp.pylsp_shared import _format - import logging - import os - -@@ -10,190 +12,3 @@ from yapf.yapflib.yapf_api import FormatCode - - from pylsp import hookimpl - from pylsp._utils import get_eol_chars -- --log = logging.getLogger(__name__) -- -- --@hookimpl --def pylsp_format_document(workspace, document, options): -- log.info("Formatting document %s with yapf", document) -- with workspace.report_progress("format: yapf"): -- return _format(document, options=options) -- -- --@hookimpl --def pylsp_format_range(document, range, options): -- log.info("Formatting document %s in range %s with yapf", document, range) -- # First we 'round' the range up/down to full lines only -- range["start"]["character"] = 0 -- range["end"]["line"] += 1 -- range["end"]["character"] = 0 -- -- # From Yapf docs: -- # lines: (list of tuples of integers) A list of tuples of lines, [start, end], -- # that we want to format. The lines are 1-based indexed. It can be used by -- # third-party code (e.g., IDEs) when reformatting a snippet of code rather -- # than a whole file. -- -- # Add 1 for 1-indexing vs LSP's 0-indexing -- lines = [(range["start"]["line"] + 1, range["end"]["line"] + 1)] -- return _format(document, lines=lines, options=options) -- -- --def get_style_config(document_path, options=None): -- # Exclude file if it follows the patterns for that -- exclude_patterns_from_ignore_file = file_resources.GetExcludePatternsForDir( -- os.getcwd() -- ) -- if file_resources.IsIgnored(document_path, exclude_patterns_from_ignore_file): -- return [] -- -- # Get the default styles as a string -- # for a preset configuration, i.e. "pep8" -- style_config = file_resources.GetDefaultStyleForDir(os.path.dirname(document_path)) -- if options is None: -- return style_config -- -- # We have options passed from LSP format request -- # let's pass them to the formatter. -- # First we want to get a dictionary of the preset style -- # to pass instead of a string so that we can modify it -- style_config = style.CreateStyleFromConfig(style_config) -- -- use_tabs = style_config["USE_TABS"] -- indent_width = style_config["INDENT_WIDTH"] -- -- if options.get("tabSize") is not None: -- indent_width = max(int(options.get("tabSize")), 1) -- -- if options.get("insertSpaces") is not None: -- # TODO is it guaranteed to be a boolean, or can it be a string? -- use_tabs = not options.get("insertSpaces") -- -- if use_tabs: -- # Indent width doesn't make sense when using tabs -- # the specifications state: "Size of a tab in spaces" -- indent_width = 1 -- -- style_config["USE_TABS"] = use_tabs -- style_config["INDENT_WIDTH"] = indent_width -- style_config["CONTINUATION_INDENT_WIDTH"] = indent_width -- -- for style_option, value in options.items(): -- # Apply arbitrary options passed as formatter options -- if style_option not in style_config: -- # ignore if it's not a known yapf config -- continue -- -- style_config[style_option] = value -- -- return style_config -- -- --def diff_to_text_edits(diff, eol_chars): -- # To keep things simple our text edits will be line based. -- # We will also return the edits uncompacted, meaning a -- # line replacement will come in as a line remove followed -- # by a line add instead of a line replace. -- text_edits = [] -- # keep track of line number since additions -- # don't include the line number it's being added -- # to in diffs. lsp is 0-indexed so we'll start with -1 -- prev_line_no = -1 -- -- for change in diff.changes: -- if change.old and change.new: -- # old and new are the same line, no change -- # diffs are 1-indexed -- prev_line_no = change.old - 1 -- elif change.new: -- # addition -- text_edits.append( -- { -- "range": { -- "start": {"line": prev_line_no + 1, "character": 0}, -- "end": {"line": prev_line_no + 1, "character": 0}, -- }, -- "newText": change.line + eol_chars, -- } -- ) -- elif change.old: -- # remove -- lsp_line_no = change.old - 1 -- text_edits.append( -- { -- "range": { -- "start": {"line": lsp_line_no, "character": 0}, -- "end": { -- # From LSP spec: -- # If you want to specify a range that contains a line -- # including the line ending character(s) then use an -- # end position denoting the start of the next line. -- "line": lsp_line_no + 1, -- "character": 0, -- }, -- }, -- "newText": "", -- } -- ) -- prev_line_no = lsp_line_no -- -- return text_edits -- -- --def ensure_eof_new_line(document, eol_chars, text_edits): -- # diffs don't include EOF newline https://github.com/google/yapf/issues/1008 -- # we'll add it ourselves if our document doesn't already have it and the diff -- # does not change the last line. -- if document.source.endswith(eol_chars): -- return -- -- lines = document.lines -- last_line_number = len(lines) - 1 -- -- if text_edits and text_edits[-1]["range"]["start"]["line"] >= last_line_number: -- return -- -- text_edits.append( -- { -- "range": { -- "start": {"line": last_line_number, "character": 0}, -- "end": {"line": last_line_number + 1, "character": 0}, -- }, -- "newText": lines[-1] + eol_chars, -- } -- ) -- -- --def _format(document, lines=None, options=None): -- source = document.source -- # Yapf doesn't work with CRLF/CR line endings, so we replace them by '\n' -- # and restore them below when adding new lines -- eol_chars = get_eol_chars(source) -- if eol_chars in ["\r", "\r\n"]: -- source = source.replace(eol_chars, "\n") -- else: -- eol_chars = "\n" -- -- style_config = get_style_config(document_path=document.path, options=options) -- -- diff_txt, changed = FormatCode( -- source, -- lines=lines, -- filename=document.filename, -- print_diff=True, -- style_config=style_config, -- ) -- -- if not changed: -- return [] -- -- patch_generator = whatthepatch.parse_patch(diff_txt) -- diff = next(patch_generator) -- patch_generator.close() -- -- text_edits = diff_to_text_edits(diff=diff, eol_chars=eol_chars) -- -- ensure_eof_new_line(document=document, eol_chars=eol_chars, text_edits=text_edits) -- -- return text_edits -diff --git a/pylsp/pylsp_shared.py b/pylsp/pylsp_shared.py -new file mode 100644 -index 0000000..24f571a ---- /dev/null -+++ b/pylsp/pylsp_shared.py -@@ -0,0 +1,3622 @@ -+import logging -+import pycodestyle -+from autopep8 import continued_indentation as autopep8_c_i -+from autopep8 import fix_code -+from pylsp._utils import get_eol_chars -+from typing import TYPE_CHECKING, Any, Dict, List -+import jedi -+from pylsp import hookimpl -+from pylsp import uris -+from pylsp import _utils -+from pylsp.config.config import Config -+from pylsp.workspace import Document -+from jedi.api import Script -+from jedi.api.classes import Name -+import threading -+import uuid -+from typing import Any, Dict, List -+from pylsp_jsonrpc.dispatchers import MethodDispatcher -+from pylsp_jsonrpc.endpoint import Endpoint -+from pylsp_jsonrpc.streams import JsonRpcStreamReader, JsonRpcStreamWriter -+from . import _utils, lsp, uris -+from ._version import __version__ -+from .config import config -+from .workspace import Cell, Document, Notebook, Workspace -+from io import StringIO -+import pytest -+from pylsp_jsonrpc.exceptions import JsonRpcException -+import os.path -+from pathlib import PurePath -+import sys -+from subprocess import PIPE, Popen -+import re -+from pylsp import lsp -+import parso -+import parso.python.tree as tree_nodes -+from pylsp import hookspec -+from pylsp.plugins._resolvers import LABEL_RESOLVER -+from pylsp.plugins._resolvers import SNIPPET_RESOLVER -+import os -+import ast -+import mccabe -+import pydocstyle -+import contextlib -+from pyflakes import api as pyflakes_api -+from pyflakes import messages -+import collections -+import shlex -+import json -+from typing import Any, Dict, Generator, List, Optional, Set, Union -+from pylsp.workspace import Workspace -+from parso.python import tree -+from parso.tree import NodeOrLeaf -+from rope.contrib.autoimport.defs import SearchResult -+from rope.contrib.autoimport.sqlite import AutoImport -+from jedi import Script -+from rope.base.resources import Resource -+from ._rope_task_handle import PylspTaskHandle -+from rope.contrib.codeassist import code_assist, sorted_proposals -+from pathlib import Path -+from pylsp.lsp import SymbolKind -+import whatthepatch -+from yapf.yapflib.yapf_api import FormatCode -+from yapf.yapflib import file_resources, style -+ -+ -+log = logging.getLogger(__name__) -+ -+def _autopep8_config(config, document=None): -+ # We user pycodestyle settings to avoid redefining things -+ path = document.path if document is not None else None -+ settings = config.plugin_settings("pycodestyle", document_path=path) -+ options = { -+ "exclude": settings.get("exclude"), -+ "hang_closing": settings.get("hangClosing"), -+ "ignore": settings.get("ignore"), -+ "max_line_length": settings.get("maxLineLength"), -+ "select": settings.get("select"), -+ "aggressive": settings.get("aggressive"), -+ } -+ -+ # Filter out null options -+ return {k: v for k, v in options.items() if v} -+ -+def _format(config, document, line_range=None): -+ options = _autopep8_config(config, document) -+ if line_range: -+ options["line_range"] = list(line_range) -+ -+ # Temporarily re-monkey-patch the continued_indentation checker - #771 -+ del pycodestyle._checks["logical_line"][pycodestyle.continued_indentation] -+ pycodestyle.register_check(autopep8_c_i) -+ -+ # Autopep8 doesn't work with CR line endings, so we replace them by '\n' -+ # and restore them below. -+ replace_cr = False -+ source = document.source -+ eol_chars = get_eol_chars(source) -+ if eol_chars == "\r": -+ replace_cr = True -+ source = source.replace("\r", "\n") -+ -+ new_source = fix_code(source, options=options) -+ -+ # Switch it back -+ del pycodestyle._checks["logical_line"][autopep8_c_i] -+ pycodestyle.register_check(pycodestyle.continued_indentation) -+ -+ if new_source == source: -+ return [] -+ -+ if replace_cr: -+ new_source = new_source.replace("\n", "\r") -+ -+ # I'm too lazy at the moment to parse diffs into TextEdit items -+ # So let's just return the entire file... -+ return [ -+ { -+ "range": { -+ "start": {"line": 0, "character": 0}, -+ # End char 0 of the line after our document -+ "end": {"line": len(document.lines), "character": 0}, -+ }, -+ "newText": new_source, -+ } -+ ] -+ -+@hookimpl(tryfirst=True) # Prefer autopep8 over YAPF -+def pylsp_format_document(config, workspace, document, options): -+ with workspace.report_progress("format: autopep8"): -+ log.info("Formatting document %s with autopep8", document) -+ return _format(config, document) -+ -+@hookimpl(tryfirst=True) # Prefer autopep8 over YAPF -+def pylsp_format_range(config, workspace, document, range, options): -+ log.info("Formatting document %s in range %s with autopep8", document, range) -+ -+ # First we 'round' the range up/down to full lines only -+ range["start"]["character"] = 0 -+ range["end"]["line"] += 1 -+ range["end"]["character"] = 0 -+ -+ # Add 1 for 1-indexing vs LSP's 0-indexing -+ line_range = (range["start"]["line"] + 1, range["end"]["line"]) -+ return _format(config, document, line_range=line_range) -+ -+MAX_JEDI_GOTO_HOPS = 100 -+ -+def _resolve_definition( -+ maybe_defn: Name, script: Script, settings: Dict[str, Any] -+) -> Name: -+ for _ in range(MAX_JEDI_GOTO_HOPS): -+ if maybe_defn.is_definition() or maybe_defn.module_path != script.path: -+ break -+ defns = script.goto( -+ follow_imports=settings.get("follow_imports", True), -+ follow_builtin_imports=settings.get("follow_builtin_imports", True), -+ line=maybe_defn.line, -+ column=maybe_defn.column, -+ ) -+ if len(defns) == 1: -+ maybe_defn = defns[0] -+ else: -+ break -+ return maybe_defn -+ -+def _not_internal_definition(definition: Name) -> bool: -+ return ( -+ definition.line is not None -+ and definition.column is not None -+ and definition.module_path is not None -+ and not definition.in_builtin_module() -+ ) -+ -+@hookimpl -+def pylsp_definitions( -+ config: Config, document: Document, position: Dict[str, int] -+) -> List[Dict[str, Any]]: -+ settings = config.plugin_settings("jedi_definition") -+ code_position = _utils.position_to_jedi_linecolumn(document, position) -+ script = document.jedi_script(use_document_path=True) -+ auto_import_modules = jedi.settings.auto_import_modules -+ -+ try: -+ jedi.settings.auto_import_modules = [] -+ definitions = script.goto( -+ follow_imports=settings.get("follow_imports", True), -+ follow_builtin_imports=settings.get("follow_builtin_imports", True), -+ **code_position, -+ ) -+ definitions = [_resolve_definition(d, script, settings) for d in definitions] -+ finally: -+ jedi.settings.auto_import_modules = auto_import_modules -+ -+ follow_builtin_defns = settings.get("follow_builtin_definitions", True) -+ return [ -+ { -+ "uri": uris.uri_with(document.uri, path=str(d.module_path)), -+ "range": { -+ "start": {"line": d.line - 1, "character": d.column}, -+ "end": {"line": d.line - 1, "character": d.column + len(d.name)}, -+ }, -+ } -+ for d in definitions -+ if d.is_definition() and (follow_builtin_defns or _not_internal_definition(d)) -+ ] -+ -+LINT_DEBOUNCE_S = 0.5 -+# 500 ms -+ -+# 500 ms -+PARENT_PROCESS_WATCH_INTERVAL = 10 -+# 10 s -+ -+# 10 s -+MAX_WORKERS = 64 -+ -+PYTHON_FILE_EXTENSIONS = (".py", ".pyi") -+ -+CONFIG_FILEs = ("pycodestyle.cfg", "setup.cfg", "tox.ini", ".flake8") -+ -+def flatten(list_of_lists): -+ return [item for lst in list_of_lists for item in lst] -+ -+def merge(list_of_dicts): -+ return {k: v for dictionary in list_of_dicts for k, v in dictionary.items()} -+ -+class PythonLSPServer(MethodDispatcher): -+ """Implementation of the Microsoft VSCode Language Server Protocol -+ https://github.com/Microsoft/language-server-protocol/blob/master/versions/protocol-1-x.md -+ """ -+ -+ def __init__( -+ self, rx, tx, check_parent_process=False, consumer=None, *, endpoint_cls=None -+ ) -> None: -+ self.workspace = None -+ self.config = None -+ self.root_uri = None -+ self.watching_thread = None -+ self.workspaces = {} -+ self.uri_workspace_mapper = {} -+ -+ self._check_parent_process = check_parent_process -+ -+ if rx is not None: -+ self._jsonrpc_stream_reader = JsonRpcStreamReader(rx) -+ else: -+ self._jsonrpc_stream_reader = None -+ -+ if tx is not None: -+ self._jsonrpc_stream_writer = JsonRpcStreamWriter(tx) -+ else: -+ self._jsonrpc_stream_writer = None -+ -+ endpoint_cls = endpoint_cls or Endpoint -+ -+ # if consumer is None, it is assumed that the default streams-based approach is being used -+ if consumer is None: -+ self._endpoint = endpoint_cls( -+ self, self._jsonrpc_stream_writer.write, max_workers=MAX_WORKERS -+ ) -+ else: -+ self._endpoint = endpoint_cls(self, consumer, max_workers=MAX_WORKERS) -+ -+ self._dispatchers = [] -+ self._shutdown = False -+ -+ def start(self) -> None: -+ """Entry point for the server.""" -+ self._jsonrpc_stream_reader.listen(self._endpoint.consume) -+ -+ def consume(self, message) -> None: -+ """Entry point for consumer based server. Alternative to stream listeners.""" -+ # assuming message will be JSON -+ self._endpoint.consume(message) -+ -+ def __getitem__(self, item): -+ """Override getitem to fallback through multiple dispatchers.""" -+ if self._shutdown and item != "exit": -+ # exit is the only allowed method during shutdown -+ log.debug("Ignoring non-exit method during shutdown: %s", item) -+ item = "invalid_request_after_shutdown" -+ -+ try: -+ return super().__getitem__(item) -+ except KeyError: -+ # Fallback through extra dispatchers -+ for dispatcher in self._dispatchers: -+ try: -+ return dispatcher[item] -+ except KeyError: -+ continue -+ -+ raise KeyError() -+ -+ def m_shutdown(self, **_kwargs) -> None: -+ for workspace in self.workspaces.values(): -+ workspace.close() -+ self._shutdown = True -+ -+ def m_invalid_request_after_shutdown(self, **_kwargs): -+ return { -+ "error": { -+ "code": lsp.ErrorCodes.InvalidRequest, -+ "message": "Requests after shutdown are not valid", -+ } -+ } -+ -+ def m_exit(self, **_kwargs) -> None: -+ self._endpoint.shutdown() -+ if self._jsonrpc_stream_reader is not None: -+ self._jsonrpc_stream_reader.close() -+ if self._jsonrpc_stream_writer is not None: -+ self._jsonrpc_stream_writer.close() -+ -+ def _match_uri_to_workspace(self, uri): -+ workspace_uri = _utils.match_uri_to_workspace(uri, self.workspaces) -+ return self.workspaces.get(workspace_uri, self.workspace) -+ -+ def _hook(self, hook_name, doc_uri=None, **kwargs): -+ """Calls hook_name and returns a list of results from all registered handlers""" -+ workspace = self._match_uri_to_workspace(doc_uri) -+ doc = workspace.get_document(doc_uri) if doc_uri else None -+ hook_handlers = self.config.plugin_manager.subset_hook_caller( -+ hook_name, self.config.disabled_plugins -+ ) -+ return hook_handlers( -+ config=self.config, workspace=workspace, document=doc, **kwargs -+ ) -+ -+ def capabilities(self): -+ server_capabilities = { -+ "codeActionProvider": True, -+ "codeLensProvider": { -+ "resolveProvider": False, # We may need to make this configurable -+ }, -+ "completionProvider": { -+ "resolveProvider": True, # We could know everything ahead of time, but this takes time to transfer -+ "triggerCharacters": ["."], -+ }, -+ "documentFormattingProvider": True, -+ "documentHighlightProvider": True, -+ "documentRangeFormattingProvider": True, -+ "documentSymbolProvider": True, -+ "definitionProvider": True, -+ "executeCommandProvider": { -+ "commands": flatten(self._hook("pylsp_commands")) -+ }, -+ "hoverProvider": True, -+ "referencesProvider": True, -+ "renameProvider": True, -+ "foldingRangeProvider": True, -+ "signatureHelpProvider": {"triggerCharacters": ["(", ",", "="]}, -+ "textDocumentSync": { -+ "change": lsp.TextDocumentSyncKind.INCREMENTAL, -+ "save": { -+ "includeText": True, -+ }, -+ "openClose": True, -+ }, -+ "notebookDocumentSync": { -+ "notebookSelector": [{"cells": [{"language": "python"}]}] -+ }, -+ "workspace": { -+ "workspaceFolders": {"supported": True, "changeNotifications": True} -+ }, -+ "experimental": merge(self._hook("pylsp_experimental_capabilities")), -+ } -+ log.info("Server capabilities: %s", server_capabilities) -+ return server_capabilities -+ -+ def m_initialize( -+ self, -+ processId=None, -+ rootUri=None, -+ rootPath=None, -+ initializationOptions=None, -+ workspaceFolders=None, -+ **_kwargs, -+ ): -+ log.debug( -+ "Language server initialized with %s %s %s %s", -+ processId, -+ rootUri, -+ rootPath, -+ initializationOptions, -+ ) -+ if rootUri is None: -+ rootUri = uris.from_fs_path(rootPath) if rootPath is not None else "" -+ -+ self.workspaces.pop(self.root_uri, None) -+ self.root_uri = rootUri -+ self.config = config.Config( -+ rootUri, -+ initializationOptions or {}, -+ processId, -+ _kwargs.get("capabilities", {}), -+ ) -+ self.workspace = Workspace(rootUri, self._endpoint, self.config) -+ self.workspaces[rootUri] = self.workspace -+ if workspaceFolders: -+ for folder in workspaceFolders: -+ uri = folder["uri"] -+ if uri == rootUri: -+ # Already created -+ continue -+ workspace_config = config.Config( -+ uri, -+ self.config._init_opts, -+ self.config._process_id, -+ self.config._capabilities, -+ ) -+ workspace_config.update(self.config._settings) -+ self.workspaces[uri] = Workspace(uri, self._endpoint, workspace_config) -+ -+ self._dispatchers = self._hook("pylsp_dispatchers") -+ self._hook("pylsp_initialize") -+ -+ if ( -+ self._check_parent_process -+ and processId is not None -+ and self.watching_thread is None -+ ): -+ -+ def watch_parent_process(pid): -+ # exit when the given pid is not alive -+ if not _utils.is_process_alive(pid): -+ log.info("parent process %s is not alive, exiting!", pid) -+ self.m_exit() -+ else: -+ threading.Timer( -+ PARENT_PROCESS_WATCH_INTERVAL, watch_parent_process, args=[pid] -+ ).start() -+ -+ self.watching_thread = threading.Thread( -+ target=watch_parent_process, args=(processId,) -+ ) -+ self.watching_thread.daemon = True -+ self.watching_thread.start() -+ # Get our capabilities -+ return { -+ "capabilities": self.capabilities(), -+ "serverInfo": { -+ "name": "pylsp", -+ "version": __version__, -+ }, -+ } -+ -+ def m_initialized(self, **_kwargs) -> None: -+ self._hook("pylsp_initialized") -+ -+ def code_actions(self, doc_uri: str, range: Dict, context: Dict): -+ return flatten( -+ self._hook("pylsp_code_actions", doc_uri, range=range, context=context) -+ ) -+ -+ def code_lens(self, doc_uri): -+ return flatten(self._hook("pylsp_code_lens", doc_uri)) -+ -+ def completions(self, doc_uri, position): -+ workspace = self._match_uri_to_workspace(doc_uri) -+ document = workspace.get_document(doc_uri) -+ ignored_names = None -+ if isinstance(document, Cell): -+ # We need to get the ignored names from the whole notebook document -+ notebook_document = workspace.get_maybe_document(document.notebook_uri) -+ ignored_names = notebook_document.jedi_names(doc_uri) -+ completions = self._hook( -+ "pylsp_completions", doc_uri, position=position, ignored_names=ignored_names -+ ) -+ return {"isIncomplete": False, "items": flatten(completions)} -+ -+ def completion_item_resolve(self, completion_item): -+ doc_uri = completion_item.get("data", {}).get("doc_uri", None) -+ return self._hook( -+ "pylsp_completion_item_resolve", doc_uri, completion_item=completion_item -+ ) -+ -+ def definitions(self, doc_uri, position): -+ return flatten(self._hook("pylsp_definitions", doc_uri, position=position)) -+ -+ def document_symbols(self, doc_uri): -+ return flatten(self._hook("pylsp_document_symbols", doc_uri)) -+ -+ def document_did_save(self, doc_uri): -+ return self._hook("pylsp_document_did_save", doc_uri) -+ -+ def execute_command(self, command, arguments): -+ return self._hook("pylsp_execute_command", command=command, arguments=arguments) -+ -+ def format_document(self, doc_uri, options): -+ return lambda: self._hook("pylsp_format_document", doc_uri, options=options) -+ -+ def format_range(self, doc_uri, range, options): -+ return self._hook("pylsp_format_range", doc_uri, range=range, options=options) -+ -+ def highlight(self, doc_uri, position): -+ return ( -+ flatten(self._hook("pylsp_document_highlight", doc_uri, position=position)) -+ or None -+ ) -+ -+ def hover(self, doc_uri, position): -+ return self._hook("pylsp_hover", doc_uri, position=position) or {"contents": ""} -+ -+ @_utils.debounce(LINT_DEBOUNCE_S, keyed_by="doc_uri") -+ def lint(self, doc_uri, is_saved) -> None: -+ # Since we're debounced, the document may no longer be open -+ workspace = self._match_uri_to_workspace(doc_uri) -+ document_object = workspace.documents.get(doc_uri, None) -+ if isinstance(document_object, Document): -+ self._lint_text_document( -+ doc_uri, workspace, is_saved, document_object.version -+ ) -+ elif isinstance(document_object, Notebook): -+ self._lint_notebook_document(document_object, workspace) -+ -+ def _lint_text_document( -+ self, doc_uri, workspace, is_saved, doc_version=None -+ ) -> None: -+ workspace.publish_diagnostics( -+ doc_uri, -+ flatten(self._hook("pylsp_lint", doc_uri, is_saved=is_saved)), -+ doc_version, -+ ) -+ -+ def _lint_notebook_document(self, notebook_document, workspace) -> None: -+ """ -+ Lint a notebook document. -+ -+ This is a bit more complicated than linting a text document, because we need to -+ send the entire notebook document to the pylsp_lint hook, but we need to send -+ the diagnostics back to the client on a per-cell basis. -+ """ -+ -+ # First, we create a temp TextDocument that represents the whole notebook -+ # contents. We'll use this to send to the pylsp_lint hook. -+ random_uri = str(uuid.uuid4()) -+ -+ # cell_list helps us map the diagnostics back to the correct cell later. -+ cell_list: List[Dict[str, Any]] = [] -+ -+ offset = 0 -+ total_source = "" -+ for cell in notebook_document.cells: -+ cell_uri = cell["document"] -+ cell_document = workspace.get_cell_document(cell_uri) -+ -+ num_lines = cell_document.line_count -+ -+ data = { -+ "uri": cell_uri, -+ "line_start": offset, -+ "line_end": offset + num_lines - 1, -+ "source": cell_document.source, -+ } -+ -+ cell_list.append(data) -+ if offset == 0: -+ total_source = cell_document.source -+ else: -+ total_source += "\n" + cell_document.source -+ -+ offset += num_lines -+ -+ workspace.put_document(random_uri, total_source) -+ -+ try: -+ document_diagnostics = flatten( -+ self._hook("pylsp_lint", random_uri, is_saved=True) -+ ) -+ -+ # Now we need to map the diagnostics back to the correct cell and publish them. -+ # Note: this is O(n*m) in the number of cells and diagnostics, respectively. -+ for cell in cell_list: -+ cell_diagnostics = [] -+ for diagnostic in document_diagnostics: -+ start_line = diagnostic["range"]["start"]["line"] -+ end_line = diagnostic["range"]["end"]["line"] -+ -+ if start_line > cell["line_end"] or end_line < cell["line_start"]: -+ continue -+ diagnostic["range"]["start"]["line"] = ( -+ start_line - cell["line_start"] -+ ) -+ diagnostic["range"]["end"]["line"] = end_line - cell["line_start"] -+ cell_diagnostics.append(diagnostic) -+ -+ workspace.publish_diagnostics(cell["uri"], cell_diagnostics) -+ finally: -+ workspace.rm_document(random_uri) -+ -+ def references(self, doc_uri, position, exclude_declaration): -+ return flatten( -+ self._hook( -+ "pylsp_references", -+ doc_uri, -+ position=position, -+ exclude_declaration=exclude_declaration, -+ ) -+ ) -+ -+ def rename(self, doc_uri, position, new_name): -+ return self._hook("pylsp_rename", doc_uri, position=position, new_name=new_name) -+ -+ def signature_help(self, doc_uri, position): -+ return self._hook("pylsp_signature_help", doc_uri, position=position) -+ -+ def folding(self, doc_uri): -+ return flatten(self._hook("pylsp_folding_range", doc_uri)) -+ -+ def m_completion_item__resolve(self, **completionItem): -+ return self.completion_item_resolve(completionItem) -+ -+ def m_notebook_document__did_open( -+ self, notebookDocument=None, cellTextDocuments=None, **_kwargs -+ ) -> None: -+ workspace = self._match_uri_to_workspace(notebookDocument["uri"]) -+ workspace.put_notebook_document( -+ notebookDocument["uri"], -+ notebookDocument["notebookType"], -+ cells=notebookDocument["cells"], -+ version=notebookDocument.get("version"), -+ metadata=notebookDocument.get("metadata"), -+ ) -+ for cell in cellTextDocuments or []: -+ workspace.put_cell_document( -+ cell["uri"], -+ notebookDocument["uri"], -+ cell["languageId"], -+ cell["text"], -+ version=cell.get("version"), -+ ) -+ self.lint(notebookDocument["uri"], is_saved=True) -+ -+ def m_notebook_document__did_close( -+ self, notebookDocument=None, cellTextDocuments=None, **_kwargs -+ ) -> None: -+ workspace = self._match_uri_to_workspace(notebookDocument["uri"]) -+ for cell in cellTextDocuments or []: -+ workspace.publish_diagnostics(cell["uri"], []) -+ workspace.rm_document(cell["uri"]) -+ workspace.rm_document(notebookDocument["uri"]) -+ -+ def m_notebook_document__did_change( -+ self, notebookDocument=None, change=None, **_kwargs -+ ) -> None: -+ """ -+ Changes to the notebook document. -+ -+ This could be one of the following: -+ 1. Notebook metadata changed -+ 2. Cell(s) added -+ 3. Cell(s) deleted -+ 4. Cell(s) data changed -+ 4.1 Cell metadata changed -+ 4.2 Cell source changed -+ """ -+ workspace = self._match_uri_to_workspace(notebookDocument["uri"]) -+ -+ if change.get("metadata"): -+ # Case 1 -+ workspace.update_notebook_metadata( -+ notebookDocument["uri"], change.get("metadata") -+ ) -+ -+ cells = change.get("cells") -+ if cells: -+ # Change to cells -+ structure = cells.get("structure") -+ if structure: -+ # Case 2 or 3 -+ notebook_cell_array_change = structure["array"] -+ start = notebook_cell_array_change["start"] -+ cell_delete_count = notebook_cell_array_change["deleteCount"] -+ if cell_delete_count == 0: -+ # Case 2 -+ # Cell documents -+ for cell_document in structure["didOpen"]: -+ workspace.put_cell_document( -+ cell_document["uri"], -+ notebookDocument["uri"], -+ cell_document["languageId"], -+ cell_document["text"], -+ cell_document.get("version"), -+ ) -+ # Cell metadata which is added to Notebook -+ workspace.add_notebook_cells( -+ notebookDocument["uri"], -+ notebook_cell_array_change["cells"], -+ start, -+ ) -+ else: -+ # Case 3 -+ # Cell documents -+ for cell_document in structure["didClose"]: -+ workspace.rm_document(cell_document["uri"]) -+ workspace.publish_diagnostics(cell_document["uri"], []) -+ # Cell metadata which is removed from Notebook -+ workspace.remove_notebook_cells( -+ notebookDocument["uri"], start, cell_delete_count -+ ) -+ -+ data = cells.get("data") -+ if data: -+ # Case 4.1 -+ for cell in data: -+ # update NotebookDocument.cells properties -+ pass -+ -+ text_content = cells.get("textContent") -+ if text_content: -+ # Case 4.2 -+ for cell in text_content: -+ cell_uri = cell["document"]["uri"] -+ # Even though the protocol says that `changes` is an array, we assume that it's always a single -+ # element array that contains the last change to the cell source. -+ workspace.update_document(cell_uri, cell["changes"][0]) -+ self.lint(notebookDocument["uri"], is_saved=True) -+ -+ def m_text_document__did_close(self, textDocument=None, **_kwargs) -> None: -+ workspace = self._match_uri_to_workspace(textDocument["uri"]) -+ workspace.publish_diagnostics(textDocument["uri"], []) -+ workspace.rm_document(textDocument["uri"]) -+ -+ def m_text_document__did_open(self, textDocument=None, **_kwargs) -> None: -+ workspace = self._match_uri_to_workspace(textDocument["uri"]) -+ workspace.put_document( -+ textDocument["uri"], -+ textDocument["text"], -+ version=textDocument.get("version"), -+ ) -+ self._hook("pylsp_document_did_open", textDocument["uri"]) -+ self.lint(textDocument["uri"], is_saved=True) -+ -+ def m_text_document__did_change( -+ self, contentChanges=None, textDocument=None, **_kwargs -+ ) -> None: -+ workspace = self._match_uri_to_workspace(textDocument["uri"]) -+ for change in contentChanges: -+ workspace.update_document( -+ textDocument["uri"], change, version=textDocument.get("version") -+ ) -+ self.lint(textDocument["uri"], is_saved=False) -+ -+ def m_text_document__did_save(self, textDocument=None, **_kwargs) -> None: -+ self.lint(textDocument["uri"], is_saved=True) -+ self.document_did_save(textDocument["uri"]) -+ -+ def m_text_document__code_action( -+ self, textDocument=None, range=None, context=None, **_kwargs -+ ): -+ return self.code_actions(textDocument["uri"], range, context) -+ -+ def m_text_document__code_lens(self, textDocument=None, **_kwargs): -+ return self.code_lens(textDocument["uri"]) -+ -+ def _cell_document__completion(self, cellDocument, position=None, **_kwargs): -+ workspace = self._match_uri_to_workspace(cellDocument.notebook_uri) -+ notebookDocument = workspace.get_maybe_document(cellDocument.notebook_uri) -+ if notebookDocument is None: -+ raise ValueError("Invalid notebook document") -+ -+ cell_data = notebookDocument.cell_data() -+ -+ # Concatenate all cells to be a single temporary document -+ total_source = "\n".join(data["source"] for data in cell_data.values()) -+ with workspace.temp_document(total_source) as temp_uri: -+ # update position to be the position in the temp document -+ if position is not None: -+ position["line"] += cell_data[cellDocument.uri]["line_start"] -+ -+ completions = self.completions(temp_uri, position) -+ -+ # Translate temp_uri locations to cell document locations -+ for item in completions.get("items", []): -+ if item.get("data", {}).get("doc_uri") == temp_uri: -+ item["data"]["doc_uri"] = cellDocument.uri -+ -+ return completions -+ -+ def m_text_document__completion(self, textDocument=None, position=None, **_kwargs): -+ # textDocument here is just a dict with a uri -+ workspace = self._match_uri_to_workspace(textDocument["uri"]) -+ document = workspace.get_document(textDocument["uri"]) -+ if isinstance(document, Cell): -+ return self._cell_document__completion(document, position, **_kwargs) -+ return self.completions(textDocument["uri"], position) -+ -+ def _cell_document__definition(self, cellDocument, position=None, **_kwargs): -+ workspace = self._match_uri_to_workspace(cellDocument.notebook_uri) -+ notebookDocument = workspace.get_maybe_document(cellDocument.notebook_uri) -+ if notebookDocument is None: -+ raise ValueError("Invalid notebook document") -+ -+ cell_data = notebookDocument.cell_data() -+ -+ # Concatenate all cells to be a single temporary document -+ total_source = "\n".join(data["source"] for data in cell_data.values()) -+ with workspace.temp_document(total_source) as temp_uri: -+ # update position to be the position in the temp document -+ if position is not None: -+ position["line"] += cell_data[cellDocument.uri]["line_start"] -+ -+ definitions = self.definitions(temp_uri, position) -+ -+ # Translate temp_uri locations to cell document locations -+ for definition in definitions: -+ if definition["uri"] == temp_uri: -+ # Find the cell the start line is in and adjust the uri and line numbers -+ for cell_uri, data in cell_data.items(): -+ if ( -+ data["line_start"] -+ <= definition["range"]["start"]["line"] -+ <= data["line_end"] -+ ): -+ definition["uri"] = cell_uri -+ definition["range"]["start"]["line"] -= data["line_start"] -+ definition["range"]["end"]["line"] -= data["line_start"] -+ break -+ -+ return definitions -+ -+ def m_text_document__definition(self, textDocument=None, position=None, **_kwargs): -+ # textDocument here is just a dict with a uri -+ workspace = self._match_uri_to_workspace(textDocument["uri"]) -+ document = workspace.get_document(textDocument["uri"]) -+ if isinstance(document, Cell): -+ return self._cell_document__definition(document, position, **_kwargs) -+ return self.definitions(textDocument["uri"], position) -+ -+ def m_text_document__document_highlight( -+ self, textDocument=None, position=None, **_kwargs -+ ): -+ return self.highlight(textDocument["uri"], position) -+ -+ def m_text_document__hover(self, textDocument=None, position=None, **_kwargs): -+ return self.hover(textDocument["uri"], position) -+ -+ def m_text_document__document_symbol(self, textDocument=None, **_kwargs): -+ return self.document_symbols(textDocument["uri"]) -+ -+ def m_text_document__formatting(self, textDocument=None, options=None, **_kwargs): -+ return self.format_document(textDocument["uri"], options) -+ -+ def m_text_document__rename( -+ self, textDocument=None, position=None, newName=None, **_kwargs -+ ): -+ return self.rename(textDocument["uri"], position, newName) -+ -+ def m_text_document__folding_range(self, textDocument=None, **_kwargs): -+ return self.folding(textDocument["uri"]) -+ -+ def m_text_document__range_formatting( -+ self, textDocument=None, range=None, options=None, **_kwargs -+ ): -+ return self.format_range(textDocument["uri"], range, options) -+ -+ def m_text_document__references( -+ self, textDocument=None, position=None, context=None, **_kwargs -+ ): -+ exclude_declaration = not context["includeDeclaration"] -+ return self.references(textDocument["uri"], position, exclude_declaration) -+ -+ def m_text_document__signature_help( -+ self, textDocument=None, position=None, **_kwargs -+ ): -+ return self.signature_help(textDocument["uri"], position) -+ -+ def m_workspace__did_change_configuration(self, settings=None) -> None: -+ if self.config is not None: -+ self.config.update((settings or {}).get("pylsp", {})) -+ for workspace in self.workspaces.values(): -+ workspace.update_config(settings) -+ self._hook("pylsp_workspace_configuration_changed") -+ for doc_uri in workspace.documents: -+ self.lint(doc_uri, is_saved=False) -+ -+ def m_workspace__did_change_workspace_folders(self, event=None, **_kwargs): -+ if event is None: -+ return -+ added = event.get("added", []) -+ removed = event.get("removed", []) -+ -+ for removed_info in removed: -+ if "uri" in removed_info: -+ removed_uri = removed_info["uri"] -+ self.workspaces.pop(removed_uri, None) -+ -+ for added_info in added: -+ if "uri" in added_info: -+ added_uri = added_info["uri"] -+ workspace_config = config.Config( -+ added_uri, -+ self.config._init_opts, -+ self.config._process_id, -+ self.config._capabilities, -+ ) -+ workspace_config.update(self.config._settings) -+ self.workspaces[added_uri] = Workspace( -+ added_uri, self._endpoint, workspace_config -+ ) -+ -+ root_workspace_removed = any( -+ removed_info["uri"] == self.root_uri for removed_info in removed -+ ) -+ workspace_added = len(added) > 0 and "uri" in added[0] -+ if root_workspace_removed and workspace_added: -+ added_uri = added[0]["uri"] -+ self.root_uri = added_uri -+ new_root_workspace = self.workspaces[added_uri] -+ self.config = new_root_workspace._config -+ self.workspace = new_root_workspace -+ elif root_workspace_removed: -+ # NOTE: Removing the root workspace can only happen when the server -+ # is closed, thus the else condition of this if can never happen. -+ if self.workspaces: -+ log.debug("Root workspace deleted!") -+ available_workspaces = sorted(self.workspaces) -+ first_workspace = available_workspaces[0] -+ new_root_workspace = self.workspaces[first_workspace] -+ self.root_uri = first_workspace -+ self.config = new_root_workspace._config -+ self.workspace = new_root_workspace -+ -+ # Migrate documents that are on the root workspace and have a better -+ # match now -+ doc_uris = list(self.workspace._docs.keys()) -+ for uri in doc_uris: -+ doc = self.workspace._docs.pop(uri) -+ new_workspace = self._match_uri_to_workspace(uri) -+ new_workspace._docs[uri] = doc -+ -+ def m_workspace__did_change_watched_files(self, changes=None, **_kwargs): -+ changed_py_files = set() -+ config_changed = False -+ for d in changes or []: -+ if d["uri"].endswith(PYTHON_FILE_EXTENSIONS): -+ changed_py_files.add(d["uri"]) -+ elif d["uri"].endswith(CONFIG_FILEs): -+ config_changed = True -+ -+ if config_changed: -+ self.config.settings.cache_clear() -+ elif not changed_py_files: -+ # Only externally changed python files and lint configs may result in changed diagnostics. -+ return -+ -+ for workspace in self.workspaces.values(): -+ for doc_uri in workspace.documents: -+ # Changes in doc_uri are already handled by m_text_document__did_save -+ if doc_uri not in changed_py_files: -+ self.lint(doc_uri, is_saved=False) -+ -+ def m_workspace__execute_command(self, command=None, arguments=None): -+ return self.execute_command(command, arguments) -+ -+class FakeEditorMethodsMixin: -+ """ -+ Represents the methods to be added to a dispatcher class when faking an editor. -+ """ -+ -+ def m_window__work_done_progress__create(self, *_args, **_kwargs): -+ """ -+ Fake editor method `window/workDoneProgress/create`. -+ -+ related spec: -+ https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#window_workDoneProgress_create -+ """ -+ return None -+ -+class FakePythonLSPServer(FakeEditorMethodsMixin, PythonLSPServer): -+ pass -+ -+class FakeEndpoint(Endpoint): -+ """ -+ Fake Endpoint representing the editor / LSP client. -+ -+ The `dispatcher` dict will be used to synchronously calculate the responses -+ for calls to `.request` and resolve the futures with the value or errors. -+ -+ Fake methods in the `dispatcher` should raise `JsonRpcException` for any -+ error. -+ """ -+ -+ def request(self, method, params=None): -+ request_future = super().request(method, params) -+ try: -+ request_future.set_result(self._dispatcher[method](params)) -+ except JsonRpcException as e: -+ request_future.set_exception(e) -+ -+ return request_future -+ -+@pytest.fixture -+def pylsp_w_workspace_folders(tmpdir): -+ """Return an initialized python LS""" -+ ls = FakePythonLSPServer(StringIO, StringIO, endpoint_cls=FakeEndpoint) -+ -+ folder1 = tmpdir.mkdir("folder1") -+ folder2 = tmpdir.mkdir("folder2") -+ -+ ls.m_initialize( -+ processId=1, -+ rootUri=uris.from_fs_path(str(folder1)), -+ initializationOptions={}, -+ workspaceFolders=[ -+ {"uri": uris.from_fs_path(str(folder1)), "name": "folder1"}, -+ {"uri": uris.from_fs_path(str(folder2)), "name": "folder2"}, -+ ], -+ ) -+ -+ workspace_folders = [folder1, folder2] -+ return (ls, workspace_folders) -+ -+@hookimpl -+def pylsp_settings(): -+ # Default flake8 to disabled -+ return {"plugins": {"flake8": {"enabled": False}}} -+ -+FIX_IGNORES_RE = re.compile(r"([^a-zA-Z0-9_,]*;.*(\W+||$))") -+ -+def run_flake8(flake8_executable, args, document, source): -+ """Run flake8 with the provided arguments, logs errors -+ from stderr if any. -+ """ -+ # a quick temporary fix to deal with Atom -+ args = [ -+ (i if not i.startswith("--ignore=") else FIX_IGNORES_RE.sub("", i)) -+ for i in args -+ if i is not None -+ ] -+ -+ if document.path and document.path.startswith(document._workspace.root_path): -+ args.extend( -+ [ -+ "--stdin-display-name", -+ os.path.relpath(document.path, document._workspace.root_path), -+ ] -+ ) -+ -+ # if executable looks like a path resolve it -+ if not os.path.isfile(flake8_executable) and os.sep in flake8_executable: -+ flake8_executable = os.path.abspath( -+ os.path.expanduser(os.path.expandvars(flake8_executable)) -+ ) -+ -+ log.debug("Calling %s with args: '%s'", flake8_executable, args) -+ popen_kwargs = {} -+ if cwd := document._workspace.root_path: -+ popen_kwargs["cwd"] = cwd -+ try: -+ cmd = [flake8_executable] -+ cmd.extend(args) -+ p = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, **popen_kwargs) -+ except IOError: -+ log.debug( -+ "Can't execute %s. Trying with '%s -m flake8'", -+ flake8_executable, -+ sys.executable, -+ ) -+ cmd = [sys.executable, "-m", "flake8"] -+ cmd.extend(args) -+ p = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, **popen_kwargs) -+ (stdout, stderr) = p.communicate(source.encode()) -+ if stderr: -+ log.error("Error while running flake8 '%s'", stderr.decode()) -+ return stdout.decode() -+ -+def build_args(options): -+ """Build arguments for calling flake8. -+ -+ Args: -+ options: dictionary of argument names and their values. -+ """ -+ args = ["-"] # use stdin -+ for arg_name, arg_val in options.items(): -+ if arg_val is None: -+ continue -+ arg = None -+ if isinstance(arg_val, list): -+ arg = "--{}={}".format(arg_name, ",".join(arg_val)) -+ elif isinstance(arg_val, bool): -+ if arg_val: -+ arg = "--{}".format(arg_name) -+ else: -+ arg = "--{}={}".format(arg_name, arg_val) -+ args.append(arg) -+ return args -+ -+UNNECESSITY_CODES = { -+ "F401", # `module` imported but unused -+ "F504", # % format unused named arguments -+ "F522", # .format(...) unused named arguments -+ "F523", # .format(...) unused positional arguments -+ "F841", # local variable `name` is assigned to but never used -+} -+ -+# NOTE: If the user sets the flake8 executable with workspace configuration, the -+# error codes in this set may be inaccurate. -+ERROR_CODES = ( -+ # Errors from the pyflakes plugin of flake8 -+ {FLAKE8_PYFLAKES_CODES.get(m.__name__, "E999") for m in PYFLAKES_ERROR_MESSAGES} -+ # Syntax error from flake8 itself -+ | {"E999"} -+) -+ -+def parse_stdout(source, stdout): -+ """ -+ Build a diagnostics from flake8's output, it should extract every result and format -+ it into a dict that looks like this: -+ { -+ 'source': 'flake8', -+ 'code': code, # 'E501' -+ 'range': { -+ 'start': { -+ 'line': start_line, -+ 'character': start_column, -+ }, -+ 'end': { -+ 'line': end_line, -+ 'character': end_column, -+ }, -+ }, -+ 'message': msg, -+ 'severity': lsp.DiagnosticSeverity.*, -+ } -+ -+ Args: -+ document: The document to be linted. -+ stdout: output from flake8 -+ Returns: -+ A list of dictionaries. -+ """ -+ -+ document_lines = source.splitlines(True) -+ diagnostics = [] -+ lines = stdout.splitlines() -+ for raw_line in lines: -+ parsed_line = re.match(r"(.*):(\d*):(\d*): (\w*) (.*)", raw_line) -+ if not parsed_line: -+ log.debug("Flake8 output parser can't parse line '%s'", raw_line) -+ continue -+ -+ parsed_line = parsed_line.groups() -+ if len(parsed_line) != 5: -+ log.debug("Flake8 output parser can't parse line '%s'", raw_line) -+ continue -+ -+ _, line, character, code, msg = parsed_line -+ line = int(line) - 1 -+ character = int(character) - 1 -+ # show also the code in message -+ msg = code + " " + msg -+ severity = lsp.DiagnosticSeverity.Warning -+ if code in ERROR_CODES: -+ severity = lsp.DiagnosticSeverity.Error -+ diagnostic = { -+ "source": "flake8", -+ "code": code, -+ "range": { -+ "start": {"line": line, "character": character}, -+ "end": { -+ "line": line, -+ # no way to determine the column -+ "character": len(document_lines[line]), -+ }, -+ }, -+ "message": msg, -+ "severity": severity, -+ } -+ if code in UNNECESSITY_CODES: -+ diagnostic["tags"] = [lsp.DiagnosticTag.Unnecessary] -+ diagnostics.append(diagnostic) -+ -+ return diagnostics -+ -+@hookimpl -+def pylsp_lint(workspace, document): -+ with workspace.report_progress("lint: flake8"): -+ config = workspace._config -+ settings = config.plugin_settings("flake8", document_path=document.path) -+ log.debug("Got flake8 settings: %s", settings) -+ -+ ignores = settings.get("ignore", []) -+ per_file_ignores = settings.get("perFileIgnores") -+ -+ if per_file_ignores: -+ prev_file_pat = None -+ for path in per_file_ignores: -+ try: -+ file_pat, errors = path.split(":") -+ prev_file_pat = file_pat -+ except ValueError: -+ # It's legal to just specify another error type for the same -+ # file pattern: -+ if prev_file_pat is None: -+ log.warning("skipping a Per-file-ignore with no file pattern") -+ continue -+ file_pat = prev_file_pat -+ errors = path -+ if PurePath(document.path).match(file_pat): -+ ignores.extend(errors.split(",")) -+ -+ opts = { -+ "config": settings.get("config"), -+ "exclude": settings.get("exclude"), -+ "extend-ignore": settings.get("extendIgnore"), -+ "extend-select": settings.get("extendSelect"), -+ "filename": settings.get("filename"), -+ "hang-closing": settings.get("hangClosing"), -+ "ignore": ignores or None, -+ "max-complexity": settings.get("maxComplexity"), -+ "max-line-length": settings.get("maxLineLength"), -+ "indent-size": settings.get("indentSize"), -+ "select": settings.get("select"), -+ } -+ -+ # flake takes only absolute path to the config. So we should check and -+ # convert if necessary -+ if opts.get("config") and not os.path.isabs(opts.get("config")): -+ opts["config"] = os.path.abspath( -+ os.path.expanduser(os.path.expandvars(opts.get("config"))) -+ ) -+ log.debug("using flake8 with config: %s", opts["config"]) -+ -+ # Call the flake8 utility then parse diagnostics from stdout -+ flake8_executable = settings.get("executable", "flake8") -+ -+ args = build_args(opts) -+ -+ # ensure the same source is used for flake8 execution and result parsing; -+ # single source access improves performance as it is only one disk access -+ source = document.source -+ output = run_flake8(flake8_executable, args, document, source) -+ return parse_stdout(source, output) -+ -+SKIP_NODES = (tree_nodes.Module, tree_nodes.IfStmt, tree_nodes.TryStmt) -+ -+def __merge_folding_ranges(left, right): -+ for start in list(left.keys()): -+ right_start = right.pop(start, None) -+ if right_start is not None: -+ left[start] = max(right_start, start) -+ left.update(right) -+ return left -+ -+IDENTATION_REGEX = re.compile(r"(\s+).+") -+ -+def __empty_identation_stack( -+ identation_stack, level_limits, current_line, folding_ranges -+): -+ while identation_stack != []: -+ upper_level = identation_stack.pop(0) -+ level_start = level_limits.pop(upper_level) -+ folding_ranges.append((level_start, current_line)) -+ return folding_ranges -+ -+def __match_identation_stack( -+ identation_stack, level, level_limits, folding_ranges, current_line -+): -+ upper_level = identation_stack.pop(0) -+ while upper_level >= level: -+ level_start = level_limits.pop(upper_level) -+ folding_ranges.append((level_start, current_line)) -+ upper_level = identation_stack.pop(0) -+ identation_stack.insert(0, upper_level) -+ return identation_stack, folding_ranges -+ -+def __compute_folding_ranges_identation(text): -+ lines = text.splitlines() -+ folding_ranges = [] -+ identation_stack = [] -+ level_limits = {} -+ current_level = 0 -+ current_line = 0 -+ while lines[current_line] == "": -+ current_line += 1 -+ for i, line in enumerate(lines): -+ if i < current_line: -+ continue -+ i += 1 -+ identation_match = IDENTATION_REGEX.match(line) -+ if identation_match is not None: -+ whitespace = identation_match.group(1) -+ level = len(whitespace) -+ if level > current_level: -+ level_limits[current_level] = current_line -+ identation_stack.insert(0, current_level) -+ current_level = level -+ elif level < current_level: -+ identation_stack, folding_ranges = __match_identation_stack( -+ identation_stack, level, level_limits, folding_ranges, current_line -+ ) -+ current_level = level -+ else: -+ folding_ranges = __empty_identation_stack( -+ identation_stack, level_limits, current_line, folding_ranges -+ ) -+ current_level = 0 -+ if line.strip() != "": -+ current_line = i -+ folding_ranges = __empty_identation_stack( -+ identation_stack, level_limits, current_line, folding_ranges -+ ) -+ return dict(folding_ranges) -+ -+def __check_if_node_is_valid(node): -+ valid = True -+ if isinstance(node, tree_nodes.PythonNode): -+ kind = node.type -+ valid = kind not in { -+ "decorated", -+ "parameters", -+ "dictorsetmaker", -+ "testlist_comp", -+ } -+ if kind == "suite": -+ if isinstance(node.parent, tree_nodes.Function): -+ valid = False -+ return valid -+ -+def __handle_skip(stack, skip): -+ body = stack[skip] -+ children = [body] -+ if hasattr(body, "children"): -+ children = body.children -+ stack = stack[:skip] + children + stack[skip + 1 :] -+ node = body -+ end_line, _ = body.end_pos -+ return node, end_line -+ -+def __handle_flow_nodes(node, end_line, stack): -+ from_keyword = False -+ if isinstance(node, tree_nodes.Keyword): -+ from_keyword = True -+ if node.value in {"if", "elif", "with", "while"}: -+ node, end_line = __handle_skip(stack, 2) -+ elif node.value in {"except"}: -+ first_node = stack[0] -+ if isinstance(first_node, tree_nodes.Operator): -+ node, end_line = __handle_skip(stack, 1) -+ else: -+ node, end_line = __handle_skip(stack, 2) -+ elif node.value in {"for"}: -+ node, end_line = __handle_skip(stack, 4) -+ elif node.value in {"else"}: -+ node, end_line = __handle_skip(stack, 1) -+ return end_line, from_keyword, node, stack -+ -+def __compute_start_end_lines(node, stack): -+ start_line, _ = node.start_pos -+ end_line, _ = node.end_pos -+ modified = False -+ end_line, from_keyword, node, stack = __handle_flow_nodes(node, end_line, stack) -+ -+ last_leaf = node.get_last_leaf() -+ last_newline = isinstance(last_leaf, tree_nodes.Newline) -+ last_operator = isinstance(last_leaf, tree_nodes.Operator) -+ node_is_operator = isinstance(node, tree_nodes.Operator) -+ last_operator = last_operator or not node_is_operator -+ -+ end_line -= 1 -+ -+ if isinstance(node.parent, tree_nodes.PythonNode) and not from_keyword: -+ kind = node.type -+ if kind in {"suite", "atom", "atom_expr", "arglist"}: -+ if len(stack) > 0: -+ next_node = stack[0] -+ next_line, _ = next_node.start_pos -+ if next_line > end_line: -+ end_line += 1 -+ modified = True -+ if not last_newline and not modified and not last_operator: -+ end_line += 1 -+ return start_line, end_line, stack -+ -+def __compute_folding_ranges(tree, lines): -+ folding_ranges = {} -+ stack = [tree] -+ -+ while len(stack) > 0: -+ node = stack.pop(0) -+ if isinstance(node, tree_nodes.Newline): -+ # Skip newline nodes -+ continue -+ if isinstance(node, tree_nodes.PythonErrorNode): -+ # Fallback to indentation-based (best-effort) folding -+ start_line, _ = node.start_pos -+ start_line -= 1 -+ padding = [""] * start_line -+ text = "\n".join(padding + lines[start_line:]) + "\n" -+ identation_ranges = __compute_folding_ranges_identation(text) -+ folding_ranges = __merge_folding_ranges(folding_ranges, identation_ranges) -+ break -+ if not isinstance(node, SKIP_NODES): -+ valid = __check_if_node_is_valid(node) -+ if valid: -+ start_line, end_line, stack = __compute_start_end_lines(node, stack) -+ if end_line > start_line: -+ current_end = folding_ranges.get(start_line, -1) -+ folding_ranges[start_line] = max(current_end, end_line) -+ if hasattr(node, "children"): -+ stack = node.children + stack -+ -+ folding_ranges = sorted(folding_ranges.items()) -+ return folding_ranges -+ -+@hookimpl -+def pylsp_folding_range(document): -+ program = document.source + "\n" -+ lines = program.splitlines() -+ tree = parso.parse(program) -+ ranges = __compute_folding_ranges(tree, lines) -+ -+ results = [] -+ for start_line, end_line in ranges: -+ start_line -= 1 -+ end_line -= 1 -+ # If start/end character is not defined, then it defaults to the -+ # corresponding line last character -+ results.append( -+ { -+ "startLine": start_line, -+ "endLine": end_line, -+ } -+ ) -+ return results -+ -+@hookimpl -+def pylsp_document_highlight(document, position): -+ code_position = _utils.position_to_jedi_linecolumn(document, position) -+ usages = document.jedi_script().get_references(**code_position) -+ -+ def is_valid(definition): -+ return definition.line is not None and definition.column is not None -+ -+ def local_to_document(definition): -+ return ( -+ not definition.module_path or str(definition.module_path) == document.path -+ ) -+ -+ return [ -+ { -+ "range": { -+ "start": {"line": d.line - 1, "character": d.column}, -+ "end": {"line": d.line - 1, "character": d.column + len(d.name)}, -+ }, -+ "kind": lsp.DocumentHighlightKind.Write -+ if d.is_definition() -+ else lsp.DocumentHighlightKind.Read, -+ } -+ for d in usages -+ if is_valid(d) and local_to_document(d) -+ ] -+ -+@hookspec -+def pylsp_code_actions(config, workspace, document, range, context): -+ pass -+ -+@hookspec -+def pylsp_code_lens(config, workspace, document) -> None: -+ pass -+ -+@hookspec -+def pylsp_commands(config, workspace) -> None: -+ """The list of command strings supported by the server. -+ -+ Returns: -+ List[str]: The supported commands. -+ """ -+ -+@hookspec -+def pylsp_completions(config, workspace, document, position, ignored_names) -> None: -+ pass -+ -+@hookspec(firstresult=True) -+def pylsp_completion_item_resolve(config, workspace, document, completion_item) -> None: -+ pass -+ -+@hookspec -+def pylsp_definitions(config, workspace, document, position) -> None: -+ pass -+ -+@hookspec -+def pylsp_dispatchers(config, workspace) -> None: -+ pass -+ -+@hookspec -+def pylsp_document_did_open(config, workspace, document) -> None: -+ pass -+ -+@hookspec -+def pylsp_document_did_save(config, workspace, document) -> None: -+ pass -+ -+@hookspec -+def pylsp_document_highlight(config, workspace, document, position) -> None: -+ pass -+ -+@hookspec -+def pylsp_document_symbols(config, workspace, document) -> None: -+ pass -+ -+@hookspec(firstresult=True) -+def pylsp_execute_command(config, workspace, command, arguments) -> None: -+ pass -+ -+@hookspec -+def pylsp_experimental_capabilities(config, workspace) -> None: -+ pass -+ -+@hookspec -+def pylsp_folding_range(config, workspace, document) -> None: -+ pass -+ -+@hookspec(firstresult=True) -+def pylsp_format_document(config, workspace, document, options) -> None: -+ pass -+ -+@hookspec(firstresult=True) -+def pylsp_format_range(config, workspace, document, range, options) -> None: -+ pass -+ -+@hookspec(firstresult=True) -+def pylsp_hover(config, workspace, document, position) -> None: -+ pass -+ -+@hookspec -+def pylsp_initialize(config, workspace) -> None: -+ pass -+ -+@hookspec -+def pylsp_initialized() -> None: -+ pass -+ -+@hookspec -+def pylsp_lint(config, workspace, document, is_saved) -> None: -+ pass -+ -+@hookspec -+def pylsp_references( -+ config, workspace, document, position, exclude_declaration -+) -> None: -+ pass -+ -+@hookspec(firstresult=True) -+def pylsp_rename(config, workspace, document, position, new_name) -> None: -+ pass -+ -+@hookspec -+def pylsp_settings(config) -> None: -+ pass -+ -+@hookspec(firstresult=True) -+def pylsp_signature_help(config, workspace, document, position) -> None: -+ pass -+ -+@hookspec -+def pylsp_workspace_configuration_changed(config, workspace) -> None: -+ pass -+ -+@hookimpl -+def pylsp_hover(config, document, position): -+ code_position = _utils.position_to_jedi_linecolumn(document, position) -+ definitions = document.jedi_script(use_document_path=True).infer(**code_position) -+ word = document.word_at_position(position) -+ -+ # Find first exact matching definition -+ definition = next((x for x in definitions if x.name == word), None) -+ -+ # Ensure a definition is used if only one is available -+ # even if the word doesn't match. An example of this case is 'np' -+ # where 'numpy' doesn't match with 'np'. Same for NumPy ufuncs -+ if len(definitions) == 1: -+ definition = definitions[0] -+ -+ if not definition: -+ return {"contents": ""} -+ -+ hover_capabilities = config.capabilities.get("textDocument", {}).get("hover", {}) -+ supported_markup_kinds = hover_capabilities.get("contentFormat", ["markdown"]) -+ preferred_markup_kind = _utils.choose_markup_kind(supported_markup_kinds) -+ -+ # Find first exact matching signature -+ signature = next( -+ ( -+ x.to_string() -+ for x in definition.get_signatures() -+ if (x.name == word and x.type not in ["module"]) -+ ), -+ "", -+ ) -+ -+ return { -+ "contents": _utils.format_docstring( -+ # raw docstring returns only doc, without signature -+ definition.docstring(raw=True), -+ preferred_markup_kind, -+ signatures=[signature] if signature else None, -+ ) -+ } -+ -+# Types of parso nodes for which snippet is not included in the completion -+_IMPORTS = ("import_name", "import_from") -+ -+# Types of parso node for errors -+_ERRORS = ("error_node",) -+ -+def use_snippets(document, position): -+ """ -+ Determine if it's necessary to return snippets in code completions. -+ -+ This returns `False` if a completion is being requested on an import -+ statement, `True` otherwise. -+ """ -+ line = position["line"] -+ lines = document.source.split("\n", line) -+ act_lines = [lines[line][: position["character"]]] -+ line -= 1 -+ last_character = "" -+ while line > -1: -+ act_line = lines[line] -+ if ( -+ act_line.rstrip().endswith("\\") -+ or act_line.rstrip().endswith("(") -+ or act_line.rstrip().endswith(",") -+ ): -+ act_lines.insert(0, act_line) -+ line -= 1 -+ if act_line.rstrip().endswith("("): -+ # Needs to be added to the end of the code before parsing -+ # to make it valid, otherwise the node type could end -+ # being an 'error_node' for multi-line imports that use '(' -+ last_character = ")" -+ else: -+ break -+ if "(" in act_lines[-1].strip(): -+ last_character = ")" -+ code = "\n".join(act_lines).rsplit(";", maxsplit=1)[-1].strip() + last_character -+ tokens = parso.parse(code) -+ expr_type = tokens.children[0].type -+ return expr_type not in _IMPORTS and not (expr_type in _ERRORS and "import" in code) -+ -+# Map to the LSP type -+# > Valid values for type are ``module``, `` class ``, ``instance``, ``function``, -+# > ``param``, ``path``, ``keyword``, ``property`` and ``statement``. -+# see: https://jedi.readthedocs.io/en/latest/docs/api-classes.html#jedi.api.classes.BaseName.type -+_TYPE_MAP = { -+ "module": lsp.CompletionItemKind.Module, -+ "namespace": lsp.CompletionItemKind.Module, # to be added in Jedi 0.18+ -+ "class": lsp.CompletionItemKind.Class, -+ "instance": lsp.CompletionItemKind.Reference, -+ "function": lsp.CompletionItemKind.Function, -+ "param": lsp.CompletionItemKind.Variable, -+ "path": lsp.CompletionItemKind.File, -+ "keyword": lsp.CompletionItemKind.Keyword, -+ "property": lsp.CompletionItemKind.Property, # added in Jedi 0.18 -+ "statement": lsp.CompletionItemKind.Variable, -+} -+ -+def is_exception_class(name): -+ """ -+ Determine if a class name is an instance of an Exception. -+ -+ This returns `False` if the name given corresponds with a instance of -+ the 'Exception' class, `True` otherwise -+ """ -+ try: -+ return name in [cls.__name__ for cls in Exception.__subclasses__()] -+ except AttributeError: -+ # Needed in case a class don't uses new-style -+ # class definition in Python 2 -+ return False -+ -+def _detail(definition): -+ try: -+ return definition.parent().full_name or "" -+ except AttributeError: -+ return definition.full_name or "" -+ -+def _resolve_completion(completion, d, markup_kind: str): -+ completion["detail"] = _detail(d) -+ try: -+ docs = _utils.format_docstring( -+ d.docstring(raw=True), -+ signatures=[signature.to_string() for signature in d.get_signatures()], -+ markup_kind=markup_kind, -+ ) -+ except Exception: -+ docs = "" -+ completion["documentation"] = docs -+ return completion -+ -+def _label(definition, resolve=False): -+ if not resolve: -+ return definition.name -+ sig = LABEL_RESOLVER.get_or_create(definition) -+ if sig: -+ return sig -+ return definition.name -+ -+def _snippet(definition, resolve=False): -+ if not resolve: -+ return {} -+ snippet = SNIPPET_RESOLVER.get_or_create(definition) -+ return snippet -+ -+def _sort_text(definition): -+ """Ensure builtins appear at the bottom. -+ Description is of format : . -+ """ -+ -+ # If its 'hidden', put it next last -+ prefix = "z{}" if definition.name.startswith("_") else "a{}" -+ return prefix.format(definition.name) -+ -+def _format_completion( -+ d, -+ markup_kind: str, -+ include_params=True, -+ resolve=False, -+ resolve_label_or_snippet=False, -+ snippet_support=False, -+): -+ completion = { -+ "label": _label(d, resolve_label_or_snippet), -+ "kind": _TYPE_MAP.get(d.type), -+ "sortText": _sort_text(d), -+ "insertText": d.name, -+ } -+ -+ if resolve: -+ completion = _resolve_completion(completion, d, markup_kind) -+ -+ # Adjustments for file completions -+ if d.type == "path": -+ path = os.path.normpath(d.name) -+ -+ # If the completion ends with os.sep, it means it's a directory. So we add os.sep at the end -+ # to ease additional file completions. -+ if d.name.endswith(os.sep): -+ if os.name == "nt": -+ path = path + "\\" -+ else: -+ path = path + "/" -+ -+ # Escape to prevent conflicts with the code snippets grammer -+ # See also https://github.com/python-lsp/python-lsp-server/issues/373 -+ if snippet_support: -+ path = path.replace("\\", "\\\\") -+ path = path.replace("/", "\\/") -+ -+ completion["insertText"] = path -+ -+ if include_params and not is_exception_class(d.name): -+ snippet = _snippet(d, resolve_label_or_snippet) -+ completion.update(snippet) -+ -+ return completion -+ -+@hookimpl -+def pylsp_completions(config, document, position): -+ """Get formatted completions for current code position""" -+ settings = config.plugin_settings("jedi_completion", document_path=document.path) -+ resolve_eagerly = settings.get("eager", False) -+ code_position = _utils.position_to_jedi_linecolumn(document, position) -+ -+ code_position["fuzzy"] = settings.get("fuzzy", False) -+ completions = document.jedi_script(use_document_path=True).complete(**code_position) -+ -+ if not completions: -+ return None -+ -+ completion_capabilities = config.capabilities.get("textDocument", {}).get( -+ "completion", {} -+ ) -+ item_capabilities = completion_capabilities.get("completionItem", {}) -+ snippet_support = item_capabilities.get("snippetSupport") -+ supported_markup_kinds = item_capabilities.get("documentationFormat", ["markdown"]) -+ preferred_markup_kind = _utils.choose_markup_kind(supported_markup_kinds) -+ -+ should_include_params = settings.get("include_params") -+ should_include_class_objects = settings.get("include_class_objects", False) -+ should_include_function_objects = settings.get("include_function_objects", False) -+ -+ max_to_resolve = settings.get("resolve_at_most", 25) -+ modules_to_cache_for = settings.get("cache_for", None) -+ if modules_to_cache_for is not None: -+ LABEL_RESOLVER.cached_modules = modules_to_cache_for -+ SNIPPET_RESOLVER.cached_modules = modules_to_cache_for -+ -+ include_params = ( -+ snippet_support and should_include_params and use_snippets(document, position) -+ ) -+ include_class_objects = ( -+ snippet_support -+ and should_include_class_objects -+ and use_snippets(document, position) -+ ) -+ include_function_objects = ( -+ snippet_support -+ and should_include_function_objects -+ and use_snippets(document, position) -+ ) -+ -+ ready_completions = [ -+ _format_completion( -+ c, -+ markup_kind=preferred_markup_kind, -+ include_params=include_params if c.type in ["class", "function"] else False, -+ resolve=resolve_eagerly, -+ resolve_label_or_snippet=(i < max_to_resolve), -+ snippet_support=snippet_support, -+ ) -+ for i, c in enumerate(completions) -+ ] -+ -+ # TODO split up once other improvements are merged -+ if include_class_objects: -+ for i, c in enumerate(completions): -+ if c.type == "class": -+ completion_dict = _format_completion( -+ c, -+ markup_kind=preferred_markup_kind, -+ include_params=False, -+ resolve=resolve_eagerly, -+ resolve_label_or_snippet=(i < max_to_resolve), -+ snippet_support=snippet_support, -+ ) -+ completion_dict["kind"] = lsp.CompletionItemKind.TypeParameter -+ completion_dict["label"] += " object" -+ ready_completions.append(completion_dict) -+ -+ if include_function_objects: -+ for i, c in enumerate(completions): -+ if c.type == "function": -+ completion_dict = _format_completion( -+ c, -+ markup_kind=preferred_markup_kind, -+ include_params=False, -+ resolve=resolve_eagerly, -+ resolve_label_or_snippet=(i < max_to_resolve), -+ snippet_support=snippet_support, -+ ) -+ completion_dict["kind"] = lsp.CompletionItemKind.TypeParameter -+ completion_dict["label"] += " object" -+ ready_completions.append(completion_dict) -+ -+ for completion_dict in ready_completions: -+ completion_dict["data"] = {"doc_uri": document.uri} -+ -+ # most recently retrieved completion items, used for resolution -+ document.shared_data["LAST_JEDI_COMPLETIONS"] = { -+ # label is the only required property; here it is assumed to be unique -+ completion["label"]: (completion, data) -+ for completion, data in zip(ready_completions, completions) -+ } -+ -+ return ready_completions or None -+ -+@hookimpl -+def pylsp_completion_item_resolve(config, completion_item, document): -+ """Resolve formatted completion for given non-resolved completion""" -+ shared_data = document.shared_data["LAST_JEDI_COMPLETIONS"].get( -+ completion_item["label"] -+ ) -+ -+ completion_capabilities = config.capabilities.get("textDocument", {}).get( -+ "completion", {} -+ ) -+ item_capabilities = completion_capabilities.get("completionItem", {}) -+ supported_markup_kinds = item_capabilities.get("documentationFormat", ["markdown"]) -+ preferred_markup_kind = _utils.choose_markup_kind(supported_markup_kinds) -+ -+ if shared_data: -+ completion, data = shared_data -+ return _resolve_completion(completion, data, markup_kind=preferred_markup_kind) -+ return completion_item -+ -+def _num_lines(file_contents): -+ "Count the number of lines in the given string." -+ if _utils.get_eol_chars(file_contents): -+ return len(file_contents.splitlines()) -+ return 0 -+ -+@hookimpl -+def pylsp_rename(config, workspace, document, position, new_name): -+ log.debug( -+ "Executing rename of %s to %s", document.word_at_position(position), new_name -+ ) -+ kwargs = _utils.position_to_jedi_linecolumn(document, position) -+ kwargs["new_name"] = new_name -+ try: -+ refactoring = document.jedi_script().rename(**kwargs) -+ except NotImplementedError as exc: -+ raise Exception( -+ "No support for renaming in Python 2/3.5 with Jedi. " -+ "Consider using the pylsp-rope plugin instead" -+ ) from exc -+ log.debug("Finished rename: %s", refactoring.get_diff()) -+ changes = [] -+ -+ changed_files = refactoring.get_changed_files() -+ for file_path, changed_file in changed_files.items(): -+ uri = uris.from_fs_path(str(file_path)) -+ doc = workspace.get_maybe_document(uri) -+ changes.append( -+ { -+ "textDocument": {"uri": uri, "version": doc.version if doc else None}, -+ "edits": [ -+ { -+ "range": { -+ "start": {"line": 0, "character": 0}, -+ "end": { -+ "line": _num_lines(changed_file.get_new_code()), -+ "character": 0, -+ }, -+ }, -+ "newText": changed_file.get_new_code(), -+ } -+ ], -+ } -+ ) -+ return {"documentChanges": changes} -+ -+THRESHOLD = "threshold" -+ -+DEFAULT_THRESHOLD = 15 -+ -+@hookimpl -+def pylsp_lint(config, workspace, document): -+ with workspace.report_progress("lint: mccabe"): -+ threshold = config.plugin_settings("mccabe", document_path=document.path).get( -+ THRESHOLD, DEFAULT_THRESHOLD -+ ) -+ log.debug("Running mccabe lint with threshold: %s", threshold) -+ -+ try: -+ tree = compile(document.source, document.path, "exec", ast.PyCF_ONLY_AST) -+ except SyntaxError: -+ # We'll let the other linters point this one out -+ return None -+ -+ visitor = mccabe.PathGraphingAstVisitor() -+ visitor.preorder(tree, visitor) -+ -+ diags = [] -+ for graph in visitor.graphs.values(): -+ if graph.complexity() >= threshold: -+ diags.append( -+ { -+ "source": "mccabe", -+ "range": { -+ "start": { -+ "line": graph.lineno - 1, -+ "character": graph.column, -+ }, -+ "end": { -+ "line": graph.lineno - 1, -+ "character": len(document.lines[graph.lineno]), -+ }, -+ }, -+ "message": "Cyclomatic complexity too high: %s (threshold %s)" -+ % (graph.complexity(), threshold), -+ "severity": lsp.DiagnosticSeverity.Warning, -+ } -+ ) -+ -+ return diags -+ -+MODULES = [ -+ "OpenGL", -+ "PIL", -+ "array", -+ "audioop", -+ "binascii", -+ "cPickle", -+ "cStringIO", -+ "cmath", -+ "collections", -+ "datetime", -+ "errno", -+ "exceptions", -+ "gc", -+ "imageop", -+ "imp", -+ "itertools", -+ "marshal", -+ "math", -+ "matplotlib", -+ "mmap", -+ "mpmath", -+ "msvcrt", -+ "networkx", -+ "nose", -+ "nt", -+ "numpy", -+ "operator", -+ "os", -+ "os.path", -+ "pandas", -+ "parser", -+ "rgbimg", -+ "scipy", -+ "signal", -+ "skimage", -+ "sklearn", -+ "statsmodels", -+ "strop", -+ "sympy", -+ "sys", -+ "thread", -+ "time", -+ "wx", -+ "xxsubtype", -+ "zipimport", -+ "zlib", -+] -+ -+@hookimpl -+def pylsp_settings(): -+ # Setup default modules to preload, and rope extension modules -+ return { -+ "plugins": {"preload": {"modules": MODULES}}, -+ "rope": {"extensionModules": MODULES}, -+ } -+ -+@hookimpl -+def pylsp_initialize(config) -> None: -+ for mod_name in config.plugin_settings("preload").get("modules", []): -+ try: -+ __import__(mod_name) -+ log.debug("Preloaded module %s", mod_name) -+ except Exception: -+ # Catch any exception since not only ImportError can be raised here -+ # For example, old versions of NumPy can cause a ValueError. -+ # See spyder-ide/spyder#13985 -+ pass -+ -+def _get_severity(code): -+ # Are style errors ever really errors? -+ if code[0] == "E" or code[0] == "W": -+ return lsp.DiagnosticSeverity.Warning -+ # If no severity is specified, why wouldn't this be informational only? -+ return lsp.DiagnosticSeverity.Information -+ -+class PyCodeStyleDiagnosticReport(pycodestyle.BaseReport): -+ def __init__(self, options) -> None: -+ self.diagnostics = [] -+ super().__init__(options=options) -+ -+ def error(self, line_number, offset, text, check): -+ code = text[:4] -+ if self._ignore_code(code): -+ return -+ -+ # Don't care about expected errors or warnings -+ if code in self.expected: -+ return -+ -+ # PyCodeStyle will sometimes give you an error the line after the end of the file -+ # e.g. no newline at end of file -+ # In that case, the end offset should just be some number ~100 -+ # (because why not? There's nothing to underline anyways) -+ err_range = { -+ "start": {"line": line_number - 1, "character": offset}, -+ "end": { -+ # FIXME: It's a little naiive to mark until the end of the line, can we not easily do better? -+ "line": line_number - 1, -+ "character": 100 -+ if line_number > len(self.lines) -+ else len(self.lines[line_number - 1]), -+ }, -+ } -+ diagnostic = { -+ "source": "pycodestyle", -+ "range": err_range, -+ "message": text, -+ "code": code, -+ # Are style errors really ever errors? -+ "severity": _get_severity(code), -+ } -+ if code.startswith("W6"): -+ diagnostic["tags"] = [lsp.DiagnosticTag.Deprecated] -+ self.diagnostics.append(diagnostic) -+ -+@hookimpl -+def pylsp_lint(workspace, document): -+ with workspace.report_progress("lint: pycodestyle"): -+ config = workspace._config -+ settings = config.plugin_settings("pycodestyle", document_path=document.path) -+ log.debug("Got pycodestyle settings: %s", settings) -+ -+ opts = { -+ "exclude": settings.get("exclude"), -+ "filename": settings.get("filename"), -+ "hang_closing": settings.get("hangClosing"), -+ "ignore": settings.get("ignore"), -+ "max_line_length": settings.get("maxLineLength"), -+ "indent_size": settings.get("indentSize"), -+ "select": settings.get("select"), -+ } -+ kwargs = {k: v for k, v in opts.items() if v} -+ styleguide = pycodestyle.StyleGuide(kwargs) -+ -+ # Use LF to lint file because other line endings can give false positives. -+ # See spyder-ide/spyder#19565 for context. -+ source = document.source -+ eol_chars = get_eol_chars(source) -+ if eol_chars in ["\r", "\r\n"]: -+ source = source.replace(eol_chars, "\n") -+ lines = source.splitlines(keepends=True) -+ else: -+ lines = document.lines -+ -+ c = pycodestyle.Checker( -+ filename=document.path, -+ lines=lines, -+ options=styleguide.options, -+ report=PyCodeStyleDiagnosticReport(styleguide.options), -+ ) -+ c.check_all() -+ diagnostics = c.report.diagnostics -+ -+ return diagnostics -+ -+@hookimpl -+def pylsp_settings(): -+ # Default pydocstyle to disabled -+ return {"plugins": {"pydocstyle": {"enabled": False}}} -+ -+DEFAULT_MATCH_RE = pydocstyle.config.ConfigurationParser.DEFAULT_MATCH_RE -+ -+DEFAULT_MATCH_DIR_RE = pydocstyle.config.ConfigurationParser.DEFAULT_MATCH_DIR_RE -+ -+def _parse_diagnostic(document, error): -+ lineno = error.definition.start - 1 -+ line = document.lines[0] if document.lines else "" -+ -+ start_character = len(line) - len(line.lstrip()) -+ end_character = len(line) -+ -+ return { -+ "source": "pydocstyle", -+ "code": error.code, -+ "message": error.message, -+ "severity": lsp.DiagnosticSeverity.Warning, -+ "range": { -+ "start": {"line": lineno, "character": start_character}, -+ "end": {"line": lineno, "character": end_character}, -+ }, -+ } -+ -+@contextlib.contextmanager -+def _patch_sys_argv(arguments) -> None: -+ old_args = sys.argv -+ -+ # Preserve argv[0] since it's the executable -+ sys.argv = old_args[0:1] + arguments -+ -+ try: -+ yield -+ finally: -+ sys.argv = old_args -+ -+@hookimpl -+def pylsp_lint(config, workspace, document): -+ with workspace.report_progress("lint: pydocstyle"): -+ settings = config.plugin_settings("pydocstyle", document_path=document.path) -+ log.debug("Got pydocstyle settings: %s", settings) -+ -+ # Explicitly passing a path to pydocstyle means it doesn't respect the --match flag, so do it ourselves -+ filename_match_re = re.compile(settings.get("match", DEFAULT_MATCH_RE) + "$") -+ if not filename_match_re.match(os.path.basename(document.path)): -+ return [] -+ -+ # Likewise with --match-dir -+ dir_match_re = re.compile(settings.get("matchDir", DEFAULT_MATCH_DIR_RE) + "$") -+ if not dir_match_re.match(os.path.basename(os.path.dirname(document.path))): -+ return [] -+ -+ args = [document.path] -+ -+ if settings.get("convention"): -+ args.append("--convention=" + settings["convention"]) -+ -+ if settings.get("addSelect"): -+ args.append("--add-select=" + ",".join(settings["addSelect"])) -+ if settings.get("addIgnore"): -+ args.append("--add-ignore=" + ",".join(settings["addIgnore"])) -+ -+ elif settings.get("select"): -+ args.append("--select=" + ",".join(settings["select"])) -+ elif settings.get("ignore"): -+ args.append("--ignore=" + ",".join(settings["ignore"])) -+ -+ log.info("Using pydocstyle args: %s", args) -+ -+ conf = pydocstyle.config.ConfigurationParser() -+ with _patch_sys_argv(args): -+ # TODO(gatesn): We can add more pydocstyle args here from our pylsp config -+ conf.parse() -+ -+ # Will only yield a single filename, the document path -+ diags = [] -+ for ( -+ filename, -+ checked_codes, -+ ignore_decorators, -+ property_decorators, -+ ignore_self_only_init, -+ ) in conf.get_files_to_check(): -+ errors = pydocstyle.checker.ConventionChecker().check_source( -+ document.source, -+ filename, -+ ignore_decorators=ignore_decorators, -+ property_decorators=property_decorators, -+ ignore_self_only_init=ignore_self_only_init, -+ ) -+ -+ try: -+ for error in errors: -+ if error.code not in checked_codes: -+ continue -+ diags.append(_parse_diagnostic(document, error)) -+ except pydocstyle.parser.ParseError: -+ # In the case we cannot parse the Python file, just continue -+ pass -+ -+ log.debug("Got pydocstyle errors: %s", diags) -+ return diags -+ -+# Pyflakes messages that should be reported as Errors instead of Warns -+PYFLAKES_ERROR_MESSAGES = ( -+ messages.UndefinedName, -+ messages.UndefinedExport, -+ messages.UndefinedLocal, -+ messages.DuplicateArgument, -+ messages.FutureFeatureNotDefined, -+ messages.ReturnOutsideFunction, -+ messages.YieldOutsideFunction, -+ messages.ContinueOutsideLoop, -+ messages.BreakOutsideLoop, -+ messages.TwoStarredExpressions, -+) -+ -+class PyflakesDiagnosticReport: -+ def __init__(self, lines) -> None: -+ self.lines = lines -+ self.diagnostics = [] -+ -+ def unexpectedError(self, _filename, msg) -> None: # pragma: no cover -+ err_range = { -+ "start": {"line": 0, "character": 0}, -+ "end": {"line": 0, "character": 0}, -+ } -+ self.diagnostics.append( -+ { -+ "source": "pyflakes", -+ "range": err_range, -+ "message": msg, -+ "severity": lsp.DiagnosticSeverity.Error, -+ } -+ ) -+ -+ def syntaxError(self, _filename, msg, lineno, offset, text) -> None: -+ # We've seen that lineno and offset can sometimes be None -+ lineno = lineno or 1 -+ offset = offset or 0 -+ # could be None if the error is due to an invalid encoding -+ # see e.g. https://github.com/python-lsp/python-lsp-server/issues/429 -+ text = text or "" -+ -+ err_range = { -+ "start": {"line": lineno - 1, "character": offset}, -+ "end": {"line": lineno - 1, "character": offset + len(text)}, -+ } -+ self.diagnostics.append( -+ { -+ "source": "pyflakes", -+ "range": err_range, -+ "message": msg, -+ "severity": lsp.DiagnosticSeverity.Error, -+ } -+ ) -+ -+ def flake(self, message) -> None: -+ """Get message like :: """ -+ err_range = { -+ "start": {"line": message.lineno - 1, "character": message.col}, -+ "end": { -+ "line": message.lineno - 1, -+ "character": len(self.lines[message.lineno - 1]), -+ }, -+ } -+ -+ severity = lsp.DiagnosticSeverity.Warning -+ for message_type in PYFLAKES_ERROR_MESSAGES: -+ if isinstance(message, message_type): -+ severity = lsp.DiagnosticSeverity.Error -+ break -+ -+ self.diagnostics.append( -+ { -+ "source": "pyflakes", -+ "range": err_range, -+ "message": message.message % message.message_args, -+ "severity": severity, -+ } -+ ) -+ -+@hookimpl -+def pylsp_lint(workspace, document): -+ with workspace.report_progress("lint: pyflakes"): -+ reporter = PyflakesDiagnosticReport(document.lines) -+ pyflakes_api.check( -+ document.source.encode("utf-8"), document.path, reporter=reporter -+ ) -+ return reporter.diagnostics -+ -+@hookimpl -+def pylsp_settings(): -+ # Default pylint to disabled because it requires a config -+ # file to be useful. -+ return { -+ "plugins": { -+ "pylint": { -+ "enabled": False, -+ "args": [], -+ # disabled by default as it can slow down the workflow -+ "executable": None, -+ } -+ } -+ } -+ -+DEPRECATION_CODES = { -+ "W0402", # Uses of a deprecated module %r -+ "W1505", # Using deprecated method %s() -+ "W1511", # Using deprecated argument %s of method %s() -+ "W1512", # Using deprecated class %s of module %s -+ "W1513", # Using deprecated decorator %s() -+} -+ -+UNNECESSITY_CODES = { -+ "W0611", # Unused import %s -+ "W0612", # Unused variable %r -+ "W0613", # Unused argument %r -+ "W0614", # Unused import %s from wildcard import -+ "W1304", # Unused-format-string-argument -+} -+ -+class PylintLinter: -+ last_diags = collections.defaultdict(list) -+ -+ @classmethod -+ def lint(cls, document, is_saved, flags=""): -+ """Plugin interface to pylsp linter. -+ -+ Args: -+ document: The document to be linted. -+ is_saved: Whether or not the file has been saved to disk. -+ flags: Additional flags to pass to pylint. Not exposed to -+ pylsp_lint, but used for testing. -+ -+ Returns: -+ A list of dicts with the following format: -+ -+ { -+ 'source': 'pylint', -+ 'range': { -+ 'start': { -+ 'line': start_line, -+ 'character': start_column, -+ }, -+ 'end': { -+ 'line': end_line, -+ 'character': end_column, -+ }, -+ } -+ 'message': msg, -+ 'severity': lsp.DiagnosticSeverity.*, -+ } -+ """ -+ if not is_saved: -+ # Pylint can only be run on files that have been saved to disk. -+ # Rather than return nothing, return the previous list of -+ # diagnostics. If we return an empty list, any diagnostics we'd -+ # previously shown will be cleared until the next save. Instead, -+ # continue showing (possibly stale) diagnostics until the next -+ # save. -+ return cls.last_diags[document.path] -+ -+ cmd = [ -+ sys.executable, -+ "-c", -+ "import sys; from pylint.lint import Run; Run(sys.argv[1:])", -+ "-f", -+ "json", -+ document.path, -+ ] + (shlex.split(str(flags)) if flags else []) -+ log.debug("Calling pylint with '%s'", " ".join(cmd)) -+ -+ cwd = document._workspace.root_path -+ if not cwd: -+ cwd = os.path.dirname(__file__) -+ -+ with Popen( -+ cmd, stdout=PIPE, stderr=PIPE, cwd=cwd, universal_newlines=True -+ ) as process: -+ json_out, err = process.communicate() -+ -+ if err != "": -+ log.error("Error calling pylint: '%s'", err) -+ -+ # pylint prints nothing rather than [] when there are no diagnostics. -+ # json.loads will not parse an empty string, so just return. -+ if not json_out.strip(): -+ cls.last_diags[document.path] = [] -+ return [] -+ -+ # Pylint's JSON output is a list of objects with the following format. -+ # -+ # { -+ # "obj": "main", -+ # "path": "foo.py", -+ # "message": "Missing function docstring", -+ # "message-id": "C0111", -+ # "symbol": "missing-docstring", -+ # "column": 0, -+ # "type": "convention", -+ # "line": 5, -+ # "module": "foo" -+ # } -+ # -+ # The type can be any of: -+ # -+ # * convention -+ # * information -+ # * error -+ # * fatal -+ # * refactor -+ # * warning -+ diagnostics = [] -+ for diag in json.loads(json_out): -+ # pylint lines index from 1, pylsp lines index from 0 -+ line = diag["line"] - 1 -+ -+ err_range = { -+ "start": { -+ "line": line, -+ # Index columns start from 0 -+ "character": diag["column"], -+ }, -+ "end": { -+ "line": line, -+ # It's possible that we're linting an empty file. Even an empty -+ # file might fail linting if it isn't named properly. -+ "character": len(document.lines[line]) if document.lines else 0, -+ }, -+ } -+ -+ if diag["type"] == "convention": -+ severity = lsp.DiagnosticSeverity.Information -+ elif diag["type"] == "information": -+ severity = lsp.DiagnosticSeverity.Information -+ elif diag["type"] == "error": -+ severity = lsp.DiagnosticSeverity.Error -+ elif diag["type"] == "fatal": -+ severity = lsp.DiagnosticSeverity.Error -+ elif diag["type"] == "refactor": -+ severity = lsp.DiagnosticSeverity.Hint -+ elif diag["type"] == "warning": -+ severity = lsp.DiagnosticSeverity.Warning -+ -+ code = diag["message-id"] -+ -+ diagnostic = { -+ "source": "pylint", -+ "range": err_range, -+ "message": "[{}] {}".format(diag["symbol"], diag["message"]), -+ "severity": severity, -+ "code": code, -+ } -+ -+ if code in UNNECESSITY_CODES: -+ diagnostic["tags"] = [lsp.DiagnosticTag.Unnecessary] -+ if code in DEPRECATION_CODES: -+ diagnostic["tags"] = [lsp.DiagnosticTag.Deprecated] -+ -+ diagnostics.append(diagnostic) -+ cls.last_diags[document.path] = diagnostics -+ return diagnostics -+ -+def _build_pylint_flags(settings): -+ """Build arguments for calling pylint.""" -+ pylint_args = settings.get("args") -+ if pylint_args is None: -+ return "" -+ return " ".join(pylint_args) -+ -+def build_args_stdio(settings): -+ """Build arguments for calling pylint. -+ -+ :param settings: client settings -+ :type settings: dict -+ -+ :return: arguments to path to pylint -+ :rtype: list -+ """ -+ pylint_args = settings.get("args") -+ if pylint_args is None: -+ return [] -+ return pylint_args -+ -+def _run_pylint_stdio(pylint_executable, document, flags): -+ """Run pylint in popen. -+ -+ :param pylint_executable: path to pylint executable -+ :type pylint_executable: string -+ :param document: document to run pylint on -+ :type document: pylsp.workspace.Document -+ :param flags: arguments to path to pylint -+ :type flags: list -+ -+ :return: result of calling pylint -+ :rtype: string -+ """ -+ log.debug("Calling %s with args: '%s'", pylint_executable, flags) -+ try: -+ cmd = [pylint_executable] -+ cmd.extend(flags) -+ cmd.extend(["--from-stdin", document.path]) -+ p = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE) -+ except IOError: -+ log.debug("Can't execute %s. Trying with 'python -m pylint'", pylint_executable) -+ cmd = [sys.executable, "-m", "pylint"] -+ cmd.extend(flags) -+ cmd.extend(["--from-stdin", document.path]) -+ p = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE) -+ (stdout, stderr) = p.communicate(document.source.encode()) -+ if stderr: -+ log.error("Error while running pylint '%s'", stderr.decode()) -+ return stdout.decode() -+ -+def _parse_pylint_stdio_result(document, stdout): -+ """Parse pylint results. -+ -+ :param document: document to run pylint on -+ :type document: pylsp.workspace.Document -+ :param stdout: pylint results to parse -+ :type stdout: string -+ -+ :return: linting diagnostics -+ :rtype: list -+ """ -+ diagnostics = [] -+ lines = stdout.splitlines() -+ for raw_line in lines: -+ parsed_line = re.match(r"(.*):(\d*):(\d*): (\w*): (.*)", raw_line) -+ if not parsed_line: -+ log.debug("Pylint output parser can't parse line '%s'", raw_line) -+ continue -+ -+ parsed_line = parsed_line.groups() -+ if len(parsed_line) != 5: -+ log.debug("Pylint output parser can't parse line '%s'", raw_line) -+ continue -+ -+ _, line, character, code, msg = parsed_line -+ line = int(line) - 1 -+ character = int(character) -+ severity_map = { -+ "C": lsp.DiagnosticSeverity.Information, -+ "E": lsp.DiagnosticSeverity.Error, -+ "F": lsp.DiagnosticSeverity.Error, -+ "I": lsp.DiagnosticSeverity.Information, -+ "R": lsp.DiagnosticSeverity.Hint, -+ "W": lsp.DiagnosticSeverity.Warning, -+ } -+ severity = severity_map[code[0]] -+ diagnostic = { -+ "source": "pylint", -+ "code": code, -+ "range": { -+ "start": {"line": line, "character": character}, -+ "end": { -+ "line": line, -+ # no way to determine the column -+ "character": len(document.lines[line]) - 1, -+ }, -+ }, -+ "message": msg, -+ "severity": severity, -+ } -+ if code in UNNECESSITY_CODES: -+ diagnostic["tags"] = [lsp.DiagnosticTag.Unnecessary] -+ if code in DEPRECATION_CODES: -+ diagnostic["tags"] = [lsp.DiagnosticTag.Deprecated] -+ diagnostics.append(diagnostic) -+ -+ return diagnostics -+ -+def pylint_lint_stdin(pylint_executable, document, flags): -+ """Run pylint linter from stdin. -+ -+ This runs pylint in a subprocess with popen. -+ This allows passing the file from stdin and as a result -+ run pylint on unsaved files. Can slowdown the workflow. -+ -+ :param pylint_executable: path to pylint executable -+ :type pylint_executable: string -+ :param document: document to run pylint on -+ :type document: pylsp.workspace.Document -+ :param flags: arguments to path to pylint -+ :type flags: list -+ -+ :return: linting diagnostics -+ :rtype: list -+ """ -+ pylint_result = _run_pylint_stdio(pylint_executable, document, flags) -+ return _parse_pylint_stdio_result(document, pylint_result) -+ -+@hookimpl -+def pylsp_lint(config, workspace, document, is_saved): -+ """Run pylint linter.""" -+ with workspace.report_progress("lint: pylint"): -+ settings = config.plugin_settings("pylint") -+ log.debug("Got pylint settings: %s", settings) -+ # pylint >= 2.5.0 is required for working through stdin and only -+ # available with python3 -+ if settings.get("executable") and sys.version_info[0] >= 3: -+ flags = build_args_stdio(settings) -+ pylint_executable = settings.get("executable", "pylint") -+ return pylint_lint_stdin(pylint_executable, document, flags) -+ flags = _build_pylint_flags(settings) -+ return PylintLinter.lint(document, is_saved, flags=flags) -+ -+@hookimpl -+def pylsp_references(document, position, exclude_declaration): -+ code_position = _utils.position_to_jedi_linecolumn(document, position) -+ usages = document.jedi_script().get_references(**code_position) -+ -+ if exclude_declaration: -+ # Filter out if the usage is the actual declaration of the thing -+ usages = [d for d in usages if not d.is_definition()] -+ -+ # Filter out builtin modules -+ return [ -+ { -+ "uri": uris.uri_with(document.uri, path=str(d.module_path)) -+ if d.module_path -+ else document.uri, -+ "range": { -+ "start": {"line": d.line - 1, "character": d.column}, -+ "end": {"line": d.line - 1, "character": d.column + len(d.name)}, -+ }, -+ } -+ for d in usages -+ if not d.in_builtin_module() -+ ] -+ -+@hookimpl -+def pylsp_settings() -> Dict[str, Dict[str, Dict[str, Any]]]: -+ # Default rope_completion to disabled -+ return { -+ "plugins": { -+ "rope_autoimport": { -+ "enabled": False, -+ "memory": False, -+ "completions": { -+ "enabled": True, -+ }, -+ "code_actions": { -+ "enabled": True, -+ }, -+ } -+ } -+ } -+ -+MAX_RESULTS_COMPLETIONS = 1000 -+ -+def _should_import_class(word_node: tree.Leaf, expr: tree.BaseNode) -> bool: -+ prev_node = None -+ for node in expr.children: -+ if isinstance(node, tree.Name): -+ if isinstance(prev_node, tree.Operator): -+ if node == word_node and prev_node.value == "(": -+ return True -+ prev_node = node -+ -+ return False -+ -+def _handle_argument(node: NodeOrLeaf, word_node: tree.Leaf): -+ if isinstance(node, tree.PythonNode): -+ if node.type == "tfpdef": -+ if node.children[2] == word_node: -+ return True -+ if node.type == "parameters": -+ for parameter in node.children: -+ if _handle_argument(parameter, word_node): -+ return True -+ return False -+ -+def _should_import_function(word_node: tree.Leaf, expr: tree.BaseNode) -> bool: -+ prev_node = None -+ for node in expr.children: -+ if _handle_argument(node, word_node): -+ return True -+ if isinstance(prev_node, tree.Operator): -+ if prev_node.value == "->": -+ if node == word_node: -+ return True -+ prev_node = node -+ return False -+ -+def _handle_first_child( -+ first_child: NodeOrLeaf, expr: tree.BaseNode, word_node: tree.Leaf -+) -> bool: -+ """Check if we suggest imports given the following first child.""" -+ if isinstance(first_child, tree.Import): -+ return False -+ if isinstance(first_child, (tree.PythonLeaf, tree.PythonErrorLeaf)): -+ # Check if the first item is a from or import statement even when incomplete -+ if first_child.value in ("import", "from"): -+ return False -+ if isinstance(first_child, tree.Keyword): -+ if first_child.value == "def": -+ return _should_import_function(word_node, expr) -+ if first_child.value == "class": -+ return _should_import_class(word_node, expr) -+ return True -+ -+def _should_insert(expr: tree.BaseNode, word_node: tree.Leaf) -> bool: -+ """ -+ Check if we should insert the word_node on the given expr. -+ -+ Works for both correct and incorrect code. This is because the -+ user is often working on the code as they write it. -+ """ -+ if not word_node: -+ return False -+ if len(expr.children) == 0: -+ return True -+ first_child = expr.children[0] -+ if isinstance(first_child, tree.EndMarker): -+ if "#" in first_child.prefix: -+ return False # Check for single line comment -+ if first_child == word_node: -+ return True # If the word is the first word then its fine -+ if len(expr.children) > 1: -+ if any( -+ node.type == "operator" and "." in node.value or node.type == "trailer" -+ for node in expr.children -+ ): -+ return False # Check if we're on a method of a function -+ if isinstance(first_child, (tree.PythonErrorNode, tree.PythonNode)): -+ # The tree will often include error nodes like this to indicate errors -+ # we want to ignore errors since the code is being written -+ return _should_insert(first_child, word_node) -+ return _handle_first_child(first_child, expr, word_node) -+ -+_score_pow = 5 -+ -+_score_max = 10**_score_pow -+ -+def _document(import_statement: str) -> str: -+ return """# Auto-Import\n""" + import_statement -+ -+def _get_score( -+ source: int, full_statement: str, suggested_name: str, desired_name -+) -> int: -+ import_length = len("import") -+ full_statement_score = len(full_statement) - import_length -+ suggested_name_score = (len(suggested_name) - len(desired_name)) ** 2 -+ source_score = 20 * source -+ return suggested_name_score + full_statement_score + source_score -+ -+def _sort_import(score: int) -> str: -+ score = max(min(score, (_score_max) - 1), 0) -+ # Since we are using ints, we need to pad them. -+ # We also want to prioritize autoimport behind everything since its the last priority. -+ # The minimum is to prevent score from overflowing the pad -+ return "[z" + str(score).rjust(_score_pow, "0") -+ -+def _process_statements( -+ suggestions: List[SearchResult], -+ doc_uri: str, -+ word: str, -+ autoimport: AutoImport, -+ document: Document, -+ feature: str = "completions", -+) -> Generator[Dict[str, Any], None, None]: -+ for suggestion in suggestions: -+ insert_line = autoimport.find_insertion_line(document.source) - 1 -+ start = {"line": insert_line, "character": 0} -+ edit_range = {"start": start, "end": start} -+ edit = {"range": edit_range, "newText": suggestion.import_statement + "\n"} -+ score = _get_score( -+ suggestion.source, suggestion.import_statement, suggestion.name, word -+ ) -+ if score > _score_max: -+ continue -+ # TODO make this markdown -+ if feature == "completions": -+ yield { -+ "label": suggestion.name, -+ "kind": suggestion.itemkind, -+ "sortText": _sort_import(score), -+ "data": {"doc_uri": doc_uri}, -+ "detail": _document(suggestion.import_statement), -+ "additionalTextEdits": [edit], -+ } -+ elif feature == "code_actions": -+ yield { -+ "title": suggestion.import_statement, -+ "kind": "quickfix", -+ "edit": {"changes": {doc_uri: [edit]}}, -+ # data is a supported field for codeAction responses -+ # See https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocument_codeAction -+ "data": {"sortText": _sort_import(score)}, -+ } -+ else: -+ raise ValueError(f"Unknown feature: {feature}") -+ -+def get_names(script: Script) -> Set[str]: -+ """Get all names to ignore from the current file.""" -+ raw_names = script.get_names(definitions=True) -+ log.debug(raw_names) -+ return {name.name for name in raw_names} -+ -+class AutoimportCache: -+ """Handles the cache creation.""" -+ -+ def __init__(self) -> None: -+ self.thread = None -+ -+ def reload_cache( -+ self, -+ config: Config, -+ workspace: Workspace, -+ files: Optional[List[Document]] = None, -+ single_thread: Optional[bool] = True, -+ ): -+ if self.is_blocked(): -+ return -+ -+ memory: bool = config.plugin_settings("rope_autoimport").get("memory", False) -+ rope_config = config.settings().get("rope", {}) -+ autoimport = workspace._rope_autoimport(rope_config, memory) -+ resources: Optional[List[Resource]] = ( -+ None -+ if files is None -+ else [document._rope_resource(rope_config) for document in files] -+ ) -+ -+ if single_thread: -+ self._reload_cache(workspace, autoimport, resources) -+ else: -+ # Creating the cache may take 10-20s for a environment with 5k python modules. That's -+ # why we decided to move cache creation into its own thread. -+ self.thread = threading.Thread( -+ target=self._reload_cache, args=(workspace, autoimport, resources) -+ ) -+ self.thread.start() -+ -+ def _reload_cache( -+ self, -+ workspace: Workspace, -+ autoimport: AutoImport, -+ resources: Optional[List[Resource]] = None, -+ ) -> None: -+ task_handle = PylspTaskHandle(workspace) -+ autoimport.generate_cache(task_handle=task_handle, resources=resources) -+ autoimport.generate_modules_cache(task_handle=task_handle) -+ -+ def is_blocked(self): -+ return self.thread and self.thread.is_alive() -+ -+cache: AutoimportCache = AutoimportCache() -+ -+@hookimpl -+def pylsp_completions( -+ config: Config, -+ workspace: Workspace, -+ document: Document, -+ position, -+ ignored_names: Union[Set[str], None], -+): -+ """Get autoimport suggestions.""" -+ if ( -+ not config.plugin_settings("rope_autoimport") -+ .get("completions", {}) -+ .get("enabled", True) -+ ) or cache.is_blocked(): -+ return [] -+ -+ line = document.lines[position["line"]] -+ expr = parso.parse(line) -+ word_node = expr.get_leaf_for_position((1, position["character"])) -+ if not _should_insert(expr, word_node): -+ return [] -+ word = word_node.value -+ log.debug(f"autoimport: searching for word: {word}") -+ rope_config = config.settings(document_path=document.path).get("rope", {}) -+ ignored_names: Set[str] = ignored_names or get_names( -+ document.jedi_script(use_document_path=True) -+ ) -+ autoimport = workspace._rope_autoimport(rope_config) -+ suggestions = list(autoimport.search_full(word, ignored_names=ignored_names)) -+ results = sorted( -+ _process_statements( -+ suggestions, document.uri, word, autoimport, document, "completions" -+ ), -+ key=lambda statement: statement["sortText"], -+ ) -+ if len(results) > MAX_RESULTS_COMPLETIONS: -+ results = results[:MAX_RESULTS_COMPLETIONS] -+ return results -+ -+MAX_RESULTS_CODE_ACTIONS = 5 -+ -+def get_name_or_module(document, diagnostic) -> str: -+ start = diagnostic["range"]["start"] -+ return ( -+ parso.parse(document.lines[start["line"]]) -+ .get_leaf_for_position((1, start["character"] + 1)) -+ .value -+ ) -+ -+@hookimpl -+def pylsp_code_actions( -+ config: Config, -+ workspace: Workspace, -+ document: Document, -+ range: Dict, -+ context: Dict, -+) -> List[Dict]: -+ """ -+ Provide code actions through rope. -+ -+ Parameters -+ ---------- -+ config : pylsp.config.config.Config -+ Current config. -+ workspace : pylsp.workspace.Workspace -+ Current workspace. -+ document : pylsp.workspace.Document -+ Document to apply code actions on. -+ range : Dict -+ Range argument given by pylsp. Not used here. -+ context : Dict -+ CodeActionContext given as dict. -+ -+ Returns -+ ------- -+ List of dicts containing the code actions. -+ """ -+ if ( -+ not config.plugin_settings("rope_autoimport") -+ .get("code_actions", {}) -+ .get("enabled", True) -+ ) or cache.is_blocked(): -+ return [] -+ -+ log.debug(f"textDocument/codeAction: {document} {range} {context}") -+ code_actions = [] -+ for diagnostic in context.get("diagnostics", []): -+ if "undefined name" not in diagnostic.get("message", "").lower(): -+ continue -+ -+ word = get_name_or_module(document, diagnostic) -+ log.debug(f"autoimport: searching for word: {word}") -+ rope_config = config.settings(document_path=document.path).get("rope", {}) -+ autoimport = workspace._rope_autoimport(rope_config) -+ suggestions = list(autoimport.search_full(word)) -+ log.debug("autoimport: suggestions: %s", suggestions) -+ results = sorted( -+ _process_statements( -+ suggestions, -+ document.uri, -+ word, -+ autoimport, -+ document, -+ "code_actions", -+ ), -+ key=lambda statement: statement["data"]["sortText"], -+ ) -+ -+ if len(results) > MAX_RESULTS_CODE_ACTIONS: -+ results = results[:MAX_RESULTS_CODE_ACTIONS] -+ code_actions.extend(results) -+ -+ return code_actions -+ -+@hookimpl -+def pylsp_initialize(config: Config, workspace: Workspace) -> None: -+ """Initialize AutoImport. -+ -+ Generates the cache for local and global items. -+ """ -+ cache.reload_cache(config, workspace) -+ -+@hookimpl -+def pylsp_document_did_open(config: Config, workspace: Workspace) -> None: -+ """Initialize AutoImport. -+ -+ Generates the cache for local and global items. -+ """ -+ cache.reload_cache(config, workspace) -+ -+@hookimpl -+def pylsp_document_did_save( -+ config: Config, workspace: Workspace, document: Document -+) -> None: -+ """Update the names associated with this document.""" -+ cache.reload_cache(config, workspace, [document]) -+ -+@hookimpl -+def pylsp_workspace_configuration_changed(config: Config, workspace: Workspace) -> None: -+ """ -+ Initialize autoimport if it has been enabled through a -+ workspace/didChangeConfiguration message from the frontend. -+ -+ Generates the cache for local and global items. -+ """ -+ if config.plugin_settings("rope_autoimport").get("enabled", False): -+ cache.reload_cache(config, workspace) -+ else: -+ log.debug("autoimport: Skipping cache reload.") -+ -+@hookimpl -+def pylsp_settings(): -+ # Default rope_completion to disabled -+ return {"plugins": {"rope_completion": {"enabled": False, "eager": False}}} -+ -+def _resolve_completion(completion, data, markup_kind): -+ try: -+ doc = _utils.format_docstring(data.get_doc(), markup_kind=markup_kind) -+ except Exception as e: -+ log.debug("Failed to resolve Rope completion: %s", e) -+ doc = "" -+ completion["detail"] = "{0} {1}".format(data.scope or "", data.name) -+ completion["documentation"] = doc -+ return completion -+ -+def _sort_text(definition): -+ """Ensure builtins appear at the bottom. -+ Description is of format : . -+ """ -+ if definition.name.startswith("_"): -+ # It's a 'hidden' func, put it next last -+ return "z" + definition.name -+ if definition.scope == "builtin": -+ return "y" + definition.name -+ -+ # Else put it at the front -+ return "a" + definition.name -+ -+def _kind(d): -+ """Return the LSP type""" -+ MAP = { -+ "none": lsp.CompletionItemKind.Value, -+ "type": lsp.CompletionItemKind.Class, -+ "tuple": lsp.CompletionItemKind.Class, -+ "dict": lsp.CompletionItemKind.Class, -+ "dictionary": lsp.CompletionItemKind.Class, -+ "function": lsp.CompletionItemKind.Function, -+ "lambda": lsp.CompletionItemKind.Function, -+ "generator": lsp.CompletionItemKind.Function, -+ "class": lsp.CompletionItemKind.Class, -+ "instance": lsp.CompletionItemKind.Reference, -+ "method": lsp.CompletionItemKind.Method, -+ "builtin": lsp.CompletionItemKind.Class, -+ "builtinfunction": lsp.CompletionItemKind.Function, -+ "module": lsp.CompletionItemKind.Module, -+ "file": lsp.CompletionItemKind.File, -+ "xrange": lsp.CompletionItemKind.Class, -+ "slice": lsp.CompletionItemKind.Class, -+ "traceback": lsp.CompletionItemKind.Class, -+ "frame": lsp.CompletionItemKind.Class, -+ "buffer": lsp.CompletionItemKind.Class, -+ "dictproxy": lsp.CompletionItemKind.Class, -+ "funcdef": lsp.CompletionItemKind.Function, -+ "property": lsp.CompletionItemKind.Property, -+ "import": lsp.CompletionItemKind.Module, -+ "keyword": lsp.CompletionItemKind.Keyword, -+ "constant": lsp.CompletionItemKind.Variable, -+ "variable": lsp.CompletionItemKind.Variable, -+ "value": lsp.CompletionItemKind.Value, -+ "param": lsp.CompletionItemKind.Variable, -+ "statement": lsp.CompletionItemKind.Keyword, -+ } -+ -+ return MAP.get(d.type) -+ -+@hookimpl -+def pylsp_completions(config, workspace, document, position): -+ settings = config.plugin_settings("rope_completion", document_path=document.path) -+ resolve_eagerly = settings.get("eager", False) -+ -+ # Rope is a bit rubbish at completing module imports, so we'll return None -+ word = document.word_at_position( -+ { -+ # The -1 should really be trying to look at the previous word, but that might be quite expensive -+ # So we only skip import completions when the cursor is one space after `import` -+ "line": position["line"], -+ "character": max(position["character"] - 1, 0), -+ } -+ ) -+ if word == "import": -+ return None -+ -+ offset = document.offset_at_position(position) -+ rope_config = config.settings(document_path=document.path).get("rope", {}) -+ rope_project = workspace._rope_project_builder(rope_config) -+ document_rope = document._rope_resource(rope_config) -+ -+ completion_capabilities = config.capabilities.get("textDocument", {}).get( -+ "completion", {} -+ ) -+ item_capabilities = completion_capabilities.get("completionItem", {}) -+ supported_markup_kinds = item_capabilities.get("documentationFormat", ["markdown"]) -+ preferred_markup_kind = _utils.choose_markup_kind(supported_markup_kinds) -+ -+ try: -+ definitions = code_assist( -+ rope_project, document.source, offset, document_rope, maxfixes=3 -+ ) -+ except Exception as e: -+ log.debug("Failed to run Rope code assist: %s", e) -+ return [] -+ -+ definitions = sorted_proposals(definitions) -+ new_definitions = [] -+ for d in definitions: -+ item = { -+ "label": d.name, -+ "kind": _kind(d), -+ "sortText": _sort_text(d), -+ "data": {"doc_uri": document.uri}, -+ } -+ if resolve_eagerly: -+ item = _resolve_completion(item, d, preferred_markup_kind) -+ new_definitions.append(item) -+ -+ # most recently retrieved completion items, used for resolution -+ document.shared_data["LAST_ROPE_COMPLETIONS"] = { -+ # label is the only required property; here it is assumed to be unique -+ completion["label"]: (completion, data) -+ for completion, data in zip(new_definitions, definitions) -+ } -+ -+ definitions = new_definitions -+ -+ return definitions or None -+ -+@hookimpl -+def pylsp_completion_item_resolve(config, completion_item, document): -+ """Resolve formatted completion for given non-resolved completion""" -+ shared_data = document.shared_data["LAST_ROPE_COMPLETIONS"].get( -+ completion_item["label"] -+ ) -+ -+ completion_capabilities = config.capabilities.get("textDocument", {}).get( -+ "completion", {} -+ ) -+ item_capabilities = completion_capabilities.get("completionItem", {}) -+ supported_markup_kinds = item_capabilities.get("documentationFormat", ["markdown"]) -+ preferred_markup_kind = _utils.choose_markup_kind(supported_markup_kinds) -+ -+ if shared_data: -+ completion, data = shared_data -+ return _resolve_completion(completion, data, preferred_markup_kind) -+ return completion_item -+ -+SPHINX = re.compile(r"\s*:param\s+(?P\w+):\s*(?P[^\n]+)") -+ -+EPYDOC = re.compile(r"\s*@param\s+(?P\w+):\s*(?P[^\n]+)") -+ -+GOOGLE = re.compile(r"\s*(?P\w+).*:\s*(?P[^\n]+)") -+ -+DOC_REGEX = [SPHINX, EPYDOC, GOOGLE] -+ -+def _param_docs(docstring, param_name): -+ for line in docstring.splitlines(): -+ for regex in DOC_REGEX: -+ m = regex.match(line) -+ if not m: -+ continue -+ if m.group("param") != param_name: -+ continue -+ return m.group("doc") or "" -+ -+@hookimpl -+def pylsp_signature_help(config, document, position): -+ code_position = _utils.position_to_jedi_linecolumn(document, position) -+ signatures = document.jedi_script().get_signatures(**code_position) -+ -+ if not signatures: -+ return {"signatures": []} -+ -+ signature_capabilities = config.capabilities.get("textDocument", {}).get( -+ "signatureHelp", {} -+ ) -+ signature_information_support = signature_capabilities.get( -+ "signatureInformation", {} -+ ) -+ supported_markup_kinds = signature_information_support.get( -+ "documentationFormat", ["markdown"] -+ ) -+ preferred_markup_kind = _utils.choose_markup_kind(supported_markup_kinds) -+ -+ s = signatures[0] -+ -+ docstring = s.docstring() -+ -+ # Docstring contains one or more lines of signature, followed by empty line, followed by docstring -+ function_sig_lines = (docstring.split("\n\n") or [""])[0].splitlines() -+ function_sig = " ".join([line.strip() for line in function_sig_lines]) -+ sig = { -+ "label": function_sig, -+ "documentation": _utils.format_docstring( -+ s.docstring(raw=True), markup_kind=preferred_markup_kind -+ ), -+ } -+ -+ # If there are params, add those -+ if s.params: -+ sig["parameters"] = [ -+ { -+ "label": p.name, -+ "documentation": _utils.format_docstring( -+ _param_docs(docstring, p.name), markup_kind=preferred_markup_kind -+ ), -+ } -+ for p in s.params -+ ] -+ -+ # We only return a single signature because Python doesn't allow overloading -+ sig_info = {"signatures": [sig], "activeSignature": 0} -+ -+ if s.index is not None and s.params: -+ # Then we know which parameter we're looking at -+ sig_info["activeParameter"] = s.index -+ -+ return sig_info -+ -+def _include_def(definition): -+ return ( -+ # Don't tend to include parameters as symbols -+ definition.type != "param" -+ and -+ # Unused vars should also be skipped -+ definition.name != "_" -+ and _kind(definition) is not None -+ ) -+ -+def _container(definition): -+ try: -+ # Jedi sometimes fails here. -+ parent = definition.parent() -+ # Here we check that a grand-parent exists to avoid declaring symbols -+ # as children of the module. -+ if parent.parent(): -+ return parent.name -+ except: -+ return None -+ -+ return None -+ -+def _range(definition): -+ # This gets us more accurate end position -+ definition = definition._name.tree_name.get_definition() -+ (start_line, start_column) = definition.start_pos -+ (end_line, end_column) = definition.end_pos -+ return { -+ "start": {"line": start_line - 1, "character": start_column}, -+ "end": {"line": end_line - 1, "character": end_column}, -+ } -+ -+def _tuple_range(definition): -+ definition = definition._name.tree_name.get_definition() -+ return (definition.start_pos, definition.end_pos) -+ -+_SYMBOL_KIND_MAP = { -+ "none": SymbolKind.Variable, -+ "type": SymbolKind.Class, -+ "tuple": SymbolKind.Class, -+ "dict": SymbolKind.Class, -+ "dictionary": SymbolKind.Class, -+ "function": SymbolKind.Function, -+ "lambda": SymbolKind.Function, -+ "generator": SymbolKind.Function, -+ "class": SymbolKind.Class, -+ "instance": SymbolKind.Class, -+ "method": SymbolKind.Method, -+ "builtin": SymbolKind.Class, -+ "builtinfunction": SymbolKind.Function, -+ "module": SymbolKind.Module, -+ "file": SymbolKind.File, -+ "xrange": SymbolKind.Array, -+ "slice": SymbolKind.Class, -+ "traceback": SymbolKind.Class, -+ "frame": SymbolKind.Class, -+ "buffer": SymbolKind.Array, -+ "dictproxy": SymbolKind.Class, -+ "funcdef": SymbolKind.Function, -+ "property": SymbolKind.Property, -+ "import": SymbolKind.Module, -+ "keyword": SymbolKind.Variable, -+ "constant": SymbolKind.Constant, -+ "variable": SymbolKind.Variable, -+ "value": SymbolKind.Variable, -+ "param": SymbolKind.Variable, -+ "statement": SymbolKind.Variable, -+ "boolean": SymbolKind.Boolean, -+ "int": SymbolKind.Number, -+ "longlean": SymbolKind.Number, -+ "float": SymbolKind.Number, -+ "complex": SymbolKind.Number, -+ "string": SymbolKind.String, -+ "unicode": SymbolKind.String, -+ "list": SymbolKind.Array, -+ "field": SymbolKind.Field, -+} -+ -+def _kind(d): -+ """Return the VSCode Symbol Type""" -+ return _SYMBOL_KIND_MAP.get(d.type) -+ -+@hookimpl -+def pylsp_document_symbols(config, document): -+ symbols_settings = config.plugin_settings("jedi_symbols") -+ all_scopes = symbols_settings.get("all_scopes", True) -+ add_import_symbols = symbols_settings.get("include_import_symbols", True) -+ definitions = document.jedi_names(all_scopes=all_scopes) -+ symbols = [] -+ exclude = set({}) -+ redefinitions = {} -+ -+ while definitions != []: -+ d = definitions.pop(0) -+ -+ # Skip symbols imported from other modules. -+ if not add_import_symbols: -+ # Skip if there's an import in the code the symbol is defined. -+ code = d.get_line_code() -+ if " import " in code or "import " in code: -+ continue -+ -+ # Skip imported symbols comparing module names. -+ sym_full_name = d.full_name -+ if sym_full_name is not None: -+ document_dot_path = document.dot_path -+ -+ # We assume a symbol is imported from another module to start -+ # with. -+ imported_symbol = True -+ -+ # The last element of sym_full_name is the symbol itself, so -+ # we need to discard it to do module comparisons below. -+ if "." in sym_full_name: -+ sym_module_name = sym_full_name.rpartition(".")[0] -+ else: -+ sym_module_name = sym_full_name -+ -+ # This is necessary to display symbols in init files (the checks -+ # below fail without it). -+ if document_dot_path.endswith("__init__"): -+ document_dot_path = document_dot_path.rpartition(".")[0] -+ -+ # document_dot_path is the module where the symbol is imported, -+ # whereas sym_module_name is the one where it was declared. -+ if document_dot_path in sym_module_name: -+ # If document_dot_path is in sym_module_name, we can safely assume -+ # that the symbol was declared in the document. -+ imported_symbol = False -+ elif sym_module_name.split(".")[0] in document_dot_path.split("."): -+ # If the first module in sym_module_name is one of the modules in -+ # document_dot_path, we need to check if sym_module_name starts -+ # with the modules in document_dot_path. -+ document_mods = document_dot_path.split(".") -+ for i in range(1, len(document_mods) + 1): -+ submod = ".".join(document_mods[-i:]) -+ if sym_module_name.startswith(submod): -+ imported_symbol = False -+ break -+ -+ # When there's no __init__.py next to a file or in one of its -+ # parents, the checks above fail. However, Jedi has a nice way -+ # to tell if the symbol was declared in the same file: if -+ # sym_module_name starts by __main__. -+ if imported_symbol: -+ if not sym_module_name.startswith("__main__"): -+ continue -+ else: -+ # We need to skip symbols if their definition doesn't have `full_name` info, they -+ # are detected as a definition, but their description (e.g. `class Foo`) doesn't -+ # match the code where they're detected by Jedi. This happens for relative imports. -+ if _include_def(d): -+ if d.description not in d.get_line_code(): -+ continue -+ else: -+ continue -+ -+ if _include_def(d) and Path(document.path) == Path(d.module_path): -+ tuple_range = _tuple_range(d) -+ if tuple_range in exclude: -+ continue -+ -+ kind = redefinitions.get(tuple_range, None) -+ if kind is not None: -+ exclude |= {tuple_range} -+ -+ if d.type == "statement": -+ if d.description.startswith("self"): -+ kind = "field" -+ -+ symbol = { -+ "name": d.name, -+ "containerName": _container(d), -+ "location": { -+ "uri": document.uri, -+ "range": _range(d), -+ }, -+ "kind": _kind(d) if kind is None else _SYMBOL_KIND_MAP[kind], -+ } -+ symbols.append(symbol) -+ -+ if d.type == "class": -+ try: -+ defined_names = list(d.defined_names()) -+ for method in defined_names: -+ if method.type == "function": -+ redefinitions[_tuple_range(method)] = "method" -+ elif method.type == "statement": -+ redefinitions[_tuple_range(method)] = "field" -+ else: -+ redefinitions[_tuple_range(method)] = method.type -+ definitions = list(defined_names) + definitions -+ except Exception: -+ pass -+ return symbols -+ -+def get_style_config(document_path, options=None): -+ # Exclude file if it follows the patterns for that -+ exclude_patterns_from_ignore_file = file_resources.GetExcludePatternsForDir( -+ os.getcwd() -+ ) -+ if file_resources.IsIgnored(document_path, exclude_patterns_from_ignore_file): -+ return [] -+ -+ # Get the default styles as a string -+ # for a preset configuration, i.e. "pep8" -+ style_config = file_resources.GetDefaultStyleForDir(os.path.dirname(document_path)) -+ if options is None: -+ return style_config -+ -+ # We have options passed from LSP format request -+ # let's pass them to the formatter. -+ # First we want to get a dictionary of the preset style -+ # to pass instead of a string so that we can modify it -+ style_config = style.CreateStyleFromConfig(style_config) -+ -+ use_tabs = style_config["USE_TABS"] -+ indent_width = style_config["INDENT_WIDTH"] -+ -+ if options.get("tabSize") is not None: -+ indent_width = max(int(options.get("tabSize")), 1) -+ -+ if options.get("insertSpaces") is not None: -+ # TODO is it guaranteed to be a boolean, or can it be a string? -+ use_tabs = not options.get("insertSpaces") -+ -+ if use_tabs: -+ # Indent width doesn't make sense when using tabs -+ # the specifications state: "Size of a tab in spaces" -+ indent_width = 1 -+ -+ style_config["USE_TABS"] = use_tabs -+ style_config["INDENT_WIDTH"] = indent_width -+ style_config["CONTINUATION_INDENT_WIDTH"] = indent_width -+ -+ for style_option, value in options.items(): -+ # Apply arbitrary options passed as formatter options -+ if style_option not in style_config: -+ # ignore if it's not a known yapf config -+ continue -+ -+ style_config[style_option] = value -+ -+ return style_config -+ -+def diff_to_text_edits(diff, eol_chars): -+ # To keep things simple our text edits will be line based. -+ # We will also return the edits uncompacted, meaning a -+ # line replacement will come in as a line remove followed -+ # by a line add instead of a line replace. -+ text_edits = [] -+ # keep track of line number since additions -+ # don't include the line number it's being added -+ # to in diffs. lsp is 0-indexed so we'll start with -1 -+ prev_line_no = -1 -+ -+ for change in diff.changes: -+ if change.old and change.new: -+ # old and new are the same line, no change -+ # diffs are 1-indexed -+ prev_line_no = change.old - 1 -+ elif change.new: -+ # addition -+ text_edits.append( -+ { -+ "range": { -+ "start": {"line": prev_line_no + 1, "character": 0}, -+ "end": {"line": prev_line_no + 1, "character": 0}, -+ }, -+ "newText": change.line + eol_chars, -+ } -+ ) -+ elif change.old: -+ # remove -+ lsp_line_no = change.old - 1 -+ text_edits.append( -+ { -+ "range": { -+ "start": {"line": lsp_line_no, "character": 0}, -+ "end": { -+ # From LSP spec: -+ # If you want to specify a range that contains a line -+ # including the line ending character(s) then use an -+ # end position denoting the start of the next line. -+ "line": lsp_line_no + 1, -+ "character": 0, -+ }, -+ }, -+ "newText": "", -+ } -+ ) -+ prev_line_no = lsp_line_no -+ -+ return text_edits -+ -+def ensure_eof_new_line(document, eol_chars, text_edits): -+ # diffs don't include EOF newline https://github.com/google/yapf/issues/1008 -+ # we'll add it ourselves if our document doesn't already have it and the diff -+ # does not change the last line. -+ if document.source.endswith(eol_chars): -+ return -+ -+ lines = document.lines -+ last_line_number = len(lines) - 1 -+ -+ if text_edits and text_edits[-1]["range"]["start"]["line"] >= last_line_number: -+ return -+ -+ text_edits.append( -+ { -+ "range": { -+ "start": {"line": last_line_number, "character": 0}, -+ "end": {"line": last_line_number + 1, "character": 0}, -+ }, -+ "newText": lines[-1] + eol_chars, -+ } -+ ) -+ -+def _format(document, lines=None, options=None): -+ source = document.source -+ # Yapf doesn't work with CRLF/CR line endings, so we replace them by '\n' -+ # and restore them below when adding new lines -+ eol_chars = get_eol_chars(source) -+ if eol_chars in ["\r", "\r\n"]: -+ source = source.replace(eol_chars, "\n") -+ else: -+ eol_chars = "\n" -+ -+ style_config = get_style_config(document_path=document.path, options=options) -+ -+ diff_txt, changed = FormatCode( -+ source, -+ lines=lines, -+ filename=document.filename, -+ print_diff=True, -+ style_config=style_config, -+ ) -+ -+ if not changed: -+ return [] -+ -+ patch_generator = whatthepatch.parse_patch(diff_txt) -+ diff = next(patch_generator) -+ patch_generator.close() -+ -+ text_edits = diff_to_text_edits(diff=diff, eol_chars=eol_chars) -+ -+ ensure_eof_new_line(document=document, eol_chars=eol_chars, text_edits=text_edits) -+ -+ return text_edits -+ -+@hookimpl -+def pylsp_format_document(workspace, document, options): -+ log.info("Formatting document %s with yapf", document) -+ with workspace.report_progress("format: yapf"): -+ return _format(document, options=options) -+ -+@hookimpl -+def pylsp_format_range(document, range, options): -+ log.info("Formatting document %s in range %s with yapf", document, range) -+ # First we 'round' the range up/down to full lines only -+ range["start"]["character"] = 0 -+ range["end"]["line"] += 1 -+ range["end"]["character"] = 0 -+ -+ # From Yapf docs: -+ # lines: (list of tuples of integers) A list of tuples of lines, [start, end], -+ # that we want to format. The lines are 1-based indexed. It can be used by -+ # third-party code (e.g., IDEs) when reformatting a snippet of code rather -+ # than a whole file. -+ -+ # Add 1 for 1-indexing vs LSP's 0-indexing -+ lines = [(range["start"]["line"] + 1, range["end"]["line"] + 1)] -+ return _format(document, lines=lines, options=options) -\ No newline at end of file -diff --git a/pylsp/python_lsp.py b/pylsp/python_lsp.py -index ba41d6a..cad1046 100644 ---- a/pylsp/python_lsp.py -+++ b/pylsp/python_lsp.py -@@ -1,6 +1,15 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import log -+from pylsp.pylsp_shared import LINT_DEBOUNCE_S -+from pylsp.pylsp_shared import PARENT_PROCESS_WATCH_INTERVAL -+from pylsp.pylsp_shared import MAX_WORKERS -+from pylsp.pylsp_shared import PYTHON_FILE_EXTENSIONS -+from pylsp.pylsp_shared import CONFIG_FILEs -+from pylsp.pylsp_shared import flatten -+from pylsp.pylsp_shared import merge -+from pylsp.pylsp_shared import PythonLSPServer - import logging - import os - import socketserver -@@ -23,15 +32,6 @@ from ._version import __version__ - from .config import config - from .workspace import Cell, Document, Notebook, Workspace - --log = logging.getLogger(__name__) -- -- --LINT_DEBOUNCE_S = 0.5 # 500 ms --PARENT_PROCESS_WATCH_INTERVAL = 10 # 10 s --MAX_WORKERS = 64 --PYTHON_FILE_EXTENSIONS = (".py", ".pyi") --CONFIG_FILEs = ("pycodestyle.cfg", "setup.cfg", "tox.ini", ".flake8") -- - - class _StreamHandlerWrapper(socketserver.StreamRequestHandler): - """A wrapper class that is used to construct a custom handler class.""" -@@ -156,742 +156,3 @@ def start_ws_lang_server(port, check_parent_process, handler_class) -> None: - await asyncio.Future() - - asyncio.run(run_server()) -- -- --class PythonLSPServer(MethodDispatcher): -- """Implementation of the Microsoft VSCode Language Server Protocol -- https://github.com/Microsoft/language-server-protocol/blob/master/versions/protocol-1-x.md -- """ -- -- def __init__( -- self, rx, tx, check_parent_process=False, consumer=None, *, endpoint_cls=None -- ) -> None: -- self.workspace = None -- self.config = None -- self.root_uri = None -- self.watching_thread = None -- self.workspaces = {} -- self.uri_workspace_mapper = {} -- -- self._check_parent_process = check_parent_process -- -- if rx is not None: -- self._jsonrpc_stream_reader = JsonRpcStreamReader(rx) -- else: -- self._jsonrpc_stream_reader = None -- -- if tx is not None: -- self._jsonrpc_stream_writer = JsonRpcStreamWriter(tx) -- else: -- self._jsonrpc_stream_writer = None -- -- endpoint_cls = endpoint_cls or Endpoint -- -- # if consumer is None, it is assumed that the default streams-based approach is being used -- if consumer is None: -- self._endpoint = endpoint_cls( -- self, self._jsonrpc_stream_writer.write, max_workers=MAX_WORKERS -- ) -- else: -- self._endpoint = endpoint_cls(self, consumer, max_workers=MAX_WORKERS) -- -- self._dispatchers = [] -- self._shutdown = False -- -- def start(self) -> None: -- """Entry point for the server.""" -- self._jsonrpc_stream_reader.listen(self._endpoint.consume) -- -- def consume(self, message) -> None: -- """Entry point for consumer based server. Alternative to stream listeners.""" -- # assuming message will be JSON -- self._endpoint.consume(message) -- -- def __getitem__(self, item): -- """Override getitem to fallback through multiple dispatchers.""" -- if self._shutdown and item != "exit": -- # exit is the only allowed method during shutdown -- log.debug("Ignoring non-exit method during shutdown: %s", item) -- item = "invalid_request_after_shutdown" -- -- try: -- return super().__getitem__(item) -- except KeyError: -- # Fallback through extra dispatchers -- for dispatcher in self._dispatchers: -- try: -- return dispatcher[item] -- except KeyError: -- continue -- -- raise KeyError() -- -- def m_shutdown(self, **_kwargs) -> None: -- for workspace in self.workspaces.values(): -- workspace.close() -- self._shutdown = True -- -- def m_invalid_request_after_shutdown(self, **_kwargs): -- return { -- "error": { -- "code": lsp.ErrorCodes.InvalidRequest, -- "message": "Requests after shutdown are not valid", -- } -- } -- -- def m_exit(self, **_kwargs) -> None: -- self._endpoint.shutdown() -- if self._jsonrpc_stream_reader is not None: -- self._jsonrpc_stream_reader.close() -- if self._jsonrpc_stream_writer is not None: -- self._jsonrpc_stream_writer.close() -- -- def _match_uri_to_workspace(self, uri): -- workspace_uri = _utils.match_uri_to_workspace(uri, self.workspaces) -- return self.workspaces.get(workspace_uri, self.workspace) -- -- def _hook(self, hook_name, doc_uri=None, **kwargs): -- """Calls hook_name and returns a list of results from all registered handlers""" -- workspace = self._match_uri_to_workspace(doc_uri) -- doc = workspace.get_document(doc_uri) if doc_uri else None -- hook_handlers = self.config.plugin_manager.subset_hook_caller( -- hook_name, self.config.disabled_plugins -- ) -- return hook_handlers( -- config=self.config, workspace=workspace, document=doc, **kwargs -- ) -- -- def capabilities(self): -- server_capabilities = { -- "codeActionProvider": True, -- "codeLensProvider": { -- "resolveProvider": False, # We may need to make this configurable -- }, -- "completionProvider": { -- "resolveProvider": True, # We could know everything ahead of time, but this takes time to transfer -- "triggerCharacters": ["."], -- }, -- "documentFormattingProvider": True, -- "documentHighlightProvider": True, -- "documentRangeFormattingProvider": True, -- "documentSymbolProvider": True, -- "definitionProvider": True, -- "executeCommandProvider": { -- "commands": flatten(self._hook("pylsp_commands")) -- }, -- "hoverProvider": True, -- "referencesProvider": True, -- "renameProvider": True, -- "foldingRangeProvider": True, -- "signatureHelpProvider": {"triggerCharacters": ["(", ",", "="]}, -- "textDocumentSync": { -- "change": lsp.TextDocumentSyncKind.INCREMENTAL, -- "save": { -- "includeText": True, -- }, -- "openClose": True, -- }, -- "notebookDocumentSync": { -- "notebookSelector": [{"cells": [{"language": "python"}]}] -- }, -- "workspace": { -- "workspaceFolders": {"supported": True, "changeNotifications": True} -- }, -- "experimental": merge(self._hook("pylsp_experimental_capabilities")), -- } -- log.info("Server capabilities: %s", server_capabilities) -- return server_capabilities -- -- def m_initialize( -- self, -- processId=None, -- rootUri=None, -- rootPath=None, -- initializationOptions=None, -- workspaceFolders=None, -- **_kwargs, -- ): -- log.debug( -- "Language server initialized with %s %s %s %s", -- processId, -- rootUri, -- rootPath, -- initializationOptions, -- ) -- if rootUri is None: -- rootUri = uris.from_fs_path(rootPath) if rootPath is not None else "" -- -- self.workspaces.pop(self.root_uri, None) -- self.root_uri = rootUri -- self.config = config.Config( -- rootUri, -- initializationOptions or {}, -- processId, -- _kwargs.get("capabilities", {}), -- ) -- self.workspace = Workspace(rootUri, self._endpoint, self.config) -- self.workspaces[rootUri] = self.workspace -- if workspaceFolders: -- for folder in workspaceFolders: -- uri = folder["uri"] -- if uri == rootUri: -- # Already created -- continue -- workspace_config = config.Config( -- uri, -- self.config._init_opts, -- self.config._process_id, -- self.config._capabilities, -- ) -- workspace_config.update(self.config._settings) -- self.workspaces[uri] = Workspace(uri, self._endpoint, workspace_config) -- -- self._dispatchers = self._hook("pylsp_dispatchers") -- self._hook("pylsp_initialize") -- -- if ( -- self._check_parent_process -- and processId is not None -- and self.watching_thread is None -- ): -- -- def watch_parent_process(pid): -- # exit when the given pid is not alive -- if not _utils.is_process_alive(pid): -- log.info("parent process %s is not alive, exiting!", pid) -- self.m_exit() -- else: -- threading.Timer( -- PARENT_PROCESS_WATCH_INTERVAL, watch_parent_process, args=[pid] -- ).start() -- -- self.watching_thread = threading.Thread( -- target=watch_parent_process, args=(processId,) -- ) -- self.watching_thread.daemon = True -- self.watching_thread.start() -- # Get our capabilities -- return { -- "capabilities": self.capabilities(), -- "serverInfo": { -- "name": "pylsp", -- "version": __version__, -- }, -- } -- -- def m_initialized(self, **_kwargs) -> None: -- self._hook("pylsp_initialized") -- -- def code_actions(self, doc_uri: str, range: Dict, context: Dict): -- return flatten( -- self._hook("pylsp_code_actions", doc_uri, range=range, context=context) -- ) -- -- def code_lens(self, doc_uri): -- return flatten(self._hook("pylsp_code_lens", doc_uri)) -- -- def completions(self, doc_uri, position): -- workspace = self._match_uri_to_workspace(doc_uri) -- document = workspace.get_document(doc_uri) -- ignored_names = None -- if isinstance(document, Cell): -- # We need to get the ignored names from the whole notebook document -- notebook_document = workspace.get_maybe_document(document.notebook_uri) -- ignored_names = notebook_document.jedi_names(doc_uri) -- completions = self._hook( -- "pylsp_completions", doc_uri, position=position, ignored_names=ignored_names -- ) -- return {"isIncomplete": False, "items": flatten(completions)} -- -- def completion_item_resolve(self, completion_item): -- doc_uri = completion_item.get("data", {}).get("doc_uri", None) -- return self._hook( -- "pylsp_completion_item_resolve", doc_uri, completion_item=completion_item -- ) -- -- def definitions(self, doc_uri, position): -- return flatten(self._hook("pylsp_definitions", doc_uri, position=position)) -- -- def document_symbols(self, doc_uri): -- return flatten(self._hook("pylsp_document_symbols", doc_uri)) -- -- def document_did_save(self, doc_uri): -- return self._hook("pylsp_document_did_save", doc_uri) -- -- def execute_command(self, command, arguments): -- return self._hook("pylsp_execute_command", command=command, arguments=arguments) -- -- def format_document(self, doc_uri, options): -- return lambda: self._hook("pylsp_format_document", doc_uri, options=options) -- -- def format_range(self, doc_uri, range, options): -- return self._hook("pylsp_format_range", doc_uri, range=range, options=options) -- -- def highlight(self, doc_uri, position): -- return ( -- flatten(self._hook("pylsp_document_highlight", doc_uri, position=position)) -- or None -- ) -- -- def hover(self, doc_uri, position): -- return self._hook("pylsp_hover", doc_uri, position=position) or {"contents": ""} -- -- @_utils.debounce(LINT_DEBOUNCE_S, keyed_by="doc_uri") -- def lint(self, doc_uri, is_saved) -> None: -- # Since we're debounced, the document may no longer be open -- workspace = self._match_uri_to_workspace(doc_uri) -- document_object = workspace.documents.get(doc_uri, None) -- if isinstance(document_object, Document): -- self._lint_text_document( -- doc_uri, workspace, is_saved, document_object.version -- ) -- elif isinstance(document_object, Notebook): -- self._lint_notebook_document(document_object, workspace) -- -- def _lint_text_document( -- self, doc_uri, workspace, is_saved, doc_version=None -- ) -> None: -- workspace.publish_diagnostics( -- doc_uri, -- flatten(self._hook("pylsp_lint", doc_uri, is_saved=is_saved)), -- doc_version, -- ) -- -- def _lint_notebook_document(self, notebook_document, workspace) -> None: -- """ -- Lint a notebook document. -- -- This is a bit more complicated than linting a text document, because we need to -- send the entire notebook document to the pylsp_lint hook, but we need to send -- the diagnostics back to the client on a per-cell basis. -- """ -- -- # First, we create a temp TextDocument that represents the whole notebook -- # contents. We'll use this to send to the pylsp_lint hook. -- random_uri = str(uuid.uuid4()) -- -- # cell_list helps us map the diagnostics back to the correct cell later. -- cell_list: List[Dict[str, Any]] = [] -- -- offset = 0 -- total_source = "" -- for cell in notebook_document.cells: -- cell_uri = cell["document"] -- cell_document = workspace.get_cell_document(cell_uri) -- -- num_lines = cell_document.line_count -- -- data = { -- "uri": cell_uri, -- "line_start": offset, -- "line_end": offset + num_lines - 1, -- "source": cell_document.source, -- } -- -- cell_list.append(data) -- if offset == 0: -- total_source = cell_document.source -- else: -- total_source += "\n" + cell_document.source -- -- offset += num_lines -- -- workspace.put_document(random_uri, total_source) -- -- try: -- document_diagnostics = flatten( -- self._hook("pylsp_lint", random_uri, is_saved=True) -- ) -- -- # Now we need to map the diagnostics back to the correct cell and publish them. -- # Note: this is O(n*m) in the number of cells and diagnostics, respectively. -- for cell in cell_list: -- cell_diagnostics = [] -- for diagnostic in document_diagnostics: -- start_line = diagnostic["range"]["start"]["line"] -- end_line = diagnostic["range"]["end"]["line"] -- -- if start_line > cell["line_end"] or end_line < cell["line_start"]: -- continue -- diagnostic["range"]["start"]["line"] = ( -- start_line - cell["line_start"] -- ) -- diagnostic["range"]["end"]["line"] = end_line - cell["line_start"] -- cell_diagnostics.append(diagnostic) -- -- workspace.publish_diagnostics(cell["uri"], cell_diagnostics) -- finally: -- workspace.rm_document(random_uri) -- -- def references(self, doc_uri, position, exclude_declaration): -- return flatten( -- self._hook( -- "pylsp_references", -- doc_uri, -- position=position, -- exclude_declaration=exclude_declaration, -- ) -- ) -- -- def rename(self, doc_uri, position, new_name): -- return self._hook("pylsp_rename", doc_uri, position=position, new_name=new_name) -- -- def signature_help(self, doc_uri, position): -- return self._hook("pylsp_signature_help", doc_uri, position=position) -- -- def folding(self, doc_uri): -- return flatten(self._hook("pylsp_folding_range", doc_uri)) -- -- def m_completion_item__resolve(self, **completionItem): -- return self.completion_item_resolve(completionItem) -- -- def m_notebook_document__did_open( -- self, notebookDocument=None, cellTextDocuments=None, **_kwargs -- ) -> None: -- workspace = self._match_uri_to_workspace(notebookDocument["uri"]) -- workspace.put_notebook_document( -- notebookDocument["uri"], -- notebookDocument["notebookType"], -- cells=notebookDocument["cells"], -- version=notebookDocument.get("version"), -- metadata=notebookDocument.get("metadata"), -- ) -- for cell in cellTextDocuments or []: -- workspace.put_cell_document( -- cell["uri"], -- notebookDocument["uri"], -- cell["languageId"], -- cell["text"], -- version=cell.get("version"), -- ) -- self.lint(notebookDocument["uri"], is_saved=True) -- -- def m_notebook_document__did_close( -- self, notebookDocument=None, cellTextDocuments=None, **_kwargs -- ) -> None: -- workspace = self._match_uri_to_workspace(notebookDocument["uri"]) -- for cell in cellTextDocuments or []: -- workspace.publish_diagnostics(cell["uri"], []) -- workspace.rm_document(cell["uri"]) -- workspace.rm_document(notebookDocument["uri"]) -- -- def m_notebook_document__did_change( -- self, notebookDocument=None, change=None, **_kwargs -- ) -> None: -- """ -- Changes to the notebook document. -- -- This could be one of the following: -- 1. Notebook metadata changed -- 2. Cell(s) added -- 3. Cell(s) deleted -- 4. Cell(s) data changed -- 4.1 Cell metadata changed -- 4.2 Cell source changed -- """ -- workspace = self._match_uri_to_workspace(notebookDocument["uri"]) -- -- if change.get("metadata"): -- # Case 1 -- workspace.update_notebook_metadata( -- notebookDocument["uri"], change.get("metadata") -- ) -- -- cells = change.get("cells") -- if cells: -- # Change to cells -- structure = cells.get("structure") -- if structure: -- # Case 2 or 3 -- notebook_cell_array_change = structure["array"] -- start = notebook_cell_array_change["start"] -- cell_delete_count = notebook_cell_array_change["deleteCount"] -- if cell_delete_count == 0: -- # Case 2 -- # Cell documents -- for cell_document in structure["didOpen"]: -- workspace.put_cell_document( -- cell_document["uri"], -- notebookDocument["uri"], -- cell_document["languageId"], -- cell_document["text"], -- cell_document.get("version"), -- ) -- # Cell metadata which is added to Notebook -- workspace.add_notebook_cells( -- notebookDocument["uri"], -- notebook_cell_array_change["cells"], -- start, -- ) -- else: -- # Case 3 -- # Cell documents -- for cell_document in structure["didClose"]: -- workspace.rm_document(cell_document["uri"]) -- workspace.publish_diagnostics(cell_document["uri"], []) -- # Cell metadata which is removed from Notebook -- workspace.remove_notebook_cells( -- notebookDocument["uri"], start, cell_delete_count -- ) -- -- data = cells.get("data") -- if data: -- # Case 4.1 -- for cell in data: -- # update NotebookDocument.cells properties -- pass -- -- text_content = cells.get("textContent") -- if text_content: -- # Case 4.2 -- for cell in text_content: -- cell_uri = cell["document"]["uri"] -- # Even though the protocol says that `changes` is an array, we assume that it's always a single -- # element array that contains the last change to the cell source. -- workspace.update_document(cell_uri, cell["changes"][0]) -- self.lint(notebookDocument["uri"], is_saved=True) -- -- def m_text_document__did_close(self, textDocument=None, **_kwargs) -> None: -- workspace = self._match_uri_to_workspace(textDocument["uri"]) -- workspace.publish_diagnostics(textDocument["uri"], []) -- workspace.rm_document(textDocument["uri"]) -- -- def m_text_document__did_open(self, textDocument=None, **_kwargs) -> None: -- workspace = self._match_uri_to_workspace(textDocument["uri"]) -- workspace.put_document( -- textDocument["uri"], -- textDocument["text"], -- version=textDocument.get("version"), -- ) -- self._hook("pylsp_document_did_open", textDocument["uri"]) -- self.lint(textDocument["uri"], is_saved=True) -- -- def m_text_document__did_change( -- self, contentChanges=None, textDocument=None, **_kwargs -- ) -> None: -- workspace = self._match_uri_to_workspace(textDocument["uri"]) -- for change in contentChanges: -- workspace.update_document( -- textDocument["uri"], change, version=textDocument.get("version") -- ) -- self.lint(textDocument["uri"], is_saved=False) -- -- def m_text_document__did_save(self, textDocument=None, **_kwargs) -> None: -- self.lint(textDocument["uri"], is_saved=True) -- self.document_did_save(textDocument["uri"]) -- -- def m_text_document__code_action( -- self, textDocument=None, range=None, context=None, **_kwargs -- ): -- return self.code_actions(textDocument["uri"], range, context) -- -- def m_text_document__code_lens(self, textDocument=None, **_kwargs): -- return self.code_lens(textDocument["uri"]) -- -- def _cell_document__completion(self, cellDocument, position=None, **_kwargs): -- workspace = self._match_uri_to_workspace(cellDocument.notebook_uri) -- notebookDocument = workspace.get_maybe_document(cellDocument.notebook_uri) -- if notebookDocument is None: -- raise ValueError("Invalid notebook document") -- -- cell_data = notebookDocument.cell_data() -- -- # Concatenate all cells to be a single temporary document -- total_source = "\n".join(data["source"] for data in cell_data.values()) -- with workspace.temp_document(total_source) as temp_uri: -- # update position to be the position in the temp document -- if position is not None: -- position["line"] += cell_data[cellDocument.uri]["line_start"] -- -- completions = self.completions(temp_uri, position) -- -- # Translate temp_uri locations to cell document locations -- for item in completions.get("items", []): -- if item.get("data", {}).get("doc_uri") == temp_uri: -- item["data"]["doc_uri"] = cellDocument.uri -- -- return completions -- -- def m_text_document__completion(self, textDocument=None, position=None, **_kwargs): -- # textDocument here is just a dict with a uri -- workspace = self._match_uri_to_workspace(textDocument["uri"]) -- document = workspace.get_document(textDocument["uri"]) -- if isinstance(document, Cell): -- return self._cell_document__completion(document, position, **_kwargs) -- return self.completions(textDocument["uri"], position) -- -- def _cell_document__definition(self, cellDocument, position=None, **_kwargs): -- workspace = self._match_uri_to_workspace(cellDocument.notebook_uri) -- notebookDocument = workspace.get_maybe_document(cellDocument.notebook_uri) -- if notebookDocument is None: -- raise ValueError("Invalid notebook document") -- -- cell_data = notebookDocument.cell_data() -- -- # Concatenate all cells to be a single temporary document -- total_source = "\n".join(data["source"] for data in cell_data.values()) -- with workspace.temp_document(total_source) as temp_uri: -- # update position to be the position in the temp document -- if position is not None: -- position["line"] += cell_data[cellDocument.uri]["line_start"] -- -- definitions = self.definitions(temp_uri, position) -- -- # Translate temp_uri locations to cell document locations -- for definition in definitions: -- if definition["uri"] == temp_uri: -- # Find the cell the start line is in and adjust the uri and line numbers -- for cell_uri, data in cell_data.items(): -- if ( -- data["line_start"] -- <= definition["range"]["start"]["line"] -- <= data["line_end"] -- ): -- definition["uri"] = cell_uri -- definition["range"]["start"]["line"] -= data["line_start"] -- definition["range"]["end"]["line"] -= data["line_start"] -- break -- -- return definitions -- -- def m_text_document__definition(self, textDocument=None, position=None, **_kwargs): -- # textDocument here is just a dict with a uri -- workspace = self._match_uri_to_workspace(textDocument["uri"]) -- document = workspace.get_document(textDocument["uri"]) -- if isinstance(document, Cell): -- return self._cell_document__definition(document, position, **_kwargs) -- return self.definitions(textDocument["uri"], position) -- -- def m_text_document__document_highlight( -- self, textDocument=None, position=None, **_kwargs -- ): -- return self.highlight(textDocument["uri"], position) -- -- def m_text_document__hover(self, textDocument=None, position=None, **_kwargs): -- return self.hover(textDocument["uri"], position) -- -- def m_text_document__document_symbol(self, textDocument=None, **_kwargs): -- return self.document_symbols(textDocument["uri"]) -- -- def m_text_document__formatting(self, textDocument=None, options=None, **_kwargs): -- return self.format_document(textDocument["uri"], options) -- -- def m_text_document__rename( -- self, textDocument=None, position=None, newName=None, **_kwargs -- ): -- return self.rename(textDocument["uri"], position, newName) -- -- def m_text_document__folding_range(self, textDocument=None, **_kwargs): -- return self.folding(textDocument["uri"]) -- -- def m_text_document__range_formatting( -- self, textDocument=None, range=None, options=None, **_kwargs -- ): -- return self.format_range(textDocument["uri"], range, options) -- -- def m_text_document__references( -- self, textDocument=None, position=None, context=None, **_kwargs -- ): -- exclude_declaration = not context["includeDeclaration"] -- return self.references(textDocument["uri"], position, exclude_declaration) -- -- def m_text_document__signature_help( -- self, textDocument=None, position=None, **_kwargs -- ): -- return self.signature_help(textDocument["uri"], position) -- -- def m_workspace__did_change_configuration(self, settings=None) -> None: -- if self.config is not None: -- self.config.update((settings or {}).get("pylsp", {})) -- for workspace in self.workspaces.values(): -- workspace.update_config(settings) -- self._hook("pylsp_workspace_configuration_changed") -- for doc_uri in workspace.documents: -- self.lint(doc_uri, is_saved=False) -- -- def m_workspace__did_change_workspace_folders(self, event=None, **_kwargs): -- if event is None: -- return -- added = event.get("added", []) -- removed = event.get("removed", []) -- -- for removed_info in removed: -- if "uri" in removed_info: -- removed_uri = removed_info["uri"] -- self.workspaces.pop(removed_uri, None) -- -- for added_info in added: -- if "uri" in added_info: -- added_uri = added_info["uri"] -- workspace_config = config.Config( -- added_uri, -- self.config._init_opts, -- self.config._process_id, -- self.config._capabilities, -- ) -- workspace_config.update(self.config._settings) -- self.workspaces[added_uri] = Workspace( -- added_uri, self._endpoint, workspace_config -- ) -- -- root_workspace_removed = any( -- removed_info["uri"] == self.root_uri for removed_info in removed -- ) -- workspace_added = len(added) > 0 and "uri" in added[0] -- if root_workspace_removed and workspace_added: -- added_uri = added[0]["uri"] -- self.root_uri = added_uri -- new_root_workspace = self.workspaces[added_uri] -- self.config = new_root_workspace._config -- self.workspace = new_root_workspace -- elif root_workspace_removed: -- # NOTE: Removing the root workspace can only happen when the server -- # is closed, thus the else condition of this if can never happen. -- if self.workspaces: -- log.debug("Root workspace deleted!") -- available_workspaces = sorted(self.workspaces) -- first_workspace = available_workspaces[0] -- new_root_workspace = self.workspaces[first_workspace] -- self.root_uri = first_workspace -- self.config = new_root_workspace._config -- self.workspace = new_root_workspace -- -- # Migrate documents that are on the root workspace and have a better -- # match now -- doc_uris = list(self.workspace._docs.keys()) -- for uri in doc_uris: -- doc = self.workspace._docs.pop(uri) -- new_workspace = self._match_uri_to_workspace(uri) -- new_workspace._docs[uri] = doc -- -- def m_workspace__did_change_watched_files(self, changes=None, **_kwargs): -- changed_py_files = set() -- config_changed = False -- for d in changes or []: -- if d["uri"].endswith(PYTHON_FILE_EXTENSIONS): -- changed_py_files.add(d["uri"]) -- elif d["uri"].endswith(CONFIG_FILEs): -- config_changed = True -- -- if config_changed: -- self.config.settings.cache_clear() -- elif not changed_py_files: -- # Only externally changed python files and lint configs may result in changed diagnostics. -- return -- -- for workspace in self.workspaces.values(): -- for doc_uri in workspace.documents: -- # Changes in doc_uri are already handled by m_text_document__did_save -- if doc_uri not in changed_py_files: -- self.lint(doc_uri, is_saved=False) -- -- def m_workspace__execute_command(self, command=None, arguments=None): -- return self.execute_command(command, arguments) -- -- --def flatten(list_of_lists): -- return [item for lst in list_of_lists for item in lst] -- -- --def merge(list_of_dicts): -- return {k: v for dictionary in list_of_dicts for k, v in dictionary.items()} -diff --git a/test/fixtures.py b/test/fixtures.py -index dd10140..094d550 100644 ---- a/test/fixtures.py -+++ b/test/fixtures.py -@@ -1,6 +1,10 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import PythonLSPServer -+from pylsp.pylsp_shared import FakeEditorMethodsMixin -+from pylsp.pylsp_shared import FakePythonLSPServer -+from pylsp.pylsp_shared import FakeEndpoint - import os - from io import StringIO - from unittest.mock import MagicMock -@@ -12,7 +16,6 @@ from pylsp_jsonrpc.exceptions import JsonRpcException - - from pylsp import uris - from pylsp.config.config import Config --from pylsp.python_lsp import PythonLSPServer - from pylsp.workspace import Document, Workspace - from test.test_utils import CALL_TIMEOUT_IN_SECONDS, ClientServerPair - -@@ -24,46 +27,6 @@ def main(): - """ - - --class FakeEditorMethodsMixin: -- """ -- Represents the methods to be added to a dispatcher class when faking an editor. -- """ -- -- def m_window__work_done_progress__create(self, *_args, **_kwargs): -- """ -- Fake editor method `window/workDoneProgress/create`. -- -- related spec: -- https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#window_workDoneProgress_create -- """ -- return None -- -- --class FakePythonLSPServer(FakeEditorMethodsMixin, PythonLSPServer): -- pass -- -- --class FakeEndpoint(Endpoint): -- """ -- Fake Endpoint representing the editor / LSP client. -- -- The `dispatcher` dict will be used to synchronously calculate the responses -- for calls to `.request` and resolve the futures with the value or errors. -- -- Fake methods in the `dispatcher` should raise `JsonRpcException` for any -- error. -- """ -- -- def request(self, method, params=None): -- request_future = super().request(method, params) -- try: -- request_future.set_result(self._dispatcher[method](params)) -- except JsonRpcException as e: -- request_future.set_exception(e) -- -- return request_future -- -- - @pytest.fixture - def pylsp(tmpdir): - """Return an initialized python LS""" -@@ -76,28 +39,6 @@ def pylsp(tmpdir): - return ls - - --@pytest.fixture --def pylsp_w_workspace_folders(tmpdir): -- """Return an initialized python LS""" -- ls = FakePythonLSPServer(StringIO, StringIO, endpoint_cls=FakeEndpoint) -- -- folder1 = tmpdir.mkdir("folder1") -- folder2 = tmpdir.mkdir("folder2") -- -- ls.m_initialize( -- processId=1, -- rootUri=uris.from_fs_path(str(folder1)), -- initializationOptions={}, -- workspaceFolders=[ -- {"uri": uris.from_fs_path(str(folder1)), "name": "folder1"}, -- {"uri": uris.from_fs_path(str(folder2)), "name": "folder2"}, -- ], -- ) -- -- workspace_folders = [folder1, folder2] -- return (ls, workspace_folders) -- -- - @pytest.fixture() - def consumer(): - return MagicMock() -diff --git a/test/plugins/test_autoimport.py b/test/plugins/test_autoimport.py -index dbad8d0..b27a0d5 100644 ---- a/test/plugins/test_autoimport.py -+++ b/test/plugins/test_autoimport.py -@@ -1,5 +1,11 @@ - # Copyright 2022- Python Language Server Contributors. - -+from pylsp.pylsp_shared import _should_insert -+from pylsp.pylsp_shared import _get_score -+from pylsp.pylsp_shared import get_names -+from pylsp.pylsp_shared import cache -+from pylsp.pylsp_shared import pylsp_completions -+from pylsp.pylsp_shared import get_name_or_module - from typing import Any, Dict, List - from unittest.mock import Mock, patch - -@@ -9,16 +15,6 @@ import pytest - - from pylsp import IS_WIN, lsp, uris - from pylsp.config.config import Config --from pylsp.plugins.rope_autoimport import ( -- _get_score, -- _should_insert, -- cache, -- get_name_or_module, -- get_names, --) --from pylsp.plugins.rope_autoimport import ( -- pylsp_completions as pylsp_autoimport_completions, --) - from pylsp.workspace import Workspace - from test.test_notebook_document import wait_for_condition - from test.test_utils import send_initialize_request, send_notebook_did_open -diff --git a/test/plugins/test_autopep8_format.py b/test/plugins/test_autopep8_format.py -index 4966b89..e689bbc 100644 ---- a/test/plugins/test_autopep8_format.py -+++ b/test/plugins/test_autopep8_format.py -@@ -1,10 +1,11 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import pylsp_format_document -+from pylsp.pylsp_shared import pylsp_format_range - import pytest - - from pylsp import uris --from pylsp.plugins.autopep8_format import pylsp_format_document, pylsp_format_range - from pylsp.workspace import Document - - DOC_URI = uris.from_fs_path(__file__) -diff --git a/test/plugins/test_completion.py b/test/plugins/test_completion.py -index d1ca5ef..48c6af7 100644 ---- a/test/plugins/test_completion.py -+++ b/test/plugins/test_completion.py -@@ -1,6 +1,8 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import pylsp_completions -+from pylsp.pylsp_shared import pylsp_completion_item_resolve - import math - import os - import sys -@@ -11,11 +13,6 @@ import pytest - - from pylsp import lsp, uris - from pylsp._utils import JEDI_VERSION --from pylsp.plugins.jedi_completion import ( -- pylsp_completion_item_resolve as pylsp_jedi_completion_item_resolve, --) --from pylsp.plugins.jedi_completion import pylsp_completions as pylsp_jedi_completions --from pylsp.plugins.rope_completion import pylsp_completions as pylsp_rope_completions - from pylsp.workspace import Document - - PY2 = sys.version[0] == "2" -diff --git a/test/plugins/test_definitions.py b/test/plugins/test_definitions.py -index 7923524..efdf5e1 100644 ---- a/test/plugins/test_definitions.py -+++ b/test/plugins/test_definitions.py -@@ -1,10 +1,10 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import pylsp_definitions - import os - - from pylsp import uris --from pylsp.plugins.definition import pylsp_definitions - from pylsp.workspace import Document - - DOC_URI = uris.from_fs_path(__file__) -diff --git a/test/plugins/test_flake8_lint.py b/test/plugins/test_flake8_lint.py -index e7b6b00..046a7b5 100644 ---- a/test/plugins/test_flake8_lint.py -+++ b/test/plugins/test_flake8_lint.py -@@ -1,6 +1,7 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import pylsp_lint - import os - import tempfile - from textwrap import dedent -@@ -32,7 +33,7 @@ def temp_document(doc_text, workspace): - - def test_flake8_unsaved(workspace) -> None: - doc = Document("", workspace, DOC) -- diags = flake8_lint.pylsp_lint(workspace, doc) -+ diags = pylsp_lint(workspace, doc) - msg = "F841 local variable 'a' is assigned to but never used" - unused_var = [d for d in diags if d["message"] == msg][0] - -@@ -47,7 +48,7 @@ def test_flake8_unsaved(workspace) -> None: - def test_flake8_lint(workspace) -> None: - name, doc = temp_document(DOC, workspace) - try: -- diags = flake8_lint.pylsp_lint(workspace, doc) -+ diags = pylsp_lint(workspace, doc) - msg = "F841 local variable 'a' is assigned to but never used" - unused_var = [d for d in diags if d["message"] == msg][0] - -@@ -91,7 +92,7 @@ def test_flake8_respecting_configuration(workspace) -> None: - workspace.put_document(made[rel]["uri"], contents) - made[rel]["document"] = workspace._docs[made[rel]["uri"]] - -- diags = flake8_lint.pylsp_lint(workspace, made["src/a.py"]["document"]) -+ diags = pylsp_lint(workspace, made["src/a.py"]["document"]) - assert diags == [ - { - "source": "flake8", -@@ -106,7 +107,7 @@ def test_flake8_respecting_configuration(workspace) -> None: - }, - ] - -- diags = flake8_lint.pylsp_lint(workspace, made["src/b.py"]["document"]) -+ diags = pylsp_lint(workspace, made["src/b.py"]["document"]) - assert diags == [ - { - "source": "flake8", -@@ -129,7 +130,7 @@ def test_flake8_config_param(workspace) -> None: - flake8_conf = "/tmp/some.cfg" - workspace._config.update({"plugins": {"flake8": {"config": flake8_conf}}}) - _name, doc = temp_document(DOC, workspace) -- flake8_lint.pylsp_lint(workspace, doc) -+ pylsp_lint(workspace, doc) - (call_args,) = popen_mock.call_args[0] - assert "flake8" in call_args - assert "--config={}".format(flake8_conf) in call_args -@@ -146,7 +147,7 @@ def test_flake8_executable_param(workspace) -> None: - ) - - _name, doc = temp_document(DOC, workspace) -- flake8_lint.pylsp_lint(workspace, doc) -+ pylsp_lint(workspace, doc) - - (call_args,) = popen_mock.call_args[0] - assert flake8_executable in call_args -@@ -190,7 +191,7 @@ exclude = - mock_instance.communicate.return_value = [bytes(), bytes()] - - doc = workspace.get_document(doc_uri) -- flake8_lint.pylsp_lint(workspace, doc) -+ pylsp_lint(workspace, doc) - - call_args = popen_mock.call_args[0][0] - -@@ -230,7 +231,7 @@ exclude = - assert len(flake8_settings["exclude"]) == 2 - - doc = workspace.get_document(doc_uri) -- res = flake8_lint.pylsp_lint(workspace, doc) -+ res = pylsp_lint(workspace, doc) - assert not res - - os.unlink(os.path.join(workspace.root_path, "setup.cfg")) -@@ -252,7 +253,7 @@ per-file-ignores = **/__init__.py:F401,E402 - assert len(flake8_settings["perFileIgnores"]) == 2 - - doc = workspace.get_document(doc_uri) -- res = flake8_lint.pylsp_lint(workspace, doc) -+ res = pylsp_lint(workspace, doc) - assert not res - - os.unlink(os.path.join(workspace.root_path, "setup.cfg")) -diff --git a/test/plugins/test_folding.py b/test/plugins/test_folding.py -index 1f0d34c..8adcd12 100644 ---- a/test/plugins/test_folding.py -+++ b/test/plugins/test_folding.py -@@ -1,11 +1,11 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import pylsp_folding_range - import sys - from textwrap import dedent - - from pylsp import uris --from pylsp.plugins.folding import pylsp_folding_range - from pylsp.workspace import Document - - DOC_URI = uris.from_fs_path(__file__) -diff --git a/test/plugins/test_highlight.py b/test/plugins/test_highlight.py -index eb5485b..0846a55 100644 ---- a/test/plugins/test_highlight.py -+++ b/test/plugins/test_highlight.py -@@ -1,8 +1,8 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import pylsp_document_highlight - from pylsp import lsp, uris --from pylsp.plugins.highlight import pylsp_document_highlight - from pylsp.workspace import Document - - DOC_URI = uris.from_fs_path(__file__) -diff --git a/test/plugins/test_hover.py b/test/plugins/test_hover.py -index 9674b87..c3dc92b 100644 ---- a/test/plugins/test_hover.py -+++ b/test/plugins/test_hover.py -@@ -1,10 +1,10 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import pylsp_hover - import os - - from pylsp import uris --from pylsp.plugins.hover import pylsp_hover - from pylsp.workspace import Document - - DOC_URI = uris.from_fs_path(__file__) -diff --git a/test/plugins/test_jedi_rename.py b/test/plugins/test_jedi_rename.py -index 349274b..1ef8969 100644 ---- a/test/plugins/test_jedi_rename.py -+++ b/test/plugins/test_jedi_rename.py -@@ -1,12 +1,12 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import pylsp_rename - import os - - import pytest - - from pylsp import uris --from pylsp.plugins.jedi_rename import pylsp_rename - from pylsp.workspace import Document - - DOC_NAME = "test1.py" -diff --git a/test/plugins/test_mccabe_lint.py b/test/plugins/test_mccabe_lint.py -index f4df0c2..64df075 100644 ---- a/test/plugins/test_mccabe_lint.py -+++ b/test/plugins/test_mccabe_lint.py -@@ -1,6 +1,7 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import pylsp_lint - from pylsp import lsp, uris - from pylsp.plugins import mccabe_lint - from pylsp.workspace import Document -@@ -19,7 +20,7 @@ def test_mccabe(config, workspace) -> None: - try: - config.update({"plugins": {"mccabe": {"threshold": 1}}}) - doc = Document(DOC_URI, workspace, DOC) -- diags = mccabe_lint.pylsp_lint(config, workspace, doc) -+ diags = pylsp_lint(config, workspace, doc) - - assert all(d["source"] == "mccabe" for d in diags) - -@@ -36,4 +37,4 @@ def test_mccabe(config, workspace) -> None: - - def test_mccabe_syntax_error(config, workspace) -> None: - doc = Document(DOC_URI, workspace, DOC_SYNTAX_ERR) -- assert mccabe_lint.pylsp_lint(config, workspace, doc) is None -+ assert pylsp_lint(config, workspace, doc) is None -diff --git a/test/plugins/test_pycodestyle_lint.py b/test/plugins/test_pycodestyle_lint.py -index eea0b7d..9e32e7f 100644 ---- a/test/plugins/test_pycodestyle_lint.py -+++ b/test/plugins/test_pycodestyle_lint.py -@@ -1,6 +1,7 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import pylsp_lint - import os - - import pytest -@@ -26,7 +27,7 @@ import json - - def test_pycodestyle(workspace) -> None: - doc = Document(DOC_URI, workspace, DOC) -- diags = pycodestyle_lint.pylsp_lint(workspace, doc) -+ diags = pylsp_lint(workspace, doc) - - assert all(d["source"] == "pycodestyle" for d in diags) - -@@ -84,7 +85,7 @@ def test_pycodestyle_config(workspace) -> None: - doc = workspace.get_document(doc_uri) - - # Make sure we get a warning for 'indentation contains tabs' -- diags = pycodestyle_lint.pylsp_lint(workspace, doc) -+ diags = pylsp_lint(workspace, doc) - assert [d for d in diags if d["code"] == "W191"] - - content = { -@@ -101,7 +102,7 @@ def test_pycodestyle_config(workspace) -> None: - workspace._config.settings.cache_clear() - - # And make sure we don't get any warnings -- diags = pycodestyle_lint.pylsp_lint(workspace, doc) -+ diags = pylsp_lint(workspace, doc) - assert len([d for d in diags if d["code"] == "W191"]) == (0 if working else 1) - assert len([d for d in diags if d["code"] == "E201"]) == (0 if working else 1) - assert [d for d in diags if d["code"] == "W391"] -@@ -111,7 +112,7 @@ def test_pycodestyle_config(workspace) -> None: - # Make sure we can ignore via the PYLS config as well - workspace._config.update({"plugins": {"pycodestyle": {"ignore": ["W191", "E201"]}}}) - # And make sure we only get one warning -- diags = pycodestyle_lint.pylsp_lint(workspace, doc) -+ diags = pylsp_lint(workspace, doc) - assert not [d for d in diags if d["code"] == "W191"] - assert not [d for d in diags if d["code"] == "E201"] - assert [d for d in diags if d["code"] == "W391"] -@@ -130,7 +131,7 @@ def test_line_endings(workspace, newline) -> None: - doc = Document(DOC_URI, workspace, source) - - # Get diagnostics -- diags = pycodestyle_lint.pylsp_lint(workspace, doc) -+ diags = pylsp_lint(workspace, doc) - - # Assert no diagnostics were given - assert len(diags) == 0 -diff --git a/test/plugins/test_pydocstyle_lint.py b/test/plugins/test_pydocstyle_lint.py -index 383aaf1..3efcf68 100644 ---- a/test/plugins/test_pydocstyle_lint.py -+++ b/test/plugins/test_pydocstyle_lint.py -@@ -1,6 +1,7 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import pylsp_lint - import os - - from pylsp import lsp, uris -@@ -21,7 +22,7 @@ import json - - def test_pydocstyle(config, workspace) -> None: - doc = Document(DOC_URI, workspace, DOC) -- diags = pydocstyle_lint.pylsp_lint(config, workspace, doc) -+ diags = pylsp_lint(config, workspace, doc) - - assert all(d["source"] == "pydocstyle" for d in diags) - -@@ -41,19 +42,19 @@ def test_pydocstyle(config, workspace) -> None: - def test_pydocstyle_test_document(config, workspace) -> None: - # The default --match argument excludes test_* documents. - doc = Document(TEST_DOC_URI, workspace, "") -- diags = pydocstyle_lint.pylsp_lint(config, workspace, doc) -+ diags = pylsp_lint(config, workspace, doc) - assert not diags - - - def test_pydocstyle_empty_source(config, workspace) -> None: - doc = Document(DOC_URI, workspace, "") -- diags = pydocstyle_lint.pylsp_lint(config, workspace, doc) -+ diags = pylsp_lint(config, workspace, doc) - assert diags[0]["message"] == "D100: Missing docstring in public module" - assert len(diags) == 1 - - - def test_pydocstyle_invalid_source(config, workspace) -> None: - doc = Document(DOC_URI, workspace, "bad syntax") -- diags = pydocstyle_lint.pylsp_lint(config, workspace, doc) -+ diags = pylsp_lint(config, workspace, doc) - # We're unable to parse the file, so can't get any pydocstyle diagnostics - assert not diags -diff --git a/test/plugins/test_pyflakes_lint.py b/test/plugins/test_pyflakes_lint.py -index 8ab3632..4034a74 100644 ---- a/test/plugins/test_pyflakes_lint.py -+++ b/test/plugins/test_pyflakes_lint.py -@@ -1,6 +1,7 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import pylsp_lint - import sys - - from pylsp import lsp, uris -@@ -30,7 +31,7 @@ import sys - - def test_pyflakes(workspace) -> None: - doc = Document(DOC_URI, workspace, DOC) -- diags = pyflakes_lint.pylsp_lint(workspace, doc) -+ diags = pylsp_lint(workspace, doc) - - # One we're expecting is: - msg = "'sys' imported but unused" -@@ -42,7 +43,7 @@ def test_pyflakes(workspace) -> None: - - def test_syntax_error_pyflakes(workspace) -> None: - doc = Document(DOC_URI, workspace, DOC_SYNTAX_ERR) -- diag = pyflakes_lint.pylsp_lint(workspace, doc)[0] -+ diag = pylsp_lint(workspace, doc)[0] - - if sys.version_info[:2] >= (3, 10): - assert diag["message"] == "expected ':'" -@@ -54,7 +55,7 @@ def test_syntax_error_pyflakes(workspace) -> None: - - def test_undefined_name_pyflakes(workspace) -> None: - doc = Document(DOC_URI, workspace, DOC_UNDEFINED_NAME_ERR) -- diag = pyflakes_lint.pylsp_lint(workspace, doc)[0] -+ diag = pylsp_lint(workspace, doc)[0] - - assert diag["message"] == "undefined name 'b'" - assert diag["range"]["start"] == {"line": 0, "character": 4} -@@ -63,7 +64,7 @@ def test_undefined_name_pyflakes(workspace) -> None: - - def test_unicode_encoding(workspace) -> None: - doc = Document(DOC_URI, workspace, DOC_ENCODING) -- diags = pyflakes_lint.pylsp_lint(workspace, doc) -+ diags = pylsp_lint(workspace, doc) - - assert len(diags) == 1 - assert diags[0]["message"] == "'sys' imported but unused" -diff --git a/test/plugins/test_pylint_lint.py b/test/plugins/test_pylint_lint.py -index b4d511d..c38794c 100644 ---- a/test/plugins/test_pylint_lint.py -+++ b/test/plugins/test_pylint_lint.py -@@ -2,6 +2,8 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import PylintLinter -+from pylsp.pylsp_shared import pylsp_lint - import contextlib - import os - import tempfile -@@ -43,7 +45,7 @@ def write_temp_doc(document, contents) -> None: - - def test_pylint(config, workspace) -> None: - with temp_document(DOC, workspace) as doc: -- diags = pylint_lint.pylsp_lint(config, workspace, doc, True) -+ diags = pylsp_lint(config, workspace, doc, True) - - msg = "[unused-import] Unused import sys" - unused_import = [d for d in diags if d["message"] == msg][0] -@@ -54,7 +56,7 @@ def test_pylint(config, workspace) -> None: - - # test running pylint in stdin - config.plugin_settings("pylint")["executable"] = "pylint" -- diags = pylint_lint.pylsp_lint(config, workspace, doc, True) -+ diags = pylsp_lint(config, workspace, doc, True) - - msg = "Unused import sys (unused-import)" - unused_import = [d for d in diags if d["message"] == msg][0] -@@ -68,7 +70,7 @@ def test_pylint(config, workspace) -> None: - - def test_syntax_error_pylint(config, workspace) -> None: - with temp_document(DOC_SYNTAX_ERR, workspace) as doc: -- diag = pylint_lint.pylsp_lint(config, workspace, doc, True)[0] -+ diag = pylsp_lint(config, workspace, doc, True)[0] - - assert diag["message"].startswith("[syntax-error]") - assert diag["message"].count("expected ':'") or diag["message"].count( -@@ -81,7 +83,7 @@ def test_syntax_error_pylint(config, workspace) -> None: - - # test running pylint in stdin - config.plugin_settings("pylint")["executable"] = "pylint" -- diag = pylint_lint.pylsp_lint(config, workspace, doc, True)[0] -+ diag = pylsp_lint(config, workspace, doc, True)[0] - - assert diag["message"].count("expected ':'") or diag["message"].count( - "invalid syntax" -@@ -96,7 +98,7 @@ def test_lint_free_pylint(config, workspace) -> None: - # match pylint's naming requirements. We should be keeping this file clean - # though, so it works for a test of an empty lint. - ws = Workspace(str(Path(__file__).absolute().parents[2]), workspace._endpoint) -- assert not pylint_lint.pylsp_lint( -+ assert not pylsp_lint( - config, ws, Document(uris.from_fs_path(__file__), ws), True - ) - -@@ -114,26 +116,26 @@ def test_lint_caching(workspace) -> None: - flags = "--disable=invalid-name" - with temp_document(DOC, workspace) as doc: - # Start with a file with errors. -- diags = pylint_lint.PylintLinter.lint(doc, True, flags) -+ diags = PylintLinter.lint(doc, True, flags) - assert diags - - # Fix lint errors and write the changes to disk. Run the linter in the - # in-memory mode to check the cached diagnostic behavior. - write_temp_doc(doc, "") -- assert pylint_lint.PylintLinter.lint(doc, False, flags) == diags -+ assert PylintLinter.lint(doc, False, flags) == diags - - # Now check the on-disk behavior. -- assert not pylint_lint.PylintLinter.lint(doc, True, flags) -+ assert not PylintLinter.lint(doc, True, flags) - - # Make sure the cache was properly cleared. -- assert not pylint_lint.PylintLinter.lint(doc, False, flags) -+ assert not PylintLinter.lint(doc, False, flags) - - - def test_per_file_caching(config, workspace) -> None: - # Ensure that diagnostics are cached per-file. - with temp_document(DOC, workspace) as doc: -- assert pylint_lint.pylsp_lint(config, workspace, doc, True) -+ assert pylsp_lint(config, workspace, doc, True) - -- assert not pylint_lint.pylsp_lint( -+ assert not pylsp_lint( - config, workspace, Document(uris.from_fs_path(__file__), workspace), False - ) -diff --git a/test/plugins/test_references.py b/test/plugins/test_references.py -index f512169..3283612 100644 ---- a/test/plugins/test_references.py -+++ b/test/plugins/test_references.py -@@ -1,12 +1,12 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import pylsp_references - import os - - import pytest - - from pylsp import uris --from pylsp.plugins.references import pylsp_references - from pylsp.workspace import Document - - DOC1_NAME = "test1.py" -diff --git a/test/plugins/test_signature.py b/test/plugins/test_signature.py -index 4a0a84e..74ead16 100644 ---- a/test/plugins/test_signature.py -+++ b/test/plugins/test_signature.py -@@ -1,6 +1,10 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import SPHINX -+from pylsp.pylsp_shared import EPYDOC -+from pylsp.pylsp_shared import GOOGLE -+from pylsp.pylsp_shared import pylsp_signature_help - import pytest - - from pylsp import uris -@@ -47,7 +51,7 @@ def test_no_signature(workspace) -> None: - sig_position = {"line": 9, "character": 0} - doc = Document(DOC_URI, workspace, DOC) - -- sigs = signature.pylsp_signature_help(doc._config, doc, sig_position)["signatures"] -+ sigs = pylsp_signature_help(doc._config, doc, sig_position)["signatures"] - assert not sigs - - -@@ -56,7 +60,7 @@ def test_signature(workspace) -> None: - sig_position = {"line": 10, "character": 5} - doc = Document(DOC_URI, workspace, DOC) - -- sig_info = signature.pylsp_signature_help(doc._config, doc, sig_position) -+ sig_info = pylsp_signature_help(doc._config, doc, sig_position) - - sigs = sig_info["signatures"] - assert len(sigs) == 1 -@@ -75,7 +79,7 @@ def test_multi_line_signature(workspace) -> None: - sig_position = {"line": 17, "character": 5} - doc = Document(DOC_URI, workspace, MULTI_LINE_DOC) - -- sig_info = signature.pylsp_signature_help(doc._config, doc, sig_position) -+ sig_info = pylsp_signature_help(doc._config, doc, sig_position) - - sigs = sig_info["signatures"] - assert len(sigs) == 1 -@@ -95,9 +99,9 @@ def test_multi_line_signature(workspace) -> None: - @pytest.mark.parametrize( - "regex,doc", - [ -- (signature.SPHINX, " :param test: parameter docstring"), -- (signature.EPYDOC, " @param test: parameter docstring"), -- (signature.GOOGLE, " test (str): parameter docstring"), -+ (SPHINX, " :param test: parameter docstring"), -+ (EPYDOC, " @param test: parameter docstring"), -+ (GOOGLE, " test (str): parameter docstring"), - ], - ) - def test_docstring_params(regex, doc) -> None: -diff --git a/test/plugins/test_symbols.py b/test/plugins/test_symbols.py -index c00ab93..0ab1cc0 100644 ---- a/test/plugins/test_symbols.py -+++ b/test/plugins/test_symbols.py -@@ -1,6 +1,7 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import pylsp_document_symbols - import os - import sys - -@@ -8,7 +9,6 @@ import pytest - - from pylsp import uris - from pylsp.lsp import SymbolKind --from pylsp.plugins.symbols import pylsp_document_symbols - from pylsp.workspace import Document - - PY2 = sys.version[0] == "2" -diff --git a/test/plugins/test_yapf_format.py b/test/plugins/test_yapf_format.py -index f69541a..1886d18 100644 ---- a/test/plugins/test_yapf_format.py -+++ b/test/plugins/test_yapf_format.py -@@ -1,10 +1,11 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import pylsp_format_document -+from pylsp.pylsp_shared import pylsp_format_range - import pytest - - from pylsp import uris --from pylsp.plugins.yapf_format import pylsp_format_document, pylsp_format_range - from pylsp.text_edit import apply_text_edits - from pylsp.workspace import Document - -diff --git a/test/test_utils.py b/test/test_utils.py -index 07d04e3..f4f1f64 100644 ---- a/test/test_utils.py -+++ b/test/test_utils.py -@@ -1,6 +1,7 @@ - # Copyright 2017-2020 Palantir Technologies, Inc. - # Copyright 2021- Python Language Server Contributors. - -+from pylsp.pylsp_shared import PythonLSPServer - import multiprocessing - import os - import sys -@@ -14,7 +15,7 @@ from flaky import flaky - - from pylsp import _utils - from pylsp.lsp import NotebookCellKind --from pylsp.python_lsp import PythonLSPServer, start_io_lang_server -+from pylsp.python_lsp import start_io_lang_server - - CALL_TIMEOUT_IN_SECONDS = 30 - diff --git a/tests/integration/codemod/canonical/openapi_add_response_none/test_indra/expected_diff.patch.skip b/tests/integration/codemod/canonical/openapi_add_response_none/test_indra/expected_diff.patch.skip deleted file mode 100644 index 320b9d326..000000000 --- a/tests/integration/codemod/canonical/openapi_add_response_none/test_indra/expected_diff.patch.skip +++ /dev/null @@ -1,212 +0,0 @@ -diff --git a/rest_api/api.py b/rest_api/api.py -index e853c84eb..67cf2f176 100644 ---- a/rest_api/api.py -+++ b/rest_api/api.py -@@ -127,6 +127,7 @@ class RunPipeline(Resource): - def options(self): - return {} - -+ @preassembly_ns.response(200) - def post(self): - """Run an assembly pipeline for a list of Statements. - -@@ -292,6 +293,7 @@ for func_name, func in pipeline_functions.items(): - class NewFunction(PreassembleStatements): - func_name = func_name - -+ @preassembly_ns.response(200) - def post(self): - return super().post() - -@@ -320,6 +322,7 @@ class ReachProcessText(Resource): - def options(self): - return {} - -+ @sources_ns.response(200) - def post(self): - """Process text with REACH and return INDRA Statements. - -@@ -375,6 +378,7 @@ class ReachProcessJson(Resource): - def options(self): - return {} - -+ @sources_ns.response(200) - def post(self): - """Process REACH json and return INDRA Statements. - -@@ -401,6 +405,7 @@ class ReachProcessPmc(Resource): - def options(self): - return {} - -+ @sources_ns.response(200) - def post(self): - """Process PubMedCentral article and return INDRA Statements. - -@@ -463,6 +468,7 @@ class TripsProcessText(Resource): - def options(self): - return {} - -+ @sources_ns.response(200) - def post(self): - """Process text with TRIPS and return INDRA Statements. - -@@ -489,6 +495,7 @@ class TripsProcessText(Resource): - def options(self): - return {} - -+ @sources_ns.response(200) - def post(self): - """Process TRIPS EKB XML and return INDRA Statements. - -@@ -523,6 +530,7 @@ class EidosProcessText(Resource): - def options(self): - return {} - -+ @sources_ns.response(200) - def post(self): - """Process text with EIDOS and return biology INDRA Statements. - -@@ -557,6 +565,7 @@ class EidosProcessJsonld(Resource): - def options(self): - return {} - -+ @sources_ns.response(200) - def post(self): - """Process an EIDOS JSON-LD and return biology INDRA Statements. - -@@ -588,6 +597,7 @@ class BelProcessNeighborhood(Resource): - def options(self): - return {} - -+ @sources_ns.response(200) - def post(self): - """Process BEL Large Corpus neighborhood and return INDRA Statements. - -@@ -617,6 +627,7 @@ class BelProcessBelRdf(Resource): - def options(self): - return {} - -+ @sources_ns.response(200) - def post(self): - """Process BEL RDF and return INDRA Statements. - -@@ -651,6 +662,7 @@ class BiopaxPathsBetween(Resource): - def options(self): - return {} - -+ @sources_ns.response(200) - def post(self): - """ - Process PathwayCommons paths between genes, return INDRA Statements. -@@ -679,6 +691,7 @@ class BiopaxPathsFromTo(Resource): - def options(self): - return {} - -+ @sources_ns.response(200) - def post(self): - """ - Process PathwayCommons paths from-to genes, return INDRA Statements. -@@ -714,6 +727,7 @@ class BiopaxNeighborhood(Resource): - def options(self): - return {} - -+ @sources_ns.response(200) - def post(self): - """Process PathwayCommons neighborhood, return INDRA Statements. - -@@ -747,6 +761,7 @@ class AssemblePysb(Resource): - def options(self): - return {} - -+ @assemblers_ns.response(200) - def post(self): - """Assemble INDRA Statements and return PySB model string. - -@@ -809,6 +824,7 @@ class AssembleCx(Resource): - def options(self): - return {} - -+ @assemblers_ns.response(200) - def post(self): - """Assemble INDRA Statements and return CX network json. - -@@ -838,6 +854,7 @@ class AssembleGraph(Resource): - def options(self): - return {} - -+ @assemblers_ns.response(200) - def post(self): - """Assemble INDRA Statements and return Graphviz graph dot string. - -@@ -867,6 +884,7 @@ class AssembleCyjs(Resource): - def options(self): - return {} - -+ @assemblers_ns.response(200) - def post(self): - """Assemble INDRA Statements and return Cytoscape JS network. - -@@ -896,6 +914,7 @@ class AssembleEnglish(Resource): - def options(self): - return {} - -+ @assemblers_ns.response(200) - def post(self): - """Assemble each statement into English sentence. - -@@ -929,6 +948,7 @@ class AssembleLoopy(Resource): - def options(self): - return {} - -+ @assemblers_ns.response(200) - def post(self): - """Assemble INDRA Statements into a Loopy model using SIF Assembler. - -@@ -963,6 +983,7 @@ class ShareModelNdex(Resource): - def options(self): - return {} - -+ @ndex_ns.response(200) - def post(self): - """Upload the model to NDEX. - -@@ -994,6 +1015,7 @@ class FetchModelNdex(Resource): - def options(self): - return {} - -+ @ndex_ns.response(200) - def post(self): - """Download model and associated pieces from NDEX. - -@@ -1033,6 +1055,7 @@ class GetEvidence(Resource): - def options(self): - return {} - -+ @indra_db_rest_ns.response(200) - def post(self): - """Get all evidence for a given INDRA statement. - -@@ -1107,6 +1130,7 @@ class CbioMrna(Resource): - def options(self): - return {} - -+ @databases_ns.response(200) - def post(self): - """Get CCLE mRNA amounts using cBioClient - -@@ -1139,6 +1163,7 @@ class CbioCna(Resource): - def options(self): - return {} - -+ @databases_ns.response(200) - def post(self): - """Get CCLE CNA - -@@ -1177,6 +1202,7 @@ class CbioMutations(Resource): - def options(self): - return {} - -+ @databases_ns.response(200) - def post(self): - """Get CCLE mutations - diff --git a/tests/integration/codemod/canonical/split_large_files/test_vite/expected_diff.patch.skip b/tests/integration/codemod/canonical/split_large_files/test_vite/expected_diff.patch.skip deleted file mode 100644 index 4e109d68b..000000000 --- a/tests/integration/codemod/canonical/split_large_files/test_vite/expected_diff.patch.skip +++ /dev/null @@ -1,25123 +0,0 @@ -diff --git a/packages/vite/src/node/__tests__/build.spec.ts b/packages/vite/src/node/__tests__/build.spec.ts -index 2dad85578..ae345af1c 100644 ---- a/packages/vite/src/node/__tests__/build.spec.ts -+++ b/packages/vite/src/node/__tests__/build.spec.ts -@@ -1,10 +1,17 @@ -+import { LibraryFormats } from 'packages/vite/src/node/build/BuildOptions'; -+import { LibraryOptions } from 'packages/vite/src/node/build/BuildOptions'; -+import { LibraryFormats } from 'packages/vite/src/node/build/resolveBuildOptions'; -+import { LibraryOptions } from 'packages/vite/src/node/build/resolveBuildOptions'; -+import { build } from 'packages/vite/src/node/build/build'; -+import { LibraryFormats } from 'packages/vite/src/node/build/resolveBuildOutputs'; -+import { LibraryOptions } from 'packages/vite/src/node/build/resolveBuildOutputs'; -+import { resolveBuildOutputs } from 'packages/vite/src/node/build/resolveBuildOutputs'; - import { basename, resolve } from 'node:path' - import { fileURLToPath } from 'node:url' - import colors from 'picocolors' - import { describe, expect, test, vi } from 'vitest' - import type { OutputChunk, OutputOptions, RollupOutput } from 'rollup' --import type { LibraryFormats, LibraryOptions } from '../build' --import { build, resolveBuildOutputs, resolveLibFilename } from '../build' -+import { resolveLibFilename } from '../build' - import type { Logger } from '../logger' - import { createLogger } from '../logger' - -diff --git a/packages/vite/src/node/__tests__/config.spec.ts b/packages/vite/src/node/__tests__/config.spec.ts -index 9fbbdd61f..d37fe9703 100644 ---- a/packages/vite/src/node/__tests__/config.spec.ts -+++ b/packages/vite/src/node/__tests__/config.spec.ts -@@ -1,8 +1,14 @@ -+import { PluginOption } from 'packages/vite/src/node/config/UserConfig'; -+import { UserConfig } from 'packages/vite/src/node/config/UserConfig'; -+import { PluginOption } from 'packages/vite/src/node/config/resolveConfig'; -+import { UserConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { resolveConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { PluginOption } from 'packages/vite/src/node/config/loadConfigFromFile'; -+import { UserConfig } from 'packages/vite/src/node/config/loadConfigFromFile'; - import http from 'node:http' - import { describe, expect, test } from 'vitest' - import type { InlineConfig } from '..' --import type { PluginOption, UserConfig, UserConfigExport } from '../config' --import { defineConfig, resolveConfig } from '../config' -+import { defineConfig } from '../config' - import { resolveEnvPrefix } from '../env' - import { createLogger, mergeConfig } from '../publicUtils' - -diff --git a/packages/vite/src/node/__tests__/plugins/css.spec.ts b/packages/vite/src/node/__tests__/plugins/css.spec.ts -index e1c435211..b6629ec1b 100644 ---- a/packages/vite/src/node/__tests__/plugins/css.spec.ts -+++ b/packages/vite/src/node/__tests__/plugins/css.spec.ts -@@ -1,12 +1,14 @@ -+import { InlineConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { resolveConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { cssPlugin } from 'packages/vite/src/node/plugins/css/cssPlugin'; -+import { cssUrlRE } from 'packages/vite/src/node/plugins/css/rebaseUrls'; -+import { cssUrlRE } from 'packages/vite/src/node/plugins/css/makeLessWorker'; -+import { cssUrlRE } from 'packages/vite/src/node/plugins/css/lessProcessor'; - import fs from 'node:fs' - import path from 'node:path' - import { describe, expect, test, vi } from 'vitest' --import { resolveConfig } from '../../config' --import type { InlineConfig } from '../../config' - import { - convertTargets, -- cssPlugin, -- cssUrlRE, - getEmptyChunkReplacer, - hoistAtRules, - preprocessCSS, -diff --git a/packages/vite/src/node/__tests__/plugins/define.spec.ts b/packages/vite/src/node/__tests__/plugins/define.spec.ts -index 2165461c7..83c22cd05 100644 ---- a/packages/vite/src/node/__tests__/plugins/define.spec.ts -+++ b/packages/vite/src/node/__tests__/plugins/define.spec.ts -@@ -1,6 +1,6 @@ -+import { resolveConfig } from 'packages/vite/src/node/config/resolveConfig'; - import { describe, expect, test } from 'vitest' - import { definePlugin } from '../../plugins/define' --import { resolveConfig } from '../../config' - - async function createDefinePluginTransform( - define: Record = {}, -diff --git a/packages/vite/src/node/__tests__/plugins/esbuild.spec.ts b/packages/vite/src/node/__tests__/plugins/esbuild.spec.ts -index 936415f9c..bcbd1358c 100644 ---- a/packages/vite/src/node/__tests__/plugins/esbuild.spec.ts -+++ b/packages/vite/src/node/__tests__/plugins/esbuild.spec.ts -@@ -1,5 +1,8 @@ -+import { UserConfig } from 'packages/vite/src/node/config/UserConfig'; -+import { UserConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { ResolvedConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { UserConfig } from 'packages/vite/src/node/config/loadConfigFromFile'; - import { describe, expect, test } from 'vitest' --import type { ResolvedConfig, UserConfig } from '../../config' - import { - resolveEsbuildTranspileOptions, - transformWithEsbuild, -diff --git a/packages/vite/src/node/__tests__/plugins/modulePreloadPolyfill/modulePreloadPolyfill.spec.ts b/packages/vite/src/node/__tests__/plugins/modulePreloadPolyfill/modulePreloadPolyfill.spec.ts -index 3b24fbd52..695b7cf6a 100644 ---- a/packages/vite/src/node/__tests__/plugins/modulePreloadPolyfill/modulePreloadPolyfill.spec.ts -+++ b/packages/vite/src/node/__tests__/plugins/modulePreloadPolyfill/modulePreloadPolyfill.spec.ts -@@ -1,6 +1,6 @@ -+import { build } from 'packages/vite/src/node/build/build'; - import { describe, it } from 'vitest' - import type { ModuleFormat, RollupOutput } from 'rollup' --import { build } from '../../../build' - import { modulePreloadPolyfillId } from '../../../plugins/modulePreloadPolyfill' - - const buildProject = ({ format = 'es' as ModuleFormat } = {}) => -diff --git a/packages/vite/src/node/build.ts b/packages/vite/src/node/build.ts -index d86393d36..daf2573d3 100644 ---- a/packages/vite/src/node/build.ts -+++ b/packages/vite/src/node/build.ts -@@ -1,3 +1,25 @@ -+import { LibraryOptions } from 'packages/vite/src/node/build/BuildOptions'; -+import { ResolveModulePreloadDependenciesFn } from 'packages/vite/src/node/build/BuildOptions'; -+import { ModulePreloadOptions } from 'packages/vite/src/node/build/BuildOptions'; -+import { BuildOptions } from 'packages/vite/src/node/build/BuildOptions'; -+import { BuildOptions } from 'packages/vite/src/node/build'; -+import { LibraryOptions } from 'packages/vite/src/node/build/resolveBuildOptions'; -+import { ResolveModulePreloadDependenciesFn } from 'packages/vite/src/node/build/resolveBuildOptions'; -+import { BuildOptions } from 'packages/vite/src/node/build/resolveBuildOptions'; -+import { ResolvedBuildOptions } from 'packages/vite/src/node/build/resolveBuildOptions'; -+import { resolveBuildOptions } from 'packages/vite/src/node/build'; -+import { build } from 'packages/vite/src/node/build'; -+import { LibraryOptions } from 'packages/vite/src/node/build/resolveBuildOutputs'; -+import { resolveBuildOutputs } from 'packages/vite/src/node/build/resolveBuildOutputs'; -+import { resolveBuildOutputs } from 'packages/vite/src/node/build'; -+import { warningIgnoreList } from 'packages/vite/src/node/build/onRollupWarning'; -+import { dynamicImportWarningIgnoreList } from 'packages/vite/src/node/build/onRollupWarning'; -+import { clearLine } from 'packages/vite/src/node/build/onRollupWarning'; -+import { onRollupWarning } from 'packages/vite/src/node/build/onRollupWarning'; -+import { onRollupWarning } from 'packages/vite/src/node/build'; -+import { InlineConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { ResolvedConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { resolveConfig } from 'packages/vite/src/node/config/resolveConfig'; - import fs from 'node:fs' - import path from 'node:path' - import colors from 'picocolors' -@@ -27,8 +49,6 @@ import { - ESBUILD_MODULES_TARGET, - VERSION, - } from './constants' --import type { InlineConfig, ResolvedConfig } from './config' --import { resolveConfig } from './config' - import { buildReporterPlugin } from './plugins/reporter' - import { buildEsbuildPlugin } from './plugins/esbuild' - import { type TerserOptions, terserPlugin } from './plugins/terser' -@@ -61,369 +81,6 @@ import { mergeConfig } from './publicUtils' - import { webWorkerPostPlugin } from './plugins/worker' - import { getHookHandler } from './plugins' - --export interface BuildOptions { -- /** -- * Compatibility transform target. The transform is performed with esbuild -- * and the lowest supported target is es2015/es6. Note this only handles -- * syntax transformation and does not cover polyfills (except for dynamic -- * import) -- * -- * Default: 'modules' - Similar to `@babel/preset-env`'s targets.esmodules, -- * transpile targeting browsers that natively support dynamic es module imports. -- * https://caniuse.com/es6-module-dynamic-import -- * -- * Another special value is 'esnext' - which only performs minimal transpiling -- * (for minification compat) and assumes native dynamic imports support. -- * -- * For custom targets, see https://esbuild.github.io/api/#target and -- * https://esbuild.github.io/content-types/#javascript for more details. -- * @default 'modules' -- */ -- target?: 'modules' | TransformOptions['target'] | false -- /** -- * whether to inject module preload polyfill. -- * Note: does not apply to library mode. -- * @default true -- * @deprecated use `modulePreload.polyfill` instead -- */ -- polyfillModulePreload?: boolean -- /** -- * Configure module preload -- * Note: does not apply to library mode. -- * @default true -- */ -- modulePreload?: boolean | ModulePreloadOptions -- /** -- * Directory relative from `root` where build output will be placed. If the -- * directory exists, it will be removed before the build. -- * @default 'dist' -- */ -- outDir?: string -- /** -- * Directory relative from `outDir` where the built js/css/image assets will -- * be placed. -- * @default 'assets' -- */ -- assetsDir?: string -- /** -- * Static asset files smaller than this number (in bytes) will be inlined as -- * base64 strings. Default limit is `4096` (4 KiB). Set to `0` to disable. -- * @default 4096 -- */ -- assetsInlineLimit?: -- | number -- | ((filePath: string, content: Buffer) => boolean | undefined) -- /** -- * Whether to code-split CSS. When enabled, CSS in async chunks will be -- * inlined as strings in the chunk and inserted via dynamically created -- * style tags when the chunk is loaded. -- * @default true -- */ -- cssCodeSplit?: boolean -- /** -- * An optional separate target for CSS minification. -- * As esbuild only supports configuring targets to mainstream -- * browsers, users may need this option when they are targeting -- * a niche browser that comes with most modern JavaScript features -- * but has poor CSS support, e.g. Android WeChat WebView, which -- * doesn't support the #RGBA syntax. -- * @default target -- */ -- cssTarget?: TransformOptions['target'] | false -- /** -- * Override CSS minification specifically instead of defaulting to `build.minify`, -- * so you can configure minification for JS and CSS separately. -- * @default 'esbuild' -- */ -- cssMinify?: boolean | 'esbuild' | 'lightningcss' -- /** -- * If `true`, a separate sourcemap file will be created. If 'inline', the -- * sourcemap will be appended to the resulting output file as data URI. -- * 'hidden' works like `true` except that the corresponding sourcemap -- * comments in the bundled files are suppressed. -- * @default false -- */ -- sourcemap?: boolean | 'inline' | 'hidden' -- /** -- * Set to `false` to disable minification, or specify the minifier to use. -- * Available options are 'terser' or 'esbuild'. -- * @default 'esbuild' -- */ -- minify?: boolean | 'terser' | 'esbuild' -- /** -- * Options for terser -- * https://terser.org/docs/api-reference#minify-options -- * -- * In addition, you can also pass a `maxWorkers: number` option to specify the -- * max number of workers to spawn. Defaults to the number of CPUs minus 1. -- */ -- terserOptions?: TerserOptions -- /** -- * Will be merged with internal rollup options. -- * https://rollupjs.org/configuration-options/ -- */ -- rollupOptions?: RollupOptions -- /** -- * Options to pass on to `@rollup/plugin-commonjs` -- */ -- commonjsOptions?: RollupCommonJSOptions -- /** -- * Options to pass on to `@rollup/plugin-dynamic-import-vars` -- */ -- dynamicImportVarsOptions?: RollupDynamicImportVarsOptions -- /** -- * Whether to write bundle to disk -- * @default true -- */ -- write?: boolean -- /** -- * Empty outDir on write. -- * @default true when outDir is a sub directory of project root -- */ -- emptyOutDir?: boolean | null -- /** -- * Copy the public directory to outDir on write. -- * @default true -- */ -- copyPublicDir?: boolean -- /** -- * Whether to emit a .vite/manifest.json under assets dir to map hash-less filenames -- * to their hashed versions. Useful when you want to generate your own HTML -- * instead of using the one generated by Vite. -- * -- * Example: -- * -- * ```json -- * { -- * "main.js": { -- * "file": "main.68fe3fad.js", -- * "css": "main.e6b63442.css", -- * "imports": [...], -- * "dynamicImports": [...] -- * } -- * } -- * ``` -- * @default false -- */ -- manifest?: boolean | string -- /** -- * Build in library mode. The value should be the global name of the lib in -- * UMD mode. This will produce esm + cjs + umd bundle formats with default -- * configurations that are suitable for distributing libraries. -- * @default false -- */ -- lib?: LibraryOptions | false -- /** -- * Produce SSR oriented build. Note this requires specifying SSR entry via -- * `rollupOptions.input`. -- * @default false -- */ -- ssr?: boolean | string -- /** -- * Generate SSR manifest for determining style links and asset preload -- * directives in production. -- * @default false -- */ -- ssrManifest?: boolean | string -- /** -- * Emit assets during SSR. -- * @default false -- */ -- ssrEmitAssets?: boolean -- /** -- * Set to false to disable reporting compressed chunk sizes. -- * Can slightly improve build speed. -- * @default true -- */ -- reportCompressedSize?: boolean -- /** -- * Adjust chunk size warning limit (in kB). -- * @default 500 -- */ -- chunkSizeWarningLimit?: number -- /** -- * Rollup watch options -- * https://rollupjs.org/configuration-options/#watch -- * @default null -- */ -- watch?: WatcherOptions | null --} -- --export interface LibraryOptions { -- /** -- * Path of library entry -- */ -- entry: InputOption -- /** -- * The name of the exposed global variable. Required when the `formats` option includes -- * `umd` or `iife` -- */ -- name?: string -- /** -- * Output bundle formats -- * @default ['es', 'umd'] -- */ -- formats?: LibraryFormats[] -- /** -- * The name of the package file output. The default file name is the name option -- * of the project package.json. It can also be defined as a function taking the -- * format as an argument. -- */ -- fileName?: string | ((format: ModuleFormat, entryName: string) => string) --} -- --export type LibraryFormats = 'es' | 'cjs' | 'umd' | 'iife' | 'system' -- --export interface ModulePreloadOptions { -- /** -- * Whether to inject a module preload polyfill. -- * Note: does not apply to library mode. -- * @default true -- */ -- polyfill?: boolean -- /** -- * Resolve the list of dependencies to preload for a given dynamic import -- * @experimental -- */ -- resolveDependencies?: ResolveModulePreloadDependenciesFn --} --export interface ResolvedModulePreloadOptions { -- polyfill: boolean -- resolveDependencies?: ResolveModulePreloadDependenciesFn --} -- --export type ResolveModulePreloadDependenciesFn = ( -- filename: string, -- deps: string[], -- context: { -- hostId: string -- hostType: 'html' | 'js' -- }, --) => string[] -- --export interface ResolvedBuildOptions -- extends Required> { -- modulePreload: false | ResolvedModulePreloadOptions --} -- --export function resolveBuildOptions( -- raw: BuildOptions | undefined, -- logger: Logger, -- root: string, --): ResolvedBuildOptions { -- const deprecatedPolyfillModulePreload = raw?.polyfillModulePreload -- if (raw) { -- const { polyfillModulePreload, ...rest } = raw -- raw = rest -- if (deprecatedPolyfillModulePreload !== undefined) { -- logger.warn( -- 'polyfillModulePreload is deprecated. Use modulePreload.polyfill instead.', -- ) -- } -- if ( -- deprecatedPolyfillModulePreload === false && -- raw.modulePreload === undefined -- ) { -- raw.modulePreload = { polyfill: false } -- } -- } -- -- const modulePreload = raw?.modulePreload -- const defaultModulePreload = { -- polyfill: true, -- } -- -- const defaultBuildOptions: BuildOptions = { -- outDir: 'dist', -- assetsDir: 'assets', -- assetsInlineLimit: DEFAULT_ASSETS_INLINE_LIMIT, -- cssCodeSplit: !raw?.lib, -- sourcemap: false, -- rollupOptions: {}, -- minify: raw?.ssr ? false : 'esbuild', -- terserOptions: {}, -- write: true, -- emptyOutDir: null, -- copyPublicDir: true, -- manifest: false, -- lib: false, -- ssr: false, -- ssrManifest: false, -- ssrEmitAssets: false, -- reportCompressedSize: true, -- chunkSizeWarningLimit: 500, -- watch: null, -- } -- -- const userBuildOptions = raw -- ? mergeConfig(defaultBuildOptions, raw) -- : defaultBuildOptions -- -- // @ts-expect-error Fallback options instead of merging -- const resolved: ResolvedBuildOptions = { -- target: 'modules', -- cssTarget: false, -- ...userBuildOptions, -- commonjsOptions: { -- include: [/node_modules/], -- extensions: ['.js', '.cjs'], -- ...userBuildOptions.commonjsOptions, -- }, -- dynamicImportVarsOptions: { -- warnOnError: true, -- exclude: [/node_modules/], -- ...userBuildOptions.dynamicImportVarsOptions, -- }, -- // Resolve to false | object -- modulePreload: -- modulePreload === false -- ? false -- : typeof modulePreload === 'object' -- ? { -- ...defaultModulePreload, -- ...modulePreload, -- } -- : defaultModulePreload, -- } -- -- // handle special build targets -- if (resolved.target === 'modules') { -- resolved.target = ESBUILD_MODULES_TARGET -- } else if (resolved.target === 'esnext' && resolved.minify === 'terser') { -- try { -- const terserPackageJsonPath = requireResolveFromRootWithFallback( -- root, -- 'terser/package.json', -- ) -- const terserPackageJson = JSON.parse( -- fs.readFileSync(terserPackageJsonPath, 'utf-8'), -- ) -- const v = terserPackageJson.version.split('.') -- if (v[0] === '5' && v[1] < 16) { -- // esnext + terser 5.16<: limit to es2021 so it can be minified by terser -- resolved.target = 'es2021' -- } -- } catch {} -- } -- -- if (!resolved.cssTarget) { -- resolved.cssTarget = resolved.target -- } -- -- // normalize false string into actual false -- if ((resolved.minify as string) === 'false') { -- resolved.minify = false -- } else if (resolved.minify === true) { -- resolved.minify = 'esbuild' -- } -- -- if (resolved.cssMinify == null) { -- resolved.cssMinify = !!resolved.minify -- } -- -- return resolved --} -- - export async function resolveBuildPlugins(config: ResolvedConfig): Promise<{ - pre: Plugin[] - post: Plugin[] -@@ -460,329 +117,6 @@ export async function resolveBuildPlugins(config: ResolvedConfig): Promise<{ - } - } - --/** -- * Bundles the app for production. -- * Returns a Promise containing the build result. -- */ --export async function build( -- inlineConfig: InlineConfig = {}, --): Promise { -- const config = await resolveConfig( -- inlineConfig, -- 'build', -- 'production', -- 'production', -- ) -- const options = config.build -- const { logger } = config -- const ssr = !!options.ssr -- const libOptions = options.lib -- -- logger.info( -- colors.cyan( -- `vite v${VERSION} ${colors.green( -- `building ${ssr ? `SSR bundle ` : ``}for ${config.mode}...`, -- )}`, -- ), -- ) -- -- const resolve = (p: string) => path.resolve(config.root, p) -- const input = libOptions -- ? options.rollupOptions?.input || -- (typeof libOptions.entry === 'string' -- ? resolve(libOptions.entry) -- : Array.isArray(libOptions.entry) -- ? libOptions.entry.map(resolve) -- : Object.fromEntries( -- Object.entries(libOptions.entry).map(([alias, file]) => [ -- alias, -- resolve(file), -- ]), -- )) -- : typeof options.ssr === 'string' -- ? resolve(options.ssr) -- : options.rollupOptions?.input || resolve('index.html') -- -- if (ssr && typeof input === 'string' && input.endsWith('.html')) { -- throw new Error( -- `rollupOptions.input should not be an html file when building for SSR. ` + -- `Please specify a dedicated SSR entry.`, -- ) -- } -- if (config.build.cssCodeSplit === false) { -- const inputs = -- typeof input === 'string' -- ? [input] -- : Array.isArray(input) -- ? input -- : Object.values(input) -- if (inputs.some((input) => input.endsWith('.css'))) { -- throw new Error( -- `When "build.cssCodeSplit: false" is set, "rollupOptions.input" should not include CSS files.`, -- ) -- } -- } -- -- const outDir = resolve(options.outDir) -- -- // inject ssr arg to plugin load/transform hooks -- const plugins = ( -- ssr ? config.plugins.map((p) => injectSsrFlagToHooks(p)) : config.plugins -- ) as Plugin[] -- -- const rollupOptions: RollupOptions = { -- preserveEntrySignatures: ssr -- ? 'allow-extension' -- : libOptions -- ? 'strict' -- : false, -- cache: config.build.watch ? undefined : false, -- ...options.rollupOptions, -- input, -- plugins, -- external: options.rollupOptions?.external, -- onwarn(warning, warn) { -- onRollupWarning(warning, warn, config) -- }, -- } -- -- /** -- * The stack string usually contains a copy of the message at the start of the stack. -- * If the stack starts with the message, we remove it and just return the stack trace -- * portion. Otherwise the original stack trace is used. -- */ -- function extractStack(e: RollupError) { -- const { stack, name = 'Error', message } = e -- -- // If we don't have a stack, not much we can do. -- if (!stack) { -- return stack -- } -- -- const expectedPrefix = `${name}: ${message}\n` -- if (stack.startsWith(expectedPrefix)) { -- return stack.slice(expectedPrefix.length) -- } -- -- return stack -- } -- -- /** -- * Esbuild code frames have newlines at the start and end of the frame, rollup doesn't -- * This function normalizes the frame to match the esbuild format which has more pleasing padding -- */ -- const normalizeCodeFrame = (frame: string) => { -- const trimmedPadding = frame.replace(/^\n|\n$/g, '') -- return `\n${trimmedPadding}\n` -- } -- -- const enhanceRollupError = (e: RollupError) => { -- const stackOnly = extractStack(e) -- -- let msg = colors.red((e.plugin ? `[${e.plugin}] ` : '') + e.message) -- if (e.id) { -- msg += `\nfile: ${colors.cyan( -- e.id + (e.loc ? `:${e.loc.line}:${e.loc.column}` : ''), -- )}` -- } -- if (e.frame) { -- msg += `\n` + colors.yellow(normalizeCodeFrame(e.frame)) -- } -- -- e.message = msg -- -- // We are rebuilding the stack trace to include the more detailed message at the top. -- // Previously this code was relying on mutating e.message changing the generated stack -- // when it was accessed, but we don't have any guarantees that the error we are working -- // with hasn't already had its stack accessed before we get here. -- if (stackOnly !== undefined) { -- e.stack = `${e.message}\n${stackOnly}` -- } -- } -- -- const outputBuildError = (e: RollupError) => { -- enhanceRollupError(e) -- clearLine() -- logger.error(e.message, { error: e }) -- } -- -- let bundle: RollupBuild | undefined -- let startTime: number | undefined -- try { -- const buildOutputOptions = (output: OutputOptions = {}): OutputOptions => { -- // @ts-expect-error See https://github.com/vitejs/vite/issues/5812#issuecomment-984345618 -- if (output.output) { -- logger.warn( -- `You've set "rollupOptions.output.output" in your config. ` + -- `This is deprecated and will override all Vite.js default output options. ` + -- `Please use "rollupOptions.output" instead.`, -- ) -- } -- if (output.file) { -- throw new Error( -- `Vite does not support "rollupOptions.output.file". ` + -- `Please use "rollupOptions.output.dir" and "rollupOptions.output.entryFileNames" instead.`, -- ) -- } -- if (output.sourcemap) { -- logger.warnOnce( -- colors.yellow( -- `Vite does not support "rollupOptions.output.sourcemap". ` + -- `Please use "build.sourcemap" instead.`, -- ), -- ) -- } -- -- const ssrNodeBuild = ssr && config.ssr.target === 'node' -- const ssrWorkerBuild = ssr && config.ssr.target === 'webworker' -- -- const format = output.format || 'es' -- const jsExt = -- ssrNodeBuild || libOptions -- ? resolveOutputJsExtension( -- format, -- findNearestPackageData(config.root, config.packageCache)?.data -- .type, -- ) -- : 'js' -- return { -- dir: outDir, -- // Default format is 'es' for regular and for SSR builds -- format, -- exports: 'auto', -- sourcemap: options.sourcemap, -- name: libOptions ? libOptions.name : undefined, -- hoistTransitiveImports: libOptions ? false : undefined, -- // es2015 enables `generatedCode.symbols` -- // - #764 add `Symbol.toStringTag` when build es module into cjs chunk -- // - #1048 add `Symbol.toStringTag` for module default export -- generatedCode: 'es2015', -- entryFileNames: ssr -- ? `[name].${jsExt}` -- : libOptions -- ? ({ name }) => -- resolveLibFilename( -- libOptions, -- format, -- name, -- config.root, -- jsExt, -- config.packageCache, -- ) -- : path.posix.join(options.assetsDir, `[name]-[hash].${jsExt}`), -- chunkFileNames: libOptions -- ? `[name]-[hash].${jsExt}` -- : path.posix.join(options.assetsDir, `[name]-[hash].${jsExt}`), -- assetFileNames: libOptions -- ? `[name].[ext]` -- : path.posix.join(options.assetsDir, `[name]-[hash].[ext]`), -- inlineDynamicImports: -- output.format === 'umd' || -- output.format === 'iife' || -- (ssrWorkerBuild && -- (typeof input === 'string' || Object.keys(input).length === 1)), -- ...output, -- } -- } -- -- // resolve lib mode outputs -- const outputs = resolveBuildOutputs( -- options.rollupOptions?.output, -- libOptions, -- logger, -- ) -- const normalizedOutputs: OutputOptions[] = [] -- -- if (Array.isArray(outputs)) { -- for (const resolvedOutput of outputs) { -- normalizedOutputs.push(buildOutputOptions(resolvedOutput)) -- } -- } else { -- normalizedOutputs.push(buildOutputOptions(outputs)) -- } -- -- const resolvedOutDirs = getResolvedOutDirs( -- config.root, -- options.outDir, -- options.rollupOptions?.output, -- ) -- const emptyOutDir = resolveEmptyOutDir( -- options.emptyOutDir, -- config.root, -- resolvedOutDirs, -- logger, -- ) -- -- // watch file changes with rollup -- if (config.build.watch) { -- logger.info(colors.cyan(`\nwatching for file changes...`)) -- -- const resolvedChokidarOptions = resolveChokidarOptions( -- config, -- config.build.watch.chokidar, -- resolvedOutDirs, -- emptyOutDir, -- ) -- -- const { watch } = await import('rollup') -- const watcher = watch({ -- ...rollupOptions, -- output: normalizedOutputs, -- watch: { -- ...config.build.watch, -- chokidar: resolvedChokidarOptions, -- }, -- }) -- -- watcher.on('event', (event) => { -- if (event.code === 'BUNDLE_START') { -- logger.info(colors.cyan(`\nbuild started...`)) -- if (options.write) { -- prepareOutDir(resolvedOutDirs, emptyOutDir, config) -- } -- } else if (event.code === 'BUNDLE_END') { -- event.result.close() -- logger.info(colors.cyan(`built in ${event.duration}ms.`)) -- } else if (event.code === 'ERROR') { -- outputBuildError(event.error) -- } -- }) -- -- return watcher -- } -- -- // write or generate files with rollup -- const { rollup } = await import('rollup') -- startTime = Date.now() -- bundle = await rollup(rollupOptions) -- -- if (options.write) { -- prepareOutDir(resolvedOutDirs, emptyOutDir, config) -- } -- -- const res: RollupOutput[] = [] -- for (const output of normalizedOutputs) { -- res.push(await bundle[options.write ? 'write' : 'generate'](output)) -- } -- logger.info( -- `${colors.green(`✓ built in ${displayTime(Date.now() - startTime)}`)}`, -- ) -- return Array.isArray(outputs) ? res : res[0] -- } catch (e) { -- enhanceRollupError(e) -- clearLine() -- if (startTime) { -- logger.error( -- `${colors.red('x')} Build failed in ${displayTime(Date.now() - startTime)}`, -- ) -- startTime = undefined -- } -- throw e -- } finally { -- if (bundle) await bundle.close() -- } --} -- - function prepareOutDir( - outDirs: Set, - emptyOutDir: boolean | null, -@@ -880,139 +214,6 @@ export function resolveLibFilename( - return `${name}.${format}.${extension}` - } - --export function resolveBuildOutputs( -- outputs: OutputOptions | OutputOptions[] | undefined, -- libOptions: LibraryOptions | false, -- logger: Logger, --): OutputOptions | OutputOptions[] | undefined { -- if (libOptions) { -- const libHasMultipleEntries = -- typeof libOptions.entry !== 'string' && -- Object.values(libOptions.entry).length > 1 -- const libFormats = -- libOptions.formats || -- (libHasMultipleEntries ? ['es', 'cjs'] : ['es', 'umd']) -- -- if (!Array.isArray(outputs)) { -- if (libFormats.includes('umd') || libFormats.includes('iife')) { -- if (libHasMultipleEntries) { -- throw new Error( -- 'Multiple entry points are not supported when output formats include "umd" or "iife".', -- ) -- } -- -- if (!libOptions.name) { -- throw new Error( -- 'Option "build.lib.name" is required when output formats include "umd" or "iife".', -- ) -- } -- } -- -- return libFormats.map((format) => ({ ...outputs, format })) -- } -- -- // By this point, we know "outputs" is an Array. -- if (libOptions.formats) { -- logger.warn( -- colors.yellow( -- '"build.lib.formats" will be ignored because "build.rollupOptions.output" is already an array format.', -- ), -- ) -- } -- -- outputs.forEach((output) => { -- if ( -- (output.format === 'umd' || output.format === 'iife') && -- !output.name -- ) { -- throw new Error( -- 'Entries in "build.rollupOptions.output" must specify "name" when the format is "umd" or "iife".', -- ) -- } -- }) -- } -- -- return outputs --} -- --const warningIgnoreList = [`CIRCULAR_DEPENDENCY`, `THIS_IS_UNDEFINED`] --const dynamicImportWarningIgnoreList = [ -- `Unsupported expression`, -- `statically analyzed`, --] -- --function clearLine() { -- const tty = process.stdout.isTTY && !process.env.CI -- if (tty) { -- process.stdout.clearLine(0) -- process.stdout.cursorTo(0) -- } --} -- --export function onRollupWarning( -- warning: RollupLog, -- warn: LoggingFunction, -- config: ResolvedConfig, --): void { -- const viteWarn: LoggingFunction = (warnLog) => { -- let warning: string | RollupLog -- -- if (typeof warnLog === 'function') { -- warning = warnLog() -- } else { -- warning = warnLog -- } -- -- if (typeof warning === 'object') { -- if (warning.code === 'UNRESOLVED_IMPORT') { -- const id = warning.id -- const exporter = warning.exporter -- // throw unless it's commonjs external... -- if (!id || !id.endsWith('?commonjs-external')) { -- throw new Error( -- `[vite]: Rollup failed to resolve import "${exporter}" from "${id}".\n` + -- `This is most likely unintended because it can break your application at runtime.\n` + -- `If you do want to externalize this module explicitly add it to\n` + -- `\`build.rollupOptions.external\``, -- ) -- } -- } -- -- if ( -- warning.plugin === 'rollup-plugin-dynamic-import-variables' && -- dynamicImportWarningIgnoreList.some((msg) => -- warning.message.includes(msg), -- ) -- ) { -- return -- } -- -- if (warningIgnoreList.includes(warning.code!)) { -- return -- } -- -- if (warning.code === 'PLUGIN_WARNING') { -- config.logger.warn( -- `${colors.bold( -- colors.yellow(`[plugin:${warning.plugin}]`), -- )} ${colors.yellow(warning.message)}`, -- ) -- return -- } -- } -- -- warn(warnLog) -- } -- -- clearLine() -- const userOnWarn = config.build.rollupOptions?.onwarn -- if (userOnWarn) { -- userOnWarn(warning, viteWarn) -- } else { -- viteWarn(warning) -- } --} -- - export function resolveUserExternal( - user: ExternalOption, - id: string, -diff --git a/packages/vite/src/node/build/BuildOptions.ts b/packages/vite/src/node/build/BuildOptions.ts -new file mode 100644 -index 000000000..7a1707a02 ---- /dev/null -+++ b/packages/vite/src/node/build/BuildOptions.ts -@@ -0,0 +1,257 @@ -+import type { -+ ExternalOption, -+ InputOption, -+ InternalModuleFormat, -+ LoggingFunction, -+ ModuleFormat, -+ OutputOptions, -+ Plugin, -+ RollupBuild, -+ RollupError, -+ RollupLog, -+ RollupOptions, -+ RollupOutput, -+ RollupWatcher, -+ WatcherOptions, -+} from 'rollup' -+import type { RollupCommonJSOptions } from 'dep-types/commonjs' -+import type { RollupDynamicImportVarsOptions } from 'dep-types/dynamicImportVars' -+import type { TransformOptions } from 'esbuild' -+import { TerserOptions } from 'packages/vite/src/node/plugins/terser'; -+ -+ -+export type LibraryFormats = 'es' | 'cjs' | 'umd' | 'iife' | 'system' -+ -+export interface LibraryOptions { -+ /** -+ * Path of library entry -+ */ -+ entry: InputOption -+ /** -+ * The name of the exposed global variable. Required when the `formats` option includes -+ * `umd` or `iife` -+ */ -+ name?: string -+ /** -+ * Output bundle formats -+ * @default ['es', 'umd'] -+ */ -+ formats?: LibraryFormats[] -+ /** -+ * The name of the package file output. The default file name is the name option -+ * of the project package.json. It can also be defined as a function taking the -+ * format as an argument. -+ */ -+ fileName?: string | ((format: ModuleFormat, entryName: string) => string) -+} -+ -+export type ResolveModulePreloadDependenciesFn = ( -+ filename: string, -+ deps: string[], -+ context: { -+ hostId: string -+ hostType: 'html' | 'js' -+ }, -+) => string[] -+ -+export interface ModulePreloadOptions { -+ /** -+ * Whether to inject a module preload polyfill. -+ * Note: does not apply to library mode. -+ * @default true -+ */ -+ polyfill?: boolean -+ /** -+ * Resolve the list of dependencies to preload for a given dynamic import -+ * @experimental -+ */ -+ resolveDependencies?: ResolveModulePreloadDependenciesFn -+} -+ -+export interface BuildOptions { -+ /** -+ * Compatibility transform target. The transform is performed with esbuild -+ * and the lowest supported target is es2015/es6. Note this only handles -+ * syntax transformation and does not cover polyfills (except for dynamic -+ * import) -+ * -+ * Default: 'modules' - Similar to `@babel/preset-env`'s targets.esmodules, -+ * transpile targeting browsers that natively support dynamic es module imports. -+ * https://caniuse.com/es6-module-dynamic-import -+ * -+ * Another special value is 'esnext' - which only performs minimal transpiling -+ * (for minification compat) and assumes native dynamic imports support. -+ * -+ * For custom targets, see https://esbuild.github.io/api/#target and -+ * https://esbuild.github.io/content-types/#javascript for more details. -+ * @default 'modules' -+ */ -+ target?: 'modules' | TransformOptions['target'] | false -+ /** -+ * whether to inject module preload polyfill. -+ * Note: does not apply to library mode. -+ * @default true -+ * @deprecated use `modulePreload.polyfill` instead -+ */ -+ polyfillModulePreload?: boolean -+ /** -+ * Configure module preload -+ * Note: does not apply to library mode. -+ * @default true -+ */ -+ modulePreload?: boolean | ModulePreloadOptions -+ /** -+ * Directory relative from `root` where build output will be placed. If the -+ * directory exists, it will be removed before the build. -+ * @default 'dist' -+ */ -+ outDir?: string -+ /** -+ * Directory relative from `outDir` where the built js/css/image assets will -+ * be placed. -+ * @default 'assets' -+ */ -+ assetsDir?: string -+ /** -+ * Static asset files smaller than this number (in bytes) will be inlined as -+ * base64 strings. Default limit is `4096` (4 KiB). Set to `0` to disable. -+ * @default 4096 -+ */ -+ assetsInlineLimit?: -+ | number -+ | ((filePath: string, content: Buffer) => boolean | undefined) -+ /** -+ * Whether to code-split CSS. When enabled, CSS in async chunks will be -+ * inlined as strings in the chunk and inserted via dynamically created -+ * style tags when the chunk is loaded. -+ * @default true -+ */ -+ cssCodeSplit?: boolean -+ /** -+ * An optional separate target for CSS minification. -+ * As esbuild only supports configuring targets to mainstream -+ * browsers, users may need this option when they are targeting -+ * a niche browser that comes with most modern JavaScript features -+ * but has poor CSS support, e.g. Android WeChat WebView, which -+ * doesn't support the #RGBA syntax. -+ * @default target -+ */ -+ cssTarget?: TransformOptions['target'] | false -+ /** -+ * Override CSS minification specifically instead of defaulting to `build.minify`, -+ * so you can configure minification for JS and CSS separately. -+ * @default 'esbuild' -+ */ -+ cssMinify?: boolean | 'esbuild' | 'lightningcss' -+ /** -+ * If `true`, a separate sourcemap file will be created. If 'inline', the -+ * sourcemap will be appended to the resulting output file as data URI. -+ * 'hidden' works like `true` except that the corresponding sourcemap -+ * comments in the bundled files are suppressed. -+ * @default false -+ */ -+ sourcemap?: boolean | 'inline' | 'hidden' -+ /** -+ * Set to `false` to disable minification, or specify the minifier to use. -+ * Available options are 'terser' or 'esbuild'. -+ * @default 'esbuild' -+ */ -+ minify?: boolean | 'terser' | 'esbuild' -+ /** -+ * Options for terser -+ * https://terser.org/docs/api-reference#minify-options -+ * -+ * In addition, you can also pass a `maxWorkers: number` option to specify the -+ * max number of workers to spawn. Defaults to the number of CPUs minus 1. -+ */ -+ terserOptions?: TerserOptions -+ /** -+ * Will be merged with internal rollup options. -+ * https://rollupjs.org/configuration-options/ -+ */ -+ rollupOptions?: RollupOptions -+ /** -+ * Options to pass on to `@rollup/plugin-commonjs` -+ */ -+ commonjsOptions?: RollupCommonJSOptions -+ /** -+ * Options to pass on to `@rollup/plugin-dynamic-import-vars` -+ */ -+ dynamicImportVarsOptions?: RollupDynamicImportVarsOptions -+ /** -+ * Whether to write bundle to disk -+ * @default true -+ */ -+ write?: boolean -+ /** -+ * Empty outDir on write. -+ * @default true when outDir is a sub directory of project root -+ */ -+ emptyOutDir?: boolean | null -+ /** -+ * Copy the public directory to outDir on write. -+ * @default true -+ */ -+ copyPublicDir?: boolean -+ /** -+ * Whether to emit a .vite/manifest.json under assets dir to map hash-less filenames -+ * to their hashed versions. Useful when you want to generate your own HTML -+ * instead of using the one generated by Vite. -+ * -+ * Example: -+ * -+ * ```json -+ * { -+ * "main.js": { -+ * "file": "main.68fe3fad.js", -+ * "css": "main.e6b63442.css", -+ * "imports": [...], -+ * "dynamicImports": [...] -+ * } -+ * } -+ * ``` -+ * @default false -+ */ -+ manifest?: boolean | string -+ /** -+ * Build in library mode. The value should be the global name of the lib in -+ * UMD mode. This will produce esm + cjs + umd bundle formats with default -+ * configurations that are suitable for distributing libraries. -+ * @default false -+ */ -+ lib?: LibraryOptions | false -+ /** -+ * Produce SSR oriented build. Note this requires specifying SSR entry via -+ * `rollupOptions.input`. -+ * @default false -+ */ -+ ssr?: boolean | string -+ /** -+ * Generate SSR manifest for determining style links and asset preload -+ * directives in production. -+ * @default false -+ */ -+ ssrManifest?: boolean | string -+ /** -+ * Emit assets during SSR. -+ * @default false -+ */ -+ ssrEmitAssets?: boolean -+ /** -+ * Set to false to disable reporting compressed chunk sizes. -+ * Can slightly improve build speed. -+ * @default true -+ */ -+ reportCompressedSize?: boolean -+ /** -+ * Adjust chunk size warning limit (in kB). -+ * @default 500 -+ */ -+ chunkSizeWarningLimit?: number -+ /** -+ * Rollup watch options -+ * https://rollupjs.org/configuration-options/#watch -+ * @default null -+ */ -+ watch?: WatcherOptions | null -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/build/build.ts b/packages/vite/src/node/build/build.ts -new file mode 100644 -index 000000000..c61e47ba3 ---- /dev/null -+++ b/packages/vite/src/node/build/build.ts -@@ -0,0 +1,350 @@ -+import path from 'node:path' -+import colors from 'picocolors' -+import type { -+ ExternalOption, -+ InputOption, -+ InternalModuleFormat, -+ LoggingFunction, -+ ModuleFormat, -+ OutputOptions, -+ Plugin, -+ RollupBuild, -+ RollupError, -+ RollupLog, -+ RollupOptions, -+ RollupOutput, -+ RollupWatcher, -+ WatcherOptions, -+} from 'rollup' -+import { VERSION } from 'packages/vite/src/node/constants'; -+import { InlineConfig } from 'packages/vite/src/node/config'; -+import { resolveConfig } from 'packages/vite/src/node/config'; -+import { displayTime } from 'packages/vite/src/node/utils'; -+import { findNearestPackageData } from 'packages/vite/src/node/packages'; -+import { getResolvedOutDirs } from 'packages/vite/src/node/watch'; -+import { resolveEmptyOutDir } from 'packages/vite/src/node/watch'; -+import { resolveChokidarOptions } from 'packages/vite/src/node/watch'; -+ -+ -+/** -+ * Bundles the app for production. -+ * Returns a Promise containing the build result. -+ */ -+export async function build( -+ inlineConfig: InlineConfig = {}, -+): Promise { -+ const config = await resolveConfig( -+ inlineConfig, -+ 'build', -+ 'production', -+ 'production', -+ ) -+ const options = config.build -+ const { logger } = config -+ const ssr = !!options.ssr -+ const libOptions = options.lib -+ -+ logger.info( -+ colors.cyan( -+ `vite v${VERSION} ${colors.green( -+ `building ${ssr ? `SSR bundle ` : ``}for ${config.mode}...`, -+ )}`, -+ ), -+ ) -+ -+ const resolve = (p: string) => path.resolve(config.root, p) -+ const input = libOptions -+ ? options.rollupOptions?.input || -+ (typeof libOptions.entry === 'string' -+ ? resolve(libOptions.entry) -+ : Array.isArray(libOptions.entry) -+ ? libOptions.entry.map(resolve) -+ : Object.fromEntries( -+ Object.entries(libOptions.entry).map(([alias, file]) => [ -+ alias, -+ resolve(file), -+ ]), -+ )) -+ : typeof options.ssr === 'string' -+ ? resolve(options.ssr) -+ : options.rollupOptions?.input || resolve('index.html') -+ -+ if (ssr && typeof input === 'string' && input.endsWith('.html')) { -+ throw new Error( -+ `rollupOptions.input should not be an html file when building for SSR. ` + -+ `Please specify a dedicated SSR entry.`, -+ ) -+ } -+ if (config.build.cssCodeSplit === false) { -+ const inputs = -+ typeof input === 'string' -+ ? [input] -+ : Array.isArray(input) -+ ? input -+ : Object.values(input) -+ if (inputs.some((input) => input.endsWith('.css'))) { -+ throw new Error( -+ `When "build.cssCodeSplit: false" is set, "rollupOptions.input" should not include CSS files.`, -+ ) -+ } -+ } -+ -+ const outDir = resolve(options.outDir) -+ -+ // inject ssr arg to plugin load/transform hooks -+ const plugins = ( -+ ssr ? config.plugins.map((p) => injectSsrFlagToHooks(p)) : config.plugins -+ ) as Plugin[] -+ -+ const rollupOptions: RollupOptions = { -+ preserveEntrySignatures: ssr -+ ? 'allow-extension' -+ : libOptions -+ ? 'strict' -+ : false, -+ cache: config.build.watch ? undefined : false, -+ ...options.rollupOptions, -+ input, -+ plugins, -+ external: options.rollupOptions?.external, -+ onwarn(warning, warn) { -+ onRollupWarning(warning, warn, config) -+ }, -+ } -+ -+ /** -+ * The stack string usually contains a copy of the message at the start of the stack. -+ * If the stack starts with the message, we remove it and just return the stack trace -+ * portion. Otherwise the original stack trace is used. -+ */ -+ function extractStack(e: RollupError) { -+ const { stack, name = 'Error', message } = e -+ -+ // If we don't have a stack, not much we can do. -+ if (!stack) { -+ return stack -+ } -+ -+ const expectedPrefix = `${name}: ${message}\n` -+ if (stack.startsWith(expectedPrefix)) { -+ return stack.slice(expectedPrefix.length) -+ } -+ -+ return stack -+ } -+ -+ /** -+ * Esbuild code frames have newlines at the start and end of the frame, rollup doesn't -+ * This function normalizes the frame to match the esbuild format which has more pleasing padding -+ */ -+ const normalizeCodeFrame = (frame: string) => { -+ const trimmedPadding = frame.replace(/^\n|\n$/g, '') -+ return `\n${trimmedPadding}\n` -+ } -+ -+ const enhanceRollupError = (e: RollupError) => { -+ const stackOnly = extractStack(e) -+ -+ let msg = colors.red((e.plugin ? `[${e.plugin}] ` : '') + e.message) -+ if (e.id) { -+ msg += `\nfile: ${colors.cyan( -+ e.id + (e.loc ? `:${e.loc.line}:${e.loc.column}` : ''), -+ )}` -+ } -+ if (e.frame) { -+ msg += `\n` + colors.yellow(normalizeCodeFrame(e.frame)) -+ } -+ -+ e.message = msg -+ -+ // We are rebuilding the stack trace to include the more detailed message at the top. -+ // Previously this code was relying on mutating e.message changing the generated stack -+ // when it was accessed, but we don't have any guarantees that the error we are working -+ // with hasn't already had its stack accessed before we get here. -+ if (stackOnly !== undefined) { -+ e.stack = `${e.message}\n${stackOnly}` -+ } -+ } -+ -+ const outputBuildError = (e: RollupError) => { -+ enhanceRollupError(e) -+ clearLine() -+ logger.error(e.message, { error: e }) -+ } -+ -+ let bundle: RollupBuild | undefined -+ let startTime: number | undefined -+ try { -+ const buildOutputOptions = (output: OutputOptions = {}): OutputOptions => { -+ // @ts-expect-error See https://github.com/vitejs/vite/issues/5812#issuecomment-984345618 -+ if (output.output) { -+ logger.warn( -+ `You've set "rollupOptions.output.output" in your config. ` + -+ `This is deprecated and will override all Vite.js default output options. ` + -+ `Please use "rollupOptions.output" instead.`, -+ ) -+ } -+ if (output.file) { -+ throw new Error( -+ `Vite does not support "rollupOptions.output.file". ` + -+ `Please use "rollupOptions.output.dir" and "rollupOptions.output.entryFileNames" instead.`, -+ ) -+ } -+ if (output.sourcemap) { -+ logger.warnOnce( -+ colors.yellow( -+ `Vite does not support "rollupOptions.output.sourcemap". ` + -+ `Please use "build.sourcemap" instead.`, -+ ), -+ ) -+ } -+ -+ const ssrNodeBuild = ssr && config.ssr.target === 'node' -+ const ssrWorkerBuild = ssr && config.ssr.target === 'webworker' -+ -+ const format = output.format || 'es' -+ const jsExt = -+ ssrNodeBuild || libOptions -+ ? resolveOutputJsExtension( -+ format, -+ findNearestPackageData(config.root, config.packageCache)?.data -+ .type, -+ ) -+ : 'js' -+ return { -+ dir: outDir, -+ // Default format is 'es' for regular and for SSR builds -+ format, -+ exports: 'auto', -+ sourcemap: options.sourcemap, -+ name: libOptions ? libOptions.name : undefined, -+ hoistTransitiveImports: libOptions ? false : undefined, -+ // es2015 enables `generatedCode.symbols` -+ // - #764 add `Symbol.toStringTag` when build es module into cjs chunk -+ // - #1048 add `Symbol.toStringTag` for module default export -+ generatedCode: 'es2015', -+ entryFileNames: ssr -+ ? `[name].${jsExt}` -+ : libOptions -+ ? ({ name }) => -+ resolveLibFilename( -+ libOptions, -+ format, -+ name, -+ config.root, -+ jsExt, -+ config.packageCache, -+ ) -+ : path.posix.join(options.assetsDir, `[name]-[hash].${jsExt}`), -+ chunkFileNames: libOptions -+ ? `[name]-[hash].${jsExt}` -+ : path.posix.join(options.assetsDir, `[name]-[hash].${jsExt}`), -+ assetFileNames: libOptions -+ ? `[name].[ext]` -+ : path.posix.join(options.assetsDir, `[name]-[hash].[ext]`), -+ inlineDynamicImports: -+ output.format === 'umd' || -+ output.format === 'iife' || -+ (ssrWorkerBuild && -+ (typeof input === 'string' || Object.keys(input).length === 1)), -+ ...output, -+ } -+ } -+ -+ // resolve lib mode outputs -+ const outputs = resolveBuildOutputs( -+ options.rollupOptions?.output, -+ libOptions, -+ logger, -+ ) -+ const normalizedOutputs: OutputOptions[] = [] -+ -+ if (Array.isArray(outputs)) { -+ for (const resolvedOutput of outputs) { -+ normalizedOutputs.push(buildOutputOptions(resolvedOutput)) -+ } -+ } else { -+ normalizedOutputs.push(buildOutputOptions(outputs)) -+ } -+ -+ const resolvedOutDirs = getResolvedOutDirs( -+ config.root, -+ options.outDir, -+ options.rollupOptions?.output, -+ ) -+ const emptyOutDir = resolveEmptyOutDir( -+ options.emptyOutDir, -+ config.root, -+ resolvedOutDirs, -+ logger, -+ ) -+ -+ // watch file changes with rollup -+ if (config.build.watch) { -+ logger.info(colors.cyan(`\nwatching for file changes...`)) -+ -+ const resolvedChokidarOptions = resolveChokidarOptions( -+ config, -+ config.build.watch.chokidar, -+ resolvedOutDirs, -+ emptyOutDir, -+ ) -+ -+ const { watch } = await import('rollup') -+ const watcher = watch({ -+ ...rollupOptions, -+ output: normalizedOutputs, -+ watch: { -+ ...config.build.watch, -+ chokidar: resolvedChokidarOptions, -+ }, -+ }) -+ -+ watcher.on('event', (event) => { -+ if (event.code === 'BUNDLE_START') { -+ logger.info(colors.cyan(`\nbuild started...`)) -+ if (options.write) { -+ prepareOutDir(resolvedOutDirs, emptyOutDir, config) -+ } -+ } else if (event.code === 'BUNDLE_END') { -+ event.result.close() -+ logger.info(colors.cyan(`built in ${event.duration}ms.`)) -+ } else if (event.code === 'ERROR') { -+ outputBuildError(event.error) -+ } -+ }) -+ -+ return watcher -+ } -+ -+ // write or generate files with rollup -+ const { rollup } = await import('rollup') -+ startTime = Date.now() -+ bundle = await rollup(rollupOptions) -+ -+ if (options.write) { -+ prepareOutDir(resolvedOutDirs, emptyOutDir, config) -+ } -+ -+ const res: RollupOutput[] = [] -+ for (const output of normalizedOutputs) { -+ res.push(await bundle[options.write ? 'write' : 'generate'](output)) -+ } -+ logger.info( -+ `${colors.green(`✓ built in ${displayTime(Date.now() - startTime)}`)}`, -+ ) -+ return Array.isArray(outputs) ? res : res[0] -+ } catch (e) { -+ enhanceRollupError(e) -+ clearLine() -+ if (startTime) { -+ logger.error( -+ `${colors.red('x')} Build failed in ${displayTime(Date.now() - startTime)}`, -+ ) -+ startTime = undefined -+ } -+ throw e -+ } finally { -+ if (bundle) await bundle.close() -+ } -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/build/onRollupWarning.ts b/packages/vite/src/node/build/onRollupWarning.ts -new file mode 100644 -index 000000000..b9697626b ---- /dev/null -+++ b/packages/vite/src/node/build/onRollupWarning.ts -@@ -0,0 +1,98 @@ -+import colors from 'picocolors' -+import type { -+ ExternalOption, -+ InputOption, -+ InternalModuleFormat, -+ LoggingFunction, -+ ModuleFormat, -+ OutputOptions, -+ Plugin, -+ RollupBuild, -+ RollupError, -+ RollupLog, -+ RollupOptions, -+ RollupOutput, -+ RollupWatcher, -+ WatcherOptions, -+} from 'rollup' -+import { ResolvedConfig } from 'packages/vite/src/node/config'; -+ -+ -+const export warningIgnoreList = [`CIRCULAR_DEPENDENCY`, `THIS_IS_UNDEFINED`] -+ -+const export dynamicImportWarningIgnoreList = [ -+ `Unsupported expression`, -+ `statically analyzed`, -+] -+ -+export function clearLine() { -+ const tty = process.stdout.isTTY && !process.env.CI -+ if (tty) { -+ process.stdout.clearLine(0) -+ process.stdout.cursorTo(0) -+ } -+} -+ -+export function onRollupWarning( -+ warning: RollupLog, -+ warn: LoggingFunction, -+ config: ResolvedConfig, -+): void { -+ const viteWarn: LoggingFunction = (warnLog) => { -+ let warning: string | RollupLog -+ -+ if (typeof warnLog === 'function') { -+ warning = warnLog() -+ } else { -+ warning = warnLog -+ } -+ -+ if (typeof warning === 'object') { -+ if (warning.code === 'UNRESOLVED_IMPORT') { -+ const id = warning.id -+ const exporter = warning.exporter -+ // throw unless it's commonjs external... -+ if (!id || !id.endsWith('?commonjs-external')) { -+ throw new Error( -+ `[vite]: Rollup failed to resolve import "${exporter}" from "${id}".\n` + -+ `This is most likely unintended because it can break your application at runtime.\n` + -+ `If you do want to externalize this module explicitly add it to\n` + -+ `\`build.rollupOptions.external\``, -+ ) -+ } -+ } -+ -+ if ( -+ warning.plugin === 'rollup-plugin-dynamic-import-variables' && -+ dynamicImportWarningIgnoreList.some((msg) => -+ warning.message.includes(msg), -+ ) -+ ) { -+ return -+ } -+ -+ if (warningIgnoreList.includes(warning.code!)) { -+ return -+ } -+ -+ if (warning.code === 'PLUGIN_WARNING') { -+ config.logger.warn( -+ `${colors.bold( -+ colors.yellow(`[plugin:${warning.plugin}]`), -+ )} ${colors.yellow(warning.message)}`, -+ ) -+ return -+ } -+ } -+ -+ warn(warnLog) -+ } -+ -+ clearLine() -+ const userOnWarn = config.build.rollupOptions?.onwarn -+ if (userOnWarn) { -+ userOnWarn(warning, viteWarn) -+ } else { -+ viteWarn(warning) -+ } -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/build/resolveBuildOptions.ts b/packages/vite/src/node/build/resolveBuildOptions.ts -new file mode 100644 -index 000000000..b93e728f4 ---- /dev/null -+++ b/packages/vite/src/node/build/resolveBuildOptions.ts -@@ -0,0 +1,391 @@ -+import fs from 'node:fs' -+import { DEFAULT_ASSETS_INLINE_LIMIT } from 'packages/vite/src/node/constants'; -+import { ESBUILD_MODULES_TARGET } from 'packages/vite/src/node/constants'; -+import { requireResolveFromRootWithFallback } from 'packages/vite/src/node/utils'; -+import { Logger } from 'packages/vite/src/node/logger'; -+import { mergeConfig } from 'packages/vite/src/node/publicUtils'; -+import type { -+ ExternalOption, -+ InputOption, -+ InternalModuleFormat, -+ LoggingFunction, -+ ModuleFormat, -+ OutputOptions, -+ Plugin, -+ RollupBuild, -+ RollupError, -+ RollupLog, -+ RollupOptions, -+ RollupOutput, -+ RollupWatcher, -+ WatcherOptions, -+} from 'rollup' -+import type { RollupCommonJSOptions } from 'dep-types/commonjs' -+import type { RollupDynamicImportVarsOptions } from 'dep-types/dynamicImportVars' -+import type { TransformOptions } from 'esbuild' -+import { TerserOptions } from 'packages/vite/src/node/plugins/terser'; -+ -+ -+export type LibraryFormats = 'es' | 'cjs' | 'umd' | 'iife' | 'system' -+ -+export interface LibraryOptions { -+ /** -+ * Path of library entry -+ */ -+ entry: InputOption -+ /** -+ * The name of the exposed global variable. Required when the `formats` option includes -+ * `umd` or `iife` -+ */ -+ name?: string -+ /** -+ * Output bundle formats -+ * @default ['es', 'umd'] -+ */ -+ formats?: LibraryFormats[] -+ /** -+ * The name of the package file output. The default file name is the name option -+ * of the project package.json. It can also be defined as a function taking the -+ * format as an argument. -+ */ -+ fileName?: string | ((format: ModuleFormat, entryName: string) => string) -+} -+ -+export type ResolveModulePreloadDependenciesFn = ( -+ filename: string, -+ deps: string[], -+ context: { -+ hostId: string -+ hostType: 'html' | 'js' -+ }, -+) => string[] -+ -+export interface ModulePreloadOptions { -+ /** -+ * Whether to inject a module preload polyfill. -+ * Note: does not apply to library mode. -+ * @default true -+ */ -+ polyfill?: boolean -+ /** -+ * Resolve the list of dependencies to preload for a given dynamic import -+ * @experimental -+ */ -+ resolveDependencies?: ResolveModulePreloadDependenciesFn -+} -+ -+export interface BuildOptions { -+ /** -+ * Compatibility transform target. The transform is performed with esbuild -+ * and the lowest supported target is es2015/es6. Note this only handles -+ * syntax transformation and does not cover polyfills (except for dynamic -+ * import) -+ * -+ * Default: 'modules' - Similar to `@babel/preset-env`'s targets.esmodules, -+ * transpile targeting browsers that natively support dynamic es module imports. -+ * https://caniuse.com/es6-module-dynamic-import -+ * -+ * Another special value is 'esnext' - which only performs minimal transpiling -+ * (for minification compat) and assumes native dynamic imports support. -+ * -+ * For custom targets, see https://esbuild.github.io/api/#target and -+ * https://esbuild.github.io/content-types/#javascript for more details. -+ * @default 'modules' -+ */ -+ target?: 'modules' | TransformOptions['target'] | false -+ /** -+ * whether to inject module preload polyfill. -+ * Note: does not apply to library mode. -+ * @default true -+ * @deprecated use `modulePreload.polyfill` instead -+ */ -+ polyfillModulePreload?: boolean -+ /** -+ * Configure module preload -+ * Note: does not apply to library mode. -+ * @default true -+ */ -+ modulePreload?: boolean | ModulePreloadOptions -+ /** -+ * Directory relative from `root` where build output will be placed. If the -+ * directory exists, it will be removed before the build. -+ * @default 'dist' -+ */ -+ outDir?: string -+ /** -+ * Directory relative from `outDir` where the built js/css/image assets will -+ * be placed. -+ * @default 'assets' -+ */ -+ assetsDir?: string -+ /** -+ * Static asset files smaller than this number (in bytes) will be inlined as -+ * base64 strings. Default limit is `4096` (4 KiB). Set to `0` to disable. -+ * @default 4096 -+ */ -+ assetsInlineLimit?: -+ | number -+ | ((filePath: string, content: Buffer) => boolean | undefined) -+ /** -+ * Whether to code-split CSS. When enabled, CSS in async chunks will be -+ * inlined as strings in the chunk and inserted via dynamically created -+ * style tags when the chunk is loaded. -+ * @default true -+ */ -+ cssCodeSplit?: boolean -+ /** -+ * An optional separate target for CSS minification. -+ * As esbuild only supports configuring targets to mainstream -+ * browsers, users may need this option when they are targeting -+ * a niche browser that comes with most modern JavaScript features -+ * but has poor CSS support, e.g. Android WeChat WebView, which -+ * doesn't support the #RGBA syntax. -+ * @default target -+ */ -+ cssTarget?: TransformOptions['target'] | false -+ /** -+ * Override CSS minification specifically instead of defaulting to `build.minify`, -+ * so you can configure minification for JS and CSS separately. -+ * @default 'esbuild' -+ */ -+ cssMinify?: boolean | 'esbuild' | 'lightningcss' -+ /** -+ * If `true`, a separate sourcemap file will be created. If 'inline', the -+ * sourcemap will be appended to the resulting output file as data URI. -+ * 'hidden' works like `true` except that the corresponding sourcemap -+ * comments in the bundled files are suppressed. -+ * @default false -+ */ -+ sourcemap?: boolean | 'inline' | 'hidden' -+ /** -+ * Set to `false` to disable minification, or specify the minifier to use. -+ * Available options are 'terser' or 'esbuild'. -+ * @default 'esbuild' -+ */ -+ minify?: boolean | 'terser' | 'esbuild' -+ /** -+ * Options for terser -+ * https://terser.org/docs/api-reference#minify-options -+ * -+ * In addition, you can also pass a `maxWorkers: number` option to specify the -+ * max number of workers to spawn. Defaults to the number of CPUs minus 1. -+ */ -+ terserOptions?: TerserOptions -+ /** -+ * Will be merged with internal rollup options. -+ * https://rollupjs.org/configuration-options/ -+ */ -+ rollupOptions?: RollupOptions -+ /** -+ * Options to pass on to `@rollup/plugin-commonjs` -+ */ -+ commonjsOptions?: RollupCommonJSOptions -+ /** -+ * Options to pass on to `@rollup/plugin-dynamic-import-vars` -+ */ -+ dynamicImportVarsOptions?: RollupDynamicImportVarsOptions -+ /** -+ * Whether to write bundle to disk -+ * @default true -+ */ -+ write?: boolean -+ /** -+ * Empty outDir on write. -+ * @default true when outDir is a sub directory of project root -+ */ -+ emptyOutDir?: boolean | null -+ /** -+ * Copy the public directory to outDir on write. -+ * @default true -+ */ -+ copyPublicDir?: boolean -+ /** -+ * Whether to emit a .vite/manifest.json under assets dir to map hash-less filenames -+ * to their hashed versions. Useful when you want to generate your own HTML -+ * instead of using the one generated by Vite. -+ * -+ * Example: -+ * -+ * ```json -+ * { -+ * "main.js": { -+ * "file": "main.68fe3fad.js", -+ * "css": "main.e6b63442.css", -+ * "imports": [...], -+ * "dynamicImports": [...] -+ * } -+ * } -+ * ``` -+ * @default false -+ */ -+ manifest?: boolean | string -+ /** -+ * Build in library mode. The value should be the global name of the lib in -+ * UMD mode. This will produce esm + cjs + umd bundle formats with default -+ * configurations that are suitable for distributing libraries. -+ * @default false -+ */ -+ lib?: LibraryOptions | false -+ /** -+ * Produce SSR oriented build. Note this requires specifying SSR entry via -+ * `rollupOptions.input`. -+ * @default false -+ */ -+ ssr?: boolean | string -+ /** -+ * Generate SSR manifest for determining style links and asset preload -+ * directives in production. -+ * @default false -+ */ -+ ssrManifest?: boolean | string -+ /** -+ * Emit assets during SSR. -+ * @default false -+ */ -+ ssrEmitAssets?: boolean -+ /** -+ * Set to false to disable reporting compressed chunk sizes. -+ * Can slightly improve build speed. -+ * @default true -+ */ -+ reportCompressedSize?: boolean -+ /** -+ * Adjust chunk size warning limit (in kB). -+ * @default 500 -+ */ -+ chunkSizeWarningLimit?: number -+ /** -+ * Rollup watch options -+ * https://rollupjs.org/configuration-options/#watch -+ * @default null -+ */ -+ watch?: WatcherOptions | null -+} -+ -+export interface ResolvedModulePreloadOptions { -+ polyfill: boolean -+ resolveDependencies?: ResolveModulePreloadDependenciesFn -+} -+ -+export interface ResolvedBuildOptions -+ extends Required> { -+ modulePreload: false | ResolvedModulePreloadOptions -+} -+ -+export function resolveBuildOptions( -+ raw: BuildOptions | undefined, -+ logger: Logger, -+ root: string, -+): ResolvedBuildOptions { -+ const deprecatedPolyfillModulePreload = raw?.polyfillModulePreload -+ if (raw) { -+ const { polyfillModulePreload, ...rest } = raw -+ raw = rest -+ if (deprecatedPolyfillModulePreload !== undefined) { -+ logger.warn( -+ 'polyfillModulePreload is deprecated. Use modulePreload.polyfill instead.', -+ ) -+ } -+ if ( -+ deprecatedPolyfillModulePreload === false && -+ raw.modulePreload === undefined -+ ) { -+ raw.modulePreload = { polyfill: false } -+ } -+ } -+ -+ const modulePreload = raw?.modulePreload -+ const defaultModulePreload = { -+ polyfill: true, -+ } -+ -+ const defaultBuildOptions: BuildOptions = { -+ outDir: 'dist', -+ assetsDir: 'assets', -+ assetsInlineLimit: DEFAULT_ASSETS_INLINE_LIMIT, -+ cssCodeSplit: !raw?.lib, -+ sourcemap: false, -+ rollupOptions: {}, -+ minify: raw?.ssr ? false : 'esbuild', -+ terserOptions: {}, -+ write: true, -+ emptyOutDir: null, -+ copyPublicDir: true, -+ manifest: false, -+ lib: false, -+ ssr: false, -+ ssrManifest: false, -+ ssrEmitAssets: false, -+ reportCompressedSize: true, -+ chunkSizeWarningLimit: 500, -+ watch: null, -+ } -+ -+ const userBuildOptions = raw -+ ? mergeConfig(defaultBuildOptions, raw) -+ : defaultBuildOptions -+ -+ // @ts-expect-error Fallback options instead of merging -+ const resolved: ResolvedBuildOptions = { -+ target: 'modules', -+ cssTarget: false, -+ ...userBuildOptions, -+ commonjsOptions: { -+ include: [/node_modules/], -+ extensions: ['.js', '.cjs'], -+ ...userBuildOptions.commonjsOptions, -+ }, -+ dynamicImportVarsOptions: { -+ warnOnError: true, -+ exclude: [/node_modules/], -+ ...userBuildOptions.dynamicImportVarsOptions, -+ }, -+ // Resolve to false | object -+ modulePreload: -+ modulePreload === false -+ ? false -+ : typeof modulePreload === 'object' -+ ? { -+ ...defaultModulePreload, -+ ...modulePreload, -+ } -+ : defaultModulePreload, -+ } -+ -+ // handle special build targets -+ if (resolved.target === 'modules') { -+ resolved.target = ESBUILD_MODULES_TARGET -+ } else if (resolved.target === 'esnext' && resolved.minify === 'terser') { -+ try { -+ const terserPackageJsonPath = requireResolveFromRootWithFallback( -+ root, -+ 'terser/package.json', -+ ) -+ const terserPackageJson = JSON.parse( -+ fs.readFileSync(terserPackageJsonPath, 'utf-8'), -+ ) -+ const v = terserPackageJson.version.split('.') -+ if (v[0] === '5' && v[1] < 16) { -+ // esnext + terser 5.16<: limit to es2021 so it can be minified by terser -+ resolved.target = 'es2021' -+ } -+ } catch {} -+ } -+ -+ if (!resolved.cssTarget) { -+ resolved.cssTarget = resolved.target -+ } -+ -+ // normalize false string into actual false -+ if ((resolved.minify as string) === 'false') { -+ resolved.minify = false -+ } else if (resolved.minify === true) { -+ resolved.minify = 'esbuild' -+ } -+ -+ if (resolved.cssMinify == null) { -+ resolved.cssMinify = !!resolved.minify -+ } -+ -+ return resolved -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/build/resolveBuildOutputs.ts b/packages/vite/src/node/build/resolveBuildOutputs.ts -new file mode 100644 -index 000000000..ac24200f3 ---- /dev/null -+++ b/packages/vite/src/node/build/resolveBuildOutputs.ts -@@ -0,0 +1,99 @@ -+import colors from 'picocolors' -+import type { -+ ExternalOption, -+ InputOption, -+ InternalModuleFormat, -+ LoggingFunction, -+ ModuleFormat, -+ OutputOptions, -+ Plugin, -+ RollupBuild, -+ RollupError, -+ RollupLog, -+ RollupOptions, -+ RollupOutput, -+ RollupWatcher, -+ WatcherOptions, -+} from 'rollup' -+import { Logger } from 'packages/vite/src/node/logger'; -+ -+ -+export type LibraryFormats = 'es' | 'cjs' | 'umd' | 'iife' | 'system' -+ -+export interface LibraryOptions { -+ /** -+ * Path of library entry -+ */ -+ entry: InputOption -+ /** -+ * The name of the exposed global variable. Required when the `formats` option includes -+ * `umd` or `iife` -+ */ -+ name?: string -+ /** -+ * Output bundle formats -+ * @default ['es', 'umd'] -+ */ -+ formats?: LibraryFormats[] -+ /** -+ * The name of the package file output. The default file name is the name option -+ * of the project package.json. It can also be defined as a function taking the -+ * format as an argument. -+ */ -+ fileName?: string | ((format: ModuleFormat, entryName: string) => string) -+} -+ -+export function resolveBuildOutputs( -+ outputs: OutputOptions | OutputOptions[] | undefined, -+ libOptions: LibraryOptions | false, -+ logger: Logger, -+): OutputOptions | OutputOptions[] | undefined { -+ if (libOptions) { -+ const libHasMultipleEntries = -+ typeof libOptions.entry !== 'string' && -+ Object.values(libOptions.entry).length > 1 -+ const libFormats = -+ libOptions.formats || -+ (libHasMultipleEntries ? ['es', 'cjs'] : ['es', 'umd']) -+ -+ if (!Array.isArray(outputs)) { -+ if (libFormats.includes('umd') || libFormats.includes('iife')) { -+ if (libHasMultipleEntries) { -+ throw new Error( -+ 'Multiple entry points are not supported when output formats include "umd" or "iife".', -+ ) -+ } -+ -+ if (!libOptions.name) { -+ throw new Error( -+ 'Option "build.lib.name" is required when output formats include "umd" or "iife".', -+ ) -+ } -+ } -+ -+ return libFormats.map((format) => ({ ...outputs, format })) -+ } -+ -+ // By this point, we know "outputs" is an Array. -+ if (libOptions.formats) { -+ logger.warn( -+ colors.yellow( -+ '"build.lib.formats" will be ignored because "build.rollupOptions.output" is already an array format.', -+ ), -+ ) -+ } -+ -+ outputs.forEach((output) => { -+ if ( -+ (output.format === 'umd' || output.format === 'iife') && -+ !output.name -+ ) { -+ throw new Error( -+ 'Entries in "build.rollupOptions.output" must specify "name" when the format is "umd" or "iife".', -+ ) -+ } -+ }) -+ } -+ -+ return outputs -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/cli.ts b/packages/vite/src/node/cli.ts -index f0fa20921..3be8d25d6 100644 ---- a/packages/vite/src/node/cli.ts -+++ b/packages/vite/src/node/cli.ts -@@ -1,15 +1,18 @@ -+import { BuildOptions } from 'packages/vite/src/node/build/BuildOptions'; -+import { BuildOptions } from 'packages/vite/src/node/build/resolveBuildOptions'; -+import { build } from 'packages/vite/src/node/build/build'; -+import { resolveConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { ServerOptions } from 'packages/vite/src/node/server/index/ServerOptions'; -+import { ServerOptions } from 'packages/vite/src/node/server/index/resolveServerOptions'; - import path from 'node:path' - import fs from 'node:fs' - import { performance } from 'node:perf_hooks' - import { cac } from 'cac' - import colors from 'picocolors' - import { VERSION } from './constants' --import type { BuildOptions } from './build' --import type { ServerOptions } from './server' - import type { CLIShortcut } from './shortcuts' - import type { LogLevel } from './logger' - import { createLogger } from './logger' --import { resolveConfig } from './config' - - const cli = cac('vite') - -@@ -265,7 +268,6 @@ cli - .option('-w, --watch', `[boolean] rebuilds when modules have changed on disk`) - .action(async (root: string, options: BuildOptions & GlobalCLIOptions) => { - filterDuplicateOptions(options) -- const { build } = await import('./build') - const buildOptions: BuildOptions = cleanOptions(options) - - try { -diff --git a/packages/vite/src/node/config.ts b/packages/vite/src/node/config.ts -index e38e5b595..325205b0f 100644 ---- a/packages/vite/src/node/config.ts -+++ b/packages/vite/src/node/config.ts -@@ -1,3 +1,48 @@ -+import { BuildOptions } from 'packages/vite/src/node/build/BuildOptions'; -+import { BuildOptions } from 'packages/vite/src/node/build/resolveBuildOptions'; -+import { ResolvedBuildOptions } from 'packages/vite/src/node/build/resolveBuildOptions'; -+import { resolveBuildOptions } from 'packages/vite/src/node/build/resolveBuildOptions'; -+import { AppType } from 'packages/vite/src/node/config/UserConfig'; -+import { PluginOption } from 'packages/vite/src/node/config/UserConfig'; -+import { HTMLOptions } from 'packages/vite/src/node/config/UserConfig'; -+import { ExperimentalOptions } from 'packages/vite/src/node/config/UserConfig'; -+import { LegacyOptions } from 'packages/vite/src/node/config/UserConfig'; -+import { UserConfig } from 'packages/vite/src/node/config/UserConfig'; -+import { UserConfig } from 'packages/vite/src/node/config'; -+import { ConfigEnv } from 'packages/vite/src/node/config/resolveConfig'; -+import { ResolvedWorkerOptions } from 'packages/vite/src/node/config/resolveConfig'; -+import { AppType } from 'packages/vite/src/node/config/resolveConfig'; -+import { ExperimentalOptions } from 'packages/vite/src/node/config/resolveConfig'; -+import { UserConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { InlineConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { ResolvedConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { checkBadCharactersInPath } from 'packages/vite/src/node/config/resolveConfig'; -+import { resolveConfig } from 'packages/vite/src/node/config'; -+import { ConfigEnv } from 'packages/vite/src/node/config/loadConfigFromFile'; -+import { AppType } from 'packages/vite/src/node/config/loadConfigFromFile'; -+import { ExperimentalOptions } from 'packages/vite/src/node/config/loadConfigFromFile'; -+import { UserConfig } from 'packages/vite/src/node/config/loadConfigFromFile'; -+import { loadConfigFromFile } from 'packages/vite/src/node/config/loadConfigFromFile'; -+import { loadConfigFromFile } from 'packages/vite/src/node/config'; -+import { bundleConfigFile } from 'packages/vite/src/node/config/bundleConfigFile'; -+import { bundleConfigFile } from 'packages/vite/src/node/config'; -+import { DepOptimizationConfig } from 'packages/vite/src/node/optimizer/index/DepOptimizationConfig'; -+import { ServerOptions } from 'packages/vite/src/node/server/index/ServerOptions'; -+import { ServerOptions } from 'packages/vite/src/node/server/index/resolveServerOptions'; -+import { ResolvedServerOptions } from 'packages/vite/src/node/server/index/resolveServerOptions'; -+import { resolveServerOptions } from 'packages/vite/src/node/server/index/resolveServerOptions'; -+import { ResolveOptions } from 'packages/vite/src/node/plugins/resolve/resolvePlugin'; -+import { InternalResolveOptions } from 'packages/vite/src/node/plugins/resolve/resolvePlugin'; -+import { resolvePlugin } from 'packages/vite/src/node/plugins/resolve/resolvePlugin'; -+import { ResolveOptions } from 'packages/vite/src/node/plugins/resolve/tryCleanFsResolve'; -+import { InternalResolveOptions } from 'packages/vite/src/node/plugins/resolve/tryCleanFsResolve'; -+import { ResolveOptions } from 'packages/vite/src/node/plugins/resolve/tryNodeResolve'; -+import { InternalResolveOptions } from 'packages/vite/src/node/plugins/resolve/tryNodeResolve'; -+import { tryNodeResolve } from 'packages/vite/src/node/plugins/resolve/tryNodeResolve'; -+import { ResolveOptions } from 'packages/vite/src/node/plugins/resolve/resolvePackageEntry'; -+import { InternalResolveOptions } from 'packages/vite/src/node/plugins/resolve/resolvePackageEntry'; -+import { ResolveOptions } from 'packages/vite/src/node/plugins/resolve/resolveDeepImport'; -+import { InternalResolveOptions } from 'packages/vite/src/node/plugins/resolve/resolveDeepImport'; - import fs from 'node:fs' - import fsp from 'node:fs/promises' - import path from 'node:path' -@@ -21,14 +66,6 @@ import { - FS_PREFIX, - } from './constants' - import type { HookHandler, Plugin, PluginWithRequiredHook } from './plugin' --import type { -- BuildOptions, -- RenderBuiltAssetUrl, -- ResolvedBuildOptions, --} from './build' --import { resolveBuildOptions } from './build' --import type { ResolvedServerOptions, ServerOptions } from './server' --import { resolveServerOptions } from './server' - import type { PreviewOptions, ResolvedPreviewOptions } from './preview' - import { resolvePreviewOptions } from './preview' - import { -@@ -59,11 +96,9 @@ import { - resolvePlugins, - } from './plugins' - import type { ESBuildOptions } from './plugins/esbuild' --import type { InternalResolveOptions, ResolveOptions } from './plugins/resolve' --import { resolvePlugin, tryNodeResolve } from './plugins/resolve' - import type { LogLevel, Logger } from './logger' - import { createLogger } from './logger' --import type { DepOptimizationConfig, DepOptimizationOptions } from './optimizer' -+import type { DepOptimizationOptions } from './optimizer' - import type { JsonOptions } from './plugins/json' - import type { PluginContainer } from './server/pluginContainer' - import { createPluginContainer } from './server/pluginContainer' -@@ -76,26 +111,6 @@ import { resolveSSROptions } from './ssr' - const debug = createDebugger('vite:config') - const promisifiedRealpath = promisify(fs.realpath) - --export interface ConfigEnv { -- /** -- * 'serve': during dev (`vite` command) -- * 'build': when building for production (`vite build` command) -- */ -- command: 'build' | 'serve' -- mode: string -- isSsrBuild?: boolean -- isPreview?: boolean --} -- --/** -- * spa: include SPA fallback middleware and configure sirv with `single: true` in preview -- * -- * mpa: only include non-SPA HTML middlewares -- * -- * custom: don't include HTML middlewares -- */ --export type AppType = 'spa' | 'mpa' | 'custom' -- - export type UserConfigFnObject = (env: ConfigEnv) => UserConfig - export type UserConfigFnPromise = (env: ConfigEnv) => Promise - export type UserConfigFn = (env: ConfigEnv) => UserConfig | Promise -@@ -120,823 +135,6 @@ export function defineConfig(config: UserConfigExport): UserConfigExport { - return config - } - --export type PluginOption = -- | Plugin -- | false -- | null -- | undefined -- | PluginOption[] -- | Promise -- --export interface UserConfig { -- /** -- * Project root directory. Can be an absolute path, or a path relative from -- * the location of the config file itself. -- * @default process.cwd() -- */ -- root?: string -- /** -- * Base public path when served in development or production. -- * @default '/' -- */ -- base?: string -- /** -- * Directory to serve as plain static assets. Files in this directory are -- * served and copied to build dist dir as-is without transform. The value -- * can be either an absolute file system path or a path relative to project root. -- * -- * Set to `false` or an empty string to disable copied static assets to build dist dir. -- * @default 'public' -- */ -- publicDir?: string | false -- /** -- * Directory to save cache files. Files in this directory are pre-bundled -- * deps or some other cache files that generated by vite, which can improve -- * the performance. You can use `--force` flag or manually delete the directory -- * to regenerate the cache files. The value can be either an absolute file -- * system path or a path relative to project root. -- * Default to `.vite` when no `package.json` is detected. -- * @default 'node_modules/.vite' -- */ -- cacheDir?: string -- /** -- * Explicitly set a mode to run in. This will override the default mode for -- * each command, and can be overridden by the command line --mode option. -- */ -- mode?: string -- /** -- * Define global variable replacements. -- * Entries will be defined on `window` during dev and replaced during build. -- */ -- define?: Record -- /** -- * Array of vite plugins to use. -- */ -- plugins?: PluginOption[] -- /** -- * Configure resolver -- */ -- resolve?: ResolveOptions & { alias?: AliasOptions } -- /** -- * HTML related options -- */ -- html?: HTMLOptions -- /** -- * CSS related options (preprocessors and CSS modules) -- */ -- css?: CSSOptions -- /** -- * JSON loading options -- */ -- json?: JsonOptions -- /** -- * Transform options to pass to esbuild. -- * Or set to `false` to disable esbuild. -- */ -- esbuild?: ESBuildOptions | false -- /** -- * Specify additional picomatch patterns to be treated as static assets. -- */ -- assetsInclude?: string | RegExp | (string | RegExp)[] -- /** -- * Server specific options, e.g. host, port, https... -- */ -- server?: ServerOptions -- /** -- * Build specific options -- */ -- build?: BuildOptions -- /** -- * Preview specific options, e.g. host, port, https... -- */ -- preview?: PreviewOptions -- /** -- * Dep optimization options -- */ -- optimizeDeps?: DepOptimizationOptions -- /** -- * SSR specific options -- */ -- ssr?: SSROptions -- /** -- * Experimental features -- * -- * Features under this field could change in the future and might NOT follow semver. -- * Please be careful and always pin Vite's version when using them. -- * @experimental -- */ -- experimental?: ExperimentalOptions -- /** -- * Legacy options -- * -- * Features under this field only follow semver for patches, they could be removed in a -- * future minor version. Please always pin Vite's version to a minor when using them. -- */ -- legacy?: LegacyOptions -- /** -- * Log level. -- * @default 'info' -- */ -- logLevel?: LogLevel -- /** -- * Custom logger. -- */ -- customLogger?: Logger -- /** -- * @default true -- */ -- clearScreen?: boolean -- /** -- * Environment files directory. Can be an absolute path, or a path relative from -- * root. -- * @default root -- */ -- envDir?: string -- /** -- * Env variables starts with `envPrefix` will be exposed to your client source code via import.meta.env. -- * @default 'VITE_' -- */ -- envPrefix?: string | string[] -- /** -- * Worker bundle options -- */ -- worker?: { -- /** -- * Output format for worker bundle -- * @default 'iife' -- */ -- format?: 'es' | 'iife' -- /** -- * Vite plugins that apply to worker bundle. The plugins returned by this function -- * should be new instances every time it is called, because they are used for each -- * rollup worker bundling process. -- */ -- plugins?: () => PluginOption[] -- /** -- * Rollup options to build worker bundle -- */ -- rollupOptions?: Omit< -- RollupOptions, -- 'plugins' | 'input' | 'onwarn' | 'preserveEntrySignatures' -- > -- } -- /** -- * Whether your application is a Single Page Application (SPA), -- * a Multi-Page Application (MPA), or Custom Application (SSR -- * and frameworks with custom HTML handling) -- * @default 'spa' -- */ -- appType?: AppType --} -- --export interface HTMLOptions { -- /** -- * A nonce value placeholder that will be used when generating script/style tags. -- * -- * Make sure that this placeholder will be replaced with a unique value for each request by the server. -- */ -- cspNonce?: string --} -- --export interface ExperimentalOptions { -- /** -- * Append fake `&lang.(ext)` when queries are specified, to preserve the file extension for following plugins to process. -- * -- * @experimental -- * @default false -- */ -- importGlobRestoreExtension?: boolean -- /** -- * Allow finegrain control over assets and public files paths -- * -- * @experimental -- */ -- renderBuiltUrl?: RenderBuiltAssetUrl -- /** -- * Enables support of HMR partial accept via `import.meta.hot.acceptExports`. -- * -- * @experimental -- * @default false -- */ -- hmrPartialAccept?: boolean -- /** -- * Skips SSR transform to make it easier to use Vite with Node ESM loaders. -- * @warning Enabling this will break normal operation of Vite's SSR in development mode. -- * -- * @experimental -- * @default false -- */ -- skipSsrTransform?: boolean --} -- --export interface LegacyOptions { -- /** -- * In Vite 4, SSR-externalized modules (modules not bundled and loaded by Node.js at runtime) -- * are implicitly proxied in dev to automatically handle `default` and `__esModule` access. -- * However, this does not correctly reflect how it works in the Node.js runtime, causing -- * inconsistencies between dev and prod. -- * -- * In Vite 5, the proxy is removed so dev and prod are consistent, but if you still require -- * the old behaviour, you can enable this option. If so, please leave your feedback at -- * https://github.com/vitejs/vite/discussions/14697. -- */ -- proxySsrExternalModules?: boolean --} -- --export interface ResolvedWorkerOptions { -- format: 'es' | 'iife' -- plugins: (bundleChain: string[]) => Promise -- rollupOptions: RollupOptions --} -- --export interface InlineConfig extends UserConfig { -- configFile?: string | false -- envFile?: false --} -- --export type ResolvedConfig = Readonly< -- Omit< -- UserConfig, -- 'plugins' | 'css' | 'assetsInclude' | 'optimizeDeps' | 'worker' | 'build' -- > & { -- configFile: string | undefined -- configFileDependencies: string[] -- inlineConfig: InlineConfig -- root: string -- base: string -- /** @internal */ -- decodedBase: string -- /** @internal */ -- rawBase: string -- publicDir: string -- cacheDir: string -- command: 'build' | 'serve' -- mode: string -- isWorker: boolean -- // in nested worker bundle to find the main config -- /** @internal */ -- mainConfig: ResolvedConfig | null -- /** @internal list of bundle entry id. used to detect recursive worker bundle. */ -- bundleChain: string[] -- isProduction: boolean -- envDir: string -- env: Record -- resolve: Required & { -- alias: Alias[] -- } -- plugins: readonly Plugin[] -- css: ResolvedCSSOptions -- esbuild: ESBuildOptions | false -- server: ResolvedServerOptions -- build: ResolvedBuildOptions -- preview: ResolvedPreviewOptions -- ssr: ResolvedSSROptions -- assetsInclude: (file: string) => boolean -- logger: Logger -- createResolver: (options?: Partial) => ResolveFn -- optimizeDeps: DepOptimizationOptions -- /** @internal */ -- packageCache: PackageCache -- worker: ResolvedWorkerOptions -- appType: AppType -- experimental: ExperimentalOptions -- } & PluginHookUtils --> -- --export interface PluginHookUtils { -- getSortedPlugins: ( -- hookName: K, -- ) => PluginWithRequiredHook[] -- getSortedPluginHooks: ( -- hookName: K, -- ) => NonNullable>[] --} -- --export type ResolveFn = ( -- id: string, -- importer?: string, -- aliasOnly?: boolean, -- ssr?: boolean, --) => Promise -- --/** -- * Check and warn if `path` includes characters that don't work well in Vite, -- * such as `#` and `?`. -- */ --function checkBadCharactersInPath(path: string, logger: Logger): void { -- const badChars = [] -- -- if (path.includes('#')) { -- badChars.push('#') -- } -- if (path.includes('?')) { -- badChars.push('?') -- } -- -- if (badChars.length > 0) { -- const charString = badChars.map((c) => `"${c}"`).join(' and ') -- const inflectedChars = badChars.length > 1 ? 'characters' : 'character' -- -- logger.warn( -- colors.yellow( -- `The project root contains the ${charString} ${inflectedChars} (${colors.cyan( -- path, -- )}), which may not work when running Vite. Consider renaming the directory to remove the characters.`, -- ), -- ) -- } --} -- --export async function resolveConfig( -- inlineConfig: InlineConfig, -- command: 'build' | 'serve', -- defaultMode = 'development', -- defaultNodeEnv = 'development', -- isPreview = false, --): Promise { -- let config = inlineConfig -- let configFileDependencies: string[] = [] -- let mode = inlineConfig.mode || defaultMode -- const isNodeEnvSet = !!process.env.NODE_ENV -- const packageCache: PackageCache = new Map() -- -- // some dependencies e.g. @vue/compiler-* relies on NODE_ENV for getting -- // production-specific behavior, so set it early on -- if (!isNodeEnvSet) { -- process.env.NODE_ENV = defaultNodeEnv -- } -- -- const configEnv: ConfigEnv = { -- mode, -- command, -- isSsrBuild: command === 'build' && !!config.build?.ssr, -- isPreview, -- } -- -- let { configFile } = config -- if (configFile !== false) { -- const loadResult = await loadConfigFromFile( -- configEnv, -- configFile, -- config.root, -- config.logLevel, -- config.customLogger, -- ) -- if (loadResult) { -- config = mergeConfig(loadResult.config, config) -- configFile = loadResult.path -- configFileDependencies = loadResult.dependencies -- } -- } -- -- // user config may provide an alternative mode. But --mode has a higher priority -- mode = inlineConfig.mode || config.mode || mode -- configEnv.mode = mode -- -- const filterPlugin = (p: Plugin) => { -- if (!p) { -- return false -- } else if (!p.apply) { -- return true -- } else if (typeof p.apply === 'function') { -- return p.apply({ ...config, mode }, configEnv) -- } else { -- return p.apply === command -- } -- } -- -- // resolve plugins -- const rawUserPlugins = ( -- (await asyncFlatten(config.plugins || [])) as Plugin[] -- ).filter(filterPlugin) -- -- const [prePlugins, normalPlugins, postPlugins] = -- sortUserPlugins(rawUserPlugins) -- -- // run config hooks -- const userPlugins = [...prePlugins, ...normalPlugins, ...postPlugins] -- config = await runConfigHook(config, userPlugins, configEnv) -- -- // Define logger -- const logger = createLogger(config.logLevel, { -- allowClearScreen: config.clearScreen, -- customLogger: config.customLogger, -- }) -- -- // resolve root -- const resolvedRoot = normalizePath( -- config.root ? path.resolve(config.root) : process.cwd(), -- ) -- -- checkBadCharactersInPath(resolvedRoot, logger) -- -- const clientAlias = [ -- { -- find: /^\/?@vite\/env/, -- replacement: path.posix.join(FS_PREFIX, normalizePath(ENV_ENTRY)), -- }, -- { -- find: /^\/?@vite\/client/, -- replacement: path.posix.join(FS_PREFIX, normalizePath(CLIENT_ENTRY)), -- }, -- ] -- -- // resolve alias with internal client alias -- const resolvedAlias = normalizeAlias( -- mergeAlias(clientAlias, config.resolve?.alias || []), -- ) -- -- const resolveOptions: ResolvedConfig['resolve'] = { -- mainFields: config.resolve?.mainFields ?? DEFAULT_MAIN_FIELDS, -- conditions: config.resolve?.conditions ?? [], -- extensions: config.resolve?.extensions ?? DEFAULT_EXTENSIONS, -- dedupe: config.resolve?.dedupe ?? [], -- preserveSymlinks: config.resolve?.preserveSymlinks ?? false, -- alias: resolvedAlias, -- } -- -- if ( -- // @ts-expect-error removed field -- config.resolve?.browserField === false && -- resolveOptions.mainFields.includes('browser') -- ) { -- logger.warn( -- colors.yellow( -- `\`resolve.browserField\` is set to false, but the option is removed in favour of ` + -- `the 'browser' string in \`resolve.mainFields\`. You may want to update \`resolve.mainFields\` ` + -- `to remove the 'browser' string and preserve the previous browser behaviour.`, -- ), -- ) -- } -- -- // load .env files -- const envDir = config.envDir -- ? normalizePath(path.resolve(resolvedRoot, config.envDir)) -- : resolvedRoot -- const userEnv = -- inlineConfig.envFile !== false && -- loadEnv(mode, envDir, resolveEnvPrefix(config)) -- -- // Note it is possible for user to have a custom mode, e.g. `staging` where -- // development-like behavior is expected. This is indicated by NODE_ENV=development -- // loaded from `.staging.env` and set by us as VITE_USER_NODE_ENV -- const userNodeEnv = process.env.VITE_USER_NODE_ENV -- if (!isNodeEnvSet && userNodeEnv) { -- if (userNodeEnv === 'development') { -- process.env.NODE_ENV = 'development' -- } else { -- // NODE_ENV=production is not supported as it could break HMR in dev for frameworks like Vue -- logger.warn( -- `NODE_ENV=${userNodeEnv} is not supported in the .env file. ` + -- `Only NODE_ENV=development is supported to create a development build of your project. ` + -- `If you need to set process.env.NODE_ENV, you can set it in the Vite config instead.`, -- ) -- } -- } -- -- const isProduction = process.env.NODE_ENV === 'production' -- -- // resolve public base url -- const isBuild = command === 'build' -- const relativeBaseShortcut = config.base === '' || config.base === './' -- -- // During dev, we ignore relative base and fallback to '/' -- // For the SSR build, relative base isn't possible by means -- // of import.meta.url. -- const resolvedBase = relativeBaseShortcut -- ? !isBuild || config.build?.ssr -- ? '/' -- : './' -- : (resolveBaseUrl(config.base, isBuild, logger) ?? '/') -- -- const resolvedBuildOptions = resolveBuildOptions( -- config.build, -- logger, -- resolvedRoot, -- ) -- -- // resolve cache directory -- const pkgDir = findNearestPackageData(resolvedRoot, packageCache)?.dir -- const cacheDir = normalizePath( -- config.cacheDir -- ? path.resolve(resolvedRoot, config.cacheDir) -- : pkgDir -- ? path.join(pkgDir, `node_modules/.vite`) -- : path.join(resolvedRoot, `.vite`), -- ) -- -- const assetsFilter = -- config.assetsInclude && -- (!Array.isArray(config.assetsInclude) || config.assetsInclude.length) -- ? createFilter(config.assetsInclude) -- : () => false -- -- // create an internal resolver to be used in special scenarios, e.g. -- // optimizer & handling css @imports -- const createResolver: ResolvedConfig['createResolver'] = (options) => { -- let aliasContainer: PluginContainer | undefined -- let resolverContainer: PluginContainer | undefined -- return async (id, importer, aliasOnly, ssr) => { -- let container: PluginContainer -- if (aliasOnly) { -- container = -- aliasContainer || -- (aliasContainer = await createPluginContainer({ -- ...resolved, -- plugins: [aliasPlugin({ entries: resolved.resolve.alias })], -- })) -- } else { -- container = -- resolverContainer || -- (resolverContainer = await createPluginContainer({ -- ...resolved, -- plugins: [ -- aliasPlugin({ entries: resolved.resolve.alias }), -- resolvePlugin({ -- ...resolved.resolve, -- root: resolvedRoot, -- isProduction, -- isBuild: command === 'build', -- ssrConfig: resolved.ssr, -- asSrc: true, -- preferRelative: false, -- tryIndex: true, -- ...options, -- idOnly: true, -- fsUtils: getFsUtils(resolved), -- }), -- ], -- })) -- } -- return ( -- await container.resolveId(id, importer, { -- ssr, -- scan: options?.scan, -- }) -- )?.id -- } -- } -- -- const { publicDir } = config -- const resolvedPublicDir = -- publicDir !== false && publicDir !== '' -- ? normalizePath( -- path.resolve( -- resolvedRoot, -- typeof publicDir === 'string' ? publicDir : 'public', -- ), -- ) -- : '' -- -- const server = resolveServerOptions(resolvedRoot, config.server, logger) -- const ssr = resolveSSROptions(config.ssr, resolveOptions.preserveSymlinks) -- -- const optimizeDeps = config.optimizeDeps || {} -- -- const BASE_URL = resolvedBase -- -- let resolved: ResolvedConfig -- -- let createUserWorkerPlugins = config.worker?.plugins -- if (Array.isArray(createUserWorkerPlugins)) { -- // @ts-expect-error backward compatibility -- createUserWorkerPlugins = () => config.worker?.plugins -- -- logger.warn( -- colors.yellow( -- `worker.plugins is now a function that returns an array of plugins. ` + -- `Please update your Vite config accordingly.\n`, -- ), -- ) -- } -- -- const createWorkerPlugins = async function (bundleChain: string[]) { -- // Some plugins that aren't intended to work in the bundling of workers (doing post-processing at build time for example). -- // And Plugins may also have cached that could be corrupted by being used in these extra rollup calls. -- // So we need to separate the worker plugin from the plugin that vite needs to run. -- const rawWorkerUserPlugins = ( -- (await asyncFlatten(createUserWorkerPlugins?.() || [])) as Plugin[] -- ).filter(filterPlugin) -- -- // resolve worker -- let workerConfig = mergeConfig({}, config) -- const [workerPrePlugins, workerNormalPlugins, workerPostPlugins] = -- sortUserPlugins(rawWorkerUserPlugins) -- -- // run config hooks -- const workerUserPlugins = [ -- ...workerPrePlugins, -- ...workerNormalPlugins, -- ...workerPostPlugins, -- ] -- workerConfig = await runConfigHook( -- workerConfig, -- workerUserPlugins, -- configEnv, -- ) -- -- const workerResolved: ResolvedConfig = { -- ...workerConfig, -- ...resolved, -- isWorker: true, -- mainConfig: resolved, -- bundleChain, -- } -- const resolvedWorkerPlugins = await resolvePlugins( -- workerResolved, -- workerPrePlugins, -- workerNormalPlugins, -- workerPostPlugins, -- ) -- -- // run configResolved hooks -- await Promise.all( -- createPluginHookUtils(resolvedWorkerPlugins) -- .getSortedPluginHooks('configResolved') -- .map((hook) => hook(workerResolved)), -- ) -- -- return resolvedWorkerPlugins -- } -- -- const resolvedWorkerOptions: ResolvedWorkerOptions = { -- format: config.worker?.format || 'iife', -- plugins: createWorkerPlugins, -- rollupOptions: config.worker?.rollupOptions || {}, -- } -- -- const base = withTrailingSlash(resolvedBase) -- -- resolved = { -- configFile: configFile ? normalizePath(configFile) : undefined, -- configFileDependencies: configFileDependencies.map((name) => -- normalizePath(path.resolve(name)), -- ), -- inlineConfig, -- root: resolvedRoot, -- base, -- decodedBase: decodeURI(base), -- rawBase: resolvedBase, -- resolve: resolveOptions, -- publicDir: resolvedPublicDir, -- cacheDir, -- command, -- mode, -- ssr, -- isWorker: false, -- mainConfig: null, -- bundleChain: [], -- isProduction, -- plugins: userPlugins, -- css: resolveCSSOptions(config.css), -- esbuild: -- config.esbuild === false -- ? false -- : { -- jsxDev: !isProduction, -- ...config.esbuild, -- }, -- server, -- build: resolvedBuildOptions, -- preview: resolvePreviewOptions(config.preview, server), -- envDir, -- env: { -- ...userEnv, -- BASE_URL, -- MODE: mode, -- DEV: !isProduction, -- PROD: isProduction, -- }, -- assetsInclude(file: string) { -- return DEFAULT_ASSETS_RE.test(file) || assetsFilter(file) -- }, -- logger, -- packageCache, -- createResolver, -- optimizeDeps: { -- holdUntilCrawlEnd: true, -- ...optimizeDeps, -- esbuildOptions: { -- preserveSymlinks: resolveOptions.preserveSymlinks, -- ...optimizeDeps.esbuildOptions, -- }, -- }, -- worker: resolvedWorkerOptions, -- appType: config.appType ?? 'spa', -- experimental: { -- importGlobRestoreExtension: false, -- hmrPartialAccept: false, -- ...config.experimental, -- }, -- getSortedPlugins: undefined!, -- getSortedPluginHooks: undefined!, -- } -- resolved = { -- ...config, -- ...resolved, -- } -- ;(resolved.plugins as Plugin[]) = await resolvePlugins( -- resolved, -- prePlugins, -- normalPlugins, -- postPlugins, -- ) -- Object.assign(resolved, createPluginHookUtils(resolved.plugins)) -- -- // call configResolved hooks -- await Promise.all( -- resolved -- .getSortedPluginHooks('configResolved') -- .map((hook) => hook(resolved)), -- ) -- -- optimizeDepsDisabledBackwardCompatibility(resolved, resolved.optimizeDeps) -- optimizeDepsDisabledBackwardCompatibility( -- resolved, -- resolved.ssr.optimizeDeps, -- 'ssr.', -- ) -- -- debug?.(`using resolved config: %O`, { -- ...resolved, -- plugins: resolved.plugins.map((p) => p.name), -- worker: { -- ...resolved.worker, -- plugins: `() => plugins`, -- }, -- }) -- -- // validate config -- -- if ( -- config.build?.terserOptions && -- config.build.minify && -- config.build.minify !== 'terser' -- ) { -- logger.warn( -- colors.yellow( -- `build.terserOptions is specified but build.minify is not set to use Terser. ` + -- `Note Vite now defaults to use esbuild for minification. If you still ` + -- `prefer Terser, set build.minify to "terser".`, -- ), -- ) -- } -- -- // Check if all assetFileNames have the same reference. -- // If not, display a warn for user. -- const outputOption = config.build?.rollupOptions?.output ?? [] -- // Use isArray to narrow its type to array -- if (Array.isArray(outputOption)) { -- const assetFileNamesList = outputOption.map( -- (output) => output.assetFileNames, -- ) -- if (assetFileNamesList.length > 1) { -- const firstAssetFileNames = assetFileNamesList[0] -- const hasDifferentReference = assetFileNamesList.some( -- (assetFileNames) => assetFileNames !== firstAssetFileNames, -- ) -- if (hasDifferentReference) { -- resolved.logger.warn( -- colors.yellow(` --assetFileNames isn't equal for every build.rollupOptions.output. A single pattern across all outputs is supported by Vite. --`), -- ) -- } -- } -- } -- -- // Warn about removal of experimental features -- if ( -- // @ts-expect-error Option removed -- config.legacy?.buildSsrCjsExternalHeuristics || -- // @ts-expect-error Option removed -- config.ssr?.format === 'cjs' -- ) { -- resolved.logger.warn( -- colors.yellow(` --(!) Experimental legacy.buildSsrCjsExternalHeuristics and ssr.format were be removed in Vite 5. -- The only SSR Output format is ESM. Find more information at https://github.com/vitejs/vite/discussions/13816. --`), -- ) -- } -- -- const resolvedBuildOutDir = normalizePath( -- path.resolve(resolved.root, resolved.build.outDir), -- ) -- if ( -- isParentDirectory(resolvedBuildOutDir, resolved.root) || -- resolvedBuildOutDir === resolved.root -- ) { -- resolved.logger.warn( -- colors.yellow(` --(!) build.outDir must not be the same directory of root or a parent directory of root as this could cause Vite to overwriting source files with build outputs. --`), -- ) -- } -- -- return resolved --} -- - /** - * Resolve base url. Note that some users use Vite to build for non-web targets like - * electron or expects to deploy -@@ -999,227 +197,6 @@ export function sortUserPlugins( - return [prePlugins, normalPlugins, postPlugins] - } - --export async function loadConfigFromFile( -- configEnv: ConfigEnv, -- configFile?: string, -- configRoot: string = process.cwd(), -- logLevel?: LogLevel, -- customLogger?: Logger, --): Promise<{ -- path: string -- config: UserConfig -- dependencies: string[] --} | null> { -- const start = performance.now() -- const getTime = () => `${(performance.now() - start).toFixed(2)}ms` -- -- let resolvedPath: string | undefined -- -- if (configFile) { -- // explicit config path is always resolved from cwd -- resolvedPath = path.resolve(configFile) -- } else { -- // implicit config file loaded from inline root (if present) -- // otherwise from cwd -- for (const filename of DEFAULT_CONFIG_FILES) { -- const filePath = path.resolve(configRoot, filename) -- if (!fs.existsSync(filePath)) continue -- -- resolvedPath = filePath -- break -- } -- } -- -- if (!resolvedPath) { -- debug?.('no config file found.') -- return null -- } -- -- const isESM = isFilePathESM(resolvedPath) -- -- try { -- const bundled = await bundleConfigFile(resolvedPath, isESM) -- const userConfig = await loadConfigFromBundledFile( -- resolvedPath, -- bundled.code, -- isESM, -- ) -- debug?.(`bundled config file loaded in ${getTime()}`) -- -- const config = await (typeof userConfig === 'function' -- ? userConfig(configEnv) -- : userConfig) -- if (!isObject(config)) { -- throw new Error(`config must export or return an object.`) -- } -- return { -- path: normalizePath(resolvedPath), -- config, -- dependencies: bundled.dependencies, -- } -- } catch (e) { -- createLogger(logLevel, { customLogger }).error( -- colors.red(`failed to load config from ${resolvedPath}`), -- { -- error: e, -- }, -- ) -- throw e -- } --} -- --async function bundleConfigFile( -- fileName: string, -- isESM: boolean, --): Promise<{ code: string; dependencies: string[] }> { -- const dirnameVarName = '__vite_injected_original_dirname' -- const filenameVarName = '__vite_injected_original_filename' -- const importMetaUrlVarName = '__vite_injected_original_import_meta_url' -- const result = await build({ -- absWorkingDir: process.cwd(), -- entryPoints: [fileName], -- write: false, -- target: [`node${process.versions.node}`], -- platform: 'node', -- bundle: true, -- format: isESM ? 'esm' : 'cjs', -- mainFields: ['main'], -- sourcemap: 'inline', -- metafile: true, -- define: { -- __dirname: dirnameVarName, -- __filename: filenameVarName, -- 'import.meta.url': importMetaUrlVarName, -- 'import.meta.dirname': dirnameVarName, -- 'import.meta.filename': filenameVarName, -- }, -- plugins: [ -- { -- name: 'externalize-deps', -- setup(build) { -- const packageCache = new Map() -- const resolveByViteResolver = ( -- id: string, -- importer: string, -- isRequire: boolean, -- ) => { -- return tryNodeResolve( -- id, -- importer, -- { -- root: path.dirname(fileName), -- isBuild: true, -- isProduction: true, -- preferRelative: false, -- tryIndex: true, -- mainFields: [], -- conditions: [], -- overrideConditions: ['node'], -- dedupe: [], -- extensions: DEFAULT_EXTENSIONS, -- preserveSymlinks: false, -- packageCache, -- isRequire, -- }, -- false, -- )?.id -- } -- -- // externalize bare imports -- build.onResolve( -- { filter: /^[^.].*/ }, -- async ({ path: id, importer, kind }) => { -- if ( -- kind === 'entry-point' || -- path.isAbsolute(id) || -- isNodeBuiltin(id) -- ) { -- return -- } -- -- // With the `isNodeBuiltin` check above, this check captures if the builtin is a -- // non-node built-in, which esbuild doesn't know how to handle. In that case, we -- // externalize it so the non-node runtime handles it instead. -- if (isBuiltin(id)) { -- return { external: true } -- } -- -- const isImport = isESM || kind === 'dynamic-import' -- let idFsPath: string | undefined -- try { -- idFsPath = resolveByViteResolver(id, importer, !isImport) -- } catch (e) { -- if (!isImport) { -- let canResolveWithImport = false -- try { -- canResolveWithImport = !!resolveByViteResolver( -- id, -- importer, -- false, -- ) -- } catch {} -- if (canResolveWithImport) { -- throw new Error( -- `Failed to resolve ${JSON.stringify( -- id, -- )}. This package is ESM only but it was tried to load by \`require\`. See https://vitejs.dev/guide/troubleshooting.html#this-package-is-esm-only for more details.`, -- ) -- } -- } -- throw e -- } -- if (idFsPath && isImport) { -- idFsPath = pathToFileURL(idFsPath).href -- } -- if ( -- idFsPath && -- !isImport && -- isFilePathESM(idFsPath, packageCache) -- ) { -- throw new Error( -- `${JSON.stringify( -- id, -- )} resolved to an ESM file. ESM file cannot be loaded by \`require\`. See https://vitejs.dev/guide/troubleshooting.html#this-package-is-esm-only for more details.`, -- ) -- } -- return { -- path: idFsPath, -- external: true, -- } -- }, -- ) -- }, -- }, -- { -- name: 'inject-file-scope-variables', -- setup(build) { -- build.onLoad({ filter: /\.[cm]?[jt]s$/ }, async (args) => { -- const contents = await fsp.readFile(args.path, 'utf-8') -- const injectValues = -- `const ${dirnameVarName} = ${JSON.stringify( -- path.dirname(args.path), -- )};` + -- `const ${filenameVarName} = ${JSON.stringify(args.path)};` + -- `const ${importMetaUrlVarName} = ${JSON.stringify( -- pathToFileURL(args.path).href, -- )};` -- -- return { -- loader: args.path.endsWith('ts') ? 'ts' : 'js', -- contents: injectValues + contents, -- } -- }) -- }, -- }, -- ], -- }) -- const { text } = result.outputFiles[0] -- return { -- code: text, -- dependencies: result.metafile ? Object.keys(result.metafile.inputs) : [], -- } --} -- - interface NodeModuleWithCompile extends NodeModule { - _compile(code: string, filename: string): any - } -diff --git a/packages/vite/src/node/config/UserConfig.ts b/packages/vite/src/node/config/UserConfig.ts -new file mode 100644 -index 000000000..5aa2ffede ---- /dev/null -+++ b/packages/vite/src/node/config/UserConfig.ts -@@ -0,0 +1,248 @@ -+import type { Alias, AliasOptions } from 'dep-types/alias' -+import type { RollupOptions } from 'rollup' -+import { BuildOptions } from 'packages/vite/src/node/build'; -+import { ServerOptions } from 'packages/vite/src/node/server/index'; -+import { PreviewOptions } from 'packages/vite/src/node/preview'; -+import { CSSOptions } from 'packages/vite/src/node/plugins/css'; -+import { ESBuildOptions } from 'packages/vite/src/node/plugins/esbuild'; -+import { ResolveOptions } from 'packages/vite/src/node/plugins/resolve'; -+import { LogLevel } from 'packages/vite/src/node/logger'; -+import { Logger } from 'packages/vite/src/node/logger'; -+import { DepOptimizationOptions } from 'packages/vite/src/node/optimizer/index'; -+import { JsonOptions } from 'packages/vite/src/node/plugins/json'; -+import { SSROptions } from 'packages/vite/src/node/ssr/index'; -+import { Plugin } from 'packages/vite/src/node/plugin'; -+import { RenderBuiltAssetUrl } from 'packages/vite/src/node/build'; -+ -+ -+/** -+ * spa: include SPA fallback middleware and configure sirv with `single: true` in preview -+ * -+ * mpa: only include non-SPA HTML middlewares -+ * -+ * custom: don't include HTML middlewares -+ */ -+export type AppType = 'spa' | 'mpa' | 'custom' -+ -+export type PluginOption = -+ | Plugin -+ | false -+ | null -+ | undefined -+ | PluginOption[] -+ | Promise -+ -+export interface HTMLOptions { -+ /** -+ * A nonce value placeholder that will be used when generating script/style tags. -+ * -+ * Make sure that this placeholder will be replaced with a unique value for each request by the server. -+ */ -+ cspNonce?: string -+} -+ -+export interface ExperimentalOptions { -+ /** -+ * Append fake `&lang.(ext)` when queries are specified, to preserve the file extension for following plugins to process. -+ * -+ * @experimental -+ * @default false -+ */ -+ importGlobRestoreExtension?: boolean -+ /** -+ * Allow finegrain control over assets and public files paths -+ * -+ * @experimental -+ */ -+ renderBuiltUrl?: RenderBuiltAssetUrl -+ /** -+ * Enables support of HMR partial accept via `import.meta.hot.acceptExports`. -+ * -+ * @experimental -+ * @default false -+ */ -+ hmrPartialAccept?: boolean -+ /** -+ * Skips SSR transform to make it easier to use Vite with Node ESM loaders. -+ * @warning Enabling this will break normal operation of Vite's SSR in development mode. -+ * -+ * @experimental -+ * @default false -+ */ -+ skipSsrTransform?: boolean -+} -+ -+export interface LegacyOptions { -+ /** -+ * In Vite 4, SSR-externalized modules (modules not bundled and loaded by Node.js at runtime) -+ * are implicitly proxied in dev to automatically handle `default` and `__esModule` access. -+ * However, this does not correctly reflect how it works in the Node.js runtime, causing -+ * inconsistencies between dev and prod. -+ * -+ * In Vite 5, the proxy is removed so dev and prod are consistent, but if you still require -+ * the old behaviour, you can enable this option. If so, please leave your feedback at -+ * https://github.com/vitejs/vite/discussions/14697. -+ */ -+ proxySsrExternalModules?: boolean -+} -+ -+export interface UserConfig { -+ /** -+ * Project root directory. Can be an absolute path, or a path relative from -+ * the location of the config file itself. -+ * @default process.cwd() -+ */ -+ root?: string -+ /** -+ * Base public path when served in development or production. -+ * @default '/' -+ */ -+ base?: string -+ /** -+ * Directory to serve as plain static assets. Files in this directory are -+ * served and copied to build dist dir as-is without transform. The value -+ * can be either an absolute file system path or a path relative to project root. -+ * -+ * Set to `false` or an empty string to disable copied static assets to build dist dir. -+ * @default 'public' -+ */ -+ publicDir?: string | false -+ /** -+ * Directory to save cache files. Files in this directory are pre-bundled -+ * deps or some other cache files that generated by vite, which can improve -+ * the performance. You can use `--force` flag or manually delete the directory -+ * to regenerate the cache files. The value can be either an absolute file -+ * system path or a path relative to project root. -+ * Default to `.vite` when no `package.json` is detected. -+ * @default 'node_modules/.vite' -+ */ -+ cacheDir?: string -+ /** -+ * Explicitly set a mode to run in. This will override the default mode for -+ * each command, and can be overridden by the command line --mode option. -+ */ -+ mode?: string -+ /** -+ * Define global variable replacements. -+ * Entries will be defined on `window` during dev and replaced during build. -+ */ -+ define?: Record -+ /** -+ * Array of vite plugins to use. -+ */ -+ plugins?: PluginOption[] -+ /** -+ * Configure resolver -+ */ -+ resolve?: ResolveOptions & { alias?: AliasOptions } -+ /** -+ * HTML related options -+ */ -+ html?: HTMLOptions -+ /** -+ * CSS related options (preprocessors and CSS modules) -+ */ -+ css?: CSSOptions -+ /** -+ * JSON loading options -+ */ -+ json?: JsonOptions -+ /** -+ * Transform options to pass to esbuild. -+ * Or set to `false` to disable esbuild. -+ */ -+ esbuild?: ESBuildOptions | false -+ /** -+ * Specify additional picomatch patterns to be treated as static assets. -+ */ -+ assetsInclude?: string | RegExp | (string | RegExp)[] -+ /** -+ * Server specific options, e.g. host, port, https... -+ */ -+ server?: ServerOptions -+ /** -+ * Build specific options -+ */ -+ build?: BuildOptions -+ /** -+ * Preview specific options, e.g. host, port, https... -+ */ -+ preview?: PreviewOptions -+ /** -+ * Dep optimization options -+ */ -+ optimizeDeps?: DepOptimizationOptions -+ /** -+ * SSR specific options -+ */ -+ ssr?: SSROptions -+ /** -+ * Experimental features -+ * -+ * Features under this field could change in the future and might NOT follow semver. -+ * Please be careful and always pin Vite's version when using them. -+ * @experimental -+ */ -+ experimental?: ExperimentalOptions -+ /** -+ * Legacy options -+ * -+ * Features under this field only follow semver for patches, they could be removed in a -+ * future minor version. Please always pin Vite's version to a minor when using them. -+ */ -+ legacy?: LegacyOptions -+ /** -+ * Log level. -+ * @default 'info' -+ */ -+ logLevel?: LogLevel -+ /** -+ * Custom logger. -+ */ -+ customLogger?: Logger -+ /** -+ * @default true -+ */ -+ clearScreen?: boolean -+ /** -+ * Environment files directory. Can be an absolute path, or a path relative from -+ * root. -+ * @default root -+ */ -+ envDir?: string -+ /** -+ * Env variables starts with `envPrefix` will be exposed to your client source code via import.meta.env. -+ * @default 'VITE_' -+ */ -+ envPrefix?: string | string[] -+ /** -+ * Worker bundle options -+ */ -+ worker?: { -+ /** -+ * Output format for worker bundle -+ * @default 'iife' -+ */ -+ format?: 'es' | 'iife' -+ /** -+ * Vite plugins that apply to worker bundle. The plugins returned by this function -+ * should be new instances every time it is called, because they are used for each -+ * rollup worker bundling process. -+ */ -+ plugins?: () => PluginOption[] -+ /** -+ * Rollup options to build worker bundle -+ */ -+ rollupOptions?: Omit< -+ RollupOptions, -+ 'plugins' | 'input' | 'onwarn' | 'preserveEntrySignatures' -+ > -+ } -+ /** -+ * Whether your application is a Single Page Application (SPA), -+ * a Multi-Page Application (MPA), or Custom Application (SSR -+ * and frameworks with custom HTML handling) -+ * @default 'spa' -+ */ -+ appType?: AppType -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/config/bundleConfigFile.ts b/packages/vite/src/node/config/bundleConfigFile.ts -new file mode 100644 -index 000000000..74fc22cfd ---- /dev/null -+++ b/packages/vite/src/node/config/bundleConfigFile.ts -@@ -0,0 +1,158 @@ -+import path from 'node:path' -+import { build } from 'esbuild' -+import { isNodeBuiltin } from 'packages/vite/src/node/utils'; -+import { isBuiltin } from 'packages/vite/src/node/utils'; -+import { isFilePathESM } from 'packages/vite/src/node/utils'; -+ -+ -+export async function bundleConfigFile( -+ fileName: string, -+ isESM: boolean, -+): Promise<{ code: string; dependencies: string[] }> { -+ const dirnameVarName = '__vite_injected_original_dirname' -+ const filenameVarName = '__vite_injected_original_filename' -+ const importMetaUrlVarName = '__vite_injected_original_import_meta_url' -+ const result = await build({ -+ absWorkingDir: process.cwd(), -+ entryPoints: [fileName], -+ write: false, -+ target: [`node${process.versions.node}`], -+ platform: 'node', -+ bundle: true, -+ format: isESM ? 'esm' : 'cjs', -+ mainFields: ['main'], -+ sourcemap: 'inline', -+ metafile: true, -+ define: { -+ __dirname: dirnameVarName, -+ __filename: filenameVarName, -+ 'import.meta.url': importMetaUrlVarName, -+ 'import.meta.dirname': dirnameVarName, -+ 'import.meta.filename': filenameVarName, -+ }, -+ plugins: [ -+ { -+ name: 'externalize-deps', -+ setup(build) { -+ const packageCache = new Map() -+ const resolveByViteResolver = ( -+ id: string, -+ importer: string, -+ isRequire: boolean, -+ ) => { -+ return tryNodeResolve( -+ id, -+ importer, -+ { -+ root: path.dirname(fileName), -+ isBuild: true, -+ isProduction: true, -+ preferRelative: false, -+ tryIndex: true, -+ mainFields: [], -+ conditions: [], -+ overrideConditions: ['node'], -+ dedupe: [], -+ extensions: DEFAULT_EXTENSIONS, -+ preserveSymlinks: false, -+ packageCache, -+ isRequire, -+ }, -+ false, -+ )?.id -+ } -+ -+ // externalize bare imports -+ build.onResolve( -+ { filter: /^[^.].*/ }, -+ async ({ path: id, importer, kind }) => { -+ if ( -+ kind === 'entry-point' || -+ path.isAbsolute(id) || -+ isNodeBuiltin(id) -+ ) { -+ return -+ } -+ -+ // With the `isNodeBuiltin` check above, this check captures if the builtin is a -+ // non-node built-in, which esbuild doesn't know how to handle. In that case, we -+ // externalize it so the non-node runtime handles it instead. -+ if (isBuiltin(id)) { -+ return { external: true } -+ } -+ -+ const isImport = isESM || kind === 'dynamic-import' -+ let idFsPath: string | undefined -+ try { -+ idFsPath = resolveByViteResolver(id, importer, !isImport) -+ } catch (e) { -+ if (!isImport) { -+ let canResolveWithImport = false -+ try { -+ canResolveWithImport = !!resolveByViteResolver( -+ id, -+ importer, -+ false, -+ ) -+ } catch {} -+ if (canResolveWithImport) { -+ throw new Error( -+ `Failed to resolve ${JSON.stringify( -+ id, -+ )}. This package is ESM only but it was tried to load by \`require\`. See https://vitejs.dev/guide/troubleshooting.html#this-package-is-esm-only for more details.`, -+ ) -+ } -+ } -+ throw e -+ } -+ if (idFsPath && isImport) { -+ idFsPath = pathToFileURL(idFsPath).href -+ } -+ if ( -+ idFsPath && -+ !isImport && -+ isFilePathESM(idFsPath, packageCache) -+ ) { -+ throw new Error( -+ `${JSON.stringify( -+ id, -+ )} resolved to an ESM file. ESM file cannot be loaded by \`require\`. See https://vitejs.dev/guide/troubleshooting.html#this-package-is-esm-only for more details.`, -+ ) -+ } -+ return { -+ path: idFsPath, -+ external: true, -+ } -+ }, -+ ) -+ }, -+ }, -+ { -+ name: 'inject-file-scope-variables', -+ setup(build) { -+ build.onLoad({ filter: /\.[cm]?[jt]s$/ }, async (args) => { -+ const contents = await fsp.readFile(args.path, 'utf-8') -+ const injectValues = -+ `const ${dirnameVarName} = ${JSON.stringify( -+ path.dirname(args.path), -+ )};` + -+ `const ${filenameVarName} = ${JSON.stringify(args.path)};` + -+ `const ${importMetaUrlVarName} = ${JSON.stringify( -+ pathToFileURL(args.path).href, -+ )};` -+ -+ return { -+ loader: args.path.endsWith('ts') ? 'ts' : 'js', -+ contents: injectValues + contents, -+ } -+ }) -+ }, -+ }, -+ ], -+ }) -+ const { text } = result.outputFiles[0] -+ return { -+ code: text, -+ dependencies: result.metafile ? Object.keys(result.metafile.inputs) : [], -+ } -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/config/loadConfigFromFile.ts b/packages/vite/src/node/config/loadConfigFromFile.ts -new file mode 100644 -index 000000000..a8c60b0af ---- /dev/null -+++ b/packages/vite/src/node/config/loadConfigFromFile.ts -@@ -0,0 +1,337 @@ -+import fs from 'node:fs' -+import path from 'node:path' -+import { performance } from 'node:perf_hooks' -+import colors from 'picocolors' -+import { DEFAULT_CONFIG_FILES } from 'packages/vite/src/node/constants'; -+import { isObject } from 'packages/vite/src/node/utils'; -+import { normalizePath } from 'packages/vite/src/node/utils'; -+import { isFilePathESM } from 'packages/vite/src/node/utils'; -+import { LogLevel } from 'packages/vite/src/node/logger'; -+import { Logger } from 'packages/vite/src/node/logger'; -+import { createLogger } from 'packages/vite/src/node/logger'; -+import type { Alias, AliasOptions } from 'dep-types/alias' -+import type { RollupOptions } from 'rollup' -+import { BuildOptions } from 'packages/vite/src/node/build'; -+import { ServerOptions } from 'packages/vite/src/node/server/index'; -+import { PreviewOptions } from 'packages/vite/src/node/preview'; -+import { CSSOptions } from 'packages/vite/src/node/plugins/css'; -+import { ESBuildOptions } from 'packages/vite/src/node/plugins/esbuild'; -+import { ResolveOptions } from 'packages/vite/src/node/plugins/resolve'; -+import { DepOptimizationOptions } from 'packages/vite/src/node/optimizer/index'; -+import { JsonOptions } from 'packages/vite/src/node/plugins/json'; -+import { SSROptions } from 'packages/vite/src/node/ssr/index'; -+import { Plugin } from 'packages/vite/src/node/plugin'; -+import { RenderBuiltAssetUrl } from 'packages/vite/src/node/build'; -+ -+ -+export interface ConfigEnv { -+ /** -+ * 'serve': during dev (`vite` command) -+ * 'build': when building for production (`vite build` command) -+ */ -+ command: 'build' | 'serve' -+ mode: string -+ isSsrBuild?: boolean -+ isPreview?: boolean -+} -+ -+/** -+ * spa: include SPA fallback middleware and configure sirv with `single: true` in preview -+ * -+ * mpa: only include non-SPA HTML middlewares -+ * -+ * custom: don't include HTML middlewares -+ */ -+export type AppType = 'spa' | 'mpa' | 'custom' -+ -+export type PluginOption = -+ | Plugin -+ | false -+ | null -+ | undefined -+ | PluginOption[] -+ | Promise -+ -+export interface HTMLOptions { -+ /** -+ * A nonce value placeholder that will be used when generating script/style tags. -+ * -+ * Make sure that this placeholder will be replaced with a unique value for each request by the server. -+ */ -+ cspNonce?: string -+} -+ -+export interface ExperimentalOptions { -+ /** -+ * Append fake `&lang.(ext)` when queries are specified, to preserve the file extension for following plugins to process. -+ * -+ * @experimental -+ * @default false -+ */ -+ importGlobRestoreExtension?: boolean -+ /** -+ * Allow finegrain control over assets and public files paths -+ * -+ * @experimental -+ */ -+ renderBuiltUrl?: RenderBuiltAssetUrl -+ /** -+ * Enables support of HMR partial accept via `import.meta.hot.acceptExports`. -+ * -+ * @experimental -+ * @default false -+ */ -+ hmrPartialAccept?: boolean -+ /** -+ * Skips SSR transform to make it easier to use Vite with Node ESM loaders. -+ * @warning Enabling this will break normal operation of Vite's SSR in development mode. -+ * -+ * @experimental -+ * @default false -+ */ -+ skipSsrTransform?: boolean -+} -+ -+export interface LegacyOptions { -+ /** -+ * In Vite 4, SSR-externalized modules (modules not bundled and loaded by Node.js at runtime) -+ * are implicitly proxied in dev to automatically handle `default` and `__esModule` access. -+ * However, this does not correctly reflect how it works in the Node.js runtime, causing -+ * inconsistencies between dev and prod. -+ * -+ * In Vite 5, the proxy is removed so dev and prod are consistent, but if you still require -+ * the old behaviour, you can enable this option. If so, please leave your feedback at -+ * https://github.com/vitejs/vite/discussions/14697. -+ */ -+ proxySsrExternalModules?: boolean -+} -+ -+export interface UserConfig { -+ /** -+ * Project root directory. Can be an absolute path, or a path relative from -+ * the location of the config file itself. -+ * @default process.cwd() -+ */ -+ root?: string -+ /** -+ * Base public path when served in development or production. -+ * @default '/' -+ */ -+ base?: string -+ /** -+ * Directory to serve as plain static assets. Files in this directory are -+ * served and copied to build dist dir as-is without transform. The value -+ * can be either an absolute file system path or a path relative to project root. -+ * -+ * Set to `false` or an empty string to disable copied static assets to build dist dir. -+ * @default 'public' -+ */ -+ publicDir?: string | false -+ /** -+ * Directory to save cache files. Files in this directory are pre-bundled -+ * deps or some other cache files that generated by vite, which can improve -+ * the performance. You can use `--force` flag or manually delete the directory -+ * to regenerate the cache files. The value can be either an absolute file -+ * system path or a path relative to project root. -+ * Default to `.vite` when no `package.json` is detected. -+ * @default 'node_modules/.vite' -+ */ -+ cacheDir?: string -+ /** -+ * Explicitly set a mode to run in. This will override the default mode for -+ * each command, and can be overridden by the command line --mode option. -+ */ -+ mode?: string -+ /** -+ * Define global variable replacements. -+ * Entries will be defined on `window` during dev and replaced during build. -+ */ -+ define?: Record -+ /** -+ * Array of vite plugins to use. -+ */ -+ plugins?: PluginOption[] -+ /** -+ * Configure resolver -+ */ -+ resolve?: ResolveOptions & { alias?: AliasOptions } -+ /** -+ * HTML related options -+ */ -+ html?: HTMLOptions -+ /** -+ * CSS related options (preprocessors and CSS modules) -+ */ -+ css?: CSSOptions -+ /** -+ * JSON loading options -+ */ -+ json?: JsonOptions -+ /** -+ * Transform options to pass to esbuild. -+ * Or set to `false` to disable esbuild. -+ */ -+ esbuild?: ESBuildOptions | false -+ /** -+ * Specify additional picomatch patterns to be treated as static assets. -+ */ -+ assetsInclude?: string | RegExp | (string | RegExp)[] -+ /** -+ * Server specific options, e.g. host, port, https... -+ */ -+ server?: ServerOptions -+ /** -+ * Build specific options -+ */ -+ build?: BuildOptions -+ /** -+ * Preview specific options, e.g. host, port, https... -+ */ -+ preview?: PreviewOptions -+ /** -+ * Dep optimization options -+ */ -+ optimizeDeps?: DepOptimizationOptions -+ /** -+ * SSR specific options -+ */ -+ ssr?: SSROptions -+ /** -+ * Experimental features -+ * -+ * Features under this field could change in the future and might NOT follow semver. -+ * Please be careful and always pin Vite's version when using them. -+ * @experimental -+ */ -+ experimental?: ExperimentalOptions -+ /** -+ * Legacy options -+ * -+ * Features under this field only follow semver for patches, they could be removed in a -+ * future minor version. Please always pin Vite's version to a minor when using them. -+ */ -+ legacy?: LegacyOptions -+ /** -+ * Log level. -+ * @default 'info' -+ */ -+ logLevel?: LogLevel -+ /** -+ * Custom logger. -+ */ -+ customLogger?: Logger -+ /** -+ * @default true -+ */ -+ clearScreen?: boolean -+ /** -+ * Environment files directory. Can be an absolute path, or a path relative from -+ * root. -+ * @default root -+ */ -+ envDir?: string -+ /** -+ * Env variables starts with `envPrefix` will be exposed to your client source code via import.meta.env. -+ * @default 'VITE_' -+ */ -+ envPrefix?: string | string[] -+ /** -+ * Worker bundle options -+ */ -+ worker?: { -+ /** -+ * Output format for worker bundle -+ * @default 'iife' -+ */ -+ format?: 'es' | 'iife' -+ /** -+ * Vite plugins that apply to worker bundle. The plugins returned by this function -+ * should be new instances every time it is called, because they are used for each -+ * rollup worker bundling process. -+ */ -+ plugins?: () => PluginOption[] -+ /** -+ * Rollup options to build worker bundle -+ */ -+ rollupOptions?: Omit< -+ RollupOptions, -+ 'plugins' | 'input' | 'onwarn' | 'preserveEntrySignatures' -+ > -+ } -+ /** -+ * Whether your application is a Single Page Application (SPA), -+ * a Multi-Page Application (MPA), or Custom Application (SSR -+ * and frameworks with custom HTML handling) -+ * @default 'spa' -+ */ -+ appType?: AppType -+} -+ -+export async function loadConfigFromFile( -+ configEnv: ConfigEnv, -+ configFile?: string, -+ configRoot: string = process.cwd(), -+ logLevel?: LogLevel, -+ customLogger?: Logger, -+): Promise<{ -+ path: string -+ config: UserConfig -+ dependencies: string[] -+} | null> { -+ const start = performance.now() -+ const getTime = () => `${(performance.now() - start).toFixed(2)}ms` -+ -+ let resolvedPath: string | undefined -+ -+ if (configFile) { -+ // explicit config path is always resolved from cwd -+ resolvedPath = path.resolve(configFile) -+ } else { -+ // implicit config file loaded from inline root (if present) -+ // otherwise from cwd -+ for (const filename of DEFAULT_CONFIG_FILES) { -+ const filePath = path.resolve(configRoot, filename) -+ if (!fs.existsSync(filePath)) continue -+ -+ resolvedPath = filePath -+ break -+ } -+ } -+ -+ if (!resolvedPath) { -+ debug?.('no config file found.') -+ return null -+ } -+ -+ const isESM = isFilePathESM(resolvedPath) -+ -+ try { -+ const bundled = await bundleConfigFile(resolvedPath, isESM) -+ const userConfig = await loadConfigFromBundledFile( -+ resolvedPath, -+ bundled.code, -+ isESM, -+ ) -+ debug?.(`bundled config file loaded in ${getTime()}`) -+ -+ const config = await (typeof userConfig === 'function' -+ ? userConfig(configEnv) -+ : userConfig) -+ if (!isObject(config)) { -+ throw new Error(`config must export or return an object.`) -+ } -+ return { -+ path: normalizePath(resolvedPath), -+ config, -+ dependencies: bundled.dependencies, -+ } -+ } catch (e) { -+ createLogger(logLevel, { customLogger }).error( -+ colors.red(`failed to load config from ${resolvedPath}`), -+ { -+ error: e, -+ }, -+ ) -+ throw e -+ } -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/config/resolveConfig.ts b/packages/vite/src/node/config/resolveConfig.ts -new file mode 100644 -index 000000000..9e10e6f6d ---- /dev/null -+++ b/packages/vite/src/node/config/resolveConfig.ts -@@ -0,0 +1,889 @@ -+import path from 'node:path' -+import colors from 'picocolors' -+import { withTrailingSlash } from 'packages/vite/src/shared/utils'; -+import { FS_PREFIX } from 'packages/vite/src/node/constants'; -+import { ENV_ENTRY } from 'packages/vite/src/node/constants'; -+import { CLIENT_ENTRY } from 'packages/vite/src/node/constants'; -+import { DEFAULT_MAIN_FIELDS } from 'packages/vite/src/node/constants'; -+import { DEFAULT_EXTENSIONS } from 'packages/vite/src/node/constants'; -+import { DEFAULT_ASSETS_RE } from 'packages/vite/src/node/constants'; -+import { Plugin } from 'packages/vite/src/node/plugin'; -+import { resolveBuildOptions } from 'packages/vite/src/node/build'; -+import { resolveServerOptions } from 'packages/vite/src/node/server/index'; -+import { resolvePreviewOptions } from 'packages/vite/src/node/preview'; -+import { resolveCSSOptions } from 'packages/vite/src/node/plugins/css'; -+import { isParentDirectory } from 'packages/vite/src/node/utils'; -+import { mergeConfig } from 'packages/vite/src/node/utils'; -+import { asyncFlatten } from 'packages/vite/src/node/utils'; -+import { normalizePath } from 'packages/vite/src/node/utils'; -+import { mergeAlias } from 'packages/vite/src/node/utils'; -+import { normalizeAlias } from 'packages/vite/src/node/utils'; -+import { createPluginHookUtils } from 'packages/vite/src/node/plugins/index'; -+import { resolvePlugins } from 'packages/vite/src/node/plugins/index'; -+import { createLogger } from 'packages/vite/src/node/logger'; -+import { PluginContainer } from 'packages/vite/src/node/server/pluginContainer'; -+import { PackageCache } from 'packages/vite/src/node/packages'; -+import { findNearestPackageData } from 'packages/vite/src/node/packages'; -+import { resolveEnvPrefix } from 'packages/vite/src/node/env'; -+import { loadEnv } from 'packages/vite/src/node/env'; -+import { resolveSSROptions } from 'packages/vite/src/node/ssr/index'; -+import type { RollupOptions } from 'rollup' -+import type { Alias, AliasOptions } from 'dep-types/alias' -+import { BuildOptions } from 'packages/vite/src/node/build'; -+import { ServerOptions } from 'packages/vite/src/node/server/index'; -+import { PreviewOptions } from 'packages/vite/src/node/preview'; -+import { CSSOptions } from 'packages/vite/src/node/plugins/css'; -+import { ESBuildOptions } from 'packages/vite/src/node/plugins/esbuild'; -+import { ResolveOptions } from 'packages/vite/src/node/plugins/resolve'; -+import { LogLevel } from 'packages/vite/src/node/logger'; -+import { Logger } from 'packages/vite/src/node/logger'; -+import { DepOptimizationOptions } from 'packages/vite/src/node/optimizer/index'; -+import { JsonOptions } from 'packages/vite/src/node/plugins/json'; -+import { SSROptions } from 'packages/vite/src/node/ssr/index'; -+import { RenderBuiltAssetUrl } from 'packages/vite/src/node/build'; -+import { ResolvedBuildOptions } from 'packages/vite/src/node/build'; -+import { ResolvedServerOptions } from 'packages/vite/src/node/server/index'; -+import { ResolvedPreviewOptions } from 'packages/vite/src/node/preview'; -+import { ResolvedCSSOptions } from 'packages/vite/src/node/plugins/css'; -+import { InternalResolveOptions } from 'packages/vite/src/node/plugins/resolve'; -+import { ResolvedSSROptions } from 'packages/vite/src/node/ssr/index'; -+import { PluginWithRequiredHook } from 'packages/vite/src/node/plugin'; -+import { HookHandler } from 'packages/vite/src/node/plugin'; -+ -+ -+export interface ConfigEnv { -+ /** -+ * 'serve': during dev (`vite` command) -+ * 'build': when building for production (`vite build` command) -+ */ -+ command: 'build' | 'serve' -+ mode: string -+ isSsrBuild?: boolean -+ isPreview?: boolean -+} -+ -+export interface ResolvedWorkerOptions { -+ format: 'es' | 'iife' -+ plugins: (bundleChain: string[]) => Promise -+ rollupOptions: RollupOptions -+} -+ -+/** -+ * spa: include SPA fallback middleware and configure sirv with `single: true` in preview -+ * -+ * mpa: only include non-SPA HTML middlewares -+ * -+ * custom: don't include HTML middlewares -+ */ -+export type AppType = 'spa' | 'mpa' | 'custom' -+ -+export type PluginOption = -+ | Plugin -+ | false -+ | null -+ | undefined -+ | PluginOption[] -+ | Promise -+ -+export interface HTMLOptions { -+ /** -+ * A nonce value placeholder that will be used when generating script/style tags. -+ * -+ * Make sure that this placeholder will be replaced with a unique value for each request by the server. -+ */ -+ cspNonce?: string -+} -+ -+export interface ExperimentalOptions { -+ /** -+ * Append fake `&lang.(ext)` when queries are specified, to preserve the file extension for following plugins to process. -+ * -+ * @experimental -+ * @default false -+ */ -+ importGlobRestoreExtension?: boolean -+ /** -+ * Allow finegrain control over assets and public files paths -+ * -+ * @experimental -+ */ -+ renderBuiltUrl?: RenderBuiltAssetUrl -+ /** -+ * Enables support of HMR partial accept via `import.meta.hot.acceptExports`. -+ * -+ * @experimental -+ * @default false -+ */ -+ hmrPartialAccept?: boolean -+ /** -+ * Skips SSR transform to make it easier to use Vite with Node ESM loaders. -+ * @warning Enabling this will break normal operation of Vite's SSR in development mode. -+ * -+ * @experimental -+ * @default false -+ */ -+ skipSsrTransform?: boolean -+} -+ -+export interface LegacyOptions { -+ /** -+ * In Vite 4, SSR-externalized modules (modules not bundled and loaded by Node.js at runtime) -+ * are implicitly proxied in dev to automatically handle `default` and `__esModule` access. -+ * However, this does not correctly reflect how it works in the Node.js runtime, causing -+ * inconsistencies between dev and prod. -+ * -+ * In Vite 5, the proxy is removed so dev and prod are consistent, but if you still require -+ * the old behaviour, you can enable this option. If so, please leave your feedback at -+ * https://github.com/vitejs/vite/discussions/14697. -+ */ -+ proxySsrExternalModules?: boolean -+} -+ -+export interface UserConfig { -+ /** -+ * Project root directory. Can be an absolute path, or a path relative from -+ * the location of the config file itself. -+ * @default process.cwd() -+ */ -+ root?: string -+ /** -+ * Base public path when served in development or production. -+ * @default '/' -+ */ -+ base?: string -+ /** -+ * Directory to serve as plain static assets. Files in this directory are -+ * served and copied to build dist dir as-is without transform. The value -+ * can be either an absolute file system path or a path relative to project root. -+ * -+ * Set to `false` or an empty string to disable copied static assets to build dist dir. -+ * @default 'public' -+ */ -+ publicDir?: string | false -+ /** -+ * Directory to save cache files. Files in this directory are pre-bundled -+ * deps or some other cache files that generated by vite, which can improve -+ * the performance. You can use `--force` flag or manually delete the directory -+ * to regenerate the cache files. The value can be either an absolute file -+ * system path or a path relative to project root. -+ * Default to `.vite` when no `package.json` is detected. -+ * @default 'node_modules/.vite' -+ */ -+ cacheDir?: string -+ /** -+ * Explicitly set a mode to run in. This will override the default mode for -+ * each command, and can be overridden by the command line --mode option. -+ */ -+ mode?: string -+ /** -+ * Define global variable replacements. -+ * Entries will be defined on `window` during dev and replaced during build. -+ */ -+ define?: Record -+ /** -+ * Array of vite plugins to use. -+ */ -+ plugins?: PluginOption[] -+ /** -+ * Configure resolver -+ */ -+ resolve?: ResolveOptions & { alias?: AliasOptions } -+ /** -+ * HTML related options -+ */ -+ html?: HTMLOptions -+ /** -+ * CSS related options (preprocessors and CSS modules) -+ */ -+ css?: CSSOptions -+ /** -+ * JSON loading options -+ */ -+ json?: JsonOptions -+ /** -+ * Transform options to pass to esbuild. -+ * Or set to `false` to disable esbuild. -+ */ -+ esbuild?: ESBuildOptions | false -+ /** -+ * Specify additional picomatch patterns to be treated as static assets. -+ */ -+ assetsInclude?: string | RegExp | (string | RegExp)[] -+ /** -+ * Server specific options, e.g. host, port, https... -+ */ -+ server?: ServerOptions -+ /** -+ * Build specific options -+ */ -+ build?: BuildOptions -+ /** -+ * Preview specific options, e.g. host, port, https... -+ */ -+ preview?: PreviewOptions -+ /** -+ * Dep optimization options -+ */ -+ optimizeDeps?: DepOptimizationOptions -+ /** -+ * SSR specific options -+ */ -+ ssr?: SSROptions -+ /** -+ * Experimental features -+ * -+ * Features under this field could change in the future and might NOT follow semver. -+ * Please be careful and always pin Vite's version when using them. -+ * @experimental -+ */ -+ experimental?: ExperimentalOptions -+ /** -+ * Legacy options -+ * -+ * Features under this field only follow semver for patches, they could be removed in a -+ * future minor version. Please always pin Vite's version to a minor when using them. -+ */ -+ legacy?: LegacyOptions -+ /** -+ * Log level. -+ * @default 'info' -+ */ -+ logLevel?: LogLevel -+ /** -+ * Custom logger. -+ */ -+ customLogger?: Logger -+ /** -+ * @default true -+ */ -+ clearScreen?: boolean -+ /** -+ * Environment files directory. Can be an absolute path, or a path relative from -+ * root. -+ * @default root -+ */ -+ envDir?: string -+ /** -+ * Env variables starts with `envPrefix` will be exposed to your client source code via import.meta.env. -+ * @default 'VITE_' -+ */ -+ envPrefix?: string | string[] -+ /** -+ * Worker bundle options -+ */ -+ worker?: { -+ /** -+ * Output format for worker bundle -+ * @default 'iife' -+ */ -+ format?: 'es' | 'iife' -+ /** -+ * Vite plugins that apply to worker bundle. The plugins returned by this function -+ * should be new instances every time it is called, because they are used for each -+ * rollup worker bundling process. -+ */ -+ plugins?: () => PluginOption[] -+ /** -+ * Rollup options to build worker bundle -+ */ -+ rollupOptions?: Omit< -+ RollupOptions, -+ 'plugins' | 'input' | 'onwarn' | 'preserveEntrySignatures' -+ > -+ } -+ /** -+ * Whether your application is a Single Page Application (SPA), -+ * a Multi-Page Application (MPA), or Custom Application (SSR -+ * and frameworks with custom HTML handling) -+ * @default 'spa' -+ */ -+ appType?: AppType -+} -+ -+export interface InlineConfig extends UserConfig { -+ configFile?: string | false -+ envFile?: false -+} -+ -+export interface PluginHookUtils { -+ getSortedPlugins: ( -+ hookName: K, -+ ) => PluginWithRequiredHook[] -+ getSortedPluginHooks: ( -+ hookName: K, -+ ) => NonNullable>[] -+} -+ -+export type ResolveFn = ( -+ id: string, -+ importer?: string, -+ aliasOnly?: boolean, -+ ssr?: boolean, -+) => Promise -+ -+export type ResolvedConfig = Readonly< -+ Omit< -+ UserConfig, -+ 'plugins' | 'css' | 'assetsInclude' | 'optimizeDeps' | 'worker' | 'build' -+ > & { -+ configFile: string | undefined -+ configFileDependencies: string[] -+ inlineConfig: InlineConfig -+ root: string -+ base: string -+ /** @internal */ -+ decodedBase: string -+ /** @internal */ -+ rawBase: string -+ publicDir: string -+ cacheDir: string -+ command: 'build' | 'serve' -+ mode: string -+ isWorker: boolean -+ // in nested worker bundle to find the main config -+ /** @internal */ -+ mainConfig: ResolvedConfig | null -+ /** @internal list of bundle entry id. used to detect recursive worker bundle. */ -+ bundleChain: string[] -+ isProduction: boolean -+ envDir: string -+ env: Record -+ resolve: Required & { -+ alias: Alias[] -+ } -+ plugins: readonly Plugin[] -+ css: ResolvedCSSOptions -+ esbuild: ESBuildOptions | false -+ server: ResolvedServerOptions -+ build: ResolvedBuildOptions -+ preview: ResolvedPreviewOptions -+ ssr: ResolvedSSROptions -+ assetsInclude: (file: string) => boolean -+ logger: Logger -+ createResolver: (options?: Partial) => ResolveFn -+ optimizeDeps: DepOptimizationOptions -+ /** @internal */ -+ packageCache: PackageCache -+ worker: ResolvedWorkerOptions -+ appType: AppType -+ experimental: ExperimentalOptions -+ } & PluginHookUtils -+> -+ -+/** -+ * Check and warn if `path` includes characters that don't work well in Vite, -+ * such as `#` and `?`. -+ */ -+export function checkBadCharactersInPath(path: string, logger: Logger): void { -+ const badChars = [] -+ -+ if (path.includes('#')) { -+ badChars.push('#') -+ } -+ if (path.includes('?')) { -+ badChars.push('?') -+ } -+ -+ if (badChars.length > 0) { -+ const charString = badChars.map((c) => `"${c}"`).join(' and ') -+ const inflectedChars = badChars.length > 1 ? 'characters' : 'character' -+ -+ logger.warn( -+ colors.yellow( -+ `The project root contains the ${charString} ${inflectedChars} (${colors.cyan( -+ path, -+ )}), which may not work when running Vite. Consider renaming the directory to remove the characters.`, -+ ), -+ ) -+ } -+} -+ -+export async function resolveConfig( -+ inlineConfig: InlineConfig, -+ command: 'build' | 'serve', -+ defaultMode = 'development', -+ defaultNodeEnv = 'development', -+ isPreview = false, -+): Promise { -+ let config = inlineConfig -+ let configFileDependencies: string[] = [] -+ let mode = inlineConfig.mode || defaultMode -+ const isNodeEnvSet = !!process.env.NODE_ENV -+ const packageCache: PackageCache = new Map() -+ -+ // some dependencies e.g. @vue/compiler-* relies on NODE_ENV for getting -+ // production-specific behavior, so set it early on -+ if (!isNodeEnvSet) { -+ process.env.NODE_ENV = defaultNodeEnv -+ } -+ -+ const configEnv: ConfigEnv = { -+ mode, -+ command, -+ isSsrBuild: command === 'build' && !!config.build?.ssr, -+ isPreview, -+ } -+ -+ let { configFile } = config -+ if (configFile !== false) { -+ const loadResult = await loadConfigFromFile( -+ configEnv, -+ configFile, -+ config.root, -+ config.logLevel, -+ config.customLogger, -+ ) -+ if (loadResult) { -+ config = mergeConfig(loadResult.config, config) -+ configFile = loadResult.path -+ configFileDependencies = loadResult.dependencies -+ } -+ } -+ -+ // user config may provide an alternative mode. But --mode has a higher priority -+ mode = inlineConfig.mode || config.mode || mode -+ configEnv.mode = mode -+ -+ const filterPlugin = (p: Plugin) => { -+ if (!p) { -+ return false -+ } else if (!p.apply) { -+ return true -+ } else if (typeof p.apply === 'function') { -+ return p.apply({ ...config, mode }, configEnv) -+ } else { -+ return p.apply === command -+ } -+ } -+ -+ // resolve plugins -+ const rawUserPlugins = ( -+ (await asyncFlatten(config.plugins || [])) as Plugin[] -+ ).filter(filterPlugin) -+ -+ const [prePlugins, normalPlugins, postPlugins] = -+ sortUserPlugins(rawUserPlugins) -+ -+ // run config hooks -+ const userPlugins = [...prePlugins, ...normalPlugins, ...postPlugins] -+ config = await runConfigHook(config, userPlugins, configEnv) -+ -+ // Define logger -+ const logger = createLogger(config.logLevel, { -+ allowClearScreen: config.clearScreen, -+ customLogger: config.customLogger, -+ }) -+ -+ // resolve root -+ const resolvedRoot = normalizePath( -+ config.root ? path.resolve(config.root) : process.cwd(), -+ ) -+ -+ checkBadCharactersInPath(resolvedRoot, logger) -+ -+ const clientAlias = [ -+ { -+ find: /^\/?@vite\/env/, -+ replacement: path.posix.join(FS_PREFIX, normalizePath(ENV_ENTRY)), -+ }, -+ { -+ find: /^\/?@vite\/client/, -+ replacement: path.posix.join(FS_PREFIX, normalizePath(CLIENT_ENTRY)), -+ }, -+ ] -+ -+ // resolve alias with internal client alias -+ const resolvedAlias = normalizeAlias( -+ mergeAlias(clientAlias, config.resolve?.alias || []), -+ ) -+ -+ const resolveOptions: ResolvedConfig['resolve'] = { -+ mainFields: config.resolve?.mainFields ?? DEFAULT_MAIN_FIELDS, -+ conditions: config.resolve?.conditions ?? [], -+ extensions: config.resolve?.extensions ?? DEFAULT_EXTENSIONS, -+ dedupe: config.resolve?.dedupe ?? [], -+ preserveSymlinks: config.resolve?.preserveSymlinks ?? false, -+ alias: resolvedAlias, -+ } -+ -+ if ( -+ // @ts-expect-error removed field -+ config.resolve?.browserField === false && -+ resolveOptions.mainFields.includes('browser') -+ ) { -+ logger.warn( -+ colors.yellow( -+ `\`resolve.browserField\` is set to false, but the option is removed in favour of ` + -+ `the 'browser' string in \`resolve.mainFields\`. You may want to update \`resolve.mainFields\` ` + -+ `to remove the 'browser' string and preserve the previous browser behaviour.`, -+ ), -+ ) -+ } -+ -+ // load .env files -+ const envDir = config.envDir -+ ? normalizePath(path.resolve(resolvedRoot, config.envDir)) -+ : resolvedRoot -+ const userEnv = -+ inlineConfig.envFile !== false && -+ loadEnv(mode, envDir, resolveEnvPrefix(config)) -+ -+ // Note it is possible for user to have a custom mode, e.g. `staging` where -+ // development-like behavior is expected. This is indicated by NODE_ENV=development -+ // loaded from `.staging.env` and set by us as VITE_USER_NODE_ENV -+ const userNodeEnv = process.env.VITE_USER_NODE_ENV -+ if (!isNodeEnvSet && userNodeEnv) { -+ if (userNodeEnv === 'development') { -+ process.env.NODE_ENV = 'development' -+ } else { -+ // NODE_ENV=production is not supported as it could break HMR in dev for frameworks like Vue -+ logger.warn( -+ `NODE_ENV=${userNodeEnv} is not supported in the .env file. ` + -+ `Only NODE_ENV=development is supported to create a development build of your project. ` + -+ `If you need to set process.env.NODE_ENV, you can set it in the Vite config instead.`, -+ ) -+ } -+ } -+ -+ const isProduction = process.env.NODE_ENV === 'production' -+ -+ // resolve public base url -+ const isBuild = command === 'build' -+ const relativeBaseShortcut = config.base === '' || config.base === './' -+ -+ // During dev, we ignore relative base and fallback to '/' -+ // For the SSR build, relative base isn't possible by means -+ // of import.meta.url. -+ const resolvedBase = relativeBaseShortcut -+ ? !isBuild || config.build?.ssr -+ ? '/' -+ : './' -+ : (resolveBaseUrl(config.base, isBuild, logger) ?? '/') -+ -+ const resolvedBuildOptions = resolveBuildOptions( -+ config.build, -+ logger, -+ resolvedRoot, -+ ) -+ -+ // resolve cache directory -+ const pkgDir = findNearestPackageData(resolvedRoot, packageCache)?.dir -+ const cacheDir = normalizePath( -+ config.cacheDir -+ ? path.resolve(resolvedRoot, config.cacheDir) -+ : pkgDir -+ ? path.join(pkgDir, `node_modules/.vite`) -+ : path.join(resolvedRoot, `.vite`), -+ ) -+ -+ const assetsFilter = -+ config.assetsInclude && -+ (!Array.isArray(config.assetsInclude) || config.assetsInclude.length) -+ ? createFilter(config.assetsInclude) -+ : () => false -+ -+ // create an internal resolver to be used in special scenarios, e.g. -+ // optimizer & handling css @imports -+ const createResolver: ResolvedConfig['createResolver'] = (options) => { -+ let aliasContainer: PluginContainer | undefined -+ let resolverContainer: PluginContainer | undefined -+ return async (id, importer, aliasOnly, ssr) => { -+ let container: PluginContainer -+ if (aliasOnly) { -+ container = -+ aliasContainer || -+ (aliasContainer = await createPluginContainer({ -+ ...resolved, -+ plugins: [aliasPlugin({ entries: resolved.resolve.alias })], -+ })) -+ } else { -+ container = -+ resolverContainer || -+ (resolverContainer = await createPluginContainer({ -+ ...resolved, -+ plugins: [ -+ aliasPlugin({ entries: resolved.resolve.alias }), -+ resolvePlugin({ -+ ...resolved.resolve, -+ root: resolvedRoot, -+ isProduction, -+ isBuild: command === 'build', -+ ssrConfig: resolved.ssr, -+ asSrc: true, -+ preferRelative: false, -+ tryIndex: true, -+ ...options, -+ idOnly: true, -+ fsUtils: getFsUtils(resolved), -+ }), -+ ], -+ })) -+ } -+ return ( -+ await container.resolveId(id, importer, { -+ ssr, -+ scan: options?.scan, -+ }) -+ )?.id -+ } -+ } -+ -+ const { publicDir } = config -+ const resolvedPublicDir = -+ publicDir !== false && publicDir !== '' -+ ? normalizePath( -+ path.resolve( -+ resolvedRoot, -+ typeof publicDir === 'string' ? publicDir : 'public', -+ ), -+ ) -+ : '' -+ -+ const server = resolveServerOptions(resolvedRoot, config.server, logger) -+ const ssr = resolveSSROptions(config.ssr, resolveOptions.preserveSymlinks) -+ -+ const optimizeDeps = config.optimizeDeps || {} -+ -+ const BASE_URL = resolvedBase -+ -+ let resolved: ResolvedConfig -+ -+ let createUserWorkerPlugins = config.worker?.plugins -+ if (Array.isArray(createUserWorkerPlugins)) { -+ // @ts-expect-error backward compatibility -+ createUserWorkerPlugins = () => config.worker?.plugins -+ -+ logger.warn( -+ colors.yellow( -+ `worker.plugins is now a function that returns an array of plugins. ` + -+ `Please update your Vite config accordingly.\n`, -+ ), -+ ) -+ } -+ -+ const createWorkerPlugins = async function (bundleChain: string[]) { -+ // Some plugins that aren't intended to work in the bundling of workers (doing post-processing at build time for example). -+ // And Plugins may also have cached that could be corrupted by being used in these extra rollup calls. -+ // So we need to separate the worker plugin from the plugin that vite needs to run. -+ const rawWorkerUserPlugins = ( -+ (await asyncFlatten(createUserWorkerPlugins?.() || [])) as Plugin[] -+ ).filter(filterPlugin) -+ -+ // resolve worker -+ let workerConfig = mergeConfig({}, config) -+ const [workerPrePlugins, workerNormalPlugins, workerPostPlugins] = -+ sortUserPlugins(rawWorkerUserPlugins) -+ -+ // run config hooks -+ const workerUserPlugins = [ -+ ...workerPrePlugins, -+ ...workerNormalPlugins, -+ ...workerPostPlugins, -+ ] -+ workerConfig = await runConfigHook( -+ workerConfig, -+ workerUserPlugins, -+ configEnv, -+ ) -+ -+ const workerResolved: ResolvedConfig = { -+ ...workerConfig, -+ ...resolved, -+ isWorker: true, -+ mainConfig: resolved, -+ bundleChain, -+ } -+ const resolvedWorkerPlugins = await resolvePlugins( -+ workerResolved, -+ workerPrePlugins, -+ workerNormalPlugins, -+ workerPostPlugins, -+ ) -+ -+ // run configResolved hooks -+ await Promise.all( -+ createPluginHookUtils(resolvedWorkerPlugins) -+ .getSortedPluginHooks('configResolved') -+ .map((hook) => hook(workerResolved)), -+ ) -+ -+ return resolvedWorkerPlugins -+ } -+ -+ const resolvedWorkerOptions: ResolvedWorkerOptions = { -+ format: config.worker?.format || 'iife', -+ plugins: createWorkerPlugins, -+ rollupOptions: config.worker?.rollupOptions || {}, -+ } -+ -+ const base = withTrailingSlash(resolvedBase) -+ -+ resolved = { -+ configFile: configFile ? normalizePath(configFile) : undefined, -+ configFileDependencies: configFileDependencies.map((name) => -+ normalizePath(path.resolve(name)), -+ ), -+ inlineConfig, -+ root: resolvedRoot, -+ base, -+ decodedBase: decodeURI(base), -+ rawBase: resolvedBase, -+ resolve: resolveOptions, -+ publicDir: resolvedPublicDir, -+ cacheDir, -+ command, -+ mode, -+ ssr, -+ isWorker: false, -+ mainConfig: null, -+ bundleChain: [], -+ isProduction, -+ plugins: userPlugins, -+ css: resolveCSSOptions(config.css), -+ esbuild: -+ config.esbuild === false -+ ? false -+ : { -+ jsxDev: !isProduction, -+ ...config.esbuild, -+ }, -+ server, -+ build: resolvedBuildOptions, -+ preview: resolvePreviewOptions(config.preview, server), -+ envDir, -+ env: { -+ ...userEnv, -+ BASE_URL, -+ MODE: mode, -+ DEV: !isProduction, -+ PROD: isProduction, -+ }, -+ assetsInclude(file: string) { -+ return DEFAULT_ASSETS_RE.test(file) || assetsFilter(file) -+ }, -+ logger, -+ packageCache, -+ createResolver, -+ optimizeDeps: { -+ holdUntilCrawlEnd: true, -+ ...optimizeDeps, -+ esbuildOptions: { -+ preserveSymlinks: resolveOptions.preserveSymlinks, -+ ...optimizeDeps.esbuildOptions, -+ }, -+ }, -+ worker: resolvedWorkerOptions, -+ appType: config.appType ?? 'spa', -+ experimental: { -+ importGlobRestoreExtension: false, -+ hmrPartialAccept: false, -+ ...config.experimental, -+ }, -+ getSortedPlugins: undefined!, -+ getSortedPluginHooks: undefined!, -+ } -+ resolved = { -+ ...config, -+ ...resolved, -+ } -+ ;(resolved.plugins as Plugin[]) = await resolvePlugins( -+ resolved, -+ prePlugins, -+ normalPlugins, -+ postPlugins, -+ ) -+ Object.assign(resolved, createPluginHookUtils(resolved.plugins)) -+ -+ // call configResolved hooks -+ await Promise.all( -+ resolved -+ .getSortedPluginHooks('configResolved') -+ .map((hook) => hook(resolved)), -+ ) -+ -+ optimizeDepsDisabledBackwardCompatibility(resolved, resolved.optimizeDeps) -+ optimizeDepsDisabledBackwardCompatibility( -+ resolved, -+ resolved.ssr.optimizeDeps, -+ 'ssr.', -+ ) -+ -+ debug?.(`using resolved config: %O`, { -+ ...resolved, -+ plugins: resolved.plugins.map((p) => p.name), -+ worker: { -+ ...resolved.worker, -+ plugins: `() => plugins`, -+ }, -+ }) -+ -+ // validate config -+ -+ if ( -+ config.build?.terserOptions && -+ config.build.minify && -+ config.build.minify !== 'terser' -+ ) { -+ logger.warn( -+ colors.yellow( -+ `build.terserOptions is specified but build.minify is not set to use Terser. ` + -+ `Note Vite now defaults to use esbuild for minification. If you still ` + -+ `prefer Terser, set build.minify to "terser".`, -+ ), -+ ) -+ } -+ -+ // Check if all assetFileNames have the same reference. -+ // If not, display a warn for user. -+ const outputOption = config.build?.rollupOptions?.output ?? [] -+ // Use isArray to narrow its type to array -+ if (Array.isArray(outputOption)) { -+ const assetFileNamesList = outputOption.map( -+ (output) => output.assetFileNames, -+ ) -+ if (assetFileNamesList.length > 1) { -+ const firstAssetFileNames = assetFileNamesList[0] -+ const hasDifferentReference = assetFileNamesList.some( -+ (assetFileNames) => assetFileNames !== firstAssetFileNames, -+ ) -+ if (hasDifferentReference) { -+ resolved.logger.warn( -+ colors.yellow(` -+assetFileNames isn't equal for every build.rollupOptions.output. A single pattern across all outputs is supported by Vite. -+`), -+ ) -+ } -+ } -+ } -+ -+ // Warn about removal of experimental features -+ if ( -+ // @ts-expect-error Option removed -+ config.legacy?.buildSsrCjsExternalHeuristics || -+ // @ts-expect-error Option removed -+ config.ssr?.format === 'cjs' -+ ) { -+ resolved.logger.warn( -+ colors.yellow(` -+(!) Experimental legacy.buildSsrCjsExternalHeuristics and ssr.format were be removed in Vite 5. -+ The only SSR Output format is ESM. Find more information at https://github.com/vitejs/vite/discussions/13816. -+`), -+ ) -+ } -+ -+ const resolvedBuildOutDir = normalizePath( -+ path.resolve(resolved.root, resolved.build.outDir), -+ ) -+ if ( -+ isParentDirectory(resolvedBuildOutDir, resolved.root) || -+ resolvedBuildOutDir === resolved.root -+ ) { -+ resolved.logger.warn( -+ colors.yellow(` -+(!) build.outDir must not be the same directory of root or a parent directory of root as this could cause Vite to overwriting source files with build outputs. -+`), -+ ) -+ } -+ -+ return resolved -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/env.ts b/packages/vite/src/node/env.ts -index 897524612..ddfe02937 100644 ---- a/packages/vite/src/node/env.ts -+++ b/packages/vite/src/node/env.ts -@@ -1,9 +1,11 @@ -+import { UserConfig } from 'packages/vite/src/node/config/UserConfig'; -+import { UserConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { UserConfig } from 'packages/vite/src/node/config/loadConfigFromFile'; - import fs from 'node:fs' - import path from 'node:path' - import { parse } from 'dotenv' - import { type DotenvPopulateInput, expand } from 'dotenv-expand' - import { arraify, normalizePath, tryStatSync } from './utils' --import type { UserConfig } from './config' - - export function getEnvFilesForMode(mode: string, envDir: string): string[] { - return [ -diff --git a/packages/vite/src/node/fsUtils.ts b/packages/vite/src/node/fsUtils.ts -index a295d4fc4..56f34b315 100644 ---- a/packages/vite/src/node/fsUtils.ts -+++ b/packages/vite/src/node/fsUtils.ts -@@ -1,7 +1,7 @@ -+import { ResolvedConfig } from 'packages/vite/src/node/config/resolveConfig'; - import fs from 'node:fs' - import path from 'node:path' - import type { FSWatcher } from 'dep-types/chokidar' --import type { ResolvedConfig } from './config' - import { - isInNodeModules, - normalizePath, -diff --git a/packages/vite/src/node/http.ts b/packages/vite/src/node/http.ts -index 51a063ba8..0bf31e9a2 100644 ---- a/packages/vite/src/node/http.ts -+++ b/packages/vite/src/node/http.ts -@@ -1,3 +1,8 @@ -+import { HttpServer } from 'packages/vite/src/node/server/index/ServerOptions'; -+import { HttpServer } from 'packages/vite/src/node/server/index/ViteDevServer'; -+import { HttpServer } from 'packages/vite/src/node/server/index/_createServer'; -+import { HttpServer } from 'packages/vite/src/node/server/index/resolveServerOptions'; -+import { HttpServer } from 'packages/vite/src/node/server/index/restartServer'; - import fsp from 'node:fs/promises' - import path from 'node:path' - import type { OutgoingHttpHeaders as HttpServerHeaders } from 'node:http' -@@ -6,7 +11,6 @@ import type { Connect } from 'dep-types/connect' - import colors from 'picocolors' - import type { ProxyOptions } from './server/middlewares/proxy' - import type { Logger } from './logger' --import type { HttpServer } from './server' - - export interface CommonServerOptions { - /** -diff --git a/packages/vite/src/node/index.ts b/packages/vite/src/node/index.ts -index 3d17f8737..6f26e34e2 100644 ---- a/packages/vite/src/node/index.ts -+++ b/packages/vite/src/node/index.ts -@@ -1,16 +1,98 @@ -+import { LibraryFormats } from 'packages/vite/src/node/build/BuildOptions'; -+import { LibraryOptions } from 'packages/vite/src/node/build/BuildOptions'; -+import { ResolveModulePreloadDependenciesFn } from 'packages/vite/src/node/build/BuildOptions'; -+import { ModulePreloadOptions } from 'packages/vite/src/node/build/BuildOptions'; -+import { BuildOptions } from 'packages/vite/src/node/build/BuildOptions'; -+import { LibraryFormats } from 'packages/vite/src/node/build/resolveBuildOptions'; -+import { LibraryOptions } from 'packages/vite/src/node/build/resolveBuildOptions'; -+import { ResolveModulePreloadDependenciesFn } from 'packages/vite/src/node/build/resolveBuildOptions'; -+import { ModulePreloadOptions } from 'packages/vite/src/node/build/resolveBuildOptions'; -+import { BuildOptions } from 'packages/vite/src/node/build/resolveBuildOptions'; -+import { ResolvedModulePreloadOptions } from 'packages/vite/src/node/build/resolveBuildOptions'; -+import { ResolvedBuildOptions } from 'packages/vite/src/node/build/resolveBuildOptions'; -+import { build } from 'packages/vite/src/node/build/build'; -+import { LibraryFormats } from 'packages/vite/src/node/build/resolveBuildOutputs'; -+import { LibraryOptions } from 'packages/vite/src/node/build/resolveBuildOutputs'; -+import { AppType } from 'packages/vite/src/node/config/UserConfig'; -+import { PluginOption } from 'packages/vite/src/node/config/UserConfig'; -+import { HTMLOptions } from 'packages/vite/src/node/config/UserConfig'; -+import { ExperimentalOptions } from 'packages/vite/src/node/config/UserConfig'; -+import { LegacyOptions } from 'packages/vite/src/node/config/UserConfig'; -+import { UserConfig } from 'packages/vite/src/node/config/UserConfig'; -+import { ConfigEnv } from 'packages/vite/src/node/config/resolveConfig'; -+import { ResolvedWorkerOptions } from 'packages/vite/src/node/config/resolveConfig'; -+import { AppType } from 'packages/vite/src/node/config/resolveConfig'; -+import { PluginOption } from 'packages/vite/src/node/config/resolveConfig'; -+import { HTMLOptions } from 'packages/vite/src/node/config/resolveConfig'; -+import { ExperimentalOptions } from 'packages/vite/src/node/config/resolveConfig'; -+import { LegacyOptions } from 'packages/vite/src/node/config/resolveConfig'; -+import { UserConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { InlineConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { PluginHookUtils } from 'packages/vite/src/node/config/resolveConfig'; -+import { ResolveFn } from 'packages/vite/src/node/config/resolveConfig'; -+import { ResolvedConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { resolveConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { ConfigEnv } from 'packages/vite/src/node/config/loadConfigFromFile'; -+import { AppType } from 'packages/vite/src/node/config/loadConfigFromFile'; -+import { PluginOption } from 'packages/vite/src/node/config/loadConfigFromFile'; -+import { HTMLOptions } from 'packages/vite/src/node/config/loadConfigFromFile'; -+import { ExperimentalOptions } from 'packages/vite/src/node/config/loadConfigFromFile'; -+import { LegacyOptions } from 'packages/vite/src/node/config/loadConfigFromFile'; -+import { UserConfig } from 'packages/vite/src/node/config/loadConfigFromFile'; -+import { loadConfigFromFile } from 'packages/vite/src/node/config/loadConfigFromFile'; -+import { HmrContext } from 'packages/vite/src/node/server/hmr/handleHMRUpdate'; -+import { DepOptimizationConfig } from 'packages/vite/src/node/optimizer/index/DepOptimizationConfig'; -+import { ExportsData } from 'packages/vite/src/node/optimizer/index/loadCachedDepOptimizationMetadata'; -+import { OptimizedDepInfo } from 'packages/vite/src/node/optimizer/index/loadCachedDepOptimizationMetadata'; -+import { DepOptimizationMetadata } from 'packages/vite/src/node/optimizer/index/loadCachedDepOptimizationMetadata'; -+import { ExportsData } from 'packages/vite/src/node/optimizer/index/runOptimizeDeps'; -+import { OptimizedDepInfo } from 'packages/vite/src/node/optimizer/index/runOptimizeDeps'; -+import { DepOptimizationMetadata } from 'packages/vite/src/node/optimizer/index/runOptimizeDeps'; -+import { ExportsData } from 'packages/vite/src/node/optimizer/index/prepareEsbuildOptimizerRun'; -+import { OptimizedDepInfo } from 'packages/vite/src/node/optimizer/index/prepareEsbuildOptimizerRun'; -+import { ExportsData } from 'packages/vite/src/node/optimizer/index/extractExportsData'; -+import { FileSystemServeOptions } from 'packages/vite/src/node/server/index/ServerOptions'; -+import { HttpServer } from 'packages/vite/src/node/server/index/ServerOptions'; -+import { ServerOptions } from 'packages/vite/src/node/server/index/ServerOptions'; -+import { HttpServer } from 'packages/vite/src/node/server/index/ViteDevServer'; -+import { ResolvedServerUrls } from 'packages/vite/src/node/server/index/ViteDevServer'; -+import { ViteDevServer } from 'packages/vite/src/node/server/index/ViteDevServer'; -+import { HttpServer } from 'packages/vite/src/node/server/index/_createServer'; -+import { ResolvedServerUrls } from 'packages/vite/src/node/server/index/_createServer'; -+import { ViteDevServer } from 'packages/vite/src/node/server/index/_createServer'; -+import { FileSystemServeOptions } from 'packages/vite/src/node/server/index/resolveServerOptions'; -+import { HttpServer } from 'packages/vite/src/node/server/index/resolveServerOptions'; -+import { ServerOptions } from 'packages/vite/src/node/server/index/resolveServerOptions'; -+import { ResolvedServerOptions } from 'packages/vite/src/node/server/index/resolveServerOptions'; -+import { HttpServer } from 'packages/vite/src/node/server/index/restartServer'; -+import { ResolvedServerUrls } from 'packages/vite/src/node/server/index/restartServer'; -+import { ViteDevServer } from 'packages/vite/src/node/server/index/restartServer'; -+import { ResolveOptions } from 'packages/vite/src/node/plugins/resolve/resolvePlugin'; -+import { InternalResolveOptions } from 'packages/vite/src/node/plugins/resolve/resolvePlugin'; -+import { ResolveOptions } from 'packages/vite/src/node/plugins/resolve/tryCleanFsResolve'; -+import { InternalResolveOptions } from 'packages/vite/src/node/plugins/resolve/tryCleanFsResolve'; -+import { ResolveOptions } from 'packages/vite/src/node/plugins/resolve/tryNodeResolve'; -+import { InternalResolveOptions } from 'packages/vite/src/node/plugins/resolve/tryNodeResolve'; -+import { ResolveOptions } from 'packages/vite/src/node/plugins/resolve/resolvePackageEntry'; -+import { InternalResolveOptions } from 'packages/vite/src/node/plugins/resolve/resolvePackageEntry'; -+import { ResolveOptions } from 'packages/vite/src/node/plugins/resolve/resolveDeepImport'; -+import { InternalResolveOptions } from 'packages/vite/src/node/plugins/resolve/resolveDeepImport'; -+import { TransformResult } from 'packages/vite/src/node/server/transformRequest/transformRequest'; -+import { TransformOptions } from 'packages/vite/src/node/server/transformRequest/transformRequest'; -+import { TransformOptions } from 'packages/vite/src/node/server/transformRequest/doTransform'; -+import { TransformResult } from 'packages/vite/src/node/server/transformRequest/loadAndTransform'; -+import { TransformOptions } from 'packages/vite/src/node/server/transformRequest/loadAndTransform'; -+import { TransformResult } from 'packages/vite/src/node/server/transformRequest/handleModuleSoftInvalidation'; - import type * as Rollup from 'rollup' - - export type { Rollup } - export { parseAst, parseAstAsync } from 'rollup/parseAst' - export { - defineConfig, -- loadConfigFromFile, -- resolveConfig, - sortUserPlugins, - } from './config' - export { createServer } from './server' - export { preview } from './preview' --export { build } from './build' - export { optimizeDeps } from './optimizer' - export { formatPostcssSourceMap, preprocessCSS } from './plugins/css' - export { transformWithEsbuild } from './plugins/esbuild' -@@ -20,58 +102,14 @@ export type { FetchModuleOptions } from './ssr/fetchModule' - export * from './publicUtils' - - // additional types --export type { -- AppType, -- ConfigEnv, -- ExperimentalOptions, -- HTMLOptions, -- InlineConfig, -- LegacyOptions, -- PluginHookUtils, -- PluginOption, -- ResolveFn, -- ResolvedWorkerOptions, -- ResolvedConfig, -- UserConfig, -- UserConfigExport, -- UserConfigFn, -- UserConfigFnObject, -- UserConfigFnPromise, --} from './config' - export type { FilterPattern } from './utils' - export type { CorsOptions, CorsOrigin, CommonServerOptions } from './http' --export type { -- ViteDevServer, -- ServerOptions, -- FileSystemServeOptions, -- ServerHook, -- ResolvedServerOptions, -- ResolvedServerUrls, -- HttpServer, --} from './server' --export type { -- BuildOptions, -- LibraryOptions, -- LibraryFormats, -- RenderBuiltAssetUrl, -- ResolvedBuildOptions, -- ModulePreloadOptions, -- ResolvedModulePreloadOptions, -- ResolveModulePreloadDependenciesFn, --} from './build' - export type { - PreviewOptions, - PreviewServer, - PreviewServerHook, - ResolvedPreviewOptions, - } from './preview' --export type { -- DepOptimizationMetadata, -- DepOptimizationOptions, -- DepOptimizationConfig, -- OptimizedDepInfo, -- ExportsData, --} from './optimizer' - export type { - ResolvedSSROptions, - SsrDepOptimizationOptions, -@@ -104,7 +142,6 @@ export type { JsonOptions } from './plugins/json' - export type { TransformOptions as EsbuildTransformOptions } from 'esbuild' - export type { ESBuildOptions, ESBuildTransformResult } from './plugins/esbuild' - export type { Manifest, ManifestChunk } from './plugins/manifest' --export type { ResolveOptions, InternalResolveOptions } from './plugins/resolve' - export type { SplitVendorChunkCache } from './plugins/splitVendorChunk' - export type { TerserOptions } from './plugins/terser' - -@@ -117,11 +154,7 @@ export type { PluginContainer } from './server/pluginContainer' - export type { ModuleGraph, ModuleNode, ResolvedUrl } from './server/moduleGraph' - export type { SendOptions } from './server/send' - export type { ProxyOptions } from './server/middlewares/proxy' --export type { -- TransformOptions, -- TransformResult, --} from './server/transformRequest' --export type { HmrOptions, HmrContext } from './server/hmr' -+export type { HmrOptions } from './server/hmr' - - export type { - HMRBroadcaster, -diff --git a/packages/vite/src/node/logger.ts b/packages/vite/src/node/logger.ts -index 7928c954b..ece4da37a 100644 ---- a/packages/vite/src/node/logger.ts -+++ b/packages/vite/src/node/logger.ts -@@ -1,9 +1,11 @@ - /* eslint no-console: 0 */ - -+import { ResolvedServerUrls } from 'packages/vite/src/node/server/index/ViteDevServer'; -+import { ResolvedServerUrls } from 'packages/vite/src/node/server/index/_createServer'; -+import { ResolvedServerUrls } from 'packages/vite/src/node/server/index/restartServer'; - import readline from 'node:readline' - import colors from 'picocolors' - import type { RollupError } from 'rollup' --import type { ResolvedServerUrls } from './server' - - export type LogType = 'error' | 'warn' | 'info' - export type LogLevel = LogType | 'silent' -diff --git a/packages/vite/src/node/optimizer/esbuildDepPlugin.ts b/packages/vite/src/node/optimizer/esbuildDepPlugin.ts -index 1f4c4dab1..36d84b6e4 100644 ---- a/packages/vite/src/node/optimizer/esbuildDepPlugin.ts -+++ b/packages/vite/src/node/optimizer/esbuildDepPlugin.ts -@@ -1,9 +1,17 @@ -+import { ResolvedConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { isCSSRequest } from 'packages/vite/src/node/plugins/css/cssPlugin'; -+import { isModuleCSSRequest } from 'packages/vite/src/node/plugins/css/cssPlugin'; -+import { isCSSRequest } from 'packages/vite/src/node/plugins/css/cssPostPlugin'; -+import { isCSSRequest } from 'packages/vite/src/node/plugins/css/cssAnalysisPlugin'; -+import { browserExternalId } from 'packages/vite/src/node/plugins/resolve/resolvePlugin'; -+import { optionalPeerDepId } from 'packages/vite/src/node/plugins/resolve/resolvePlugin'; -+import { optionalPeerDepId } from 'packages/vite/src/node/plugins/resolve/tryNodeResolve'; -+import { browserExternalId } from 'packages/vite/src/node/plugins/resolve/resolveDeepImport'; - import path from 'node:path' - import type { ImportKind, Plugin } from 'esbuild' - import { KNOWN_ASSET_TYPES } from '../constants' - import type { PackageCache } from '../packages' - import { getDepOptimizationConfig } from '../config' --import type { ResolvedConfig } from '../config' - import { - escapeRegex, - flattenId, -@@ -12,8 +20,6 @@ import { - moduleListContains, - normalizePath, - } from '../utils' --import { browserExternalId, optionalPeerDepId } from '../plugins/resolve' --import { isCSSRequest, isModuleCSSRequest } from '../plugins/css' - - const externalWithConversionNamespace = - 'vite:dep-pre-bundle:external-conversion' -diff --git a/packages/vite/src/node/optimizer/index.ts b/packages/vite/src/node/optimizer/index.ts -index e62d78fdf..826868a61 100644 ---- a/packages/vite/src/node/optimizer/index.ts -+++ b/packages/vite/src/node/optimizer/index.ts -@@ -1,3 +1,30 @@ -+import { ResolvedConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { DepOptimizationConfig } from 'packages/vite/src/node/optimizer/index/DepOptimizationConfig'; -+import { DepOptimizationConfig } from 'packages/vite/src/node/optimizer/index'; -+import { debug } from 'packages/vite/src/node/optimizer/index/loadCachedDepOptimizationMetadata'; -+import { ExportsData } from 'packages/vite/src/node/optimizer/index/loadCachedDepOptimizationMetadata'; -+import { OptimizedDepInfo } from 'packages/vite/src/node/optimizer/index/loadCachedDepOptimizationMetadata'; -+import { DepOptimizationMetadata } from 'packages/vite/src/node/optimizer/index/loadCachedDepOptimizationMetadata'; -+import { loadCachedDepOptimizationMetadata } from 'packages/vite/src/node/optimizer/index/loadCachedDepOptimizationMetadata'; -+import { loadCachedDepOptimizationMetadata } from 'packages/vite/src/node/optimizer/index'; -+import { jsMapExtensionRE } from 'packages/vite/src/node/optimizer/index/runOptimizeDeps'; -+import { ExportsData } from 'packages/vite/src/node/optimizer/index/runOptimizeDeps'; -+import { OptimizedDepInfo } from 'packages/vite/src/node/optimizer/index/runOptimizeDeps'; -+import { DepOptimizationMetadata } from 'packages/vite/src/node/optimizer/index/runOptimizeDeps'; -+import { DepOptimizationResult } from 'packages/vite/src/node/optimizer/index/runOptimizeDeps'; -+import { initDepsOptimizerMetadata } from 'packages/vite/src/node/optimizer/index/runOptimizeDeps'; -+import { addOptimizedDepInfo } from 'packages/vite/src/node/optimizer/index/runOptimizeDeps'; -+import { runOptimizeDeps } from 'packages/vite/src/node/optimizer/index/runOptimizeDeps'; -+import { runOptimizeDeps } from 'packages/vite/src/node/optimizer/index'; -+import { ExportsData } from 'packages/vite/src/node/optimizer/index/prepareEsbuildOptimizerRun'; -+import { OptimizedDepInfo } from 'packages/vite/src/node/optimizer/index/prepareEsbuildOptimizerRun'; -+import { prepareEsbuildOptimizerRun } from 'packages/vite/src/node/optimizer/index/prepareEsbuildOptimizerRun'; -+import { prepareEsbuildOptimizerRun } from 'packages/vite/src/node/optimizer/index'; -+import { addManuallyIncludedOptimizeDeps } from 'packages/vite/src/node/optimizer/index/addManuallyIncludedOptimizeDeps'; -+import { addManuallyIncludedOptimizeDeps } from 'packages/vite/src/node/optimizer/index'; -+import { ExportsData } from 'packages/vite/src/node/optimizer/index/extractExportsData'; -+import { extractExportsData } from 'packages/vite/src/node/optimizer/index/extractExportsData'; -+import { extractExportsData } from 'packages/vite/src/node/optimizer/index'; - import fs from 'node:fs' - import fsp from 'node:fs/promises' - import path from 'node:path' -@@ -9,7 +36,6 @@ import esbuild, { build } from 'esbuild' - import { init, parse } from 'es-module-lexer' - import glob from 'fast-glob' - import { getDepOptimizationConfig } from '../config' --import type { ResolvedConfig } from '../config' - import { - createDebugger, - flattenId, -@@ -37,18 +63,7 @@ export { - getDepsOptimizer, - } from './optimizer' - --const debug = createDebugger('vite:deps') -- - const jsExtensionRE = /\.js$/i --const jsMapExtensionRE = /\.js\.map$/i -- --export type ExportsData = { -- hasModuleSyntax: boolean -- // exported names (for `export { a as b }`, `b` is exported name) -- exports: readonly string[] -- // hint if the dep requires loading as jsx -- jsxLoader?: boolean --} - - export interface DepsOptimizer { - metadata: DepOptimizationMetadata -@@ -65,88 +80,6 @@ export interface DepsOptimizer { - options: DepOptimizationOptions - } - --export interface DepOptimizationConfig { -- /** -- * Force optimize listed dependencies (must be resolvable import paths, -- * cannot be globs). -- */ -- include?: string[] -- /** -- * Do not optimize these dependencies (must be resolvable import paths, -- * cannot be globs). -- */ -- exclude?: string[] -- /** -- * Forces ESM interop when importing these dependencies. Some legacy -- * packages advertise themselves as ESM but use `require` internally -- * @experimental -- */ -- needsInterop?: string[] -- /** -- * Options to pass to esbuild during the dep scanning and optimization -- * -- * Certain options are omitted since changing them would not be compatible -- * with Vite's dep optimization. -- * -- * - `external` is also omitted, use Vite's `optimizeDeps.exclude` option -- * - `plugins` are merged with Vite's dep plugin -- * -- * https://esbuild.github.io/api -- */ -- esbuildOptions?: Omit< -- EsbuildBuildOptions, -- | 'bundle' -- | 'entryPoints' -- | 'external' -- | 'write' -- | 'watch' -- | 'outdir' -- | 'outfile' -- | 'outbase' -- | 'outExtension' -- | 'metafile' -- > -- /** -- * List of file extensions that can be optimized. A corresponding esbuild -- * plugin must exist to handle the specific extension. -- * -- * By default, Vite can optimize `.mjs`, `.js`, `.ts`, and `.mts` files. This option -- * allows specifying additional extensions. -- * -- * @experimental -- */ -- extensions?: string[] -- /** -- * Deps optimization during build was removed in Vite 5.1. This option is -- * now redundant and will be removed in a future version. Switch to using -- * `optimizeDeps.noDiscovery` and an empty or undefined `optimizeDeps.include`. -- * true or 'dev' disables the optimizer, false or 'build' leaves it enabled. -- * @default 'build' -- * @deprecated -- * @experimental -- */ -- disabled?: boolean | 'build' | 'dev' -- /** -- * Automatic dependency discovery. When `noDiscovery` is true, only dependencies -- * listed in `include` will be optimized. The scanner isn't run for cold start -- * in this case. CJS-only dependencies must be present in `include` during dev. -- * @default false -- * @experimental -- */ -- noDiscovery?: boolean -- /** -- * When enabled, it will hold the first optimized deps results until all static -- * imports are crawled on cold start. This avoids the need for full-page reloads -- * when new dependencies are discovered and they trigger the generation of new -- * common chunks. If all dependencies are found by the scanner plus the explicitly -- * defined ones in `include`, it is better to disable this option to let the -- * browser process more requests in parallel. -- * @default true -- * @experimental -- */ -- holdUntilCrawlEnd?: boolean --} -- - export type DepOptimizationOptions = DepOptimizationConfig & { - /** - * By default, Vite will crawl your `index.html` to detect dependencies that -@@ -166,76 +99,6 @@ export type DepOptimizationOptions = DepOptimizationConfig & { - force?: boolean - } - --export interface DepOptimizationResult { -- metadata: DepOptimizationMetadata -- /** -- * When doing a re-run, if there are newly discovered dependencies -- * the page reload will be delayed until the next rerun so we need -- * to be able to discard the result -- */ -- commit: () => Promise -- cancel: () => void --} -- --export interface OptimizedDepInfo { -- id: string -- file: string -- src?: string -- needsInterop?: boolean -- browserHash?: string -- fileHash?: string -- /** -- * During optimization, ids can still be resolved to their final location -- * but the bundles may not yet be saved to disk -- */ -- processing?: Promise -- /** -- * ExportData cache, discovered deps will parse the src entry to get exports -- * data used both to define if interop is needed and when pre-bundling -- */ -- exportsData?: Promise --} -- --export interface DepOptimizationMetadata { -- /** -- * The main hash is determined by user config and dependency lockfiles. -- * This is checked on server startup to avoid unnecessary re-bundles. -- */ -- hash: string -- /** -- * This hash is determined by dependency lockfiles. -- * This is checked on server startup to avoid unnecessary re-bundles. -- */ -- lockfileHash: string -- /** -- * This hash is determined by user config. -- * This is checked on server startup to avoid unnecessary re-bundles. -- */ -- configHash: string -- /** -- * The browser hash is determined by the main hash plus additional dependencies -- * discovered at runtime. This is used to invalidate browser requests to -- * optimized deps. -- */ -- browserHash: string -- /** -- * Metadata for each already optimized dependency -- */ -- optimized: Record -- /** -- * Metadata for non-entry optimized chunks and dynamic imports -- */ -- chunks: Record -- /** -- * Metadata for each newly discovered dependency after processing -- */ -- discovered: Record -- /** -- * OptimizedDepInfo list -- */ -- depInfoList: OptimizedDepInfo[] --} -- - /** - * Scan and optimize dependencies within a project. - * Used by Vite CLI when running `vite optimize`. -@@ -301,91 +164,8 @@ export async function optimizeServerSsrDeps( - return result.metadata - } - --export function initDepsOptimizerMetadata( -- config: ResolvedConfig, -- ssr: boolean, -- timestamp?: string, --): DepOptimizationMetadata { -- const { lockfileHash, configHash, hash } = getDepHash(config, ssr) -- return { -- hash, -- lockfileHash, -- configHash, -- browserHash: getOptimizedBrowserHash(hash, {}, timestamp), -- optimized: {}, -- chunks: {}, -- discovered: {}, -- depInfoList: [], -- } --} -- --export function addOptimizedDepInfo( -- metadata: DepOptimizationMetadata, -- type: 'optimized' | 'discovered' | 'chunks', -- depInfo: OptimizedDepInfo, --): OptimizedDepInfo { -- metadata[type][depInfo.id] = depInfo -- metadata.depInfoList.push(depInfo) -- return depInfo --} -- - let firstLoadCachedDepOptimizationMetadata = true - --/** -- * Creates the initial dep optimization metadata, loading it from the deps cache -- * if it exists and pre-bundling isn't forced -- */ --export async function loadCachedDepOptimizationMetadata( -- config: ResolvedConfig, -- ssr: boolean, -- force = config.optimizeDeps.force, -- asCommand = false, --): Promise { -- const log = asCommand ? config.logger.info : debug -- -- if (firstLoadCachedDepOptimizationMetadata) { -- firstLoadCachedDepOptimizationMetadata = false -- // Fire up a clean up of stale processing deps dirs if older process exited early -- setTimeout(() => cleanupDepsCacheStaleDirs(config), 0) -- } -- -- const depsCacheDir = getDepsCacheDir(config, ssr) -- -- if (!force) { -- let cachedMetadata: DepOptimizationMetadata | undefined -- try { -- const cachedMetadataPath = path.join(depsCacheDir, METADATA_FILENAME) -- cachedMetadata = parseDepsOptimizerMetadata( -- await fsp.readFile(cachedMetadataPath, 'utf-8'), -- depsCacheDir, -- ) -- } catch (e) {} -- // hash is consistent, no need to re-bundle -- if (cachedMetadata) { -- if (cachedMetadata.lockfileHash !== getLockfileHash(config, ssr)) { -- config.logger.info( -- 'Re-optimizing dependencies because lockfile has changed', -- ) -- } else if (cachedMetadata.configHash !== getConfigHash(config, ssr)) { -- config.logger.info( -- 'Re-optimizing dependencies because vite config has changed', -- ) -- } else { -- log?.('Hash is consistent. Skipping. Use --force to override.') -- // Nothing to commit or cancel as we are using the cache, we only -- // need to resolve the processing promise so requests can move on -- return cachedMetadata -- } -- } -- } else { -- config.logger.info('Forced re-optimization of dependencies') -- } -- -- // Start with a fresh cache -- debug?.(colors.green(`removing old cache dir ${depsCacheDir}`)) -- await fsp.rm(depsCacheDir, { recursive: true, force: true }) --} -- - /** - * Initial optimizeDeps at server start. Perform a fast scan using esbuild to - * find deps to pre-bundle and include user hard-coded dependencies -@@ -447,422 +227,6 @@ export function depsLogString(qualifiedIds: string[]): string { - return colors.yellow(qualifiedIds.join(`, `)) - } - --/** -- * Internally, Vite uses this function to prepare a optimizeDeps run. When Vite starts, we can get -- * the metadata and start the server without waiting for the optimizeDeps processing to be completed -- */ --export function runOptimizeDeps( -- resolvedConfig: ResolvedConfig, -- depsInfo: Record, -- ssr: boolean, --): { -- cancel: () => Promise -- result: Promise --} { -- const optimizerContext = { cancelled: false } -- -- const config: ResolvedConfig = { -- ...resolvedConfig, -- command: 'build', -- } -- -- const depsCacheDir = getDepsCacheDir(resolvedConfig, ssr) -- const processingCacheDir = getProcessingDepsCacheDir(resolvedConfig, ssr) -- -- // Create a temporary directory so we don't need to delete optimized deps -- // until they have been processed. This also avoids leaving the deps cache -- // directory in a corrupted state if there is an error -- fs.mkdirSync(processingCacheDir, { recursive: true }) -- -- // a hint for Node.js -- // all files in the cache directory should be recognized as ES modules -- debug?.(colors.green(`creating package.json in ${processingCacheDir}`)) -- fs.writeFileSync( -- path.resolve(processingCacheDir, 'package.json'), -- `{\n "type": "module"\n}\n`, -- ) -- -- const metadata = initDepsOptimizerMetadata(config, ssr) -- -- metadata.browserHash = getOptimizedBrowserHash( -- metadata.hash, -- depsFromOptimizedDepInfo(depsInfo), -- ) -- -- // We prebundle dependencies with esbuild and cache them, but there is no need -- // to wait here. Code that needs to access the cached deps needs to await -- // the optimizedDepInfo.processing promise for each dep -- -- const qualifiedIds = Object.keys(depsInfo) -- let cleaned = false -- let committed = false -- const cleanUp = () => { -- // If commit was already called, ignore the clean up even if a cancel was requested -- // This minimizes the chances of leaving the deps cache in a corrupted state -- if (!cleaned && !committed) { -- cleaned = true -- // No need to wait, we can clean up in the background because temp folders -- // are unique per run -- debug?.(colors.green(`removing cache dir ${processingCacheDir}`)) -- try { -- // When exiting the process, `fsp.rm` may not take effect, so we use `fs.rmSync` -- fs.rmSync(processingCacheDir, { recursive: true, force: true }) -- } catch (error) { -- // Ignore errors -- } -- } -- } -- -- const successfulResult: DepOptimizationResult = { -- metadata, -- cancel: cleanUp, -- commit: async () => { -- if (cleaned) { -- throw new Error( -- 'Can not commit a Deps Optimization run as it was cancelled', -- ) -- } -- // Ignore clean up requests after this point so the temp folder isn't deleted before -- // we finish committing the new deps cache files to the deps folder -- committed = true -- -- // Write metadata file, then commit the processing folder to the global deps cache -- // Rewire the file paths from the temporary processing dir to the final deps cache dir -- const dataPath = path.join(processingCacheDir, METADATA_FILENAME) -- debug?.( -- colors.green(`creating ${METADATA_FILENAME} in ${processingCacheDir}`), -- ) -- fs.writeFileSync( -- dataPath, -- stringifyDepsOptimizerMetadata(metadata, depsCacheDir), -- ) -- -- // In order to minimize the time where the deps folder isn't in a consistent state, -- // we first rename the old depsCacheDir to a temporary path, then we rename the -- // new processing cache dir to the depsCacheDir. In systems where doing so in sync -- // is safe, we do an atomic operation (at least for this thread). For Windows, we -- // found there are cases where the rename operation may finish before it's done -- // so we do a graceful rename checking that the folder has been properly renamed. -- // We found that the rename-rename (then delete the old folder in the background) -- // is safer than a delete-rename operation. -- const temporaryPath = depsCacheDir + getTempSuffix() -- const depsCacheDirPresent = fs.existsSync(depsCacheDir) -- if (isWindows) { -- if (depsCacheDirPresent) { -- debug?.(colors.green(`renaming ${depsCacheDir} to ${temporaryPath}`)) -- await safeRename(depsCacheDir, temporaryPath) -- } -- debug?.( -- colors.green(`renaming ${processingCacheDir} to ${depsCacheDir}`), -- ) -- await safeRename(processingCacheDir, depsCacheDir) -- } else { -- if (depsCacheDirPresent) { -- debug?.(colors.green(`renaming ${depsCacheDir} to ${temporaryPath}`)) -- fs.renameSync(depsCacheDir, temporaryPath) -- } -- debug?.( -- colors.green(`renaming ${processingCacheDir} to ${depsCacheDir}`), -- ) -- fs.renameSync(processingCacheDir, depsCacheDir) -- } -- -- // Delete temporary path in the background -- if (depsCacheDirPresent) { -- debug?.(colors.green(`removing cache temp dir ${temporaryPath}`)) -- fsp.rm(temporaryPath, { recursive: true, force: true }) -- } -- }, -- } -- -- if (!qualifiedIds.length) { -- // No deps to optimize, we still commit the processing cache dir to remove -- // the previous optimized deps if they exist, and let the next server start -- // skip the scanner step if the lockfile hasn't changed -- return { -- cancel: async () => cleanUp(), -- result: Promise.resolve(successfulResult), -- } -- } -- -- const cancelledResult: DepOptimizationResult = { -- metadata, -- commit: async () => cleanUp(), -- cancel: cleanUp, -- } -- -- const start = performance.now() -- -- const preparedRun = prepareEsbuildOptimizerRun( -- resolvedConfig, -- depsInfo, -- ssr, -- processingCacheDir, -- optimizerContext, -- ) -- -- const runResult = preparedRun.then(({ context, idToExports }) => { -- function disposeContext() { -- return context?.dispose().catch((e) => { -- config.logger.error('Failed to dispose esbuild context', { error: e }) -- }) -- } -- if (!context || optimizerContext.cancelled) { -- disposeContext() -- return cancelledResult -- } -- -- return context -- .rebuild() -- .then((result) => { -- const meta = result.metafile! -- -- // the paths in `meta.outputs` are relative to `process.cwd()` -- const processingCacheDirOutputPath = path.relative( -- process.cwd(), -- processingCacheDir, -- ) -- -- for (const id in depsInfo) { -- const output = esbuildOutputFromId( -- meta.outputs, -- id, -- processingCacheDir, -- ) -- -- const { exportsData, ...info } = depsInfo[id] -- addOptimizedDepInfo(metadata, 'optimized', { -- ...info, -- // We only need to hash the output.imports in to check for stability, but adding the hash -- // and file path gives us a unique hash that may be useful for other things in the future -- fileHash: getHash( -- metadata.hash + -- depsInfo[id].file + -- JSON.stringify(output.imports), -- ), -- browserHash: metadata.browserHash, -- // After bundling we have more information and can warn the user about legacy packages -- // that require manual configuration -- needsInterop: needsInterop( -- config, -- ssr, -- id, -- idToExports[id], -- output, -- ), -- }) -- } -- -- for (const o of Object.keys(meta.outputs)) { -- if (!jsMapExtensionRE.test(o)) { -- const id = path -- .relative(processingCacheDirOutputPath, o) -- .replace(jsExtensionRE, '') -- const file = getOptimizedDepPath(id, resolvedConfig, ssr) -- if ( -- !findOptimizedDepInfoInRecord( -- metadata.optimized, -- (depInfo) => depInfo.file === file, -- ) -- ) { -- addOptimizedDepInfo(metadata, 'chunks', { -- id, -- file, -- needsInterop: false, -- browserHash: metadata.browserHash, -- }) -- } -- } -- } -- -- debug?.( -- `Dependencies bundled in ${(performance.now() - start).toFixed(2)}ms`, -- ) -- -- return successfulResult -- }) -- -- .catch((e) => { -- if (e.errors && e.message.includes('The build was canceled')) { -- // esbuild logs an error when cancelling, but this is expected so -- // return an empty result instead -- return cancelledResult -- } -- throw e -- }) -- .finally(() => { -- return disposeContext() -- }) -- }) -- -- runResult.catch(() => { -- cleanUp() -- }) -- -- return { -- async cancel() { -- optimizerContext.cancelled = true -- const { context } = await preparedRun -- await context?.cancel() -- cleanUp() -- }, -- result: runResult, -- } --} -- --async function prepareEsbuildOptimizerRun( -- resolvedConfig: ResolvedConfig, -- depsInfo: Record, -- ssr: boolean, -- processingCacheDir: string, -- optimizerContext: { cancelled: boolean }, --): Promise<{ -- context?: BuildContext -- idToExports: Record --}> { -- const config: ResolvedConfig = { -- ...resolvedConfig, -- command: 'build', -- } -- -- // esbuild generates nested directory output with lowest common ancestor base -- // this is unpredictable and makes it difficult to analyze entry / output -- // mapping. So what we do here is: -- // 1. flatten all ids to eliminate slash -- // 2. in the plugin, read the entry ourselves as virtual files to retain the -- // path. -- const flatIdDeps: Record = {} -- const idToExports: Record = {} -- -- const optimizeDeps = getDepOptimizationConfig(config, ssr) -- -- const { plugins: pluginsFromConfig = [], ...esbuildOptions } = -- optimizeDeps?.esbuildOptions ?? {} -- -- await Promise.all( -- Object.keys(depsInfo).map(async (id) => { -- const src = depsInfo[id].src! -- const exportsData = await (depsInfo[id].exportsData ?? -- extractExportsData(src, config, ssr)) -- if (exportsData.jsxLoader && !esbuildOptions.loader?.['.js']) { -- // Ensure that optimization won't fail by defaulting '.js' to the JSX parser. -- // This is useful for packages such as Gatsby. -- esbuildOptions.loader = { -- '.js': 'jsx', -- ...esbuildOptions.loader, -- } -- } -- const flatId = flattenId(id) -- flatIdDeps[flatId] = src -- idToExports[id] = exportsData -- }), -- ) -- -- if (optimizerContext.cancelled) return { context: undefined, idToExports } -- -- const define = { -- 'process.env.NODE_ENV': JSON.stringify(process.env.NODE_ENV || config.mode), -- } -- -- const platform = -- ssr && config.ssr?.target !== 'webworker' ? 'node' : 'browser' -- -- const external = [...(optimizeDeps?.exclude ?? [])] -- -- const plugins = [...pluginsFromConfig] -- if (external.length) { -- plugins.push(esbuildCjsExternalPlugin(external, platform)) -- } -- plugins.push(esbuildDepPlugin(flatIdDeps, external, config, ssr)) -- -- const context = await esbuild.context({ -- absWorkingDir: process.cwd(), -- entryPoints: Object.keys(flatIdDeps), -- bundle: true, -- // We can't use platform 'neutral', as esbuild has custom handling -- // when the platform is 'node' or 'browser' that can't be emulated -- // by using mainFields and conditions -- platform, -- define, -- format: 'esm', -- // See https://github.com/evanw/esbuild/issues/1921#issuecomment-1152991694 -- banner: -- platform === 'node' -- ? { -- js: `import { createRequire } from 'module';const require = createRequire(import.meta.url);`, -- } -- : undefined, -- target: ESBUILD_MODULES_TARGET, -- external, -- logLevel: 'error', -- splitting: true, -- sourcemap: true, -- outdir: processingCacheDir, -- ignoreAnnotations: true, -- metafile: true, -- plugins, -- charset: 'utf8', -- ...esbuildOptions, -- supported: { -- ...defaultEsbuildSupported, -- ...esbuildOptions.supported, -- }, -- }) -- return { context, idToExports } --} -- --export async function addManuallyIncludedOptimizeDeps( -- deps: Record, -- config: ResolvedConfig, -- ssr: boolean, --): Promise { -- const { logger } = config -- const optimizeDeps = getDepOptimizationConfig(config, ssr) -- const optimizeDepsInclude = optimizeDeps?.include ?? [] -- if (optimizeDepsInclude.length) { -- const unableToOptimize = (id: string, msg: string) => { -- if (optimizeDepsInclude.includes(id)) { -- logger.warn( -- `${msg}: ${colors.cyan(id)}, present in '${ -- ssr ? 'ssr.' : '' -- }optimizeDeps.include'`, -- ) -- } -- } -- -- const includes = [...optimizeDepsInclude] -- for (let i = 0; i < includes.length; i++) { -- const id = includes[i] -- if (glob.isDynamicPattern(id)) { -- const globIds = expandGlobIds(id, config) -- includes.splice(i, 1, ...globIds) -- i += globIds.length - 1 -- } -- } -- -- const resolve = createOptimizeDepsIncludeResolver(config, ssr) -- for (const id of includes) { -- // normalize 'foo >bar` as 'foo > bar' to prevent same id being added -- // and for pretty printing -- const normalizedId = normalizeId(id) -- if (!deps[normalizedId]) { -- const entry = await resolve(id) -- if (entry) { -- if (isOptimizable(entry, optimizeDeps)) { -- if (!entry.endsWith('?__vite_skip_optimization')) { -- deps[normalizedId] = entry -- } -- } else { -- unableToOptimize(id, 'Cannot optimize dependency') -- } -- } else { -- unableToOptimize(id, 'Failed to resolve dependency') -- } -- } -- } -- } --} -- - // Convert to { id: src } - export function depsFromOptimizedDepInfo( - depsInfo: Record, -@@ -1059,60 +423,6 @@ function esbuildOutputFromId( - } - } - --export async function extractExportsData( -- filePath: string, -- config: ResolvedConfig, -- ssr: boolean, --): Promise { -- await init -- -- const optimizeDeps = getDepOptimizationConfig(config, ssr) -- -- const esbuildOptions = optimizeDeps?.esbuildOptions ?? {} -- if (optimizeDeps.extensions?.some((ext) => filePath.endsWith(ext))) { -- // For custom supported extensions, build the entry file to transform it into JS, -- // and then parse with es-module-lexer. Note that the `bundle` option is not `true`, -- // so only the entry file is being transformed. -- const result = await build({ -- ...esbuildOptions, -- entryPoints: [filePath], -- write: false, -- format: 'esm', -- }) -- const [, exports, , hasModuleSyntax] = parse(result.outputFiles[0].text) -- return { -- hasModuleSyntax, -- exports: exports.map((e) => e.n), -- } -- } -- -- let parseResult: ReturnType -- let usedJsxLoader = false -- -- const entryContent = await fsp.readFile(filePath, 'utf-8') -- try { -- parseResult = parse(entryContent) -- } catch { -- const loader = esbuildOptions.loader?.[path.extname(filePath)] || 'jsx' -- debug?.( -- `Unable to parse: ${filePath}.\n Trying again with a ${loader} transform.`, -- ) -- const transformed = await transformWithEsbuild(entryContent, filePath, { -- loader, -- }) -- parseResult = parse(transformed.code) -- usedJsxLoader = true -- } -- -- const [, exports, , hasModuleSyntax] = parseResult -- const exportsData: ExportsData = { -- hasModuleSyntax, -- exports: exports.map((e) => e.n), -- jsxLoader: usedJsxLoader, -- } -- return exportsData --} -- - function needsInterop( - config: ResolvedConfig, - ssr: boolean, -diff --git a/packages/vite/src/node/optimizer/index/DepOptimizationConfig.ts b/packages/vite/src/node/optimizer/index/DepOptimizationConfig.ts -new file mode 100644 -index 000000000..5d0aa2a3c ---- /dev/null -+++ b/packages/vite/src/node/optimizer/index/DepOptimizationConfig.ts -@@ -0,0 +1,84 @@ -+import type { BuildContext, BuildOptions as EsbuildBuildOptions } from 'esbuild' -+ -+ -+export interface DepOptimizationConfig { -+ /** -+ * Force optimize listed dependencies (must be resolvable import paths, -+ * cannot be globs). -+ */ -+ include?: string[] -+ /** -+ * Do not optimize these dependencies (must be resolvable import paths, -+ * cannot be globs). -+ */ -+ exclude?: string[] -+ /** -+ * Forces ESM interop when importing these dependencies. Some legacy -+ * packages advertise themselves as ESM but use `require` internally -+ * @experimental -+ */ -+ needsInterop?: string[] -+ /** -+ * Options to pass to esbuild during the dep scanning and optimization -+ * -+ * Certain options are omitted since changing them would not be compatible -+ * with Vite's dep optimization. -+ * -+ * - `external` is also omitted, use Vite's `optimizeDeps.exclude` option -+ * - `plugins` are merged with Vite's dep plugin -+ * -+ * https://esbuild.github.io/api -+ */ -+ esbuildOptions?: Omit< -+ EsbuildBuildOptions, -+ | 'bundle' -+ | 'entryPoints' -+ | 'external' -+ | 'write' -+ | 'watch' -+ | 'outdir' -+ | 'outfile' -+ | 'outbase' -+ | 'outExtension' -+ | 'metafile' -+ > -+ /** -+ * List of file extensions that can be optimized. A corresponding esbuild -+ * plugin must exist to handle the specific extension. -+ * -+ * By default, Vite can optimize `.mjs`, `.js`, `.ts`, and `.mts` files. This option -+ * allows specifying additional extensions. -+ * -+ * @experimental -+ */ -+ extensions?: string[] -+ /** -+ * Deps optimization during build was removed in Vite 5.1. This option is -+ * now redundant and will be removed in a future version. Switch to using -+ * `optimizeDeps.noDiscovery` and an empty or undefined `optimizeDeps.include`. -+ * true or 'dev' disables the optimizer, false or 'build' leaves it enabled. -+ * @default 'build' -+ * @deprecated -+ * @experimental -+ */ -+ disabled?: boolean | 'build' | 'dev' -+ /** -+ * Automatic dependency discovery. When `noDiscovery` is true, only dependencies -+ * listed in `include` will be optimized. The scanner isn't run for cold start -+ * in this case. CJS-only dependencies must be present in `include` during dev. -+ * @default false -+ * @experimental -+ */ -+ noDiscovery?: boolean -+ /** -+ * When enabled, it will hold the first optimized deps results until all static -+ * imports are crawled on cold start. This avoids the need for full-page reloads -+ * when new dependencies are discovered and they trigger the generation of new -+ * common chunks. If all dependencies are found by the scanner plus the explicitly -+ * defined ones in `include`, it is better to disable this option to let the -+ * browser process more requests in parallel. -+ * @default true -+ * @experimental -+ */ -+ holdUntilCrawlEnd?: boolean -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/optimizer/index/addManuallyIncludedOptimizeDeps.ts b/packages/vite/src/node/optimizer/index/addManuallyIncludedOptimizeDeps.ts -new file mode 100644 -index 000000000..a3b0a80be ---- /dev/null -+++ b/packages/vite/src/node/optimizer/index/addManuallyIncludedOptimizeDeps.ts -@@ -0,0 +1,61 @@ -+import colors from 'picocolors' -+import glob from 'fast-glob' -+import { getDepOptimizationConfig } from 'packages/vite/src/node/config'; -+import { ResolvedConfig } from 'packages/vite/src/node/config'; -+import { isOptimizable } from 'packages/vite/src/node/utils'; -+import { normalizeId } from 'packages/vite/src/node/utils'; -+import { expandGlobIds } from 'packages/vite/src/node/optimizer/resolve'; -+import { createOptimizeDepsIncludeResolver } from 'packages/vite/src/node/optimizer/resolve'; -+ -+ -+export async function addManuallyIncludedOptimizeDeps( -+ deps: Record, -+ config: ResolvedConfig, -+ ssr: boolean, -+): Promise { -+ const { logger } = config -+ const optimizeDeps = getDepOptimizationConfig(config, ssr) -+ const optimizeDepsInclude = optimizeDeps?.include ?? [] -+ if (optimizeDepsInclude.length) { -+ const unableToOptimize = (id: string, msg: string) => { -+ if (optimizeDepsInclude.includes(id)) { -+ logger.warn( -+ `${msg}: ${colors.cyan(id)}, present in '${ -+ ssr ? 'ssr.' : '' -+ }optimizeDeps.include'`, -+ ) -+ } -+ } -+ -+ const includes = [...optimizeDepsInclude] -+ for (let i = 0; i < includes.length; i++) { -+ const id = includes[i] -+ if (glob.isDynamicPattern(id)) { -+ const globIds = expandGlobIds(id, config) -+ includes.splice(i, 1, ...globIds) -+ i += globIds.length - 1 -+ } -+ } -+ -+ const resolve = createOptimizeDepsIncludeResolver(config, ssr) -+ for (const id of includes) { -+ // normalize 'foo >bar` as 'foo > bar' to prevent same id being added -+ // and for pretty printing -+ const normalizedId = normalizeId(id) -+ if (!deps[normalizedId]) { -+ const entry = await resolve(id) -+ if (entry) { -+ if (isOptimizable(entry, optimizeDeps)) { -+ if (!entry.endsWith('?__vite_skip_optimization')) { -+ deps[normalizedId] = entry -+ } -+ } else { -+ unableToOptimize(id, 'Cannot optimize dependency') -+ } -+ } else { -+ unableToOptimize(id, 'Failed to resolve dependency') -+ } -+ } -+ } -+ } -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/optimizer/index/extractExportsData.ts b/packages/vite/src/node/optimizer/index/extractExportsData.ts -new file mode 100644 -index 000000000..a9540b5b9 ---- /dev/null -+++ b/packages/vite/src/node/optimizer/index/extractExportsData.ts -@@ -0,0 +1,70 @@ -+import fsp from 'node:fs/promises' -+import path from 'node:path' -+import esbuild, { build } from 'esbuild' -+import { init, parse } from 'es-module-lexer' -+import { getDepOptimizationConfig } from 'packages/vite/src/node/config'; -+import { ResolvedConfig } from 'packages/vite/src/node/config'; -+import { transformWithEsbuild } from 'packages/vite/src/node/plugins/esbuild'; -+ -+ -+export type ExportsData = { -+ hasModuleSyntax: boolean -+ // exported names (for `export { a as b }`, `b` is exported name) -+ exports: readonly string[] -+ // hint if the dep requires loading as jsx -+ jsxLoader?: boolean -+} -+ -+export async function extractExportsData( -+ filePath: string, -+ config: ResolvedConfig, -+ ssr: boolean, -+): Promise { -+ await init -+ -+ const optimizeDeps = getDepOptimizationConfig(config, ssr) -+ -+ const esbuildOptions = optimizeDeps?.esbuildOptions ?? {} -+ if (optimizeDeps.extensions?.some((ext) => filePath.endsWith(ext))) { -+ // For custom supported extensions, build the entry file to transform it into JS, -+ // and then parse with es-module-lexer. Note that the `bundle` option is not `true`, -+ // so only the entry file is being transformed. -+ const result = await build({ -+ ...esbuildOptions, -+ entryPoints: [filePath], -+ write: false, -+ format: 'esm', -+ }) -+ const [, exports, , hasModuleSyntax] = parse(result.outputFiles[0].text) -+ return { -+ hasModuleSyntax, -+ exports: exports.map((e) => e.n), -+ } -+ } -+ -+ let parseResult: ReturnType -+ let usedJsxLoader = false -+ -+ const entryContent = await fsp.readFile(filePath, 'utf-8') -+ try { -+ parseResult = parse(entryContent) -+ } catch { -+ const loader = esbuildOptions.loader?.[path.extname(filePath)] || 'jsx' -+ debug?.( -+ `Unable to parse: ${filePath}.\n Trying again with a ${loader} transform.`, -+ ) -+ const transformed = await transformWithEsbuild(entryContent, filePath, { -+ loader, -+ }) -+ parseResult = parse(transformed.code) -+ usedJsxLoader = true -+ } -+ -+ const [, exports, , hasModuleSyntax] = parseResult -+ const exportsData: ExportsData = { -+ hasModuleSyntax, -+ exports: exports.map((e) => e.n), -+ jsxLoader: usedJsxLoader, -+ } -+ return exportsData -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/optimizer/index/loadCachedDepOptimizationMetadata.ts b/packages/vite/src/node/optimizer/index/loadCachedDepOptimizationMetadata.ts -new file mode 100644 -index 000000000..157298771 ---- /dev/null -+++ b/packages/vite/src/node/optimizer/index/loadCachedDepOptimizationMetadata.ts -@@ -0,0 +1,131 @@ -+import fsp from 'node:fs/promises' -+import path from 'node:path' -+import colors from 'picocolors' -+import { ResolvedConfig } from 'packages/vite/src/node/config'; -+import { METADATA_FILENAME } from 'packages/vite/src/node/constants'; -+import { createDebugger } from 'packages/vite/src/node/utils'; -+ -+ -+const export debug = createDebugger('vite:deps') -+ -+export type ExportsData = { -+ hasModuleSyntax: boolean -+ // exported names (for `export { a as b }`, `b` is exported name) -+ exports: readonly string[] -+ // hint if the dep requires loading as jsx -+ jsxLoader?: boolean -+} -+ -+export interface OptimizedDepInfo { -+ id: string -+ file: string -+ src?: string -+ needsInterop?: boolean -+ browserHash?: string -+ fileHash?: string -+ /** -+ * During optimization, ids can still be resolved to their final location -+ * but the bundles may not yet be saved to disk -+ */ -+ processing?: Promise -+ /** -+ * ExportData cache, discovered deps will parse the src entry to get exports -+ * data used both to define if interop is needed and when pre-bundling -+ */ -+ exportsData?: Promise -+} -+ -+export interface DepOptimizationMetadata { -+ /** -+ * The main hash is determined by user config and dependency lockfiles. -+ * This is checked on server startup to avoid unnecessary re-bundles. -+ */ -+ hash: string -+ /** -+ * This hash is determined by dependency lockfiles. -+ * This is checked on server startup to avoid unnecessary re-bundles. -+ */ -+ lockfileHash: string -+ /** -+ * This hash is determined by user config. -+ * This is checked on server startup to avoid unnecessary re-bundles. -+ */ -+ configHash: string -+ /** -+ * The browser hash is determined by the main hash plus additional dependencies -+ * discovered at runtime. This is used to invalidate browser requests to -+ * optimized deps. -+ */ -+ browserHash: string -+ /** -+ * Metadata for each already optimized dependency -+ */ -+ optimized: Record -+ /** -+ * Metadata for non-entry optimized chunks and dynamic imports -+ */ -+ chunks: Record -+ /** -+ * Metadata for each newly discovered dependency after processing -+ */ -+ discovered: Record -+ /** -+ * OptimizedDepInfo list -+ */ -+ depInfoList: OptimizedDepInfo[] -+} -+ -+/** -+ * Creates the initial dep optimization metadata, loading it from the deps cache -+ * if it exists and pre-bundling isn't forced -+ */ -+export async function loadCachedDepOptimizationMetadata( -+ config: ResolvedConfig, -+ ssr: boolean, -+ force = config.optimizeDeps.force, -+ asCommand = false, -+): Promise { -+ const log = asCommand ? config.logger.info : debug -+ -+ if (firstLoadCachedDepOptimizationMetadata) { -+ firstLoadCachedDepOptimizationMetadata = false -+ // Fire up a clean up of stale processing deps dirs if older process exited early -+ setTimeout(() => cleanupDepsCacheStaleDirs(config), 0) -+ } -+ -+ const depsCacheDir = getDepsCacheDir(config, ssr) -+ -+ if (!force) { -+ let cachedMetadata: DepOptimizationMetadata | undefined -+ try { -+ const cachedMetadataPath = path.join(depsCacheDir, METADATA_FILENAME) -+ cachedMetadata = parseDepsOptimizerMetadata( -+ await fsp.readFile(cachedMetadataPath, 'utf-8'), -+ depsCacheDir, -+ ) -+ } catch (e) {} -+ // hash is consistent, no need to re-bundle -+ if (cachedMetadata) { -+ if (cachedMetadata.lockfileHash !== getLockfileHash(config, ssr)) { -+ config.logger.info( -+ 'Re-optimizing dependencies because lockfile has changed', -+ ) -+ } else if (cachedMetadata.configHash !== getConfigHash(config, ssr)) { -+ config.logger.info( -+ 'Re-optimizing dependencies because vite config has changed', -+ ) -+ } else { -+ log?.('Hash is consistent. Skipping. Use --force to override.') -+ // Nothing to commit or cancel as we are using the cache, we only -+ // need to resolve the processing promise so requests can move on -+ return cachedMetadata -+ } -+ } -+ } else { -+ config.logger.info('Forced re-optimization of dependencies') -+ } -+ -+ // Start with a fresh cache -+ debug?.(colors.green(`removing old cache dir ${depsCacheDir}`)) -+ await fsp.rm(depsCacheDir, { recursive: true, force: true }) -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/optimizer/index/prepareEsbuildOptimizerRun.ts b/packages/vite/src/node/optimizer/index/prepareEsbuildOptimizerRun.ts -new file mode 100644 -index 000000000..37e1560b0 ---- /dev/null -+++ b/packages/vite/src/node/optimizer/index/prepareEsbuildOptimizerRun.ts -@@ -0,0 +1,136 @@ -+import type { BuildContext, BuildOptions as EsbuildBuildOptions } from 'esbuild' -+import esbuild, { build } from 'esbuild' -+import { getDepOptimizationConfig } from 'packages/vite/src/node/config'; -+import { ResolvedConfig } from 'packages/vite/src/node/config'; -+import { defaultEsbuildSupported } from 'packages/vite/src/node/plugins/esbuild'; -+import { ESBUILD_MODULES_TARGET } from 'packages/vite/src/node/constants'; -+import { esbuildCjsExternalPlugin } from 'packages/vite/src/node/optimizer/esbuildDepPlugin'; -+ -+ -+export type ExportsData = { -+ hasModuleSyntax: boolean -+ // exported names (for `export { a as b }`, `b` is exported name) -+ exports: readonly string[] -+ // hint if the dep requires loading as jsx -+ jsxLoader?: boolean -+} -+ -+export interface OptimizedDepInfo { -+ id: string -+ file: string -+ src?: string -+ needsInterop?: boolean -+ browserHash?: string -+ fileHash?: string -+ /** -+ * During optimization, ids can still be resolved to their final location -+ * but the bundles may not yet be saved to disk -+ */ -+ processing?: Promise -+ /** -+ * ExportData cache, discovered deps will parse the src entry to get exports -+ * data used both to define if interop is needed and when pre-bundling -+ */ -+ exportsData?: Promise -+} -+ -+export async function prepareEsbuildOptimizerRun( -+ resolvedConfig: ResolvedConfig, -+ depsInfo: Record, -+ ssr: boolean, -+ processingCacheDir: string, -+ optimizerContext: { cancelled: boolean }, -+): Promise<{ -+ context?: BuildContext -+ idToExports: Record -+}> { -+ const config: ResolvedConfig = { -+ ...resolvedConfig, -+ command: 'build', -+ } -+ -+ // esbuild generates nested directory output with lowest common ancestor base -+ // this is unpredictable and makes it difficult to analyze entry / output -+ // mapping. So what we do here is: -+ // 1. flatten all ids to eliminate slash -+ // 2. in the plugin, read the entry ourselves as virtual files to retain the -+ // path. -+ const flatIdDeps: Record = {} -+ const idToExports: Record = {} -+ -+ const optimizeDeps = getDepOptimizationConfig(config, ssr) -+ -+ const { plugins: pluginsFromConfig = [], ...esbuildOptions } = -+ optimizeDeps?.esbuildOptions ?? {} -+ -+ await Promise.all( -+ Object.keys(depsInfo).map(async (id) => { -+ const src = depsInfo[id].src! -+ const exportsData = await (depsInfo[id].exportsData ?? -+ extractExportsData(src, config, ssr)) -+ if (exportsData.jsxLoader && !esbuildOptions.loader?.['.js']) { -+ // Ensure that optimization won't fail by defaulting '.js' to the JSX parser. -+ // This is useful for packages such as Gatsby. -+ esbuildOptions.loader = { -+ '.js': 'jsx', -+ ...esbuildOptions.loader, -+ } -+ } -+ const flatId = flattenId(id) -+ flatIdDeps[flatId] = src -+ idToExports[id] = exportsData -+ }), -+ ) -+ -+ if (optimizerContext.cancelled) return { context: undefined, idToExports } -+ -+ const define = { -+ 'process.env.NODE_ENV': JSON.stringify(process.env.NODE_ENV || config.mode), -+ } -+ -+ const platform = -+ ssr && config.ssr?.target !== 'webworker' ? 'node' : 'browser' -+ -+ const external = [...(optimizeDeps?.exclude ?? [])] -+ -+ const plugins = [...pluginsFromConfig] -+ if (external.length) { -+ plugins.push(esbuildCjsExternalPlugin(external, platform)) -+ } -+ plugins.push(esbuildDepPlugin(flatIdDeps, external, config, ssr)) -+ -+ const context = await esbuild.context({ -+ absWorkingDir: process.cwd(), -+ entryPoints: Object.keys(flatIdDeps), -+ bundle: true, -+ // We can't use platform 'neutral', as esbuild has custom handling -+ // when the platform is 'node' or 'browser' that can't be emulated -+ // by using mainFields and conditions -+ platform, -+ define, -+ format: 'esm', -+ // See https://github.com/evanw/esbuild/issues/1921#issuecomment-1152991694 -+ banner: -+ platform === 'node' -+ ? { -+ js: `import { createRequire } from 'module';const require = createRequire(import.meta.url);`, -+ } -+ : undefined, -+ target: ESBUILD_MODULES_TARGET, -+ external, -+ logLevel: 'error', -+ splitting: true, -+ sourcemap: true, -+ outdir: processingCacheDir, -+ ignoreAnnotations: true, -+ metafile: true, -+ plugins, -+ charset: 'utf8', -+ ...esbuildOptions, -+ supported: { -+ ...defaultEsbuildSupported, -+ ...esbuildOptions.supported, -+ }, -+ }) -+ return { context, idToExports } -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/optimizer/index/runOptimizeDeps.ts b/packages/vite/src/node/optimizer/index/runOptimizeDeps.ts -new file mode 100644 -index 000000000..02743a788 ---- /dev/null -+++ b/packages/vite/src/node/optimizer/index/runOptimizeDeps.ts -@@ -0,0 +1,381 @@ -+import fs from 'node:fs' -+import fsp from 'node:fs/promises' -+import path from 'node:path' -+import { performance } from 'node:perf_hooks' -+import colors from 'picocolors' -+import { ResolvedConfig } from 'packages/vite/src/node/config'; -+import { getHash } from 'packages/vite/src/node/utils'; -+import { METADATA_FILENAME } from 'packages/vite/src/node/constants'; -+import { isWindows } from 'packages/vite/src/shared/utils'; -+ -+ -+const export jsMapExtensionRE = /\.js\.map$/i -+ -+export type ExportsData = { -+ hasModuleSyntax: boolean -+ // exported names (for `export { a as b }`, `b` is exported name) -+ exports: readonly string[] -+ // hint if the dep requires loading as jsx -+ jsxLoader?: boolean -+} -+ -+export interface OptimizedDepInfo { -+ id: string -+ file: string -+ src?: string -+ needsInterop?: boolean -+ browserHash?: string -+ fileHash?: string -+ /** -+ * During optimization, ids can still be resolved to their final location -+ * but the bundles may not yet be saved to disk -+ */ -+ processing?: Promise -+ /** -+ * ExportData cache, discovered deps will parse the src entry to get exports -+ * data used both to define if interop is needed and when pre-bundling -+ */ -+ exportsData?: Promise -+} -+ -+export interface DepOptimizationMetadata { -+ /** -+ * The main hash is determined by user config and dependency lockfiles. -+ * This is checked on server startup to avoid unnecessary re-bundles. -+ */ -+ hash: string -+ /** -+ * This hash is determined by dependency lockfiles. -+ * This is checked on server startup to avoid unnecessary re-bundles. -+ */ -+ lockfileHash: string -+ /** -+ * This hash is determined by user config. -+ * This is checked on server startup to avoid unnecessary re-bundles. -+ */ -+ configHash: string -+ /** -+ * The browser hash is determined by the main hash plus additional dependencies -+ * discovered at runtime. This is used to invalidate browser requests to -+ * optimized deps. -+ */ -+ browserHash: string -+ /** -+ * Metadata for each already optimized dependency -+ */ -+ optimized: Record -+ /** -+ * Metadata for non-entry optimized chunks and dynamic imports -+ */ -+ chunks: Record -+ /** -+ * Metadata for each newly discovered dependency after processing -+ */ -+ discovered: Record -+ /** -+ * OptimizedDepInfo list -+ */ -+ depInfoList: OptimizedDepInfo[] -+} -+ -+export interface DepOptimizationResult { -+ metadata: DepOptimizationMetadata -+ /** -+ * When doing a re-run, if there are newly discovered dependencies -+ * the page reload will be delayed until the next rerun so we need -+ * to be able to discard the result -+ */ -+ commit: () => Promise -+ cancel: () => void -+} -+ -+export function initDepsOptimizerMetadata( -+ config: ResolvedConfig, -+ ssr: boolean, -+ timestamp?: string, -+): DepOptimizationMetadata { -+ const { lockfileHash, configHash, hash } = getDepHash(config, ssr) -+ return { -+ hash, -+ lockfileHash, -+ configHash, -+ browserHash: getOptimizedBrowserHash(hash, {}, timestamp), -+ optimized: {}, -+ chunks: {}, -+ discovered: {}, -+ depInfoList: [], -+ } -+} -+ -+export function addOptimizedDepInfo( -+ metadata: DepOptimizationMetadata, -+ type: 'optimized' | 'discovered' | 'chunks', -+ depInfo: OptimizedDepInfo, -+): OptimizedDepInfo { -+ metadata[type][depInfo.id] = depInfo -+ metadata.depInfoList.push(depInfo) -+ return depInfo -+} -+ -+/** -+ * Internally, Vite uses this function to prepare a optimizeDeps run. When Vite starts, we can get -+ * the metadata and start the server without waiting for the optimizeDeps processing to be completed -+ */ -+export function runOptimizeDeps( -+ resolvedConfig: ResolvedConfig, -+ depsInfo: Record, -+ ssr: boolean, -+): { -+ cancel: () => Promise -+ result: Promise -+} { -+ const optimizerContext = { cancelled: false } -+ -+ const config: ResolvedConfig = { -+ ...resolvedConfig, -+ command: 'build', -+ } -+ -+ const depsCacheDir = getDepsCacheDir(resolvedConfig, ssr) -+ const processingCacheDir = getProcessingDepsCacheDir(resolvedConfig, ssr) -+ -+ // Create a temporary directory so we don't need to delete optimized deps -+ // until they have been processed. This also avoids leaving the deps cache -+ // directory in a corrupted state if there is an error -+ fs.mkdirSync(processingCacheDir, { recursive: true }) -+ -+ // a hint for Node.js -+ // all files in the cache directory should be recognized as ES modules -+ debug?.(colors.green(`creating package.json in ${processingCacheDir}`)) -+ fs.writeFileSync( -+ path.resolve(processingCacheDir, 'package.json'), -+ `{\n "type": "module"\n}\n`, -+ ) -+ -+ const metadata = initDepsOptimizerMetadata(config, ssr) -+ -+ metadata.browserHash = getOptimizedBrowserHash( -+ metadata.hash, -+ depsFromOptimizedDepInfo(depsInfo), -+ ) -+ -+ // We prebundle dependencies with esbuild and cache them, but there is no need -+ // to wait here. Code that needs to access the cached deps needs to await -+ // the optimizedDepInfo.processing promise for each dep -+ -+ const qualifiedIds = Object.keys(depsInfo) -+ let cleaned = false -+ let committed = false -+ const cleanUp = () => { -+ // If commit was already called, ignore the clean up even if a cancel was requested -+ // This minimizes the chances of leaving the deps cache in a corrupted state -+ if (!cleaned && !committed) { -+ cleaned = true -+ // No need to wait, we can clean up in the background because temp folders -+ // are unique per run -+ debug?.(colors.green(`removing cache dir ${processingCacheDir}`)) -+ try { -+ // When exiting the process, `fsp.rm` may not take effect, so we use `fs.rmSync` -+ fs.rmSync(processingCacheDir, { recursive: true, force: true }) -+ } catch (error) { -+ // Ignore errors -+ } -+ } -+ } -+ -+ const successfulResult: DepOptimizationResult = { -+ metadata, -+ cancel: cleanUp, -+ commit: async () => { -+ if (cleaned) { -+ throw new Error( -+ 'Can not commit a Deps Optimization run as it was cancelled', -+ ) -+ } -+ // Ignore clean up requests after this point so the temp folder isn't deleted before -+ // we finish committing the new deps cache files to the deps folder -+ committed = true -+ -+ // Write metadata file, then commit the processing folder to the global deps cache -+ // Rewire the file paths from the temporary processing dir to the final deps cache dir -+ const dataPath = path.join(processingCacheDir, METADATA_FILENAME) -+ debug?.( -+ colors.green(`creating ${METADATA_FILENAME} in ${processingCacheDir}`), -+ ) -+ fs.writeFileSync( -+ dataPath, -+ stringifyDepsOptimizerMetadata(metadata, depsCacheDir), -+ ) -+ -+ // In order to minimize the time where the deps folder isn't in a consistent state, -+ // we first rename the old depsCacheDir to a temporary path, then we rename the -+ // new processing cache dir to the depsCacheDir. In systems where doing so in sync -+ // is safe, we do an atomic operation (at least for this thread). For Windows, we -+ // found there are cases where the rename operation may finish before it's done -+ // so we do a graceful rename checking that the folder has been properly renamed. -+ // We found that the rename-rename (then delete the old folder in the background) -+ // is safer than a delete-rename operation. -+ const temporaryPath = depsCacheDir + getTempSuffix() -+ const depsCacheDirPresent = fs.existsSync(depsCacheDir) -+ if (isWindows) { -+ if (depsCacheDirPresent) { -+ debug?.(colors.green(`renaming ${depsCacheDir} to ${temporaryPath}`)) -+ await safeRename(depsCacheDir, temporaryPath) -+ } -+ debug?.( -+ colors.green(`renaming ${processingCacheDir} to ${depsCacheDir}`), -+ ) -+ await safeRename(processingCacheDir, depsCacheDir) -+ } else { -+ if (depsCacheDirPresent) { -+ debug?.(colors.green(`renaming ${depsCacheDir} to ${temporaryPath}`)) -+ fs.renameSync(depsCacheDir, temporaryPath) -+ } -+ debug?.( -+ colors.green(`renaming ${processingCacheDir} to ${depsCacheDir}`), -+ ) -+ fs.renameSync(processingCacheDir, depsCacheDir) -+ } -+ -+ // Delete temporary path in the background -+ if (depsCacheDirPresent) { -+ debug?.(colors.green(`removing cache temp dir ${temporaryPath}`)) -+ fsp.rm(temporaryPath, { recursive: true, force: true }) -+ } -+ }, -+ } -+ -+ if (!qualifiedIds.length) { -+ // No deps to optimize, we still commit the processing cache dir to remove -+ // the previous optimized deps if they exist, and let the next server start -+ // skip the scanner step if the lockfile hasn't changed -+ return { -+ cancel: async () => cleanUp(), -+ result: Promise.resolve(successfulResult), -+ } -+ } -+ -+ const cancelledResult: DepOptimizationResult = { -+ metadata, -+ commit: async () => cleanUp(), -+ cancel: cleanUp, -+ } -+ -+ const start = performance.now() -+ -+ const preparedRun = prepareEsbuildOptimizerRun( -+ resolvedConfig, -+ depsInfo, -+ ssr, -+ processingCacheDir, -+ optimizerContext, -+ ) -+ -+ const runResult = preparedRun.then(({ context, idToExports }) => { -+ function disposeContext() { -+ return context?.dispose().catch((e) => { -+ config.logger.error('Failed to dispose esbuild context', { error: e }) -+ }) -+ } -+ if (!context || optimizerContext.cancelled) { -+ disposeContext() -+ return cancelledResult -+ } -+ -+ return context -+ .rebuild() -+ .then((result) => { -+ const meta = result.metafile! -+ -+ // the paths in `meta.outputs` are relative to `process.cwd()` -+ const processingCacheDirOutputPath = path.relative( -+ process.cwd(), -+ processingCacheDir, -+ ) -+ -+ for (const id in depsInfo) { -+ const output = esbuildOutputFromId( -+ meta.outputs, -+ id, -+ processingCacheDir, -+ ) -+ -+ const { exportsData, ...info } = depsInfo[id] -+ addOptimizedDepInfo(metadata, 'optimized', { -+ ...info, -+ // We only need to hash the output.imports in to check for stability, but adding the hash -+ // and file path gives us a unique hash that may be useful for other things in the future -+ fileHash: getHash( -+ metadata.hash + -+ depsInfo[id].file + -+ JSON.stringify(output.imports), -+ ), -+ browserHash: metadata.browserHash, -+ // After bundling we have more information and can warn the user about legacy packages -+ // that require manual configuration -+ needsInterop: needsInterop( -+ config, -+ ssr, -+ id, -+ idToExports[id], -+ output, -+ ), -+ }) -+ } -+ -+ for (const o of Object.keys(meta.outputs)) { -+ if (!jsMapExtensionRE.test(o)) { -+ const id = path -+ .relative(processingCacheDirOutputPath, o) -+ .replace(jsExtensionRE, '') -+ const file = getOptimizedDepPath(id, resolvedConfig, ssr) -+ if ( -+ !findOptimizedDepInfoInRecord( -+ metadata.optimized, -+ (depInfo) => depInfo.file === file, -+ ) -+ ) { -+ addOptimizedDepInfo(metadata, 'chunks', { -+ id, -+ file, -+ needsInterop: false, -+ browserHash: metadata.browserHash, -+ }) -+ } -+ } -+ } -+ -+ debug?.( -+ `Dependencies bundled in ${(performance.now() - start).toFixed(2)}ms`, -+ ) -+ -+ return successfulResult -+ }) -+ -+ .catch((e) => { -+ if (e.errors && e.message.includes('The build was canceled')) { -+ // esbuild logs an error when cancelling, but this is expected so -+ // return an empty result instead -+ return cancelledResult -+ } -+ throw e -+ }) -+ .finally(() => { -+ return disposeContext() -+ }) -+ }) -+ -+ runResult.catch(() => { -+ cleanUp() -+ }) -+ -+ return { -+ async cancel() { -+ optimizerContext.cancelled = true -+ const { context } = await preparedRun -+ await context?.cancel() -+ cleanUp() -+ }, -+ result: runResult, -+ } -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/optimizer/optimizer.ts b/packages/vite/src/node/optimizer/optimizer.ts -index 3f76e480a..046df3009 100644 ---- a/packages/vite/src/node/optimizer/optimizer.ts -+++ b/packages/vite/src/node/optimizer/optimizer.ts -@@ -1,25 +1,29 @@ -+import { OptimizedDepInfo } from 'packages/vite/src/node/optimizer/index/loadCachedDepOptimizationMetadata'; -+import { loadCachedDepOptimizationMetadata } from 'packages/vite/src/node/optimizer/index/loadCachedDepOptimizationMetadata'; -+import { OptimizedDepInfo } from 'packages/vite/src/node/optimizer/index/runOptimizeDeps'; -+import { DepOptimizationResult } from 'packages/vite/src/node/optimizer/index/runOptimizeDeps'; -+import { initDepsOptimizerMetadata } from 'packages/vite/src/node/optimizer/index/runOptimizeDeps'; -+import { addOptimizedDepInfo } from 'packages/vite/src/node/optimizer/index/runOptimizeDeps'; -+import { runOptimizeDeps } from 'packages/vite/src/node/optimizer/index/runOptimizeDeps'; -+import { OptimizedDepInfo } from 'packages/vite/src/node/optimizer/index/prepareEsbuildOptimizerRun'; -+import { addManuallyIncludedOptimizeDeps } from 'packages/vite/src/node/optimizer/index/addManuallyIncludedOptimizeDeps'; -+import { extractExportsData } from 'packages/vite/src/node/optimizer/index/extractExportsData'; - import colors from 'picocolors' - import { createDebugger, getHash, promiseWithResolvers } from '../utils' - import type { PromiseWithResolvers } from '../utils' - import { getDepOptimizationConfig } from '../config' - import type { ResolvedConfig, ViteDevServer } from '..' - import { -- addManuallyIncludedOptimizeDeps, -- addOptimizedDepInfo, -+ , - createIsOptimizedDepFile, - createIsOptimizedDepUrl, - depsFromOptimizedDepInfo, - depsLogString, - discoverProjectDependencies, -- extractExportsData, - getOptimizedDepPath, -- initDepsOptimizerMetadata, -- loadCachedDepOptimizationMetadata, - optimizeServerSsrDeps, -- runOptimizeDeps, - toDiscoveredDependencies, - } from '.' --import type { DepOptimizationResult, DepsOptimizer, OptimizedDepInfo } from '.' - - const debug = createDebugger('vite:deps') - -diff --git a/packages/vite/src/node/optimizer/resolve.ts b/packages/vite/src/node/optimizer/resolve.ts -index 822b19e18..5e2041a30 100644 ---- a/packages/vite/src/node/optimizer/resolve.ts -+++ b/packages/vite/src/node/optimizer/resolve.ts -@@ -1,7 +1,7 @@ -+import { ResolvedConfig } from 'packages/vite/src/node/config/resolveConfig'; - import path from 'node:path' - import glob from 'fast-glob' - import micromatch from 'micromatch' --import type { ResolvedConfig } from '../config' - import { escapeRegex, getNpmPackageName } from '../utils' - import { resolvePackageData } from '../packages' - import { slash } from '../../shared/utils' -diff --git a/packages/vite/src/node/plugin.ts b/packages/vite/src/node/plugin.ts -index 8ef2f2281..2240dd95d 100644 ---- a/packages/vite/src/node/plugin.ts -+++ b/packages/vite/src/node/plugin.ts -@@ -1,3 +1,10 @@ -+import { UserConfig } from 'packages/vite/src/node/config/UserConfig'; -+import { ConfigEnv } from 'packages/vite/src/node/config/resolveConfig'; -+import { UserConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { ResolvedConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { ConfigEnv } from 'packages/vite/src/node/config/loadConfigFromFile'; -+import { UserConfig } from 'packages/vite/src/node/config/loadConfigFromFile'; -+import { HmrContext } from 'packages/vite/src/node/server/hmr/handleHMRUpdate'; - import type { - CustomPluginOptions, - LoadResult, -@@ -9,11 +16,9 @@ import type { - TransformResult, - } from 'rollup' - export type { PluginContext } from 'rollup' --import type { ConfigEnv, ResolvedConfig, UserConfig } from './config' - import type { ServerHook } from './server' - import type { IndexHtmlTransform } from './plugins/html' - import type { ModuleNode } from './server/moduleGraph' --import type { HmrContext } from './server/hmr' - import type { PreviewServerHook } from './preview' - - /** -diff --git a/packages/vite/src/node/plugins/asset.ts b/packages/vite/src/node/plugins/asset.ts -index 1415211d5..0e4221cf0 100644 ---- a/packages/vite/src/node/plugins/asset.ts -+++ b/packages/vite/src/node/plugins/asset.ts -@@ -1,3 +1,4 @@ -+import { ResolvedConfig } from 'packages/vite/src/node/config/resolveConfig'; - import path from 'node:path' - import { parse as parseUrl } from 'node:url' - import fsp from 'node:fs/promises' -@@ -15,7 +16,6 @@ import { - toOutputFilePathInJS, - } from '../build' - import type { Plugin } from '../plugin' --import type { ResolvedConfig } from '../config' - import { checkPublicFile } from '../publicDir' - import { - encodeURIPath, -diff --git a/packages/vite/src/node/plugins/assetImportMetaUrl.ts b/packages/vite/src/node/plugins/assetImportMetaUrl.ts -index 588f6e08b..924658276 100644 ---- a/packages/vite/src/node/plugins/assetImportMetaUrl.ts -+++ b/packages/vite/src/node/plugins/assetImportMetaUrl.ts -@@ -1,16 +1,21 @@ -+import { ResolvedConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { InternalResolveOptions } from 'packages/vite/src/node/plugins/resolve/resolvePlugin'; -+import { InternalResolveOptions } from 'packages/vite/src/node/plugins/resolve/tryCleanFsResolve'; -+import { InternalResolveOptions } from 'packages/vite/src/node/plugins/resolve/tryNodeResolve'; -+import { InternalResolveOptions } from 'packages/vite/src/node/plugins/resolve/resolvePackageEntry'; -+import { tryFsResolve } from 'packages/vite/src/node/plugins/resolve/resolvePackageEntry'; -+import { InternalResolveOptions } from 'packages/vite/src/node/plugins/resolve/resolveDeepImport'; -+import { tryFsResolve } from 'packages/vite/src/node/plugins/resolve/resolveDeepImport'; - import path from 'node:path' - import MagicString from 'magic-string' - import { stripLiteral } from 'strip-literal' - import type { Plugin } from '../plugin' --import type { ResolvedConfig } from '../config' - import type { ResolveFn } from '../' - import { injectQuery, isParentDirectory, transformStableResult } from '../utils' - import { CLIENT_ENTRY } from '../constants' - import { slash } from '../../shared/utils' - import { fileToUrl } from './asset' - import { preloadHelperId } from './importAnalysisBuild' --import type { InternalResolveOptions } from './resolve' --import { tryFsResolve } from './resolve' - import { hasViteIgnoreRE } from './importAnalysis' - - /** -diff --git a/packages/vite/src/node/plugins/clientInjections.ts b/packages/vite/src/node/plugins/clientInjections.ts -index c66f3877e..accf28907 100644 ---- a/packages/vite/src/node/plugins/clientInjections.ts -+++ b/packages/vite/src/node/plugins/clientInjections.ts -@@ -1,6 +1,6 @@ -+import { ResolvedConfig } from 'packages/vite/src/node/config/resolveConfig'; - import path from 'node:path' - import type { Plugin } from '../plugin' --import type { ResolvedConfig } from '../config' - import { CLIENT_ENTRY, ENV_ENTRY } from '../constants' - import { isObject, normalizePath, resolveHostname } from '../utils' - import { replaceDefine, serializeDefine } from './define' -diff --git a/packages/vite/src/node/plugins/css.ts b/packages/vite/src/node/plugins/css.ts -index 26ba17c19..9d7570f37 100644 ---- a/packages/vite/src/node/plugins/css.ts -+++ b/packages/vite/src/node/plugins/css.ts -@@ -1,3 +1,135 @@ -+import { ResolvedConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { commonjsProxyRE } from 'packages/vite/src/node/plugins/css/cssPlugin'; -+import { isCSSRequest } from 'packages/vite/src/node/plugins/css/cssPlugin'; -+import { cssModuleRE } from 'packages/vite/src/node/plugins/css/cssPlugin'; -+import { isModuleCSSRequest } from 'packages/vite/src/node/plugins/css/cssPlugin'; -+import { cssModulesCache } from 'packages/vite/src/node/plugins/css/cssPlugin'; -+import { removedPureCssFilesCache } from 'packages/vite/src/node/plugins/css/cssPlugin'; -+import { cssPlugin } from 'packages/vite/src/node/plugins/css'; -+import { cssModuleRE } from 'packages/vite/src/node/plugins/css/cssPostPlugin'; -+import { commonjsProxyRE } from 'packages/vite/src/node/plugins/css/cssPostPlugin'; -+import { styleAttrRE } from 'packages/vite/src/node/plugins/css/cssPostPlugin'; -+import { transformOnlyRE } from 'packages/vite/src/node/plugins/css/cssPostPlugin'; -+import { cssBundleName } from 'packages/vite/src/node/plugins/css/cssPostPlugin'; -+import { isCSSRequest } from 'packages/vite/src/node/plugins/css/cssPostPlugin'; -+import { directRequestRE } from 'packages/vite/src/node/plugins/css/cssPostPlugin'; -+import { isDirectCSSRequest } from 'packages/vite/src/node/plugins/css/cssPostPlugin'; -+import { cssUrlAssetRE } from 'packages/vite/src/node/plugins/css/cssPostPlugin'; -+import { cssPostPlugin } from 'packages/vite/src/node/plugins/css'; -+import { commonjsProxyRE } from 'packages/vite/src/node/plugins/css/cssAnalysisPlugin'; -+import { isCSSRequest } from 'packages/vite/src/node/plugins/css/cssAnalysisPlugin'; -+import { cssAnalysisPlugin } from 'packages/vite/src/node/plugins/css'; -+import { PreprocessLang } from 'packages/vite/src/node/plugins/css/compileCSSPreprocessors'; -+import { compileCSSPreprocessors } from 'packages/vite/src/node/plugins/css/compileCSSPreprocessors'; -+import { compileCSSPreprocessors } from 'packages/vite/src/node/plugins/css'; -+import { cssModuleRE } from 'packages/vite/src/node/plugins/css/compileCSS'; -+import { PreprocessLang } from 'packages/vite/src/node/plugins/css/compileCSS'; -+import { PostCssDialectLang } from 'packages/vite/src/node/plugins/css/compileCSS'; -+import { CssLang } from 'packages/vite/src/node/plugins/css/compileCSS'; -+import { CSSAtImportResolvers } from 'packages/vite/src/node/plugins/css/compileCSS'; -+import { getCssResolversKeys } from 'packages/vite/src/node/plugins/css/compileCSS'; -+import { compileCSSPreprocessors } from 'packages/vite/src/node/plugins/css/compileCSS'; -+import { getAtImportResolvers } from 'packages/vite/src/node/plugins/css/compileCSS'; -+import { compileCSS } from 'packages/vite/src/node/plugins/css/compileCSS'; -+import { compileCSS } from 'packages/vite/src/node/plugins/css'; -+import { decoder } from 'packages/vite/src/node/plugins/css/minifyCSS'; -+import { cssBundleName } from 'packages/vite/src/node/plugins/css/minifyCSS'; -+import { minifyCSS } from 'packages/vite/src/node/plugins/css/minifyCSS'; -+import { minifyCSS } from 'packages/vite/src/node/plugins/css'; -+import { CSSAtImportResolvers } from 'packages/vite/src/node/plugins/css/makeScssWorker'; -+import { PreprocessorAdditionalData } from 'packages/vite/src/node/plugins/css/makeScssWorker'; -+import { StylePreprocessorOptions } from 'packages/vite/src/node/plugins/css/makeScssWorker'; -+import { SassStylePreprocessorOptions } from 'packages/vite/src/node/plugins/css/makeScssWorker'; -+import { cleanScssBugUrl } from 'packages/vite/src/node/plugins/css/makeScssWorker'; -+import { fixScssBugImportValue } from 'packages/vite/src/node/plugins/css/makeScssWorker'; -+import { makeScssWorker } from 'packages/vite/src/node/plugins/css/makeScssWorker'; -+import { makeScssWorker } from 'packages/vite/src/node/plugins/css'; -+import { CSSAtImportResolvers } from 'packages/vite/src/node/plugins/css/makeModernScssWorker'; -+import { PreprocessorAdditionalData } from 'packages/vite/src/node/plugins/css/makeModernScssWorker'; -+import { StylePreprocessorOptions } from 'packages/vite/src/node/plugins/css/makeModernScssWorker'; -+import { SassStylePreprocessorOptions } from 'packages/vite/src/node/plugins/css/makeModernScssWorker'; -+import { cleanScssBugUrl } from 'packages/vite/src/node/plugins/css/makeModernScssWorker'; -+import { makeModernScssWorker } from 'packages/vite/src/node/plugins/css/makeModernScssWorker'; -+import { makeModernScssWorker } from 'packages/vite/src/node/plugins/css'; -+import { CSSAtImportResolvers } from 'packages/vite/src/node/plugins/css/makeModernCompilerScssWorker'; -+import { PreprocessorAdditionalData } from 'packages/vite/src/node/plugins/css/makeModernCompilerScssWorker'; -+import { StylePreprocessorOptions } from 'packages/vite/src/node/plugins/css/makeModernCompilerScssWorker'; -+import { SassStylePreprocessorOptions } from 'packages/vite/src/node/plugins/css/makeModernCompilerScssWorker'; -+import { cleanScssBugUrl } from 'packages/vite/src/node/plugins/css/makeModernCompilerScssWorker'; -+import { makeModernScssWorker } from 'packages/vite/src/node/plugins/css/makeModernCompilerScssWorker'; -+import { makeModernCompilerScssWorker } from 'packages/vite/src/node/plugins/css/makeModernCompilerScssWorker'; -+import { makeModernCompilerScssWorker } from 'packages/vite/src/node/plugins/css'; -+import { CSSAtImportResolvers } from 'packages/vite/src/node/plugins/css/scssProcessor'; -+import { PreprocessorAdditionalData } from 'packages/vite/src/node/plugins/css/scssProcessor'; -+import { StylePreprocessorOptions } from 'packages/vite/src/node/plugins/css/scssProcessor'; -+import { SassStylePreprocessorOptions } from 'packages/vite/src/node/plugins/css/scssProcessor'; -+import { StylePreprocessorResults } from 'packages/vite/src/node/plugins/css/scssProcessor'; -+import { SassStylePreprocessor } from 'packages/vite/src/node/plugins/css/scssProcessor'; -+import { cleanScssBugUrl } from 'packages/vite/src/node/plugins/css/scssProcessor'; -+import { makeScssWorker } from 'packages/vite/src/node/plugins/css/scssProcessor'; -+import { makeModernScssWorker } from 'packages/vite/src/node/plugins/css/scssProcessor'; -+import { makeModernCompilerScssWorker } from 'packages/vite/src/node/plugins/css/scssProcessor'; -+import { scssProcessor } from 'packages/vite/src/node/plugins/css/scssProcessor'; -+import { scssProcessor } from 'packages/vite/src/node/plugins/css'; -+import { cssUrlRE } from 'packages/vite/src/node/plugins/css/rebaseUrls'; -+import { cssDataUriRE } from 'packages/vite/src/node/plugins/css/rebaseUrls'; -+import { importCssRE } from 'packages/vite/src/node/plugins/css/rebaseUrls'; -+import { CssUrlReplacer } from 'packages/vite/src/node/plugins/css/rebaseUrls'; -+import { rewriteCssUrls } from 'packages/vite/src/node/plugins/css/rebaseUrls'; -+import { rewriteCssDataUris } from 'packages/vite/src/node/plugins/css/rebaseUrls'; -+import { rewriteImportCss } from 'packages/vite/src/node/plugins/css/rebaseUrls'; -+import { rebaseUrls } from 'packages/vite/src/node/plugins/css/rebaseUrls'; -+import { rebaseUrls } from 'packages/vite/src/node/plugins/css'; -+import { CSSAtImportResolvers } from 'packages/vite/src/node/plugins/css/makeLessWorker'; -+import { PreprocessorAdditionalData } from 'packages/vite/src/node/plugins/css/makeLessWorker'; -+import { StylePreprocessorOptions } from 'packages/vite/src/node/plugins/css/makeLessWorker'; -+import { cssUrlRE } from 'packages/vite/src/node/plugins/css/makeLessWorker'; -+import { cssDataUriRE } from 'packages/vite/src/node/plugins/css/makeLessWorker'; -+import { importCssRE } from 'packages/vite/src/node/plugins/css/makeLessWorker'; -+import { CssUrlReplacer } from 'packages/vite/src/node/plugins/css/makeLessWorker'; -+import { rewriteCssUrls } from 'packages/vite/src/node/plugins/css/makeLessWorker'; -+import { rebaseUrls } from 'packages/vite/src/node/plugins/css/makeLessWorker'; -+import { makeLessWorker } from 'packages/vite/src/node/plugins/css/makeLessWorker'; -+import { makeLessWorker } from 'packages/vite/src/node/plugins/css'; -+import { CSSAtImportResolvers } from 'packages/vite/src/node/plugins/css/lessProcessor'; -+import { PreprocessorAdditionalData } from 'packages/vite/src/node/plugins/css/lessProcessor'; -+import { StylePreprocessorOptions } from 'packages/vite/src/node/plugins/css/lessProcessor'; -+import { StylePreprocessorResults } from 'packages/vite/src/node/plugins/css/lessProcessor'; -+import { StylePreprocessor } from 'packages/vite/src/node/plugins/css/lessProcessor'; -+import { cssUrlRE } from 'packages/vite/src/node/plugins/css/lessProcessor'; -+import { cssDataUriRE } from 'packages/vite/src/node/plugins/css/lessProcessor'; -+import { importCssRE } from 'packages/vite/src/node/plugins/css/lessProcessor'; -+import { CssUrlReplacer } from 'packages/vite/src/node/plugins/css/lessProcessor'; -+import { rewriteCssUrls } from 'packages/vite/src/node/plugins/css/lessProcessor'; -+import { rebaseUrls } from 'packages/vite/src/node/plugins/css/lessProcessor'; -+import { makeLessWorker } from 'packages/vite/src/node/plugins/css/lessProcessor'; -+import { lessProcessor } from 'packages/vite/src/node/plugins/css/lessProcessor'; -+import { lessProcessor } from 'packages/vite/src/node/plugins/css'; -+import { PreprocessorAdditionalData } from 'packages/vite/src/node/plugins/css/makeStylWorker'; -+import { StylePreprocessorOptions } from 'packages/vite/src/node/plugins/css/makeStylWorker'; -+import { StylusStylePreprocessorOptions } from 'packages/vite/src/node/plugins/css/makeStylWorker'; -+import { makeStylWorker } from 'packages/vite/src/node/plugins/css/makeStylWorker'; -+import { makeStylWorker } from 'packages/vite/src/node/plugins/css'; -+import { CSSAtImportResolvers } from 'packages/vite/src/node/plugins/css/stylProcessor'; -+import { PreprocessorAdditionalData } from 'packages/vite/src/node/plugins/css/stylProcessor'; -+import { StylePreprocessorOptions } from 'packages/vite/src/node/plugins/css/stylProcessor'; -+import { StylusStylePreprocessorOptions } from 'packages/vite/src/node/plugins/css/stylProcessor'; -+import { StylePreprocessorResults } from 'packages/vite/src/node/plugins/css/stylProcessor'; -+import { StylusStylePreprocessor } from 'packages/vite/src/node/plugins/css/stylProcessor'; -+import { makeStylWorker } from 'packages/vite/src/node/plugins/css/stylProcessor'; -+import { stylProcessor } from 'packages/vite/src/node/plugins/css/stylProcessor'; -+import { stylProcessor } from 'packages/vite/src/node/plugins/css'; -+import { decoder } from 'packages/vite/src/node/plugins/css/compileLightningCSS'; -+import { cssModuleRE } from 'packages/vite/src/node/plugins/css/compileLightningCSS'; -+import { styleAttrRE } from 'packages/vite/src/node/plugins/css/compileLightningCSS'; -+import { PreprocessLang } from 'packages/vite/src/node/plugins/css/compileLightningCSS'; -+import { PostCssDialectLang } from 'packages/vite/src/node/plugins/css/compileLightningCSS'; -+import { CSSAtImportResolvers } from 'packages/vite/src/node/plugins/css/compileLightningCSS'; -+import { getAtImportResolvers } from 'packages/vite/src/node/plugins/css/compileLightningCSS'; -+import { compileCSS } from 'packages/vite/src/node/plugins/css/compileLightningCSS'; -+import { CssUrlReplacer } from 'packages/vite/src/node/plugins/css/compileLightningCSS'; -+import { skipUrlReplacer } from 'packages/vite/src/node/plugins/css/compileLightningCSS'; -+import { compileLightningCSS } from 'packages/vite/src/node/plugins/css'; - import fs from 'node:fs' - import fsp from 'node:fs/promises' - import path from 'node:path' -@@ -42,7 +174,6 @@ import { - ESBUILD_MODULES_TARGET, - SPECIAL_QUERY_RE, - } from '../constants' --import type { ResolvedConfig } from '../config' - import type { Plugin } from '../plugin' - import { checkPublicFile } from '../publicDir' - import { -@@ -84,8 +215,6 @@ import { - } from './asset' - import type { ESBuildOptions } from './esbuild' - import { getChunkOriginalFileName } from './manifest' -- --const decoder = new TextDecoder() - // const debug = createDebugger('vite:css') - - export interface CSSOptions { -@@ -185,61 +314,21 @@ export function resolveCSSOptions( - } - return { ...options, lightningcss: undefined } - } -- --const cssModuleRE = new RegExp(`\\.module${CSS_LANGS_RE.source}`) --const directRequestRE = /[?&]direct\b/ - const htmlProxyRE = /[?&]html-proxy\b/ - const htmlProxyIndexRE = /&index=(\d+)/ --const commonjsProxyRE = /\?commonjs-proxy/ - const inlineRE = /[?&]inline\b/ - const inlineCSSRE = /[?&]inline-css\b/ --const styleAttrRE = /[?&]style-attr\b/ --const functionCallRE = /^[A-Z_][\w-]*\(/i --const transformOnlyRE = /[?&]transform-only\b/ - const nonEscapedDoubleQuoteRe = /(? -- CSS_LANGS_RE.test(request) -+export - --export const isModuleCSSRequest = (request: string): boolean => -- cssModuleRE.test(request) -+export - --export const isDirectCSSRequest = (request: string): boolean => -- CSS_LANGS_RE.test(request) && directRequestRE.test(request) -+export - - export const isDirectRequest = (request: string): boolean => - directRequestRE.test(request) - --const cssModulesCache = new WeakMap< -- ResolvedConfig, -- Map> -->() -- --export const removedPureCssFilesCache = new WeakMap< -- ResolvedConfig, -- Map -->() -- - const postcssConfigCache = new WeakMap< - ResolvedConfig, - PostCSSConfigResult | null | Promise -@@ -249,784 +338,6 @@ function encodePublicUrlsInCSS(config: ResolvedConfig) { - return config.command === 'build' - } - --const cssUrlAssetRE = /__VITE_CSS_URL__([\da-f]+)__/g -- --/** -- * Plugin applied before user plugins -- */ --export function cssPlugin(config: ResolvedConfig): Plugin { -- const isBuild = config.command === 'build' -- let moduleCache: Map> -- -- const resolveUrl = config.createResolver({ -- preferRelative: true, -- tryIndex: false, -- extensions: [], -- }) -- -- let preprocessorWorkerController: PreprocessorWorkerController | undefined -- -- // warm up cache for resolved postcss config -- if (config.css?.transformer !== 'lightningcss') { -- resolvePostcssConfig(config) -- } -- -- return { -- name: 'vite:css', -- -- buildStart() { -- // Ensure a new cache for every build (i.e. rebuilding in watch mode) -- moduleCache = new Map>() -- cssModulesCache.set(config, moduleCache) -- -- removedPureCssFilesCache.set(config, new Map()) -- -- preprocessorWorkerController = createPreprocessorWorkerController( -- normalizeMaxWorkers(config.css.preprocessorMaxWorkers), -- ) -- preprocessorWorkerControllerCache.set( -- config, -- preprocessorWorkerController, -- ) -- }, -- -- buildEnd() { -- preprocessorWorkerController?.close() -- }, -- -- async load(id) { -- if (!isCSSRequest(id)) return -- -- if (urlRE.test(id)) { -- if (isModuleCSSRequest(id)) { -- throw new Error( -- `?url is not supported with CSS modules. (tried to import ${JSON.stringify( -- id, -- )})`, -- ) -- } -- -- // *.css?url -- // in dev, it's handled by assets plugin. -- if (isBuild) { -- id = injectQuery(removeUrlQuery(id), 'transform-only') -- return ( -- `import ${JSON.stringify(id)};` + -- `export default "__VITE_CSS_URL__${Buffer.from(id).toString( -- 'hex', -- )}__"` -- ) -- } -- } -- }, -- -- async transform(raw, id) { -- if ( -- !isCSSRequest(id) || -- commonjsProxyRE.test(id) || -- SPECIAL_QUERY_RE.test(id) -- ) { -- return -- } -- const urlReplacer: CssUrlReplacer = async (url, importer) => { -- const decodedUrl = decodeURI(url) -- if (checkPublicFile(decodedUrl, config)) { -- if (encodePublicUrlsInCSS(config)) { -- return publicFileToBuiltUrl(decodedUrl, config) -- } else { -- return joinUrlSegments(config.base, decodedUrl) -- } -- } -- const [id, fragment] = decodedUrl.split('#') -- let resolved = await resolveUrl(id, importer) -- if (resolved) { -- if (fragment) resolved += '#' + fragment -- return fileToUrl(resolved, config, this) -- } -- if (config.command === 'build') { -- const isExternal = config.build.rollupOptions.external -- ? resolveUserExternal( -- config.build.rollupOptions.external, -- decodedUrl, // use URL as id since id could not be resolved -- id, -- false, -- ) -- : false -- -- if (!isExternal) { -- // #9800 If we cannot resolve the css url, leave a warning. -- config.logger.warnOnce( -- `\n${decodedUrl} referenced in ${id} didn't resolve at build time, it will remain unchanged to be resolved at runtime`, -- ) -- } -- } -- return url -- } -- -- const { -- code: css, -- modules, -- deps, -- map, -- } = await compileCSS( -- id, -- raw, -- config, -- preprocessorWorkerController!, -- urlReplacer, -- ) -- if (modules) { -- moduleCache.set(id, modules) -- } -- -- if (deps) { -- for (const file of deps) { -- this.addWatchFile(file) -- } -- } -- -- return { -- code: css, -- map, -- } -- }, -- } --} -- --/** -- * Plugin applied after user plugins -- */ --export function cssPostPlugin(config: ResolvedConfig): Plugin { -- // styles initialization in buildStart causes a styling loss in watch -- const styles: Map = new Map() -- // queue to emit css serially to guarantee the files are emitted in a deterministic order -- let codeSplitEmitQueue = createSerialPromiseQueue() -- const urlEmitQueue = createSerialPromiseQueue() -- let pureCssChunks: Set -- -- // when there are multiple rollup outputs and extracting CSS, only emit once, -- // since output formats have no effect on the generated CSS. -- let hasEmitted = false -- let chunkCSSMap: Map -- -- const rollupOptionsOutput = config.build.rollupOptions.output -- const assetFileNames = ( -- Array.isArray(rollupOptionsOutput) -- ? rollupOptionsOutput[0] -- : rollupOptionsOutput -- )?.assetFileNames -- const getCssAssetDirname = (cssAssetName: string) => { -- const cssAssetNameDir = path.dirname(cssAssetName) -- if (!assetFileNames) { -- return path.join(config.build.assetsDir, cssAssetNameDir) -- } else if (typeof assetFileNames === 'string') { -- return path.join(path.dirname(assetFileNames), cssAssetNameDir) -- } else { -- return path.dirname( -- assetFileNames({ -- name: cssAssetName, -- type: 'asset', -- source: '/* vite internal call, ignore */', -- }), -- ) -- } -- } -- -- return { -- name: 'vite:css-post', -- -- renderStart() { -- // Ensure new caches for every build (i.e. rebuilding in watch mode) -- pureCssChunks = new Set() -- hasEmitted = false -- chunkCSSMap = new Map() -- codeSplitEmitQueue = createSerialPromiseQueue() -- }, -- -- async transform(css, id, options) { -- if ( -- !isCSSRequest(id) || -- commonjsProxyRE.test(id) || -- SPECIAL_QUERY_RE.test(id) -- ) { -- return -- } -- -- css = stripBomTag(css) -- -- // cache css compile result to map -- // and then use the cache replace inline-style-flag -- // when `generateBundle` in vite:build-html plugin and devHtmlHook -- const inlineCSS = inlineCSSRE.test(id) -- const isHTMLProxy = htmlProxyRE.test(id) -- if (inlineCSS && isHTMLProxy) { -- if (styleAttrRE.test(id)) { -- css = css.replace(/"/g, '"') -- } -- const index = htmlProxyIndexRE.exec(id)?.[1] -- if (index == null) { -- throw new Error(`HTML proxy index in "${id}" not found`) -- } -- addToHTMLProxyTransformResult( -- `${getHash(cleanUrl(id))}_${Number.parseInt(index)}`, -- css, -- ) -- return `export default ''` -- } -- -- const inlined = inlineRE.test(id) -- const modules = cssModulesCache.get(config)!.get(id) -- -- // #6984, #7552 -- // `foo.module.css` => modulesCode -- // `foo.module.css?inline` => cssContent -- const modulesCode = -- modules && -- !inlined && -- dataToEsm(modules, { namedExports: true, preferConst: true }) -- -- if (config.command === 'serve') { -- const getContentWithSourcemap = async (content: string) => { -- if (config.css?.devSourcemap) { -- const sourcemap = this.getCombinedSourcemap() -- if (sourcemap.mappings) { -- await injectSourcesContent(sourcemap, cleanUrl(id), config.logger) -- } -- return getCodeWithSourcemap('css', content, sourcemap) -- } -- return content -- } -- -- if (isDirectCSSRequest(id)) { -- return null -- } -- // server only -- if (options?.ssr) { -- return modulesCode || `export default ${JSON.stringify(css)}` -- } -- if (inlined) { -- return `export default ${JSON.stringify(css)}` -- } -- -- const cssContent = await getContentWithSourcemap(css) -- const code = [ -- `import { updateStyle as __vite__updateStyle, removeStyle as __vite__removeStyle } from ${JSON.stringify( -- path.posix.join(config.base, CLIENT_PUBLIC_PATH), -- )}`, -- `const __vite__id = ${JSON.stringify(id)}`, -- `const __vite__css = ${JSON.stringify(cssContent)}`, -- `__vite__updateStyle(__vite__id, __vite__css)`, -- // css modules exports change on edit so it can't self accept -- `${modulesCode || 'import.meta.hot.accept()'}`, -- `import.meta.hot.prune(() => __vite__removeStyle(__vite__id))`, -- ].join('\n') -- return { code, map: { mappings: '' } } -- } -- -- // build CSS handling ---------------------------------------------------- -- -- // record css -- if (!inlined) { -- styles.set(id, css) -- } -- -- let code: string -- if (modulesCode) { -- code = modulesCode -- } else if (inlined) { -- let content = css -- if (config.build.cssMinify) { -- content = await minifyCSS(content, config, true) -- } -- code = `export default ${JSON.stringify(content)}` -- } else { -- // empty module when it's not a CSS module nor `?inline` -- code = '' -- } -- -- return { -- code, -- map: { mappings: '' }, -- // avoid the css module from being tree-shaken so that we can retrieve -- // it in renderChunk() -- moduleSideEffects: modulesCode || inlined ? false : 'no-treeshake', -- } -- }, -- -- async renderChunk(code, chunk, opts) { -- let chunkCSS = '' -- // the chunk is empty if it's a dynamic entry chunk that only contains a CSS import -- const isJsChunkEmpty = code === '' && !chunk.isEntry -- let isPureCssChunk = true -- const ids = Object.keys(chunk.modules) -- for (const id of ids) { -- if (styles.has(id)) { -- // ?transform-only is used for ?url and shouldn't be included in normal CSS chunks -- if (!transformOnlyRE.test(id)) { -- chunkCSS += styles.get(id) -- // a css module contains JS, so it makes this not a pure css chunk -- if (cssModuleRE.test(id)) { -- isPureCssChunk = false -- } -- } -- } else if (!isJsChunkEmpty) { -- // if the module does not have a style, then it's not a pure css chunk. -- // this is true because in the `transform` hook above, only modules -- // that are css gets added to the `styles` map. -- isPureCssChunk = false -- } -- } -- -- const publicAssetUrlMap = publicAssetUrlCache.get(config)! -- -- // resolve asset URL placeholders to their built file URLs -- const resolveAssetUrlsInCss = ( -- chunkCSS: string, -- cssAssetName: string, -- ) => { -- const encodedPublicUrls = encodePublicUrlsInCSS(config) -- -- const relative = config.base === './' || config.base === '' -- const cssAssetDirname = -- encodedPublicUrls || relative -- ? slash(getCssAssetDirname(cssAssetName)) -- : undefined -- -- const toRelative = (filename: string) => { -- // relative base + extracted CSS -- const relativePath = path.posix.relative(cssAssetDirname!, filename) -- return relativePath[0] === '.' ? relativePath : './' + relativePath -- } -- -- // replace asset url references with resolved url. -- chunkCSS = chunkCSS.replace(assetUrlRE, (_, fileHash, postfix = '') => { -- const filename = this.getFileName(fileHash) + postfix -- chunk.viteMetadata!.importedAssets.add(cleanUrl(filename)) -- return encodeURIPath( -- toOutputFilePathInCss( -- filename, -- 'asset', -- cssAssetName, -- 'css', -- config, -- toRelative, -- ), -- ) -- }) -- // resolve public URL from CSS paths -- if (encodedPublicUrls) { -- const relativePathToPublicFromCSS = path.posix.relative( -- cssAssetDirname!, -- '', -- ) -- chunkCSS = chunkCSS.replace(publicAssetUrlRE, (_, hash) => { -- const publicUrl = publicAssetUrlMap.get(hash)!.slice(1) -- return encodeURIPath( -- toOutputFilePathInCss( -- publicUrl, -- 'public', -- cssAssetName, -- 'css', -- config, -- () => `${relativePathToPublicFromCSS}/${publicUrl}`, -- ), -- ) -- }) -- } -- return chunkCSS -- } -- -- function ensureFileExt(name: string, ext: string) { -- return normalizePath( -- path.format({ ...path.parse(name), base: undefined, ext }), -- ) -- } -- -- let s: MagicString | undefined -- const urlEmitTasks: Array<{ -- cssAssetName: string -- originalFilename: string -- content: string -- start: number -- end: number -- }> = [] -- -- if (code.includes('__VITE_CSS_URL__')) { -- let match: RegExpExecArray | null -- cssUrlAssetRE.lastIndex = 0 -- while ((match = cssUrlAssetRE.exec(code))) { -- const [full, idHex] = match -- const id = Buffer.from(idHex, 'hex').toString() -- const originalFilename = cleanUrl(id) -- const cssAssetName = ensureFileExt( -- path.basename(originalFilename), -- '.css', -- ) -- if (!styles.has(id)) { -- throw new Error( -- `css content for ${JSON.stringify(id)} was not found`, -- ) -- } -- -- let cssContent = styles.get(id)! -- -- cssContent = resolveAssetUrlsInCss(cssContent, cssAssetName) -- -- urlEmitTasks.push({ -- cssAssetName, -- originalFilename, -- content: cssContent, -- start: match.index, -- end: match.index + full.length, -- }) -- } -- } -- -- // should await even if this chunk does not include __VITE_CSS_URL__ -- // so that code after this line runs in the same order -- await urlEmitQueue.run(async () => -- Promise.all( -- urlEmitTasks.map(async (info) => { -- info.content = await finalizeCss(info.content, true, config) -- }), -- ), -- ) -- if (urlEmitTasks.length > 0) { -- const toRelativeRuntime = createToImportMetaURLBasedRelativeRuntime( -- opts.format, -- config.isWorker, -- ) -- s ||= new MagicString(code) -- -- for (const { -- cssAssetName, -- originalFilename, -- content, -- start, -- end, -- } of urlEmitTasks) { -- const referenceId = this.emitFile({ -- name: cssAssetName, -- type: 'asset', -- source: content, -- }) -- generatedAssets -- .get(config)! -- .set(referenceId, { originalName: originalFilename }) -- -- const filename = this.getFileName(referenceId) -- chunk.viteMetadata!.importedAssets.add(cleanUrl(filename)) -- const replacement = toOutputFilePathInJS( -- filename, -- 'asset', -- chunk.fileName, -- 'js', -- config, -- toRelativeRuntime, -- ) -- const replacementString = -- typeof replacement === 'string' -- ? JSON.stringify(encodeURIPath(replacement)).slice(1, -1) -- : `"+${replacement.runtime}+"` -- s.update(start, end, replacementString) -- } -- } -- -- if (chunkCSS) { -- if (isPureCssChunk && (opts.format === 'es' || opts.format === 'cjs')) { -- // this is a shared CSS-only chunk that is empty. -- pureCssChunks.add(chunk) -- } -- -- if (config.build.cssCodeSplit) { -- if (opts.format === 'es' || opts.format === 'cjs') { -- const isEntry = chunk.isEntry && isPureCssChunk -- const cssFullAssetName = ensureFileExt(chunk.name, '.css') -- // if facadeModuleId doesn't exist or doesn't have a CSS extension, -- // that means a JS entry file imports a CSS file. -- // in this case, only use the filename for the CSS chunk name like JS chunks. -- const cssAssetName = -- chunk.isEntry && -- (!chunk.facadeModuleId || !isCSSRequest(chunk.facadeModuleId)) -- ? path.basename(cssFullAssetName) -- : cssFullAssetName -- const originalFilename = getChunkOriginalFileName( -- chunk, -- config.root, -- opts.format, -- ) -- -- chunkCSS = resolveAssetUrlsInCss(chunkCSS, cssAssetName) -- -- // wait for previous tasks as well -- chunkCSS = await codeSplitEmitQueue.run(async () => { -- return finalizeCss(chunkCSS, true, config) -- }) -- -- // emit corresponding css file -- const referenceId = this.emitFile({ -- name: cssAssetName, -- type: 'asset', -- source: chunkCSS, -- }) -- generatedAssets -- .get(config)! -- .set(referenceId, { originalName: originalFilename, isEntry }) -- chunk.viteMetadata!.importedCss.add(this.getFileName(referenceId)) -- } else if (!config.build.ssr) { -- // legacy build and inline css -- -- // Entry chunk CSS will be collected into `chunk.viteMetadata.importedCss` -- // and injected later by the `'vite:build-html'` plugin into the `index.html` -- // so it will be duplicated. (https://github.com/vitejs/vite/issues/2062#issuecomment-782388010) -- // But because entry chunk can be imported by dynamic import, -- // we shouldn't remove the inlined CSS. (#10285) -- -- chunkCSS = await finalizeCss(chunkCSS, true, config) -- let cssString = JSON.stringify(chunkCSS) -- cssString = -- renderAssetUrlInJS( -- this, -- config, -- chunk, -- opts, -- cssString, -- )?.toString() || cssString -- const style = `__vite_style__` -- const injectCode = -- `var ${style} = document.createElement('style');` + -- `${style}.textContent = ${cssString};` + -- `document.head.appendChild(${style});` -- let injectionPoint -- const wrapIdx = code.indexOf('System.register') -- if (wrapIdx >= 0) { -- const executeFnStart = code.indexOf('execute:', wrapIdx) -- injectionPoint = code.indexOf('{', executeFnStart) + 1 -- } else { -- const insertMark = "'use strict';" -- injectionPoint = code.indexOf(insertMark) + insertMark.length -- } -- s ||= new MagicString(code) -- s.appendRight(injectionPoint, injectCode) -- } -- } else { -- // resolve public URL from CSS paths, we need to use absolute paths -- chunkCSS = resolveAssetUrlsInCss(chunkCSS, cssBundleName) -- // finalizeCss is called for the aggregated chunk in generateBundle -- -- chunkCSSMap.set(chunk.fileName, chunkCSS) -- } -- } -- -- if (s) { -- if (config.build.sourcemap) { -- return { -- code: s.toString(), -- map: s.generateMap({ hires: 'boundary' }), -- } -- } else { -- return { code: s.toString() } -- } -- } -- return null -- }, -- -- augmentChunkHash(chunk) { -- if (chunk.viteMetadata?.importedCss.size) { -- let hash = '' -- for (const id of chunk.viteMetadata.importedCss) { -- hash += id -- } -- return hash -- } -- }, -- -- async generateBundle(opts, bundle) { -- // @ts-expect-error asset emits are skipped in legacy bundle -- if (opts.__vite_skip_asset_emit__) { -- return -- } -- -- function extractCss() { -- let css = '' -- const collected = new Set() -- // will be populated in order they are used by entry points -- const dynamicImports = new Set() -- -- function collect(chunk: OutputChunk | OutputAsset) { -- if (!chunk || chunk.type !== 'chunk' || collected.has(chunk)) return -- collected.add(chunk) -- -- // First collect all styles from the synchronous imports (lowest priority) -- chunk.imports.forEach((importName) => collect(bundle[importName])) -- // Save dynamic imports in deterministic order to add the styles later (to have the highest priority) -- chunk.dynamicImports.forEach((importName) => -- dynamicImports.add(importName), -- ) -- // Then collect the styles of the current chunk (might overwrite some styles from previous imports) -- css += chunkCSSMap.get(chunk.preliminaryFileName) ?? '' -- } -- -- // The bundle is guaranteed to be deterministic, if not then we have a bug in rollup. -- // So we use it to ensure a deterministic order of styles -- for (const chunk of Object.values(bundle)) { -- if (chunk.type === 'chunk' && chunk.isEntry) { -- collect(chunk) -- } -- } -- // Now collect the dynamic chunks, this is done last to have the styles overwrite the previous ones -- for (const chunkName of dynamicImports) { -- collect(bundle[chunkName]) -- } -- -- return css -- } -- let extractedCss = !hasEmitted && extractCss() -- if (extractedCss) { -- hasEmitted = true -- extractedCss = await finalizeCss(extractedCss, true, config) -- this.emitFile({ -- name: cssBundleName, -- type: 'asset', -- source: extractedCss, -- }) -- } -- -- // remove empty css chunks and their imports -- if (pureCssChunks.size) { -- // map each pure css chunk (rendered chunk) to it's corresponding bundle -- // chunk. we check that by `preliminaryFileName` as they have different -- // `filename`s (rendered chunk has the !~{XXX}~ placeholder) -- const prelimaryNameToChunkMap = Object.fromEntries( -- Object.values(bundle) -- .filter((chunk): chunk is OutputChunk => chunk.type === 'chunk') -- .map((chunk) => [chunk.preliminaryFileName, chunk.fileName]), -- ) -- -- // When running in watch mode the generateBundle is called once per output format -- // in this case the `bundle` is not populated with the other output files -- // but they are still in `pureCssChunks`. -- // So we need to filter the names and only use those who are defined -- const pureCssChunkNames = [...pureCssChunks] -- .map((pureCssChunk) => prelimaryNameToChunkMap[pureCssChunk.fileName]) -- .filter(Boolean) -- -- const replaceEmptyChunk = getEmptyChunkReplacer( -- pureCssChunkNames, -- opts.format, -- ) -- -- for (const file in bundle) { -- const chunk = bundle[file] -- if (chunk.type === 'chunk') { -- let chunkImportsPureCssChunk = false -- // remove pure css chunk from other chunk's imports, -- // and also register the emitted CSS files under the importer -- // chunks instead. -- chunk.imports = chunk.imports.filter((file) => { -- if (pureCssChunkNames.includes(file)) { -- const { importedCss, importedAssets } = ( -- bundle[file] as OutputChunk -- ).viteMetadata! -- importedCss.forEach((file) => -- chunk.viteMetadata!.importedCss.add(file), -- ) -- importedAssets.forEach((file) => -- chunk.viteMetadata!.importedAssets.add(file), -- ) -- chunkImportsPureCssChunk = true -- return false -- } -- return true -- }) -- if (chunkImportsPureCssChunk) { -- chunk.code = replaceEmptyChunk(chunk.code) -- } -- } -- } -- -- const removedPureCssFiles = removedPureCssFilesCache.get(config)! -- pureCssChunkNames.forEach((fileName) => { -- removedPureCssFiles.set(fileName, bundle[fileName] as RenderedChunk) -- delete bundle[fileName] -- delete bundle[`${fileName}.map`] -- }) -- } -- }, -- } --} -- --export function cssAnalysisPlugin(config: ResolvedConfig): Plugin { -- let server: ViteDevServer -- -- return { -- name: 'vite:css-analysis', -- -- configureServer(_server) { -- server = _server -- }, -- -- async transform(_, id, options) { -- if ( -- !isCSSRequest(id) || -- commonjsProxyRE.test(id) || -- SPECIAL_QUERY_RE.test(id) -- ) { -- return -- } -- -- const ssr = options?.ssr === true -- const { moduleGraph } = server -- const thisModule = moduleGraph.getModuleById(id) -- -- // Handle CSS @import dependency HMR and other added modules via this.addWatchFile. -- // JS-related HMR is handled in the import-analysis plugin. -- if (thisModule) { -- // CSS modules cannot self-accept since it exports values -- const isSelfAccepting = -- !cssModulesCache.get(config)?.get(id) && -- !inlineRE.test(id) && -- !htmlProxyRE.test(id) -- // attached by pluginContainer.addWatchFile -- const pluginImports = (this as unknown as TransformPluginContext) -- ._addedImports -- if (pluginImports) { -- // record deps in the module graph so edits to @import css can trigger -- // main import to hot update -- const depModules = new Set() -- const devBase = config.base -- for (const file of pluginImports) { -- depModules.add( -- isCSSRequest(file) -- ? moduleGraph.createFileOnlyEntry(file) -- : await moduleGraph.ensureEntryFromUrl( -- stripBase( -- await fileToUrl(file, config, this), -- (config.server?.origin ?? '') + devBase, -- ), -- ssr, -- ), -- ) -- } -- moduleGraph.updateModuleInfo( -- thisModule, -- depModules, -- null, -- // The root CSS proxy module is self-accepting and should not -- // have an explicit accept list -- new Set(), -- null, -- isSelfAccepting, -- ssr, -- ) -- } else { -- thisModule.isSelfAccepting = isSelfAccepting -- } -- } -- }, -- } --} -- - /** - * Create a replacer function that takes code and replaces given pure CSS chunk imports - * @param pureCssChunkNames The chunks that only contain pure CSS and should be replaced -@@ -1062,408 +373,6 @@ export function getEmptyChunkReplacer( - ) - } - --interface CSSAtImportResolvers { -- css: ResolveFn -- sass: ResolveFn -- less: ResolveFn --} -- --function createCSSResolvers(config: ResolvedConfig): CSSAtImportResolvers { -- let cssResolve: ResolveFn | undefined -- let sassResolve: ResolveFn | undefined -- let lessResolve: ResolveFn | undefined -- return { -- get css() { -- return ( -- cssResolve || -- (cssResolve = config.createResolver({ -- extensions: ['.css'], -- mainFields: ['style'], -- conditions: ['style'], -- tryIndex: false, -- preferRelative: true, -- })) -- ) -- }, -- -- get sass() { -- return ( -- sassResolve || -- (sassResolve = config.createResolver({ -- extensions: ['.scss', '.sass', '.css'], -- mainFields: ['sass', 'style'], -- conditions: ['sass', 'style'], -- tryIndex: true, -- tryPrefix: '_', -- preferRelative: true, -- })) -- ) -- }, -- -- get less() { -- return ( -- lessResolve || -- (lessResolve = config.createResolver({ -- extensions: ['.less', '.css'], -- mainFields: ['less', 'style'], -- conditions: ['less', 'style'], -- tryIndex: false, -- preferRelative: true, -- })) -- ) -- }, -- } --} -- --function getCssResolversKeys( -- resolvers: CSSAtImportResolvers, --): Array { -- return Object.keys(resolvers) as unknown as Array --} -- --async function compileCSSPreprocessors( -- id: string, -- lang: PreprocessLang, -- code: string, -- config: ResolvedConfig, -- workerController: PreprocessorWorkerController, --): Promise<{ code: string; map?: ExistingRawSourceMap; deps?: Set }> { -- const { preprocessorOptions, devSourcemap } = config.css ?? {} -- const atImportResolvers = getAtImportResolvers(config) -- -- const preProcessor = workerController[lang] -- let opts = (preprocessorOptions && preprocessorOptions[lang]) || {} -- // support @import from node dependencies by default -- switch (lang) { -- case PreprocessLang.scss: -- case PreprocessLang.sass: -- opts = { -- includePaths: ['node_modules'], -- alias: config.resolve.alias, -- ...opts, -- } -- break -- case PreprocessLang.less: -- case PreprocessLang.styl: -- case PreprocessLang.stylus: -- opts = { -- paths: ['node_modules'], -- alias: config.resolve.alias, -- ...opts, -- } -- } -- // important: set this for relative import resolving -- opts.filename = cleanUrl(id) -- opts.enableSourcemap = devSourcemap ?? false -- -- const preprocessResult = await preProcessor( -- code, -- config.root, -- opts, -- atImportResolvers, -- ) -- if (preprocessResult.error) { -- throw preprocessResult.error -- } -- -- let deps: Set | undefined -- if (preprocessResult.deps) { -- const normalizedFilename = normalizePath(opts.filename) -- // sometimes sass registers the file itself as a dep -- deps = new Set( -- [...preprocessResult.deps].filter( -- (dep) => normalizePath(dep) !== normalizedFilename, -- ), -- ) -- } -- -- return { -- code: preprocessResult.code, -- map: combineSourcemapsIfExists( -- opts.filename, -- preprocessResult.map, -- preprocessResult.additionalMap, -- ), -- deps, -- } --} -- --const configToAtImportResolvers = new WeakMap< -- ResolvedConfig, -- CSSAtImportResolvers -->() --function getAtImportResolvers(config: ResolvedConfig) { -- let atImportResolvers = configToAtImportResolvers.get(config) -- if (!atImportResolvers) { -- atImportResolvers = createCSSResolvers(config) -- configToAtImportResolvers.set(config, atImportResolvers) -- } -- return atImportResolvers --} -- --async function compileCSS( -- id: string, -- code: string, -- config: ResolvedConfig, -- workerController: PreprocessorWorkerController, -- urlReplacer?: CssUrlReplacer, --): Promise<{ -- code: string -- map?: SourceMapInput -- ast?: PostCSS.Result -- modules?: Record -- deps?: Set --}> { -- if (config.css?.transformer === 'lightningcss') { -- return compileLightningCSS(id, code, config, urlReplacer) -- } -- -- const { modules: modulesOptions, devSourcemap } = config.css || {} -- const isModule = modulesOptions !== false && cssModuleRE.test(id) -- // although at serve time it can work without processing, we do need to -- // crawl them in order to register watch dependencies. -- const needInlineImport = code.includes('@import') -- const hasUrl = cssUrlRE.test(code) || cssImageSetRE.test(code) -- const lang = CSS_LANGS_RE.exec(id)?.[1] as CssLang | undefined -- const postcssConfig = await resolvePostcssConfig(config) -- -- // 1. plain css that needs no processing -- if ( -- lang === 'css' && -- !postcssConfig && -- !isModule && -- !needInlineImport && -- !hasUrl -- ) { -- return { code, map: null } -- } -- -- let modules: Record | undefined -- const deps = new Set() -- -- // 2. pre-processors: sass etc. -- let preprocessorMap: ExistingRawSourceMap | undefined -- if (isPreProcessor(lang)) { -- const preprocessorResult = await compileCSSPreprocessors( -- id, -- lang, -- code, -- config, -- workerController, -- ) -- code = preprocessorResult.code -- preprocessorMap = preprocessorResult.map -- preprocessorResult.deps?.forEach((dep) => deps.add(dep)) -- } -- -- // 3. postcss -- const atImportResolvers = getAtImportResolvers(config) -- const postcssOptions = (postcssConfig && postcssConfig.options) || {} -- -- const postcssPlugins = -- postcssConfig && postcssConfig.plugins ? postcssConfig.plugins.slice() : [] -- -- if (needInlineImport) { -- postcssPlugins.unshift( -- (await importPostcssImport()).default({ -- async resolve(id, basedir) { -- const publicFile = checkPublicFile(id, config) -- if (publicFile) { -- return publicFile -- } -- -- const resolved = await atImportResolvers.css( -- id, -- path.join(basedir, '*'), -- ) -- -- if (resolved) { -- return path.resolve(resolved) -- } -- -- // postcss-import falls back to `resolve` dep if this is unresolved, -- // but we've shimmed to remove the `resolve` dep to cut on bundle size. -- // warn here to provide a better error message. -- if (!path.isAbsolute(id)) { -- config.logger.error( -- colors.red( -- `Unable to resolve \`@import "${id}"\` from ${basedir}`, -- ), -- ) -- } -- -- return id -- }, -- async load(id) { -- const code = await fs.promises.readFile(id, 'utf-8') -- const lang = CSS_LANGS_RE.exec(id)?.[1] as CssLang | undefined -- if (isPreProcessor(lang)) { -- const result = await compileCSSPreprocessors( -- id, -- lang, -- code, -- config, -- workerController, -- ) -- result.deps?.forEach((dep) => deps.add(dep)) -- // TODO: support source map -- return result.code -- } -- return code -- }, -- nameLayer(index) { -- return `vite--anon-layer-${getHash(id)}-${index}` -- }, -- }), -- ) -- } -- -- if (urlReplacer) { -- postcssPlugins.push( -- UrlRewritePostcssPlugin({ -- replacer: urlReplacer, -- logger: config.logger, -- }), -- ) -- } -- -- if (isModule) { -- postcssPlugins.unshift( -- (await importPostcssModules()).default({ -- ...modulesOptions, -- localsConvention: modulesOptions?.localsConvention, -- getJSON( -- cssFileName: string, -- _modules: Record, -- outputFileName: string, -- ) { -- modules = _modules -- if (modulesOptions && typeof modulesOptions.getJSON === 'function') { -- modulesOptions.getJSON(cssFileName, _modules, outputFileName) -- } -- }, -- async resolve(id: string, importer: string) { -- for (const key of getCssResolversKeys(atImportResolvers)) { -- const resolved = await atImportResolvers[key](id, importer) -- if (resolved) { -- return path.resolve(resolved) -- } -- } -- -- return id -- }, -- }), -- ) -- } -- -- if (!postcssPlugins.length) { -- return { -- code, -- map: preprocessorMap, -- deps, -- } -- } -- -- let postcssResult: PostCSS.Result -- try { -- const source = removeDirectQuery(id) -- const postcss = await importPostcss() -- // postcss is an unbundled dep and should be lazy imported -- postcssResult = await postcss.default(postcssPlugins).process(code, { -- ...postcssOptions, -- parser: lang === 'sss' ? loadSss(config.root) : postcssOptions.parser, -- to: source, -- from: source, -- ...(devSourcemap -- ? { -- map: { -- inline: false, -- annotation: false, -- // postcss may return virtual files -- // we cannot obtain content of them, so this needs to be enabled -- sourcesContent: true, -- // when "prev: preprocessorMap", the result map may include duplicate filename in `postcssResult.map.sources` -- // prev: preprocessorMap, -- }, -- } -- : {}), -- }) -- -- // record CSS dependencies from @imports -- for (const message of postcssResult.messages) { -- if (message.type === 'dependency') { -- deps.add(normalizePath(message.file as string)) -- } else if (message.type === 'dir-dependency') { -- // https://github.com/postcss/postcss/blob/main/docs/guidelines/plugin.md#3-dependencies -- const { dir, glob: globPattern = '**' } = message -- const pattern = -- glob.escapePath(normalizePath(path.resolve(path.dirname(id), dir))) + -- `/` + -- globPattern -- const files = glob.sync(pattern, { -- ignore: ['**/node_modules/**'], -- }) -- for (let i = 0; i < files.length; i++) { -- deps.add(files[i]) -- } -- } else if (message.type === 'warning') { -- const warning = message as PostCSS.Warning -- let msg = `[vite:css] ${warning.text}` -- msg += `\n${generateCodeFrame( -- code, -- { -- line: warning.line, -- column: warning.column - 1, // 1-based -- }, -- warning.endLine !== undefined && warning.endColumn !== undefined -- ? { -- line: warning.endLine, -- column: warning.endColumn - 1, // 1-based -- } -- : undefined, -- )}` -- config.logger.warn(colors.yellow(msg)) -- } -- } -- } catch (e) { -- e.message = `[postcss] ${e.message}` -- e.code = code -- e.loc = { -- file: e.file, -- line: e.line, -- column: e.column - 1, // 1-based -- } -- throw e -- } -- -- if (!devSourcemap) { -- return { -- ast: postcssResult, -- code: postcssResult.css, -- map: { mappings: '' }, -- modules, -- deps, -- } -- } -- -- const rawPostcssMap = postcssResult.map.toJSON() -- -- const postcssMap = await formatPostcssSourceMap( -- // version property of rawPostcssMap is declared as string -- // but actually it is a number -- rawPostcssMap as Omit as ExistingRawSourceMap, -- cleanUrl(id), -- ) -- -- return { -- ast: postcssResult, -- code: postcssResult.css, -- map: combineSourcemapsIfExists(cleanUrl(id), postcssMap, preprocessorMap), -- modules, -- deps, -- } --} -- - function createCachedImport(imp: () => Promise): () => T | Promise { - let cached: T | Promise - return () => { -@@ -1622,17 +531,6 @@ async function resolvePostcssConfig( - postcssConfigCache.set(config, result) - return result - } -- --type CssUrlReplacer = ( -- url: string, -- importer?: string, --) => string | Promise --// https://drafts.csswg.org/css-syntax-3/#identifier-code-point --export const cssUrlRE = -- /(?<=^|[^\w\-\u0080-\uffff])url\((\s*('[^']+'|"[^"]+")\s*|[^'")]+)\)/ --export const cssDataUriRE = -- /(?<=^|[^\w\-\u0080-\uffff])data-uri\((\s*('[^']+'|"[^"]+")\s*|[^'")]+)\)/ --export const importCssRE = /@import ('[^']+\.css'|"[^"]+\.css"|[^'")]+\.css)/ - // Assuming a function name won't be longer than 256 chars - // eslint-disable-next-line regexp/no-unused-capturing-group -- doesn't detect asyncReplace usage - const cssImageSetRE = /(?<=image-set\()((?:[\w\-]{1,256}\([^)]*\)|[^)])*)(?=\))/ -@@ -1685,36 +583,6 @@ const UrlRewritePostcssPlugin: PostCSS.PluginCreator<{ - } - UrlRewritePostcssPlugin.postcss = true - --function rewriteCssUrls( -- css: string, -- replacer: CssUrlReplacer, --): Promise { -- return asyncReplace(css, cssUrlRE, async (match) => { -- const [matched, rawUrl] = match -- return await doUrlReplace(rawUrl.trim(), matched, replacer) -- }) --} -- --function rewriteCssDataUris( -- css: string, -- replacer: CssUrlReplacer, --): Promise { -- return asyncReplace(css, cssDataUriRE, async (match) => { -- const [matched, rawUrl] = match -- return await doUrlReplace(rawUrl.trim(), matched, replacer, 'data-uri') -- }) --} -- --function rewriteImportCss( -- css: string, -- replacer: CssUrlReplacer, --): Promise { -- return asyncReplace(css, importCssRE, async (match) => { -- const [matched, rawUrl] = match -- return await doImportCSSReplace(rawUrl, matched, replacer) -- }) --} -- - // TODO: image and cross-fade could contain a "url" that needs to be processed - // https://drafts.csswg.org/css-images-4/#image-notation - // https://drafts.csswg.org/css-images-4/#cross-fade-function -@@ -1739,14 +607,6 @@ async function rewriteCssImageSet( - return url - }) - } --function skipUrlReplacer(rawUrl: string) { -- return ( -- isExternalUrl(rawUrl) || -- isDataUrl(rawUrl) || -- rawUrl[0] === '#' || -- functionCallRE.test(rawUrl) -- ) --} - async function doUrlReplace( - rawUrl: string, - matched: string, -@@ -1799,65 +659,6 @@ async function doImportCSSReplace( - return `@import ${wrap}${await replacer(rawUrl)}${wrap}` - } - --async function minifyCSS( -- css: string, -- config: ResolvedConfig, -- inlined: boolean, --) { -- // We want inlined CSS to not end with a linebreak, while ensuring that -- // regular CSS assets do end with a linebreak. -- // See https://github.com/vitejs/vite/pull/13893#issuecomment-1678628198 -- -- if (config.build.cssMinify === 'lightningcss') { -- const { code, warnings } = (await importLightningCSS()).transform({ -- ...config.css?.lightningcss, -- targets: convertTargets(config.build.cssTarget), -- cssModules: undefined, -- filename: cssBundleName, -- code: Buffer.from(css), -- minify: true, -- }) -- if (warnings.length) { -- config.logger.warn( -- colors.yellow( -- `warnings when minifying css:\n${warnings -- .map((w) => w.message) -- .join('\n')}`, -- ), -- ) -- } -- -- // NodeJS res.code = Buffer -- // Deno res.code = Uint8Array -- // For correct decode compiled css need to use TextDecoder -- // LightningCSS output does not return a linebreak at the end -- return decoder.decode(code) + (inlined ? '' : '\n') -- } -- try { -- const { code, warnings } = await transform(css, { -- loader: 'css', -- target: config.build.cssTarget || undefined, -- ...resolveMinifyCssEsbuildOptions(config.esbuild || {}), -- }) -- if (warnings.length) { -- const msgs = await formatMessages(warnings, { kind: 'warning' }) -- config.logger.warn( -- colors.yellow(`warnings when minifying css:\n${msgs.join('\n')}`), -- ) -- } -- // esbuild output does return a linebreak at the end -- return inlined ? code.trimEnd() : code -- } catch (e) { -- if (e.errors) { -- e.message = '[esbuild css minify] ' + e.message -- const msgs = await formatMessages(e.errors, { kind: 'error' }) -- e.frame = '\n' + msgs.join('\n') -- e.loc = e.errors[0].location -- } -- throw e -- } --} -- - function resolveMinifyCssEsbuildOptions( - options: ESBuildOptions, - ): TransformOptions { -@@ -1923,75 +724,6 @@ export async function hoistAtRules(css: string): Promise { - - // Preprocessor support. This logic is largely replicated from @vue/compiler-sfc - --type PreprocessorAdditionalDataResult = -- | string -- | { content: string; map?: ExistingRawSourceMap } -- --type PreprocessorAdditionalData = -- | string -- | (( -- source: string, -- filename: string, -- ) => -- | PreprocessorAdditionalDataResult -- | Promise) -- --type StylePreprocessorOptions = { -- [key: string]: any -- additionalData?: PreprocessorAdditionalData -- maxWorkers?: number | true -- filename: string -- alias: Alias[] -- enableSourcemap: boolean --} -- --type SassStylePreprocessorOptions = StylePreprocessorOptions & -- Omit, 'data' | 'file' | 'outFile'> & { -- api?: 'legacy' | 'modern' | 'modern-compiler' -- } -- --type StylusStylePreprocessorOptions = StylePreprocessorOptions & { -- define?: Record --} -- --type StylePreprocessor = { -- process: ( -- source: string, -- root: string, -- options: StylePreprocessorOptions, -- resolvers: CSSAtImportResolvers, -- ) => StylePreprocessorResults | Promise -- close: () => void --} -- --type SassStylePreprocessor = { -- process: ( -- source: string, -- root: string, -- options: SassStylePreprocessorOptions, -- resolvers: CSSAtImportResolvers, -- ) => StylePreprocessorResults | Promise -- close: () => void --} -- --type StylusStylePreprocessor = { -- process: ( -- source: string, -- root: string, -- options: StylusStylePreprocessorOptions, -- resolvers: CSSAtImportResolvers, -- ) => StylePreprocessorResults | Promise -- close: () => void --} -- --export interface StylePreprocessorResults { -- code: string -- map?: ExistingRawSourceMap | undefined -- additionalMap?: ExistingRawSourceMap | undefined -- error?: RollupError -- deps: string[] --} -- - const loadedPreprocessorPath: Partial< - Record - > = {} -@@ -2053,774 +785,14 @@ function loadSss(root: string) { - declare const window: unknown | undefined - declare const location: { href: string } | undefined - --// in unix, scss might append `location.href` in environments that shim `location` --// see https://github.com/sass/dart-sass/issues/710 --function cleanScssBugUrl(url: string) { -- if ( -- // check bug via `window` and `location` global -- typeof window !== 'undefined' && -- typeof location !== 'undefined' && -- typeof location?.href === 'string' -- ) { -- const prefix = location.href.replace(/\/$/, '') -- return url.replace(prefix, '') -- } else { -- return url -- } --} -- --function fixScssBugImportValue( -- data: Sass.LegacyImporterResult, --): Sass.LegacyImporterResult { -- // the scss bug doesn't load files properly so we have to load it ourselves -- // to prevent internal error when it loads itself -- if ( -- // check bug via `window` and `location` global -- typeof window !== 'undefined' && -- typeof location !== 'undefined' && -- data && -- 'file' in data && -- (!('contents' in data) || data.contents == null) -- ) { -- // @ts-expect-error we need to preserve file property for HMR -- data.contents = fs.readFileSync(data.file, 'utf-8') -- } -- return data --} -- --// #region Sass --// .scss/.sass processor --const makeScssWorker = ( -- resolvers: CSSAtImportResolvers, -- alias: Alias[], -- maxWorkers: number | undefined, --) => { -- const internalImporter = async ( -- url: string, -- importer: string, -- filename: string, -- ) => { -- importer = cleanScssBugUrl(importer) -- const resolved = await resolvers.sass(url, importer) -- if (resolved) { -- try { -- const data = await rebaseUrls( -- resolved, -- filename, -- alias, -- '$', -- resolvers.sass, -- ) -- return fixScssBugImportValue(data) -- } catch (data) { -- return data -- } -- } else { -- return null -- } -- } -- -- const worker = new WorkerWithFallback( -- () => -- async ( -- sassPath: string, -- data: string, -- // additionalData can a function that is not cloneable but it won't be used -- options: SassStylePreprocessorOptions & { additionalData: undefined }, -- ) => { -- // eslint-disable-next-line no-restricted-globals -- this function runs inside a cjs worker -- const sass: typeof Sass = require(sassPath) -- // eslint-disable-next-line no-restricted-globals -- const path: typeof import('node:path') = require('node:path') -- -- // NOTE: `sass` always runs it's own importer first, and only falls back to -- // the `importer` option when it can't resolve a path -- const _internalImporter: Sass.LegacyAsyncImporter = ( -- url, -- importer, -- done, -- ) => { -- internalImporter(url, importer, options.filename).then((data) => -- done?.(data), -- ) -- } -- const importer = [_internalImporter] -- if (options.importer) { -- Array.isArray(options.importer) -- ? importer.unshift(...options.importer) -- : importer.unshift(options.importer) -- } -- -- const finalOptions: Sass.LegacyOptions<'async'> = { -- ...options, -- data, -- file: options.filename, -- outFile: options.filename, -- importer, -- ...(options.enableSourcemap -- ? { -- sourceMap: true, -- omitSourceMapUrl: true, -- sourceMapRoot: path.dirname(options.filename), -- } -- : {}), -- } -- return new Promise((resolve, reject) => { -- sass.render(finalOptions, (err, res) => { -- if (err) { -- reject(err) -- } else { -- resolve({ -- css: res!.css.toString(), -- map: res!.map?.toString(), -- stats: res!.stats, -- }) -- } -- }) -- }) -- }, -- { -- parentFunctions: { internalImporter }, -- shouldUseFake(_sassPath, _data, options) { -- // functions and importer is a function and is not serializable -- // in that case, fallback to running in main thread -- return !!( -- (options.functions && Object.keys(options.functions).length > 0) || -- (options.importer && -- (!Array.isArray(options.importer) || options.importer.length > 0)) -- ) -- }, -- max: maxWorkers, -- }, -- ) -- return worker --} -- --const makeModernScssWorker = ( -- resolvers: CSSAtImportResolvers, -- alias: Alias[], -- maxWorkers: number | undefined, --) => { -- const internalCanonicalize = async ( -- url: string, -- importer: string, -- ): Promise => { -- importer = cleanScssBugUrl(importer) -- const resolved = await resolvers.sass(url, importer) -- return resolved ?? null -- } -- -- const internalLoad = async (file: string, rootFile: string) => { -- const result = await rebaseUrls(file, rootFile, alias, '$', resolvers.sass) -- if (result.contents) { -- return result.contents -- } -- return await fsp.readFile(result.file, 'utf-8') -- } -- -- const worker = new WorkerWithFallback( -- () => -- async ( -- sassPath: string, -- data: string, -- // additionalData can a function that is not cloneable but it won't be used -- options: SassStylePreprocessorOptions & { additionalData: undefined }, -- ) => { -- // eslint-disable-next-line no-restricted-globals -- this function runs inside a cjs worker -- const sass: typeof Sass = require(sassPath) -- // eslint-disable-next-line no-restricted-globals -- const path: typeof import('node:path') = require('node:path') -- -- const { fileURLToPath, pathToFileURL }: typeof import('node:url') = -- // eslint-disable-next-line no-restricted-globals -- require('node:url') -- -- const sassOptions = { ...options } as Sass.StringOptions<'async'> -- sassOptions.url = pathToFileURL(options.filename) -- sassOptions.sourceMap = options.enableSourcemap -- -- const internalImporter: Sass.Importer<'async'> = { -- async canonicalize(url, context) { -- const importer = context.containingUrl -- ? fileURLToPath(context.containingUrl) -- : options.filename -- const resolved = await internalCanonicalize(url, importer) -- return resolved ? pathToFileURL(resolved) : null -- }, -- async load(canonicalUrl) { -- const ext = path.extname(canonicalUrl.pathname) -- let syntax: Sass.Syntax = 'scss' -- if (ext === '.sass') { -- syntax = 'indented' -- } else if (ext === '.css') { -- syntax = 'css' -- } -- const contents = await internalLoad( -- fileURLToPath(canonicalUrl), -- options.filename, -- ) -- return { contents, syntax } -- }, -- } -- sassOptions.importers = [ -- ...(sassOptions.importers ?? []), -- internalImporter, -- ] -- -- const result = await sass.compileStringAsync(data, sassOptions) -- return { -- css: result.css, -- map: result.sourceMap ? JSON.stringify(result.sourceMap) : undefined, -- stats: { -- includedFiles: result.loadedUrls -- .filter((url) => url.protocol === 'file:') -- .map((url) => fileURLToPath(url)), -- }, -- } satisfies ScssWorkerResult -- }, -- { -- parentFunctions: { -- internalCanonicalize, -- internalLoad, -- }, -- shouldUseFake(_sassPath, _data, options) { -- // functions and importer is a function and is not serializable -- // in that case, fallback to running in main thread -- return !!( -- (options.functions && Object.keys(options.functions).length > 0) || -- (options.importers && -- (!Array.isArray(options.importers) || options.importers.length > 0)) -- ) -- }, -- max: maxWorkers, -- }, -- ) -- return worker --} -- --// this is mostly a copy&paste of makeModernScssWorker --// however sharing code between two is hard because --// makeModernScssWorker above needs function inlined for worker. --const makeModernCompilerScssWorker = ( -- resolvers: CSSAtImportResolvers, -- alias: Alias[], -- _maxWorkers: number | undefined, --) => { -- let compiler: Sass.AsyncCompiler | undefined -- -- const worker: Awaited> = { -- async run(sassPath, data, options) { -- // need pathToFileURL for windows since import("D:...") fails -- // https://github.com/nodejs/node/issues/31710 -- const sass: typeof Sass = (await import(pathToFileURL(sassPath).href)) -- .default -- compiler ??= await sass.initAsyncCompiler() -- -- const sassOptions = { ...options } as Sass.StringOptions<'async'> -- sassOptions.url = pathToFileURL(options.filename) -- sassOptions.sourceMap = options.enableSourcemap -- -- const internalImporter: Sass.Importer<'async'> = { -- async canonicalize(url, context) { -- const importer = context.containingUrl -- ? fileURLToPath(context.containingUrl) -- : options.filename -- const resolved = await resolvers.sass(url, cleanScssBugUrl(importer)) -- return resolved ? pathToFileURL(resolved) : null -- }, -- async load(canonicalUrl) { -- const ext = path.extname(canonicalUrl.pathname) -- let syntax: Sass.Syntax = 'scss' -- if (ext === '.sass') { -- syntax = 'indented' -- } else if (ext === '.css') { -- syntax = 'css' -- } -- const result = await rebaseUrls( -- fileURLToPath(canonicalUrl), -- options.filename, -- alias, -- '$', -- resolvers.sass, -- ) -- const contents = -- result.contents ?? (await fsp.readFile(result.file, 'utf-8')) -- return { contents, syntax } -- }, -- } -- sassOptions.importers = [ -- ...(sassOptions.importers ?? []), -- internalImporter, -- ] -- -- const result = await compiler.compileStringAsync(data, sassOptions) -- return { -- css: result.css, -- map: result.sourceMap ? JSON.stringify(result.sourceMap) : undefined, -- stats: { -- includedFiles: result.loadedUrls -- .filter((url) => url.protocol === 'file:') -- .map((url) => fileURLToPath(url)), -- }, -- } satisfies ScssWorkerResult -- }, -- async stop() { -- compiler?.dispose() -- compiler = undefined -- }, -- } -- -- return worker --} -- - type ScssWorkerResult = { - css: string - map?: string | undefined - stats: Pick - } -- --const scssProcessor = ( -- maxWorkers: number | undefined, --): SassStylePreprocessor => { -- const workerMap = new Map>() -- -- return { -- close() { -- for (const worker of workerMap.values()) { -- worker.stop() -- } -- }, -- async process(source, root, options, resolvers) { -- const sassPackage = loadSassPackage(root) -- // TODO: change default in v6 -- // options.api ?? sassPackage.name === "sass-embedded" ? "modern-compiler" : "modern"; -- const api = options.api ?? 'legacy' -- -- if (!workerMap.has(options.alias)) { -- workerMap.set( -- options.alias, -- api === 'modern-compiler' -- ? makeModernCompilerScssWorker(resolvers, options.alias, maxWorkers) -- : api === 'modern' -- ? makeModernScssWorker(resolvers, options.alias, maxWorkers) -- : makeScssWorker(resolvers, options.alias, maxWorkers), -- ) -- } -- const worker = workerMap.get(options.alias)! -- -- const { content: data, map: additionalMap } = await getSource( -- source, -- options.filename, -- options.additionalData, -- options.enableSourcemap, -- ) -- -- const optionsWithoutAdditionalData = { -- ...options, -- additionalData: undefined, -- } -- try { -- const result = await worker.run( -- sassPackage.path, -- data, -- optionsWithoutAdditionalData, -- ) -- const deps = result.stats.includedFiles.map((f) => cleanScssBugUrl(f)) -- const map: ExistingRawSourceMap | undefined = result.map -- ? JSON.parse(result.map.toString()) -- : undefined -- -- return { -- code: result.css.toString(), -- map, -- additionalMap, -- deps, -- } -- } catch (e) { -- // normalize SASS error -- e.message = `[sass] ${e.message}` -- e.id = e.file -- e.frame = e.formatted -- return { code: '', error: e, deps: [] } -- } -- }, -- } --} - // #endregion -- --/** -- * relative url() inside \@imported sass and less files must be rebased to use -- * root file as base. -- */ --async function rebaseUrls( -- file: string, -- rootFile: string, -- alias: Alias[], -- variablePrefix: string, -- resolver: ResolveFn, --): Promise<{ file: string; contents?: string }> { -- file = path.resolve(file) // ensure os-specific flashes -- // in the same dir, no need to rebase -- const fileDir = path.dirname(file) -- const rootDir = path.dirname(rootFile) -- if (fileDir === rootDir) { -- return { file } -- } -- -- const content = await fsp.readFile(file, 'utf-8') -- // no url() -- const hasUrls = cssUrlRE.test(content) -- // data-uri() calls -- const hasDataUris = cssDataUriRE.test(content) -- // no @import xxx.css -- const hasImportCss = importCssRE.test(content) -- -- if (!hasUrls && !hasDataUris && !hasImportCss) { -- return { file } -- } -- -- let rebased -- const rebaseFn = async (url: string) => { -- if (url[0] === '/') return url -- // ignore url's starting with variable -- if (url.startsWith(variablePrefix)) return url -- // match alias, no need to rewrite -- for (const { find } of alias) { -- const matches = -- typeof find === 'string' ? url.startsWith(find) : find.test(url) -- if (matches) { -- return url -- } -- } -- const absolute = (await resolver(url, file)) || path.resolve(fileDir, url) -- const relative = path.relative(rootDir, absolute) -- return normalizePath(relative) -- } -- -- // fix css imports in less such as `@import "foo.css"` -- if (hasImportCss) { -- rebased = await rewriteImportCss(content, rebaseFn) -- } -- -- if (hasUrls) { -- rebased = await rewriteCssUrls(rebased || content, rebaseFn) -- } -- -- if (hasDataUris) { -- rebased = await rewriteCssDataUris(rebased || content, rebaseFn) -- } -- -- return { -- file, -- contents: rebased, -- } --} -- --// #region Less --// .less --const makeLessWorker = ( -- resolvers: CSSAtImportResolvers, -- alias: Alias[], -- maxWorkers: number | undefined, --) => { -- const viteLessResolve = async ( -- filename: string, -- dir: string, -- rootFile: string, -- ) => { -- const resolved = await resolvers.less(filename, path.join(dir, '*')) -- if (!resolved) return undefined -- -- const result = await rebaseUrls( -- resolved, -- rootFile, -- alias, -- '@', -- resolvers.less, -- ) -- if (result) { -- return { -- resolved, -- contents: 'contents' in result ? result.contents : undefined, -- } -- } -- return result -- } -- -- const worker = new WorkerWithFallback( -- () => { -- // eslint-disable-next-line no-restricted-globals -- this function runs inside a cjs worker -- const fsp = require('node:fs/promises') -- // eslint-disable-next-line no-restricted-globals -- const path = require('node:path') -- -- let ViteLessManager: any -- const createViteLessPlugin = ( -- less: typeof Less, -- rootFile: string, -- ): Less.Plugin => { -- const { FileManager } = less -- ViteLessManager ??= class ViteManager extends FileManager { -- rootFile -- constructor(rootFile: string) { -- super() -- this.rootFile = rootFile -- } -- override supports(filename: string) { -- return !/^(?:https?:)?\/\//.test(filename) -- } -- override supportsSync() { -- return false -- } -- override async loadFile( -- filename: string, -- dir: string, -- opts: any, -- env: any, -- ): Promise { -- const result = await viteLessResolve(filename, dir, this.rootFile) -- if (result) { -- return { -- filename: path.resolve(result.resolved), -- contents: -- result.contents ?? -- (await fsp.readFile(result.resolved, 'utf-8')), -- } -- } else { -- return super.loadFile(filename, dir, opts, env) -- } -- } -- } -- -- return { -- install(_, pluginManager) { -- pluginManager.addFileManager(new ViteLessManager(rootFile)) -- }, -- minVersion: [3, 0, 0], -- } -- } -- -- return async ( -- lessPath: string, -- content: string, -- // additionalData can a function that is not cloneable but it won't be used -- options: StylePreprocessorOptions & { additionalData: undefined }, -- ) => { -- // eslint-disable-next-line no-restricted-globals -- this function runs inside a cjs worker -- const nodeLess: typeof Less = require(lessPath) -- const viteResolverPlugin = createViteLessPlugin( -- nodeLess, -- options.filename, -- ) -- const result = await nodeLess.render(content, { -- ...options, -- plugins: [viteResolverPlugin, ...(options.plugins || [])], -- ...(options.enableSourcemap -- ? { -- sourceMap: { -- outputSourceFiles: true, -- sourceMapFileInline: false, -- }, -- } -- : {}), -- }) -- return result -- } -- }, -- { -- parentFunctions: { viteLessResolve }, -- shouldUseFake(_lessPath, _content, options) { -- // plugins are a function and is not serializable -- // in that case, fallback to running in main thread -- return options.plugins?.length > 0 -- }, -- max: maxWorkers, -- }, -- ) -- return worker --} -- --const lessProcessor = (maxWorkers: number | undefined): StylePreprocessor => { -- const workerMap = new Map>() -- -- return { -- close() { -- for (const worker of workerMap.values()) { -- worker.stop() -- } -- }, -- async process(source, root, options, resolvers) { -- const lessPath = loadPreprocessorPath(PreprocessLang.less, root) -- -- if (!workerMap.has(options.alias)) { -- workerMap.set( -- options.alias, -- makeLessWorker(resolvers, options.alias, maxWorkers), -- ) -- } -- const worker = workerMap.get(options.alias)! -- -- const { content, map: additionalMap } = await getSource( -- source, -- options.filename, -- options.additionalData, -- options.enableSourcemap, -- ) -- -- let result: Less.RenderOutput | undefined -- const optionsWithoutAdditionalData = { -- ...options, -- additionalData: undefined, -- } -- try { -- result = await worker.run( -- lessPath, -- content, -- optionsWithoutAdditionalData, -- ) -- } catch (e) { -- const error = e as Less.RenderError -- // normalize error info -- const normalizedError: RollupError = new Error( -- `[less] ${error.message || error.type}`, -- ) as RollupError -- normalizedError.loc = { -- file: error.filename || options.filename, -- line: error.line, -- column: error.column, -- } -- return { code: '', error: normalizedError, deps: [] } -- } -- -- const map: ExistingRawSourceMap = result.map && JSON.parse(result.map) -- if (map) { -- delete map.sourcesContent -- } -- -- return { -- code: result.css.toString(), -- map, -- additionalMap, -- deps: result.imports, -- } -- }, -- } --} - // #endregion - --// #region Stylus --// .styl --const makeStylWorker = (maxWorkers: number | undefined) => { -- const worker = new WorkerWithFallback( -- () => { -- return async ( -- stylusPath: string, -- content: string, -- root: string, -- // additionalData can a function that is not cloneable but it won't be used -- options: StylusStylePreprocessorOptions & { additionalData: undefined }, -- ) => { -- // eslint-disable-next-line no-restricted-globals -- this function runs inside a cjs worker -- const nodeStylus: typeof Stylus = require(stylusPath) -- -- const ref = nodeStylus(content, options) -- if (options.define) { -- for (const key in options.define) { -- ref.define(key, options.define[key]) -- } -- } -- if (options.enableSourcemap) { -- ref.set('sourcemap', { -- comment: false, -- inline: false, -- basePath: root, -- }) -- } -- -- return { -- code: ref.render(), -- // @ts-expect-error sourcemap exists -- map: ref.sourcemap as ExistingRawSourceMap | undefined, -- deps: ref.deps(), -- } -- } -- }, -- { -- shouldUseFake(_stylusPath, _content, _root, options) { -- // define can include functions and those are not serializable -- // in that case, fallback to running in main thread -- return !!( -- options.define && -- Object.values(options.define).some((d) => typeof d === 'function') -- ) -- }, -- max: maxWorkers, -- }, -- ) -- return worker --} -- --const stylProcessor = ( -- maxWorkers: number | undefined, --): StylusStylePreprocessor => { -- const workerMap = new Map>() -- -- return { -- close() { -- for (const worker of workerMap.values()) { -- worker.stop() -- } -- }, -- async process(source, root, options, resolvers) { -- const stylusPath = loadPreprocessorPath(PreprocessLang.stylus, root) -- -- if (!workerMap.has(options.alias)) { -- workerMap.set(options.alias, makeStylWorker(maxWorkers)) -- } -- const worker = workerMap.get(options.alias)! -- -- // Get source with preprocessor options.additionalData. Make sure a new line separator -- // is added to avoid any render error, as added stylus content may not have semi-colon separators -- const { content, map: additionalMap } = await getSource( -- source, -- options.filename, -- options.additionalData, -- options.enableSourcemap, -- '\n', -- ) -- // Get preprocessor options.imports dependencies as stylus -- // does not return them with its builtin `.deps()` method -- const importsDeps = (options.imports ?? []).map((dep: string) => -- path.resolve(dep), -- ) -- const optionsWithoutAdditionalData = { -- ...options, -- additionalData: undefined, -- } -- try { -- const { code, map, deps } = await worker.run( -- stylusPath, -- content, -- root, -- optionsWithoutAdditionalData, -- ) -- return { -- code, -- map: formatStylusSourceMap(map, root), -- additionalMap, -- // Concat imports deps with computed deps -- deps: [...deps, ...importsDeps], -- } -- } catch (e) { -- const wrapped = new Error(`[stylus] ${e.message}`) -- wrapped.name = e.name -- wrapped.stack = e.stack -- return { code: '', error: wrapped, deps: [] } -- } -- }, -- } --} -- - function formatStylusSourceMap( - mapBefore: ExistingRawSourceMap | undefined, - root: string, -@@ -2932,118 +904,6 @@ function isPreProcessor(lang: any): lang is PreprocessLang { - } - - const importLightningCSS = createCachedImport(() => import('lightningcss')) --async function compileLightningCSS( -- id: string, -- src: string, -- config: ResolvedConfig, -- urlReplacer?: CssUrlReplacer, --): ReturnType { -- const deps = new Set() -- // Relative path is needed to get stable hash when using CSS modules -- const filename = cleanUrl(path.relative(config.root, id)) -- const toAbsolute = (filePath: string) => -- path.isAbsolute(filePath) ? filePath : path.join(config.root, filePath) -- -- const res = styleAttrRE.test(id) -- ? (await importLightningCSS()).transformStyleAttribute({ -- filename, -- code: Buffer.from(src), -- targets: config.css?.lightningcss?.targets, -- minify: config.isProduction && !!config.build.cssMinify, -- analyzeDependencies: true, -- }) -- : await ( -- await importLightningCSS() -- ).bundleAsync({ -- ...config.css?.lightningcss, -- filename, -- resolver: { -- read(filePath) { -- if (filePath === filename) { -- return src -- } -- // This happens with html-proxy (#13776) -- if (!filePath.endsWith('.css')) { -- return src -- } -- return fs.readFileSync(toAbsolute(filePath), 'utf-8') -- }, -- async resolve(id, from) { -- const publicFile = checkPublicFile(id, config) -- if (publicFile) { -- return publicFile -- } -- -- const resolved = await getAtImportResolvers(config).css( -- id, -- toAbsolute(from), -- ) -- -- if (resolved) { -- deps.add(resolved) -- return resolved -- } -- return id -- }, -- }, -- minify: config.isProduction && !!config.build.cssMinify, -- sourceMap: -- config.command === 'build' -- ? !!config.build.sourcemap -- : config.css?.devSourcemap, -- analyzeDependencies: true, -- cssModules: cssModuleRE.test(id) -- ? (config.css?.lightningcss?.cssModules ?? true) -- : undefined, -- }) -- -- // NodeJS res.code = Buffer -- // Deno res.code = Uint8Array -- // For correct decode compiled css need to use TextDecoder -- let css = decoder.decode(res.code) -- for (const dep of res.dependencies!) { -- switch (dep.type) { -- case 'url': -- if (skipUrlReplacer(dep.url)) { -- css = css.replace(dep.placeholder, () => dep.url) -- break -- } -- deps.add(dep.url) -- if (urlReplacer) { -- const replaceUrl = await urlReplacer(dep.url, id) -- css = css.replace(dep.placeholder, () => replaceUrl) -- } else { -- css = css.replace(dep.placeholder, () => dep.url) -- } -- break -- default: -- throw new Error(`Unsupported dependency type: ${dep.type}`) -- } -- } -- -- let modules: Record | undefined -- if ('exports' in res && res.exports) { -- modules = {} -- // https://github.com/parcel-bundler/lightningcss/issues/291 -- const sortedEntries = Object.entries(res.exports).sort((a, b) => -- a[0].localeCompare(b[0]), -- ) -- for (const [key, value] of sortedEntries) { -- modules[key] = value.name -- // https://lightningcss.dev/css-modules.html#class-composition -- for (const c of value.composes) { -- modules[key] += ' ' + c.name -- } -- } -- } -- -- return { -- code: css, -- map: 'map' in res ? res.map?.toString() : undefined, -- deps, -- modules, -- } --} - - // Convert https://esbuild.github.io/api/#target - // To https://github.com/parcel-bundler/lightningcss/blob/master/node/targets.d.ts -diff --git a/packages/vite/src/node/plugins/css/compileCSS.ts b/packages/vite/src/node/plugins/css/compileCSS.ts -new file mode 100644 -index 000000000..c939adbe3 ---- /dev/null -+++ b/packages/vite/src/node/plugins/css/compileCSS.ts -@@ -0,0 +1,449 @@ -+import path from 'node:path' -+import glob from 'fast-glob' -+import type { -+ ExistingRawSourceMap, -+ ModuleFormat, -+ OutputAsset, -+ OutputChunk, -+ RenderedChunk, -+ RollupError, -+ SourceMapInput, -+} from 'rollup' -+import colors from 'picocolors' -+import type * as PostCSS from 'postcss' -+import type { RawSourceMap } from '@ampproject/remapping' -+import { CSS_LANGS_RE } from 'packages/vite/src/node/constants'; -+import { ResolvedConfig } from 'packages/vite/src/node/config'; -+import { getHash } from 'packages/vite/src/node/utils'; -+import { normalizePath } from 'packages/vite/src/node/utils'; -+import { removeDirectQuery } from 'packages/vite/src/node/utils'; -+import { generateCodeFrame } from 'packages/vite/src/node/utils'; -+import { cleanUrl } from 'packages/vite/src/shared/utils'; -+import { ResolveFn } from 'packages/vite/src/node/index'; -+ -+ -+const export cssModuleRE = new RegExp(`\\.module${CSS_LANGS_RE.source}`) -+ -+export const enum PreprocessLang { -+ less = 'less', -+ sass = 'sass', -+ scss = 'scss', -+ styl = 'styl', -+ stylus = 'stylus', -+} -+ -+export const enum PureCssLang { -+ css = 'css', -+} -+ -+export const enum PostCssDialectLang { -+ sss = 'sugarss', -+} -+ -+export type CssLang = -+ | keyof typeof PureCssLang -+ | keyof typeof PreprocessLang -+ | keyof typeof PostCssDialectLang -+ -+export interface CSSAtImportResolvers { -+ css: ResolveFn -+ sass: ResolveFn -+ less: ResolveFn -+} -+ -+export function getCssResolversKeys( -+ resolvers: CSSAtImportResolvers, -+): Array { -+ return Object.keys(resolvers) as unknown as Array -+} -+ -+export async function compileCSSPreprocessors( -+ id: string, -+ lang: PreprocessLang, -+ code: string, -+ config: ResolvedConfig, -+ workerController: PreprocessorWorkerController, -+): Promise<{ code: string; map?: ExistingRawSourceMap; deps?: Set }> { -+ const { preprocessorOptions, devSourcemap } = config.css ?? {} -+ const atImportResolvers = getAtImportResolvers(config) -+ -+ const preProcessor = workerController[lang] -+ let opts = (preprocessorOptions && preprocessorOptions[lang]) || {} -+ // support @import from node dependencies by default -+ switch (lang) { -+ case PreprocessLang.scss: -+ case PreprocessLang.sass: -+ opts = { -+ includePaths: ['node_modules'], -+ alias: config.resolve.alias, -+ ...opts, -+ } -+ break -+ case PreprocessLang.less: -+ case PreprocessLang.styl: -+ case PreprocessLang.stylus: -+ opts = { -+ paths: ['node_modules'], -+ alias: config.resolve.alias, -+ ...opts, -+ } -+ } -+ // important: set this for relative import resolving -+ opts.filename = cleanUrl(id) -+ opts.enableSourcemap = devSourcemap ?? false -+ -+ const preprocessResult = await preProcessor( -+ code, -+ config.root, -+ opts, -+ atImportResolvers, -+ ) -+ if (preprocessResult.error) { -+ throw preprocessResult.error -+ } -+ -+ let deps: Set | undefined -+ if (preprocessResult.deps) { -+ const normalizedFilename = normalizePath(opts.filename) -+ // sometimes sass registers the file itself as a dep -+ deps = new Set( -+ [...preprocessResult.deps].filter( -+ (dep) => normalizePath(dep) !== normalizedFilename, -+ ), -+ ) -+ } -+ -+ return { -+ code: preprocessResult.code, -+ map: combineSourcemapsIfExists( -+ opts.filename, -+ preprocessResult.map, -+ preprocessResult.additionalMap, -+ ), -+ deps, -+ } -+} -+ -+export function createCSSResolvers(config: ResolvedConfig): CSSAtImportResolvers { -+ let cssResolve: ResolveFn | undefined -+ let sassResolve: ResolveFn | undefined -+ let lessResolve: ResolveFn | undefined -+ return { -+ get css() { -+ return ( -+ cssResolve || -+ (cssResolve = config.createResolver({ -+ extensions: ['.css'], -+ mainFields: ['style'], -+ conditions: ['style'], -+ tryIndex: false, -+ preferRelative: true, -+ })) -+ ) -+ }, -+ -+ get sass() { -+ return ( -+ sassResolve || -+ (sassResolve = config.createResolver({ -+ extensions: ['.scss', '.sass', '.css'], -+ mainFields: ['sass', 'style'], -+ conditions: ['sass', 'style'], -+ tryIndex: true, -+ tryPrefix: '_', -+ preferRelative: true, -+ })) -+ ) -+ }, -+ -+ get less() { -+ return ( -+ lessResolve || -+ (lessResolve = config.createResolver({ -+ extensions: ['.less', '.css'], -+ mainFields: ['less', 'style'], -+ conditions: ['less', 'style'], -+ tryIndex: false, -+ preferRelative: true, -+ })) -+ ) -+ }, -+ } -+} -+ -+const export configToAtImportResolvers = new WeakMap< -+ ResolvedConfig, -+ CSSAtImportResolvers -+>() -+ -+export function getAtImportResolvers(config: ResolvedConfig) { -+ let atImportResolvers = configToAtImportResolvers.get(config) -+ if (!atImportResolvers) { -+ atImportResolvers = createCSSResolvers(config) -+ configToAtImportResolvers.set(config, atImportResolvers) -+ } -+ return atImportResolvers -+} -+ -+export async function compileCSS( -+ id: string, -+ code: string, -+ config: ResolvedConfig, -+ workerController: PreprocessorWorkerController, -+ urlReplacer?: CssUrlReplacer, -+): Promise<{ -+ code: string -+ map?: SourceMapInput -+ ast?: PostCSS.Result -+ modules?: Record -+ deps?: Set -+}> { -+ if (config.css?.transformer === 'lightningcss') { -+ return compileLightningCSS(id, code, config, urlReplacer) -+ } -+ -+ const { modules: modulesOptions, devSourcemap } = config.css || {} -+ const isModule = modulesOptions !== false && cssModuleRE.test(id) -+ // although at serve time it can work without processing, we do need to -+ // crawl them in order to register watch dependencies. -+ const needInlineImport = code.includes('@import') -+ const hasUrl = cssUrlRE.test(code) || cssImageSetRE.test(code) -+ const lang = CSS_LANGS_RE.exec(id)?.[1] as CssLang | undefined -+ const postcssConfig = await resolvePostcssConfig(config) -+ -+ // 1. plain css that needs no processing -+ if ( -+ lang === 'css' && -+ !postcssConfig && -+ !isModule && -+ !needInlineImport && -+ !hasUrl -+ ) { -+ return { code, map: null } -+ } -+ -+ let modules: Record | undefined -+ const deps = new Set() -+ -+ // 2. pre-processors: sass etc. -+ let preprocessorMap: ExistingRawSourceMap | undefined -+ if (isPreProcessor(lang)) { -+ const preprocessorResult = await compileCSSPreprocessors( -+ id, -+ lang, -+ code, -+ config, -+ workerController, -+ ) -+ code = preprocessorResult.code -+ preprocessorMap = preprocessorResult.map -+ preprocessorResult.deps?.forEach((dep) => deps.add(dep)) -+ } -+ -+ // 3. postcss -+ const atImportResolvers = getAtImportResolvers(config) -+ const postcssOptions = (postcssConfig && postcssConfig.options) || {} -+ -+ const postcssPlugins = -+ postcssConfig && postcssConfig.plugins ? postcssConfig.plugins.slice() : [] -+ -+ if (needInlineImport) { -+ postcssPlugins.unshift( -+ (await importPostcssImport()).default({ -+ async resolve(id, basedir) { -+ const publicFile = checkPublicFile(id, config) -+ if (publicFile) { -+ return publicFile -+ } -+ -+ const resolved = await atImportResolvers.css( -+ id, -+ path.join(basedir, '*'), -+ ) -+ -+ if (resolved) { -+ return path.resolve(resolved) -+ } -+ -+ // postcss-import falls back to `resolve` dep if this is unresolved, -+ // but we've shimmed to remove the `resolve` dep to cut on bundle size. -+ // warn here to provide a better error message. -+ if (!path.isAbsolute(id)) { -+ config.logger.error( -+ colors.red( -+ `Unable to resolve \`@import "${id}"\` from ${basedir}`, -+ ), -+ ) -+ } -+ -+ return id -+ }, -+ async load(id) { -+ const code = await fs.promises.readFile(id, 'utf-8') -+ const lang = CSS_LANGS_RE.exec(id)?.[1] as CssLang | undefined -+ if (isPreProcessor(lang)) { -+ const result = await compileCSSPreprocessors( -+ id, -+ lang, -+ code, -+ config, -+ workerController, -+ ) -+ result.deps?.forEach((dep) => deps.add(dep)) -+ // TODO: support source map -+ return result.code -+ } -+ return code -+ }, -+ nameLayer(index) { -+ return `vite--anon-layer-${getHash(id)}-${index}` -+ }, -+ }), -+ ) -+ } -+ -+ if (urlReplacer) { -+ postcssPlugins.push( -+ UrlRewritePostcssPlugin({ -+ replacer: urlReplacer, -+ logger: config.logger, -+ }), -+ ) -+ } -+ -+ if (isModule) { -+ postcssPlugins.unshift( -+ (await importPostcssModules()).default({ -+ ...modulesOptions, -+ localsConvention: modulesOptions?.localsConvention, -+ getJSON( -+ cssFileName: string, -+ _modules: Record, -+ outputFileName: string, -+ ) { -+ modules = _modules -+ if (modulesOptions && typeof modulesOptions.getJSON === 'function') { -+ modulesOptions.getJSON(cssFileName, _modules, outputFileName) -+ } -+ }, -+ async resolve(id: string, importer: string) { -+ for (const key of getCssResolversKeys(atImportResolvers)) { -+ const resolved = await atImportResolvers[key](id, importer) -+ if (resolved) { -+ return path.resolve(resolved) -+ } -+ } -+ -+ return id -+ }, -+ }), -+ ) -+ } -+ -+ if (!postcssPlugins.length) { -+ return { -+ code, -+ map: preprocessorMap, -+ deps, -+ } -+ } -+ -+ let postcssResult: PostCSS.Result -+ try { -+ const source = removeDirectQuery(id) -+ const postcss = await importPostcss() -+ // postcss is an unbundled dep and should be lazy imported -+ postcssResult = await postcss.default(postcssPlugins).process(code, { -+ ...postcssOptions, -+ parser: lang === 'sss' ? loadSss(config.root) : postcssOptions.parser, -+ to: source, -+ from: source, -+ ...(devSourcemap -+ ? { -+ map: { -+ inline: false, -+ annotation: false, -+ // postcss may return virtual files -+ // we cannot obtain content of them, so this needs to be enabled -+ sourcesContent: true, -+ // when "prev: preprocessorMap", the result map may include duplicate filename in `postcssResult.map.sources` -+ // prev: preprocessorMap, -+ }, -+ } -+ : {}), -+ }) -+ -+ // record CSS dependencies from @imports -+ for (const message of postcssResult.messages) { -+ if (message.type === 'dependency') { -+ deps.add(normalizePath(message.file as string)) -+ } else if (message.type === 'dir-dependency') { -+ // https://github.com/postcss/postcss/blob/main/docs/guidelines/plugin.md#3-dependencies -+ const { dir, glob: globPattern = '**' } = message -+ const pattern = -+ glob.escapePath(normalizePath(path.resolve(path.dirname(id), dir))) + -+ `/` + -+ globPattern -+ const files = glob.sync(pattern, { -+ ignore: ['**/node_modules/**'], -+ }) -+ for (let i = 0; i < files.length; i++) { -+ deps.add(files[i]) -+ } -+ } else if (message.type === 'warning') { -+ const warning = message as PostCSS.Warning -+ let msg = `[vite:css] ${warning.text}` -+ msg += `\n${generateCodeFrame( -+ code, -+ { -+ line: warning.line, -+ column: warning.column - 1, // 1-based -+ }, -+ warning.endLine !== undefined && warning.endColumn !== undefined -+ ? { -+ line: warning.endLine, -+ column: warning.endColumn - 1, // 1-based -+ } -+ : undefined, -+ )}` -+ config.logger.warn(colors.yellow(msg)) -+ } -+ } -+ } catch (e) { -+ e.message = `[postcss] ${e.message}` -+ e.code = code -+ e.loc = { -+ file: e.file, -+ line: e.line, -+ column: e.column - 1, // 1-based -+ } -+ throw e -+ } -+ -+ if (!devSourcemap) { -+ return { -+ ast: postcssResult, -+ code: postcssResult.css, -+ map: { mappings: '' }, -+ modules, -+ deps, -+ } -+ } -+ -+ const rawPostcssMap = postcssResult.map.toJSON() -+ -+ const postcssMap = await formatPostcssSourceMap( -+ // version property of rawPostcssMap is declared as string -+ // but actually it is a number -+ rawPostcssMap as Omit as ExistingRawSourceMap, -+ cleanUrl(id), -+ ) -+ -+ return { -+ ast: postcssResult, -+ code: postcssResult.css, -+ map: combineSourcemapsIfExists(cleanUrl(id), postcssMap, preprocessorMap), -+ modules, -+ deps, -+ } -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/plugins/css/compileCSSPreprocessors.ts b/packages/vite/src/node/plugins/css/compileCSSPreprocessors.ts -new file mode 100644 -index 000000000..289ca71a0 ---- /dev/null -+++ b/packages/vite/src/node/plugins/css/compileCSSPreprocessors.ts -@@ -0,0 +1,88 @@ -+import type { -+ ExistingRawSourceMap, -+ ModuleFormat, -+ OutputAsset, -+ OutputChunk, -+ RenderedChunk, -+ RollupError, -+ SourceMapInput, -+} from 'rollup' -+import { ResolvedConfig } from 'packages/vite/src/node/config'; -+import { normalizePath } from 'packages/vite/src/node/utils'; -+import { cleanUrl } from 'packages/vite/src/shared/utils'; -+ -+ -+export const enum PreprocessLang { -+ less = 'less', -+ sass = 'sass', -+ scss = 'scss', -+ styl = 'styl', -+ stylus = 'stylus', -+} -+ -+export async function compileCSSPreprocessors( -+ id: string, -+ lang: PreprocessLang, -+ code: string, -+ config: ResolvedConfig, -+ workerController: PreprocessorWorkerController, -+): Promise<{ code: string; map?: ExistingRawSourceMap; deps?: Set }> { -+ const { preprocessorOptions, devSourcemap } = config.css ?? {} -+ const atImportResolvers = getAtImportResolvers(config) -+ -+ const preProcessor = workerController[lang] -+ let opts = (preprocessorOptions && preprocessorOptions[lang]) || {} -+ // support @import from node dependencies by default -+ switch (lang) { -+ case PreprocessLang.scss: -+ case PreprocessLang.sass: -+ opts = { -+ includePaths: ['node_modules'], -+ alias: config.resolve.alias, -+ ...opts, -+ } -+ break -+ case PreprocessLang.less: -+ case PreprocessLang.styl: -+ case PreprocessLang.stylus: -+ opts = { -+ paths: ['node_modules'], -+ alias: config.resolve.alias, -+ ...opts, -+ } -+ } -+ // important: set this for relative import resolving -+ opts.filename = cleanUrl(id) -+ opts.enableSourcemap = devSourcemap ?? false -+ -+ const preprocessResult = await preProcessor( -+ code, -+ config.root, -+ opts, -+ atImportResolvers, -+ ) -+ if (preprocessResult.error) { -+ throw preprocessResult.error -+ } -+ -+ let deps: Set | undefined -+ if (preprocessResult.deps) { -+ const normalizedFilename = normalizePath(opts.filename) -+ // sometimes sass registers the file itself as a dep -+ deps = new Set( -+ [...preprocessResult.deps].filter( -+ (dep) => normalizePath(dep) !== normalizedFilename, -+ ), -+ ) -+ } -+ -+ return { -+ code: preprocessResult.code, -+ map: combineSourcemapsIfExists( -+ opts.filename, -+ preprocessResult.map, -+ preprocessResult.additionalMap, -+ ), -+ deps, -+ } -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/plugins/css/compileLightningCSS.ts b/packages/vite/src/node/plugins/css/compileLightningCSS.ts -new file mode 100644 -index 000000000..7041a99cb ---- /dev/null -+++ b/packages/vite/src/node/plugins/css/compileLightningCSS.ts -@@ -0,0 +1,585 @@ -+import fs from 'node:fs' -+import path from 'node:path' -+import { ResolvedConfig } from 'packages/vite/src/node/config'; -+import { cleanUrl } from 'packages/vite/src/shared/utils'; -+import { CSS_LANGS_RE } from 'packages/vite/src/node/constants'; -+import glob from 'fast-glob' -+import type { -+ ExistingRawSourceMap, -+ ModuleFormat, -+ OutputAsset, -+ OutputChunk, -+ RenderedChunk, -+ RollupError, -+ SourceMapInput, -+} from 'rollup' -+import colors from 'picocolors' -+import type * as PostCSS from 'postcss' -+import type { RawSourceMap } from '@ampproject/remapping' -+import { getHash } from 'packages/vite/src/node/utils'; -+import { normalizePath } from 'packages/vite/src/node/utils'; -+import { removeDirectQuery } from 'packages/vite/src/node/utils'; -+import { generateCodeFrame } from 'packages/vite/src/node/utils'; -+import { ResolveFn } from 'packages/vite/src/node/index'; -+import { isExternalUrl } from 'packages/vite/src/node/utils'; -+import { isDataUrl } from 'packages/vite/src/node/utils'; -+ -+ -+const export decoder = new TextDecoder() -+ -+const export cssModuleRE = new RegExp(`\\.module${CSS_LANGS_RE.source}`) -+ -+const export styleAttrRE = /[?&]style-attr\b/ -+ -+export const enum PreprocessLang { -+ less = 'less', -+ sass = 'sass', -+ scss = 'scss', -+ styl = 'styl', -+ stylus = 'stylus', -+} -+ -+export const enum PureCssLang { -+ css = 'css', -+} -+ -+export const enum PostCssDialectLang { -+ sss = 'sugarss', -+} -+ -+export type CssLang = -+ | keyof typeof PureCssLang -+ | keyof typeof PreprocessLang -+ | keyof typeof PostCssDialectLang -+ -+export interface CSSAtImportResolvers { -+ css: ResolveFn -+ sass: ResolveFn -+ less: ResolveFn -+} -+ -+export function getCssResolversKeys( -+ resolvers: CSSAtImportResolvers, -+): Array { -+ return Object.keys(resolvers) as unknown as Array -+} -+ -+export async function compileCSSPreprocessors( -+ id: string, -+ lang: PreprocessLang, -+ code: string, -+ config: ResolvedConfig, -+ workerController: PreprocessorWorkerController, -+): Promise<{ code: string; map?: ExistingRawSourceMap; deps?: Set }> { -+ const { preprocessorOptions, devSourcemap } = config.css ?? {} -+ const atImportResolvers = getAtImportResolvers(config) -+ -+ const preProcessor = workerController[lang] -+ let opts = (preprocessorOptions && preprocessorOptions[lang]) || {} -+ // support @import from node dependencies by default -+ switch (lang) { -+ case PreprocessLang.scss: -+ case PreprocessLang.sass: -+ opts = { -+ includePaths: ['node_modules'], -+ alias: config.resolve.alias, -+ ...opts, -+ } -+ break -+ case PreprocessLang.less: -+ case PreprocessLang.styl: -+ case PreprocessLang.stylus: -+ opts = { -+ paths: ['node_modules'], -+ alias: config.resolve.alias, -+ ...opts, -+ } -+ } -+ // important: set this for relative import resolving -+ opts.filename = cleanUrl(id) -+ opts.enableSourcemap = devSourcemap ?? false -+ -+ const preprocessResult = await preProcessor( -+ code, -+ config.root, -+ opts, -+ atImportResolvers, -+ ) -+ if (preprocessResult.error) { -+ throw preprocessResult.error -+ } -+ -+ let deps: Set | undefined -+ if (preprocessResult.deps) { -+ const normalizedFilename = normalizePath(opts.filename) -+ // sometimes sass registers the file itself as a dep -+ deps = new Set( -+ [...preprocessResult.deps].filter( -+ (dep) => normalizePath(dep) !== normalizedFilename, -+ ), -+ ) -+ } -+ -+ return { -+ code: preprocessResult.code, -+ map: combineSourcemapsIfExists( -+ opts.filename, -+ preprocessResult.map, -+ preprocessResult.additionalMap, -+ ), -+ deps, -+ } -+} -+ -+export function createCSSResolvers(config: ResolvedConfig): CSSAtImportResolvers { -+ let cssResolve: ResolveFn | undefined -+ let sassResolve: ResolveFn | undefined -+ let lessResolve: ResolveFn | undefined -+ return { -+ get css() { -+ return ( -+ cssResolve || -+ (cssResolve = config.createResolver({ -+ extensions: ['.css'], -+ mainFields: ['style'], -+ conditions: ['style'], -+ tryIndex: false, -+ preferRelative: true, -+ })) -+ ) -+ }, -+ -+ get sass() { -+ return ( -+ sassResolve || -+ (sassResolve = config.createResolver({ -+ extensions: ['.scss', '.sass', '.css'], -+ mainFields: ['sass', 'style'], -+ conditions: ['sass', 'style'], -+ tryIndex: true, -+ tryPrefix: '_', -+ preferRelative: true, -+ })) -+ ) -+ }, -+ -+ get less() { -+ return ( -+ lessResolve || -+ (lessResolve = config.createResolver({ -+ extensions: ['.less', '.css'], -+ mainFields: ['less', 'style'], -+ conditions: ['less', 'style'], -+ tryIndex: false, -+ preferRelative: true, -+ })) -+ ) -+ }, -+ } -+} -+ -+const export configToAtImportResolvers = new WeakMap< -+ ResolvedConfig, -+ CSSAtImportResolvers -+>() -+ -+export function getAtImportResolvers(config: ResolvedConfig) { -+ let atImportResolvers = configToAtImportResolvers.get(config) -+ if (!atImportResolvers) { -+ atImportResolvers = createCSSResolvers(config) -+ configToAtImportResolvers.set(config, atImportResolvers) -+ } -+ return atImportResolvers -+} -+ -+export async function compileCSS( -+ id: string, -+ code: string, -+ config: ResolvedConfig, -+ workerController: PreprocessorWorkerController, -+ urlReplacer?: CssUrlReplacer, -+): Promise<{ -+ code: string -+ map?: SourceMapInput -+ ast?: PostCSS.Result -+ modules?: Record -+ deps?: Set -+}> { -+ if (config.css?.transformer === 'lightningcss') { -+ return compileLightningCSS(id, code, config, urlReplacer) -+ } -+ -+ const { modules: modulesOptions, devSourcemap } = config.css || {} -+ const isModule = modulesOptions !== false && cssModuleRE.test(id) -+ // although at serve time it can work without processing, we do need to -+ // crawl them in order to register watch dependencies. -+ const needInlineImport = code.includes('@import') -+ const hasUrl = cssUrlRE.test(code) || cssImageSetRE.test(code) -+ const lang = CSS_LANGS_RE.exec(id)?.[1] as CssLang | undefined -+ const postcssConfig = await resolvePostcssConfig(config) -+ -+ // 1. plain css that needs no processing -+ if ( -+ lang === 'css' && -+ !postcssConfig && -+ !isModule && -+ !needInlineImport && -+ !hasUrl -+ ) { -+ return { code, map: null } -+ } -+ -+ let modules: Record | undefined -+ const deps = new Set() -+ -+ // 2. pre-processors: sass etc. -+ let preprocessorMap: ExistingRawSourceMap | undefined -+ if (isPreProcessor(lang)) { -+ const preprocessorResult = await compileCSSPreprocessors( -+ id, -+ lang, -+ code, -+ config, -+ workerController, -+ ) -+ code = preprocessorResult.code -+ preprocessorMap = preprocessorResult.map -+ preprocessorResult.deps?.forEach((dep) => deps.add(dep)) -+ } -+ -+ // 3. postcss -+ const atImportResolvers = getAtImportResolvers(config) -+ const postcssOptions = (postcssConfig && postcssConfig.options) || {} -+ -+ const postcssPlugins = -+ postcssConfig && postcssConfig.plugins ? postcssConfig.plugins.slice() : [] -+ -+ if (needInlineImport) { -+ postcssPlugins.unshift( -+ (await importPostcssImport()).default({ -+ async resolve(id, basedir) { -+ const publicFile = checkPublicFile(id, config) -+ if (publicFile) { -+ return publicFile -+ } -+ -+ const resolved = await atImportResolvers.css( -+ id, -+ path.join(basedir, '*'), -+ ) -+ -+ if (resolved) { -+ return path.resolve(resolved) -+ } -+ -+ // postcss-import falls back to `resolve` dep if this is unresolved, -+ // but we've shimmed to remove the `resolve` dep to cut on bundle size. -+ // warn here to provide a better error message. -+ if (!path.isAbsolute(id)) { -+ config.logger.error( -+ colors.red( -+ `Unable to resolve \`@import "${id}"\` from ${basedir}`, -+ ), -+ ) -+ } -+ -+ return id -+ }, -+ async load(id) { -+ const code = await fs.promises.readFile(id, 'utf-8') -+ const lang = CSS_LANGS_RE.exec(id)?.[1] as CssLang | undefined -+ if (isPreProcessor(lang)) { -+ const result = await compileCSSPreprocessors( -+ id, -+ lang, -+ code, -+ config, -+ workerController, -+ ) -+ result.deps?.forEach((dep) => deps.add(dep)) -+ // TODO: support source map -+ return result.code -+ } -+ return code -+ }, -+ nameLayer(index) { -+ return `vite--anon-layer-${getHash(id)}-${index}` -+ }, -+ }), -+ ) -+ } -+ -+ if (urlReplacer) { -+ postcssPlugins.push( -+ UrlRewritePostcssPlugin({ -+ replacer: urlReplacer, -+ logger: config.logger, -+ }), -+ ) -+ } -+ -+ if (isModule) { -+ postcssPlugins.unshift( -+ (await importPostcssModules()).default({ -+ ...modulesOptions, -+ localsConvention: modulesOptions?.localsConvention, -+ getJSON( -+ cssFileName: string, -+ _modules: Record, -+ outputFileName: string, -+ ) { -+ modules = _modules -+ if (modulesOptions && typeof modulesOptions.getJSON === 'function') { -+ modulesOptions.getJSON(cssFileName, _modules, outputFileName) -+ } -+ }, -+ async resolve(id: string, importer: string) { -+ for (const key of getCssResolversKeys(atImportResolvers)) { -+ const resolved = await atImportResolvers[key](id, importer) -+ if (resolved) { -+ return path.resolve(resolved) -+ } -+ } -+ -+ return id -+ }, -+ }), -+ ) -+ } -+ -+ if (!postcssPlugins.length) { -+ return { -+ code, -+ map: preprocessorMap, -+ deps, -+ } -+ } -+ -+ let postcssResult: PostCSS.Result -+ try { -+ const source = removeDirectQuery(id) -+ const postcss = await importPostcss() -+ // postcss is an unbundled dep and should be lazy imported -+ postcssResult = await postcss.default(postcssPlugins).process(code, { -+ ...postcssOptions, -+ parser: lang === 'sss' ? loadSss(config.root) : postcssOptions.parser, -+ to: source, -+ from: source, -+ ...(devSourcemap -+ ? { -+ map: { -+ inline: false, -+ annotation: false, -+ // postcss may return virtual files -+ // we cannot obtain content of them, so this needs to be enabled -+ sourcesContent: true, -+ // when "prev: preprocessorMap", the result map may include duplicate filename in `postcssResult.map.sources` -+ // prev: preprocessorMap, -+ }, -+ } -+ : {}), -+ }) -+ -+ // record CSS dependencies from @imports -+ for (const message of postcssResult.messages) { -+ if (message.type === 'dependency') { -+ deps.add(normalizePath(message.file as string)) -+ } else if (message.type === 'dir-dependency') { -+ // https://github.com/postcss/postcss/blob/main/docs/guidelines/plugin.md#3-dependencies -+ const { dir, glob: globPattern = '**' } = message -+ const pattern = -+ glob.escapePath(normalizePath(path.resolve(path.dirname(id), dir))) + -+ `/` + -+ globPattern -+ const files = glob.sync(pattern, { -+ ignore: ['**/node_modules/**'], -+ }) -+ for (let i = 0; i < files.length; i++) { -+ deps.add(files[i]) -+ } -+ } else if (message.type === 'warning') { -+ const warning = message as PostCSS.Warning -+ let msg = `[vite:css] ${warning.text}` -+ msg += `\n${generateCodeFrame( -+ code, -+ { -+ line: warning.line, -+ column: warning.column - 1, // 1-based -+ }, -+ warning.endLine !== undefined && warning.endColumn !== undefined -+ ? { -+ line: warning.endLine, -+ column: warning.endColumn - 1, // 1-based -+ } -+ : undefined, -+ )}` -+ config.logger.warn(colors.yellow(msg)) -+ } -+ } -+ } catch (e) { -+ e.message = `[postcss] ${e.message}` -+ e.code = code -+ e.loc = { -+ file: e.file, -+ line: e.line, -+ column: e.column - 1, // 1-based -+ } -+ throw e -+ } -+ -+ if (!devSourcemap) { -+ return { -+ ast: postcssResult, -+ code: postcssResult.css, -+ map: { mappings: '' }, -+ modules, -+ deps, -+ } -+ } -+ -+ const rawPostcssMap = postcssResult.map.toJSON() -+ -+ const postcssMap = await formatPostcssSourceMap( -+ // version property of rawPostcssMap is declared as string -+ // but actually it is a number -+ rawPostcssMap as Omit as ExistingRawSourceMap, -+ cleanUrl(id), -+ ) -+ -+ return { -+ ast: postcssResult, -+ code: postcssResult.css, -+ map: combineSourcemapsIfExists(cleanUrl(id), postcssMap, preprocessorMap), -+ modules, -+ deps, -+ } -+} -+ -+export type CssUrlReplacer = ( -+ url: string, -+ importer?: string, -+) => string | Promise -+ -+const export functionCallRE = /^[A-Z_][\w-]*\(/i -+ -+export function skipUrlReplacer(rawUrl: string) { -+ return ( -+ isExternalUrl(rawUrl) || -+ isDataUrl(rawUrl) || -+ rawUrl[0] === '#' || -+ functionCallRE.test(rawUrl) -+ ) -+} -+ -+export async function compileLightningCSS( -+ id: string, -+ src: string, -+ config: ResolvedConfig, -+ urlReplacer?: CssUrlReplacer, -+): ReturnType { -+ const deps = new Set() -+ // Relative path is needed to get stable hash when using CSS modules -+ const filename = cleanUrl(path.relative(config.root, id)) -+ const toAbsolute = (filePath: string) => -+ path.isAbsolute(filePath) ? filePath : path.join(config.root, filePath) -+ -+ const res = styleAttrRE.test(id) -+ ? (await importLightningCSS()).transformStyleAttribute({ -+ filename, -+ code: Buffer.from(src), -+ targets: config.css?.lightningcss?.targets, -+ minify: config.isProduction && !!config.build.cssMinify, -+ analyzeDependencies: true, -+ }) -+ : await ( -+ await importLightningCSS() -+ ).bundleAsync({ -+ ...config.css?.lightningcss, -+ filename, -+ resolver: { -+ read(filePath) { -+ if (filePath === filename) { -+ return src -+ } -+ // This happens with html-proxy (#13776) -+ if (!filePath.endsWith('.css')) { -+ return src -+ } -+ return fs.readFileSync(toAbsolute(filePath), 'utf-8') -+ }, -+ async resolve(id, from) { -+ const publicFile = checkPublicFile(id, config) -+ if (publicFile) { -+ return publicFile -+ } -+ -+ const resolved = await getAtImportResolvers(config).css( -+ id, -+ toAbsolute(from), -+ ) -+ -+ if (resolved) { -+ deps.add(resolved) -+ return resolved -+ } -+ return id -+ }, -+ }, -+ minify: config.isProduction && !!config.build.cssMinify, -+ sourceMap: -+ config.command === 'build' -+ ? !!config.build.sourcemap -+ : config.css?.devSourcemap, -+ analyzeDependencies: true, -+ cssModules: cssModuleRE.test(id) -+ ? (config.css?.lightningcss?.cssModules ?? true) -+ : undefined, -+ }) -+ -+ // NodeJS res.code = Buffer -+ // Deno res.code = Uint8Array -+ // For correct decode compiled css need to use TextDecoder -+ let css = decoder.decode(res.code) -+ for (const dep of res.dependencies!) { -+ switch (dep.type) { -+ case 'url': -+ if (skipUrlReplacer(dep.url)) { -+ css = css.replace(dep.placeholder, () => dep.url) -+ break -+ } -+ deps.add(dep.url) -+ if (urlReplacer) { -+ const replaceUrl = await urlReplacer(dep.url, id) -+ css = css.replace(dep.placeholder, () => replaceUrl) -+ } else { -+ css = css.replace(dep.placeholder, () => dep.url) -+ } -+ break -+ default: -+ throw new Error(`Unsupported dependency type: ${dep.type}`) -+ } -+ } -+ -+ let modules: Record | undefined -+ if ('exports' in res && res.exports) { -+ modules = {} -+ // https://github.com/parcel-bundler/lightningcss/issues/291 -+ const sortedEntries = Object.entries(res.exports).sort((a, b) => -+ a[0].localeCompare(b[0]), -+ ) -+ for (const [key, value] of sortedEntries) { -+ modules[key] = value.name -+ // https://lightningcss.dev/css-modules.html#class-composition -+ for (const c of value.composes) { -+ modules[key] += ' ' + c.name -+ } -+ } -+ } -+ -+ return { -+ code: css, -+ map: 'map' in res ? res.map?.toString() : undefined, -+ deps, -+ modules, -+ } -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/plugins/css/cssAnalysisPlugin.ts b/packages/vite/src/node/plugins/css/cssAnalysisPlugin.ts -new file mode 100644 -index 000000000..5df79c534 ---- /dev/null -+++ b/packages/vite/src/node/plugins/css/cssAnalysisPlugin.ts -@@ -0,0 +1,84 @@ -+import { ViteDevServer } from 'packages/vite/src/node/index'; -+import { SPECIAL_QUERY_RE } from 'packages/vite/src/node/constants'; -+import { ResolvedConfig } from 'packages/vite/src/node/config'; -+import { Plugin } from 'packages/vite/src/node/plugin'; -+import { stripBase } from 'packages/vite/src/node/utils'; -+import { fileToUrl } from 'packages/vite/src/node/plugins/asset'; -+import { CSS_LANGS_RE } from 'packages/vite/src/node/constants'; -+ -+ -+const export commonjsProxyRE = /\?commonjs-proxy/ -+ -+const isCSSRequest = (request: string): boolean => -+ CSS_LANGS_RE.test(request) -+ -+export function cssAnalysisPlugin(config: ResolvedConfig): Plugin { -+ let server: ViteDevServer -+ -+ return { -+ name: 'vite:css-analysis', -+ -+ configureServer(_server) { -+ server = _server -+ }, -+ -+ async transform(_, id, options) { -+ if ( -+ !isCSSRequest(id) || -+ commonjsProxyRE.test(id) || -+ SPECIAL_QUERY_RE.test(id) -+ ) { -+ return -+ } -+ -+ const ssr = options?.ssr === true -+ const { moduleGraph } = server -+ const thisModule = moduleGraph.getModuleById(id) -+ -+ // Handle CSS @import dependency HMR and other added modules via this.addWatchFile. -+ // JS-related HMR is handled in the import-analysis plugin. -+ if (thisModule) { -+ // CSS modules cannot self-accept since it exports values -+ const isSelfAccepting = -+ !cssModulesCache.get(config)?.get(id) && -+ !inlineRE.test(id) && -+ !htmlProxyRE.test(id) -+ // attached by pluginContainer.addWatchFile -+ const pluginImports = (this as unknown as TransformPluginContext) -+ ._addedImports -+ if (pluginImports) { -+ // record deps in the module graph so edits to @import css can trigger -+ // main import to hot update -+ const depModules = new Set() -+ const devBase = config.base -+ for (const file of pluginImports) { -+ depModules.add( -+ isCSSRequest(file) -+ ? moduleGraph.createFileOnlyEntry(file) -+ : await moduleGraph.ensureEntryFromUrl( -+ stripBase( -+ await fileToUrl(file, config, this), -+ (config.server?.origin ?? '') + devBase, -+ ), -+ ssr, -+ ), -+ ) -+ } -+ moduleGraph.updateModuleInfo( -+ thisModule, -+ depModules, -+ null, -+ // The root CSS proxy module is self-accepting and should not -+ // have an explicit accept list -+ new Set(), -+ null, -+ isSelfAccepting, -+ ssr, -+ ) -+ } else { -+ thisModule.isSelfAccepting = isSelfAccepting -+ } -+ } -+ }, -+ } -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/plugins/css/cssPlugin.ts b/packages/vite/src/node/plugins/css/cssPlugin.ts -new file mode 100644 -index 000000000..520fd5f6f ---- /dev/null -+++ b/packages/vite/src/node/plugins/css/cssPlugin.ts -@@ -0,0 +1,177 @@ -+import type { -+ ExistingRawSourceMap, -+ ModuleFormat, -+ OutputAsset, -+ OutputChunk, -+ RenderedChunk, -+ RollupError, -+ SourceMapInput, -+} from 'rollup' -+import { SPECIAL_QUERY_RE } from 'packages/vite/src/node/constants'; -+import { ResolvedConfig } from 'packages/vite/src/node/config'; -+import { Plugin } from 'packages/vite/src/node/plugin'; -+import { urlRE } from 'packages/vite/src/node/utils'; -+import { CSS_LANGS_RE } from 'packages/vite/src/node/constants'; -+ -+ -+const export commonjsProxyRE = /\?commonjs-proxy/ -+ -+const isCSSRequest = (request: string): boolean => -+ CSS_LANGS_RE.test(request) -+ -+const export cssModuleRE = new RegExp(`\\.module${CSS_LANGS_RE.source}`) -+ -+const isModuleCSSRequest = (request: string): boolean => -+ cssModuleRE.test(request) -+ -+const export cssModulesCache = new WeakMap< -+ ResolvedConfig, -+ Map> -+>() -+ -+export const removedPureCssFilesCache = new WeakMap< -+ ResolvedConfig, -+ Map -+>() -+ -+/** -+ * Plugin applied before user plugins -+ */ -+export function cssPlugin(config: ResolvedConfig): Plugin { -+ const isBuild = config.command === 'build' -+ let moduleCache: Map> -+ -+ const resolveUrl = config.createResolver({ -+ preferRelative: true, -+ tryIndex: false, -+ extensions: [], -+ }) -+ -+ let preprocessorWorkerController: PreprocessorWorkerController | undefined -+ -+ // warm up cache for resolved postcss config -+ if (config.css?.transformer !== 'lightningcss') { -+ resolvePostcssConfig(config) -+ } -+ -+ return { -+ name: 'vite:css', -+ -+ buildStart() { -+ // Ensure a new cache for every build (i.e. rebuilding in watch mode) -+ moduleCache = new Map>() -+ cssModulesCache.set(config, moduleCache) -+ -+ removedPureCssFilesCache.set(config, new Map()) -+ -+ preprocessorWorkerController = createPreprocessorWorkerController( -+ normalizeMaxWorkers(config.css.preprocessorMaxWorkers), -+ ) -+ preprocessorWorkerControllerCache.set( -+ config, -+ preprocessorWorkerController, -+ ) -+ }, -+ -+ buildEnd() { -+ preprocessorWorkerController?.close() -+ }, -+ -+ async load(id) { -+ if (!isCSSRequest(id)) return -+ -+ if (urlRE.test(id)) { -+ if (isModuleCSSRequest(id)) { -+ throw new Error( -+ `?url is not supported with CSS modules. (tried to import ${JSON.stringify( -+ id, -+ )})`, -+ ) -+ } -+ -+ // *.css?url -+ // in dev, it's handled by assets plugin. -+ if (isBuild) { -+ id = injectQuery(removeUrlQuery(id), 'transform-only') -+ return ( -+ `import ${JSON.stringify(id)};` + -+ `export default "__VITE_CSS_URL__${Buffer.from(id).toString( -+ 'hex', -+ )}__"` -+ ) -+ } -+ } -+ }, -+ -+ async transform(raw, id) { -+ if ( -+ !isCSSRequest(id) || -+ commonjsProxyRE.test(id) || -+ SPECIAL_QUERY_RE.test(id) -+ ) { -+ return -+ } -+ const urlReplacer: CssUrlReplacer = async (url, importer) => { -+ const decodedUrl = decodeURI(url) -+ if (checkPublicFile(decodedUrl, config)) { -+ if (encodePublicUrlsInCSS(config)) { -+ return publicFileToBuiltUrl(decodedUrl, config) -+ } else { -+ return joinUrlSegments(config.base, decodedUrl) -+ } -+ } -+ const [id, fragment] = decodedUrl.split('#') -+ let resolved = await resolveUrl(id, importer) -+ if (resolved) { -+ if (fragment) resolved += '#' + fragment -+ return fileToUrl(resolved, config, this) -+ } -+ if (config.command === 'build') { -+ const isExternal = config.build.rollupOptions.external -+ ? resolveUserExternal( -+ config.build.rollupOptions.external, -+ decodedUrl, // use URL as id since id could not be resolved -+ id, -+ false, -+ ) -+ : false -+ -+ if (!isExternal) { -+ // #9800 If we cannot resolve the css url, leave a warning. -+ config.logger.warnOnce( -+ `\n${decodedUrl} referenced in ${id} didn't resolve at build time, it will remain unchanged to be resolved at runtime`, -+ ) -+ } -+ } -+ return url -+ } -+ -+ const { -+ code: css, -+ modules, -+ deps, -+ map, -+ } = await compileCSS( -+ id, -+ raw, -+ config, -+ preprocessorWorkerController!, -+ urlReplacer, -+ ) -+ if (modules) { -+ moduleCache.set(id, modules) -+ } -+ -+ if (deps) { -+ for (const file of deps) { -+ this.addWatchFile(file) -+ } -+ } -+ -+ return { -+ code: css, -+ map, -+ } -+ }, -+ } -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/plugins/css/cssPostPlugin.ts b/packages/vite/src/node/plugins/css/cssPostPlugin.ts -new file mode 100644 -index 000000000..9b6ae36c6 ---- /dev/null -+++ b/packages/vite/src/node/plugins/css/cssPostPlugin.ts -@@ -0,0 +1,602 @@ -+import path from 'node:path' -+import type { -+ ExistingRawSourceMap, -+ ModuleFormat, -+ OutputAsset, -+ OutputChunk, -+ RenderedChunk, -+ RollupError, -+ SourceMapInput, -+} from 'rollup' -+import { SPECIAL_QUERY_RE } from 'packages/vite/src/node/constants'; -+import { ResolvedConfig } from 'packages/vite/src/node/config'; -+import { Plugin } from 'packages/vite/src/node/plugin'; -+import { getHash } from 'packages/vite/src/node/utils'; -+import { createSerialPromiseQueue } from 'packages/vite/src/node/utils'; -+import { cleanUrl } from 'packages/vite/src/shared/utils'; -+import { addToHTMLProxyTransformResult } from 'packages/vite/src/node/plugins/html'; -+import { CSS_LANGS_RE } from 'packages/vite/src/node/constants'; -+ -+ -+const export cssModuleRE = new RegExp(`\\.module${CSS_LANGS_RE.source}`) -+ -+const export commonjsProxyRE = /\?commonjs-proxy/ -+ -+const export styleAttrRE = /[?&]style-attr\b/ -+ -+const export transformOnlyRE = /[?&]transform-only\b/ -+ -+const export cssBundleName = 'style.css' -+ -+const isCSSRequest = (request: string): boolean => -+ CSS_LANGS_RE.test(request) -+ -+const export directRequestRE = /[?&]direct\b/ -+ -+const isDirectCSSRequest = (request: string): boolean => -+ CSS_LANGS_RE.test(request) && directRequestRE.test(request) -+ -+const export cssUrlAssetRE = /__VITE_CSS_URL__([\da-f]+)__/g -+ -+/** -+ * Plugin applied after user plugins -+ */ -+export function cssPostPlugin(config: ResolvedConfig): Plugin { -+ // styles initialization in buildStart causes a styling loss in watch -+ const styles: Map = new Map() -+ // queue to emit css serially to guarantee the files are emitted in a deterministic order -+ let codeSplitEmitQueue = createSerialPromiseQueue() -+ const urlEmitQueue = createSerialPromiseQueue() -+ let pureCssChunks: Set -+ -+ // when there are multiple rollup outputs and extracting CSS, only emit once, -+ // since output formats have no effect on the generated CSS. -+ let hasEmitted = false -+ let chunkCSSMap: Map -+ -+ const rollupOptionsOutput = config.build.rollupOptions.output -+ const assetFileNames = ( -+ Array.isArray(rollupOptionsOutput) -+ ? rollupOptionsOutput[0] -+ : rollupOptionsOutput -+ )?.assetFileNames -+ const getCssAssetDirname = (cssAssetName: string) => { -+ const cssAssetNameDir = path.dirname(cssAssetName) -+ if (!assetFileNames) { -+ return path.join(config.build.assetsDir, cssAssetNameDir) -+ } else if (typeof assetFileNames === 'string') { -+ return path.join(path.dirname(assetFileNames), cssAssetNameDir) -+ } else { -+ return path.dirname( -+ assetFileNames({ -+ name: cssAssetName, -+ type: 'asset', -+ source: '/* vite internal call, ignore */', -+ }), -+ ) -+ } -+ } -+ -+ return { -+ name: 'vite:css-post', -+ -+ renderStart() { -+ // Ensure new caches for every build (i.e. rebuilding in watch mode) -+ pureCssChunks = new Set() -+ hasEmitted = false -+ chunkCSSMap = new Map() -+ codeSplitEmitQueue = createSerialPromiseQueue() -+ }, -+ -+ async transform(css, id, options) { -+ if ( -+ !isCSSRequest(id) || -+ commonjsProxyRE.test(id) || -+ SPECIAL_QUERY_RE.test(id) -+ ) { -+ return -+ } -+ -+ css = stripBomTag(css) -+ -+ // cache css compile result to map -+ // and then use the cache replace inline-style-flag -+ // when `generateBundle` in vite:build-html plugin and devHtmlHook -+ const inlineCSS = inlineCSSRE.test(id) -+ const isHTMLProxy = htmlProxyRE.test(id) -+ if (inlineCSS && isHTMLProxy) { -+ if (styleAttrRE.test(id)) { -+ css = css.replace(/"/g, '"') -+ } -+ const index = htmlProxyIndexRE.exec(id)?.[1] -+ if (index == null) { -+ throw new Error(`HTML proxy index in "${id}" not found`) -+ } -+ addToHTMLProxyTransformResult( -+ `${getHash(cleanUrl(id))}_${Number.parseInt(index)}`, -+ css, -+ ) -+ return `export default ''` -+ } -+ -+ const inlined = inlineRE.test(id) -+ const modules = cssModulesCache.get(config)!.get(id) -+ -+ // #6984, #7552 -+ // `foo.module.css` => modulesCode -+ // `foo.module.css?inline` => cssContent -+ const modulesCode = -+ modules && -+ !inlined && -+ dataToEsm(modules, { namedExports: true, preferConst: true }) -+ -+ if (config.command === 'serve') { -+ const getContentWithSourcemap = async (content: string) => { -+ if (config.css?.devSourcemap) { -+ const sourcemap = this.getCombinedSourcemap() -+ if (sourcemap.mappings) { -+ await injectSourcesContent(sourcemap, cleanUrl(id), config.logger) -+ } -+ return getCodeWithSourcemap('css', content, sourcemap) -+ } -+ return content -+ } -+ -+ if (isDirectCSSRequest(id)) { -+ return null -+ } -+ // server only -+ if (options?.ssr) { -+ return modulesCode || `export default ${JSON.stringify(css)}` -+ } -+ if (inlined) { -+ return `export default ${JSON.stringify(css)}` -+ } -+ -+ const cssContent = await getContentWithSourcemap(css) -+ const code = [ -+ `import { updateStyle as __vite__updateStyle, removeStyle as __vite__removeStyle } from ${JSON.stringify( -+ path.posix.join(config.base, CLIENT_PUBLIC_PATH), -+ )}`, -+ `const __vite__id = ${JSON.stringify(id)}`, -+ `const __vite__css = ${JSON.stringify(cssContent)}`, -+ `__vite__updateStyle(__vite__id, __vite__css)`, -+ // css modules exports change on edit so it can't self accept -+ `${modulesCode || 'import.meta.hot.accept()'}`, -+ `import.meta.hot.prune(() => __vite__removeStyle(__vite__id))`, -+ ].join('\n') -+ return { code, map: { mappings: '' } } -+ } -+ -+ // build CSS handling ---------------------------------------------------- -+ -+ // record css -+ if (!inlined) { -+ styles.set(id, css) -+ } -+ -+ let code: string -+ if (modulesCode) { -+ code = modulesCode -+ } else if (inlined) { -+ let content = css -+ if (config.build.cssMinify) { -+ content = await minifyCSS(content, config, true) -+ } -+ code = `export default ${JSON.stringify(content)}` -+ } else { -+ // empty module when it's not a CSS module nor `?inline` -+ code = '' -+ } -+ -+ return { -+ code, -+ map: { mappings: '' }, -+ // avoid the css module from being tree-shaken so that we can retrieve -+ // it in renderChunk() -+ moduleSideEffects: modulesCode || inlined ? false : 'no-treeshake', -+ } -+ }, -+ -+ async renderChunk(code, chunk, opts) { -+ let chunkCSS = '' -+ // the chunk is empty if it's a dynamic entry chunk that only contains a CSS import -+ const isJsChunkEmpty = code === '' && !chunk.isEntry -+ let isPureCssChunk = true -+ const ids = Object.keys(chunk.modules) -+ for (const id of ids) { -+ if (styles.has(id)) { -+ // ?transform-only is used for ?url and shouldn't be included in normal CSS chunks -+ if (!transformOnlyRE.test(id)) { -+ chunkCSS += styles.get(id) -+ // a css module contains JS, so it makes this not a pure css chunk -+ if (cssModuleRE.test(id)) { -+ isPureCssChunk = false -+ } -+ } -+ } else if (!isJsChunkEmpty) { -+ // if the module does not have a style, then it's not a pure css chunk. -+ // this is true because in the `transform` hook above, only modules -+ // that are css gets added to the `styles` map. -+ isPureCssChunk = false -+ } -+ } -+ -+ const publicAssetUrlMap = publicAssetUrlCache.get(config)! -+ -+ // resolve asset URL placeholders to their built file URLs -+ const resolveAssetUrlsInCss = ( -+ chunkCSS: string, -+ cssAssetName: string, -+ ) => { -+ const encodedPublicUrls = encodePublicUrlsInCSS(config) -+ -+ const relative = config.base === './' || config.base === '' -+ const cssAssetDirname = -+ encodedPublicUrls || relative -+ ? slash(getCssAssetDirname(cssAssetName)) -+ : undefined -+ -+ const toRelative = (filename: string) => { -+ // relative base + extracted CSS -+ const relativePath = path.posix.relative(cssAssetDirname!, filename) -+ return relativePath[0] === '.' ? relativePath : './' + relativePath -+ } -+ -+ // replace asset url references with resolved url. -+ chunkCSS = chunkCSS.replace(assetUrlRE, (_, fileHash, postfix = '') => { -+ const filename = this.getFileName(fileHash) + postfix -+ chunk.viteMetadata!.importedAssets.add(cleanUrl(filename)) -+ return encodeURIPath( -+ toOutputFilePathInCss( -+ filename, -+ 'asset', -+ cssAssetName, -+ 'css', -+ config, -+ toRelative, -+ ), -+ ) -+ }) -+ // resolve public URL from CSS paths -+ if (encodedPublicUrls) { -+ const relativePathToPublicFromCSS = path.posix.relative( -+ cssAssetDirname!, -+ '', -+ ) -+ chunkCSS = chunkCSS.replace(publicAssetUrlRE, (_, hash) => { -+ const publicUrl = publicAssetUrlMap.get(hash)!.slice(1) -+ return encodeURIPath( -+ toOutputFilePathInCss( -+ publicUrl, -+ 'public', -+ cssAssetName, -+ 'css', -+ config, -+ () => `${relativePathToPublicFromCSS}/${publicUrl}`, -+ ), -+ ) -+ }) -+ } -+ return chunkCSS -+ } -+ -+ function ensureFileExt(name: string, ext: string) { -+ return normalizePath( -+ path.format({ ...path.parse(name), base: undefined, ext }), -+ ) -+ } -+ -+ let s: MagicString | undefined -+ const urlEmitTasks: Array<{ -+ cssAssetName: string -+ originalFilename: string -+ content: string -+ start: number -+ end: number -+ }> = [] -+ -+ if (code.includes('__VITE_CSS_URL__')) { -+ let match: RegExpExecArray | null -+ cssUrlAssetRE.lastIndex = 0 -+ while ((match = cssUrlAssetRE.exec(code))) { -+ const [full, idHex] = match -+ const id = Buffer.from(idHex, 'hex').toString() -+ const originalFilename = cleanUrl(id) -+ const cssAssetName = ensureFileExt( -+ path.basename(originalFilename), -+ '.css', -+ ) -+ if (!styles.has(id)) { -+ throw new Error( -+ `css content for ${JSON.stringify(id)} was not found`, -+ ) -+ } -+ -+ let cssContent = styles.get(id)! -+ -+ cssContent = resolveAssetUrlsInCss(cssContent, cssAssetName) -+ -+ urlEmitTasks.push({ -+ cssAssetName, -+ originalFilename, -+ content: cssContent, -+ start: match.index, -+ end: match.index + full.length, -+ }) -+ } -+ } -+ -+ // should await even if this chunk does not include __VITE_CSS_URL__ -+ // so that code after this line runs in the same order -+ await urlEmitQueue.run(async () => -+ Promise.all( -+ urlEmitTasks.map(async (info) => { -+ info.content = await finalizeCss(info.content, true, config) -+ }), -+ ), -+ ) -+ if (urlEmitTasks.length > 0) { -+ const toRelativeRuntime = createToImportMetaURLBasedRelativeRuntime( -+ opts.format, -+ config.isWorker, -+ ) -+ s ||= new MagicString(code) -+ -+ for (const { -+ cssAssetName, -+ originalFilename, -+ content, -+ start, -+ end, -+ } of urlEmitTasks) { -+ const referenceId = this.emitFile({ -+ name: cssAssetName, -+ type: 'asset', -+ source: content, -+ }) -+ generatedAssets -+ .get(config)! -+ .set(referenceId, { originalName: originalFilename }) -+ -+ const filename = this.getFileName(referenceId) -+ chunk.viteMetadata!.importedAssets.add(cleanUrl(filename)) -+ const replacement = toOutputFilePathInJS( -+ filename, -+ 'asset', -+ chunk.fileName, -+ 'js', -+ config, -+ toRelativeRuntime, -+ ) -+ const replacementString = -+ typeof replacement === 'string' -+ ? JSON.stringify(encodeURIPath(replacement)).slice(1, -1) -+ : `"+${replacement.runtime}+"` -+ s.update(start, end, replacementString) -+ } -+ } -+ -+ if (chunkCSS) { -+ if (isPureCssChunk && (opts.format === 'es' || opts.format === 'cjs')) { -+ // this is a shared CSS-only chunk that is empty. -+ pureCssChunks.add(chunk) -+ } -+ -+ if (config.build.cssCodeSplit) { -+ if (opts.format === 'es' || opts.format === 'cjs') { -+ const isEntry = chunk.isEntry && isPureCssChunk -+ const cssFullAssetName = ensureFileExt(chunk.name, '.css') -+ // if facadeModuleId doesn't exist or doesn't have a CSS extension, -+ // that means a JS entry file imports a CSS file. -+ // in this case, only use the filename for the CSS chunk name like JS chunks. -+ const cssAssetName = -+ chunk.isEntry && -+ (!chunk.facadeModuleId || !isCSSRequest(chunk.facadeModuleId)) -+ ? path.basename(cssFullAssetName) -+ : cssFullAssetName -+ const originalFilename = getChunkOriginalFileName( -+ chunk, -+ config.root, -+ opts.format, -+ ) -+ -+ chunkCSS = resolveAssetUrlsInCss(chunkCSS, cssAssetName) -+ -+ // wait for previous tasks as well -+ chunkCSS = await codeSplitEmitQueue.run(async () => { -+ return finalizeCss(chunkCSS, true, config) -+ }) -+ -+ // emit corresponding css file -+ const referenceId = this.emitFile({ -+ name: cssAssetName, -+ type: 'asset', -+ source: chunkCSS, -+ }) -+ generatedAssets -+ .get(config)! -+ .set(referenceId, { originalName: originalFilename, isEntry }) -+ chunk.viteMetadata!.importedCss.add(this.getFileName(referenceId)) -+ } else if (!config.build.ssr) { -+ // legacy build and inline css -+ -+ // Entry chunk CSS will be collected into `chunk.viteMetadata.importedCss` -+ // and injected later by the `'vite:build-html'` plugin into the `index.html` -+ // so it will be duplicated. (https://github.com/vitejs/vite/issues/2062#issuecomment-782388010) -+ // But because entry chunk can be imported by dynamic import, -+ // we shouldn't remove the inlined CSS. (#10285) -+ -+ chunkCSS = await finalizeCss(chunkCSS, true, config) -+ let cssString = JSON.stringify(chunkCSS) -+ cssString = -+ renderAssetUrlInJS( -+ this, -+ config, -+ chunk, -+ opts, -+ cssString, -+ )?.toString() || cssString -+ const style = `__vite_style__` -+ const injectCode = -+ `var ${style} = document.createElement('style');` + -+ `${style}.textContent = ${cssString};` + -+ `document.head.appendChild(${style});` -+ let injectionPoint -+ const wrapIdx = code.indexOf('System.register') -+ if (wrapIdx >= 0) { -+ const executeFnStart = code.indexOf('execute:', wrapIdx) -+ injectionPoint = code.indexOf('{', executeFnStart) + 1 -+ } else { -+ const insertMark = "'use strict';" -+ injectionPoint = code.indexOf(insertMark) + insertMark.length -+ } -+ s ||= new MagicString(code) -+ s.appendRight(injectionPoint, injectCode) -+ } -+ } else { -+ // resolve public URL from CSS paths, we need to use absolute paths -+ chunkCSS = resolveAssetUrlsInCss(chunkCSS, cssBundleName) -+ // finalizeCss is called for the aggregated chunk in generateBundle -+ -+ chunkCSSMap.set(chunk.fileName, chunkCSS) -+ } -+ } -+ -+ if (s) { -+ if (config.build.sourcemap) { -+ return { -+ code: s.toString(), -+ map: s.generateMap({ hires: 'boundary' }), -+ } -+ } else { -+ return { code: s.toString() } -+ } -+ } -+ return null -+ }, -+ -+ augmentChunkHash(chunk) { -+ if (chunk.viteMetadata?.importedCss.size) { -+ let hash = '' -+ for (const id of chunk.viteMetadata.importedCss) { -+ hash += id -+ } -+ return hash -+ } -+ }, -+ -+ async generateBundle(opts, bundle) { -+ // @ts-expect-error asset emits are skipped in legacy bundle -+ if (opts.__vite_skip_asset_emit__) { -+ return -+ } -+ -+ function extractCss() { -+ let css = '' -+ const collected = new Set() -+ // will be populated in order they are used by entry points -+ const dynamicImports = new Set() -+ -+ function collect(chunk: OutputChunk | OutputAsset) { -+ if (!chunk || chunk.type !== 'chunk' || collected.has(chunk)) return -+ collected.add(chunk) -+ -+ // First collect all styles from the synchronous imports (lowest priority) -+ chunk.imports.forEach((importName) => collect(bundle[importName])) -+ // Save dynamic imports in deterministic order to add the styles later (to have the highest priority) -+ chunk.dynamicImports.forEach((importName) => -+ dynamicImports.add(importName), -+ ) -+ // Then collect the styles of the current chunk (might overwrite some styles from previous imports) -+ css += chunkCSSMap.get(chunk.preliminaryFileName) ?? '' -+ } -+ -+ // The bundle is guaranteed to be deterministic, if not then we have a bug in rollup. -+ // So we use it to ensure a deterministic order of styles -+ for (const chunk of Object.values(bundle)) { -+ if (chunk.type === 'chunk' && chunk.isEntry) { -+ collect(chunk) -+ } -+ } -+ // Now collect the dynamic chunks, this is done last to have the styles overwrite the previous ones -+ for (const chunkName of dynamicImports) { -+ collect(bundle[chunkName]) -+ } -+ -+ return css -+ } -+ let extractedCss = !hasEmitted && extractCss() -+ if (extractedCss) { -+ hasEmitted = true -+ extractedCss = await finalizeCss(extractedCss, true, config) -+ this.emitFile({ -+ name: cssBundleName, -+ type: 'asset', -+ source: extractedCss, -+ }) -+ } -+ -+ // remove empty css chunks and their imports -+ if (pureCssChunks.size) { -+ // map each pure css chunk (rendered chunk) to it's corresponding bundle -+ // chunk. we check that by `preliminaryFileName` as they have different -+ // `filename`s (rendered chunk has the !~{XXX}~ placeholder) -+ const prelimaryNameToChunkMap = Object.fromEntries( -+ Object.values(bundle) -+ .filter((chunk): chunk is OutputChunk => chunk.type === 'chunk') -+ .map((chunk) => [chunk.preliminaryFileName, chunk.fileName]), -+ ) -+ -+ // When running in watch mode the generateBundle is called once per output format -+ // in this case the `bundle` is not populated with the other output files -+ // but they are still in `pureCssChunks`. -+ // So we need to filter the names and only use those who are defined -+ const pureCssChunkNames = [...pureCssChunks] -+ .map((pureCssChunk) => prelimaryNameToChunkMap[pureCssChunk.fileName]) -+ .filter(Boolean) -+ -+ const replaceEmptyChunk = getEmptyChunkReplacer( -+ pureCssChunkNames, -+ opts.format, -+ ) -+ -+ for (const file in bundle) { -+ const chunk = bundle[file] -+ if (chunk.type === 'chunk') { -+ let chunkImportsPureCssChunk = false -+ // remove pure css chunk from other chunk's imports, -+ // and also register the emitted CSS files under the importer -+ // chunks instead. -+ chunk.imports = chunk.imports.filter((file) => { -+ if (pureCssChunkNames.includes(file)) { -+ const { importedCss, importedAssets } = ( -+ bundle[file] as OutputChunk -+ ).viteMetadata! -+ importedCss.forEach((file) => -+ chunk.viteMetadata!.importedCss.add(file), -+ ) -+ importedAssets.forEach((file) => -+ chunk.viteMetadata!.importedAssets.add(file), -+ ) -+ chunkImportsPureCssChunk = true -+ return false -+ } -+ return true -+ }) -+ if (chunkImportsPureCssChunk) { -+ chunk.code = replaceEmptyChunk(chunk.code) -+ } -+ } -+ } -+ -+ const removedPureCssFiles = removedPureCssFilesCache.get(config)! -+ pureCssChunkNames.forEach((fileName) => { -+ removedPureCssFiles.set(fileName, bundle[fileName] as RenderedChunk) -+ delete bundle[fileName] -+ delete bundle[`${fileName}.map`] -+ }) -+ } -+ }, -+ } -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/plugins/css/lessProcessor.ts b/packages/vite/src/node/plugins/css/lessProcessor.ts -new file mode 100644 -index 000000000..93eed495a ---- /dev/null -+++ b/packages/vite/src/node/plugins/css/lessProcessor.ts -@@ -0,0 +1,368 @@ -+import { ResolveFn } from 'packages/vite/src/node/index'; -+import type { Alias } from 'dep-types/alias' -+import type { -+ ExistingRawSourceMap, -+ ModuleFormat, -+ OutputAsset, -+ OutputChunk, -+ RenderedChunk, -+ RollupError, -+ SourceMapInput, -+} from 'rollup' -+import fsp from 'node:fs/promises' -+import path from 'node:path' -+import type Less from 'less' -+import { WorkerWithFallback } from 'artichokie' -+import { Plugin } from 'packages/vite/src/node/plugin'; -+import { normalizePath } from 'packages/vite/src/node/utils'; -+import { asyncReplace } from 'packages/vite/src/node/utils'; -+ -+ -+export interface CSSAtImportResolvers { -+ css: ResolveFn -+ sass: ResolveFn -+ less: ResolveFn -+} -+ -+export type PreprocessorAdditionalDataResult = -+ | string -+ | { content: string; map?: ExistingRawSourceMap } -+ -+export type PreprocessorAdditionalData = -+ | string -+ | (( -+ source: string, -+ filename: string, -+ ) => -+ | PreprocessorAdditionalDataResult -+ | Promise) -+ -+export type StylePreprocessorOptions = { -+ [key: string]: any -+ additionalData?: PreprocessorAdditionalData -+ maxWorkers?: number | true -+ filename: string -+ alias: Alias[] -+ enableSourcemap: boolean -+} -+ -+export interface StylePreprocessorResults { -+ code: string -+ map?: ExistingRawSourceMap | undefined -+ additionalMap?: ExistingRawSourceMap | undefined -+ error?: RollupError -+ deps: string[] -+} -+ -+export type StylePreprocessor = { -+ process: ( -+ source: string, -+ root: string, -+ options: StylePreprocessorOptions, -+ resolvers: CSSAtImportResolvers, -+ ) => StylePreprocessorResults | Promise -+ close: () => void -+} -+ -+// https://drafts.csswg.org/css-syntax-3/#identifier-code-point -+export const cssUrlRE = -+ /(?<=^|[^\w\-\u0080-\uffff])url\((\s*('[^']+'|"[^"]+")\s*|[^'")]+)\)/ -+ -+export const cssDataUriRE = -+ /(?<=^|[^\w\-\u0080-\uffff])data-uri\((\s*('[^']+'|"[^"]+")\s*|[^'")]+)\)/ -+ -+export const importCssRE = /@import ('[^']+\.css'|"[^"]+\.css"|[^'")]+\.css)/ -+ -+export type CssUrlReplacer = ( -+ url: string, -+ importer?: string, -+) => string | Promise -+ -+export function rewriteCssUrls( -+ css: string, -+ replacer: CssUrlReplacer, -+): Promise { -+ return asyncReplace(css, cssUrlRE, async (match) => { -+ const [matched, rawUrl] = match -+ return await doUrlReplace(rawUrl.trim(), matched, replacer) -+ }) -+} -+ -+export function rewriteCssDataUris( -+ css: string, -+ replacer: CssUrlReplacer, -+): Promise { -+ return asyncReplace(css, cssDataUriRE, async (match) => { -+ const [matched, rawUrl] = match -+ return await doUrlReplace(rawUrl.trim(), matched, replacer, 'data-uri') -+ }) -+} -+ -+export function rewriteImportCss( -+ css: string, -+ replacer: CssUrlReplacer, -+): Promise { -+ return asyncReplace(css, importCssRE, async (match) => { -+ const [matched, rawUrl] = match -+ return await doImportCSSReplace(rawUrl, matched, replacer) -+ }) -+} -+ -+/** -+ * relative url() inside \@imported sass and less files must be rebased to use -+ * root file as base. -+ */ -+export async function rebaseUrls( -+ file: string, -+ rootFile: string, -+ alias: Alias[], -+ variablePrefix: string, -+ resolver: ResolveFn, -+): Promise<{ file: string; contents?: string }> { -+ file = path.resolve(file) // ensure os-specific flashes -+ // in the same dir, no need to rebase -+ const fileDir = path.dirname(file) -+ const rootDir = path.dirname(rootFile) -+ if (fileDir === rootDir) { -+ return { file } -+ } -+ -+ const content = await fsp.readFile(file, 'utf-8') -+ // no url() -+ const hasUrls = cssUrlRE.test(content) -+ // data-uri() calls -+ const hasDataUris = cssDataUriRE.test(content) -+ // no @import xxx.css -+ const hasImportCss = importCssRE.test(content) -+ -+ if (!hasUrls && !hasDataUris && !hasImportCss) { -+ return { file } -+ } -+ -+ let rebased -+ const rebaseFn = async (url: string) => { -+ if (url[0] === '/') return url -+ // ignore url's starting with variable -+ if (url.startsWith(variablePrefix)) return url -+ // match alias, no need to rewrite -+ for (const { find } of alias) { -+ const matches = -+ typeof find === 'string' ? url.startsWith(find) : find.test(url) -+ if (matches) { -+ return url -+ } -+ } -+ const absolute = (await resolver(url, file)) || path.resolve(fileDir, url) -+ const relative = path.relative(rootDir, absolute) -+ return normalizePath(relative) -+ } -+ -+ // fix css imports in less such as `@import "foo.css"` -+ if (hasImportCss) { -+ rebased = await rewriteImportCss(content, rebaseFn) -+ } -+ -+ if (hasUrls) { -+ rebased = await rewriteCssUrls(rebased || content, rebaseFn) -+ } -+ -+ if (hasDataUris) { -+ rebased = await rewriteCssDataUris(rebased || content, rebaseFn) -+ } -+ -+ return { -+ file, -+ contents: rebased, -+ } -+} -+ -+// #region Less -+// .less -+export const makeLessWorker = ( -+ resolvers: CSSAtImportResolvers, -+ alias: Alias[], -+ maxWorkers: number | undefined, -+) => { -+ const viteLessResolve = async ( -+ filename: string, -+ dir: string, -+ rootFile: string, -+ ) => { -+ const resolved = await resolvers.less(filename, path.join(dir, '*')) -+ if (!resolved) return undefined -+ -+ const result = await rebaseUrls( -+ resolved, -+ rootFile, -+ alias, -+ '@', -+ resolvers.less, -+ ) -+ if (result) { -+ return { -+ resolved, -+ contents: 'contents' in result ? result.contents : undefined, -+ } -+ } -+ return result -+ } -+ -+ const worker = new WorkerWithFallback( -+ () => { -+ // eslint-disable-next-line no-restricted-globals -- this function runs inside a cjs worker -+ const fsp = require('node:fs/promises') -+ // eslint-disable-next-line no-restricted-globals -+ const path = require('node:path') -+ -+ let ViteLessManager: any -+ const createViteLessPlugin = ( -+ less: typeof Less, -+ rootFile: string, -+ ): Less.Plugin => { -+ const { FileManager } = less -+ ViteLessManager ??= class ViteManager extends FileManager { -+ rootFile -+ constructor(rootFile: string) { -+ super() -+ this.rootFile = rootFile -+ } -+ override supports(filename: string) { -+ return !/^(?:https?:)?\/\//.test(filename) -+ } -+ override supportsSync() { -+ return false -+ } -+ override async loadFile( -+ filename: string, -+ dir: string, -+ opts: any, -+ env: any, -+ ): Promise { -+ const result = await viteLessResolve(filename, dir, this.rootFile) -+ if (result) { -+ return { -+ filename: path.resolve(result.resolved), -+ contents: -+ result.contents ?? -+ (await fsp.readFile(result.resolved, 'utf-8')), -+ } -+ } else { -+ return super.loadFile(filename, dir, opts, env) -+ } -+ } -+ } -+ -+ return { -+ install(_, pluginManager) { -+ pluginManager.addFileManager(new ViteLessManager(rootFile)) -+ }, -+ minVersion: [3, 0, 0], -+ } -+ } -+ -+ return async ( -+ lessPath: string, -+ content: string, -+ // additionalData can a function that is not cloneable but it won't be used -+ options: StylePreprocessorOptions & { additionalData: undefined }, -+ ) => { -+ // eslint-disable-next-line no-restricted-globals -- this function runs inside a cjs worker -+ const nodeLess: typeof Less = require(lessPath) -+ const viteResolverPlugin = createViteLessPlugin( -+ nodeLess, -+ options.filename, -+ ) -+ const result = await nodeLess.render(content, { -+ ...options, -+ plugins: [viteResolverPlugin, ...(options.plugins || [])], -+ ...(options.enableSourcemap -+ ? { -+ sourceMap: { -+ outputSourceFiles: true, -+ sourceMapFileInline: false, -+ }, -+ } -+ : {}), -+ }) -+ return result -+ } -+ }, -+ { -+ parentFunctions: { viteLessResolve }, -+ shouldUseFake(_lessPath, _content, options) { -+ // plugins are a function and is not serializable -+ // in that case, fallback to running in main thread -+ return options.plugins?.length > 0 -+ }, -+ max: maxWorkers, -+ }, -+ ) -+ return worker -+} -+ -+export const lessProcessor = (maxWorkers: number | undefined): StylePreprocessor => { -+ const workerMap = new Map>() -+ -+ return { -+ close() { -+ for (const worker of workerMap.values()) { -+ worker.stop() -+ } -+ }, -+ async process(source, root, options, resolvers) { -+ const lessPath = loadPreprocessorPath(PreprocessLang.less, root) -+ -+ if (!workerMap.has(options.alias)) { -+ workerMap.set( -+ options.alias, -+ makeLessWorker(resolvers, options.alias, maxWorkers), -+ ) -+ } -+ const worker = workerMap.get(options.alias)! -+ -+ const { content, map: additionalMap } = await getSource( -+ source, -+ options.filename, -+ options.additionalData, -+ options.enableSourcemap, -+ ) -+ -+ let result: Less.RenderOutput | undefined -+ const optionsWithoutAdditionalData = { -+ ...options, -+ additionalData: undefined, -+ } -+ try { -+ result = await worker.run( -+ lessPath, -+ content, -+ optionsWithoutAdditionalData, -+ ) -+ } catch (e) { -+ const error = e as Less.RenderError -+ // normalize error info -+ const normalizedError: RollupError = new Error( -+ `[less] ${error.message || error.type}`, -+ ) as RollupError -+ normalizedError.loc = { -+ file: error.filename || options.filename, -+ line: error.line, -+ column: error.column, -+ } -+ return { code: '', error: normalizedError, deps: [] } -+ } -+ -+ const map: ExistingRawSourceMap = result.map && JSON.parse(result.map) -+ if (map) { -+ delete map.sourcesContent -+ } -+ -+ return { -+ code: result.css.toString(), -+ map, -+ additionalMap, -+ deps: result.imports, -+ } -+ }, -+ } -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/plugins/css/makeLessWorker.ts b/packages/vite/src/node/plugins/css/makeLessWorker.ts -new file mode 100644 -index 000000000..b288ea776 ---- /dev/null -+++ b/packages/vite/src/node/plugins/css/makeLessWorker.ts -@@ -0,0 +1,283 @@ -+import fsp from 'node:fs/promises' -+import path from 'node:path' -+import type Less from 'less' -+import type { Alias } from 'dep-types/alias' -+import { WorkerWithFallback } from 'artichokie' -+import { Plugin } from 'packages/vite/src/node/plugin'; -+import { ResolveFn } from 'packages/vite/src/node/index'; -+import type { -+ ExistingRawSourceMap, -+ ModuleFormat, -+ OutputAsset, -+ OutputChunk, -+ RenderedChunk, -+ RollupError, -+ SourceMapInput, -+} from 'rollup' -+import { normalizePath } from 'packages/vite/src/node/utils'; -+import { asyncReplace } from 'packages/vite/src/node/utils'; -+ -+ -+export interface CSSAtImportResolvers { -+ css: ResolveFn -+ sass: ResolveFn -+ less: ResolveFn -+} -+ -+export type PreprocessorAdditionalDataResult = -+ | string -+ | { content: string; map?: ExistingRawSourceMap } -+ -+export type PreprocessorAdditionalData = -+ | string -+ | (( -+ source: string, -+ filename: string, -+ ) => -+ | PreprocessorAdditionalDataResult -+ | Promise) -+ -+export type StylePreprocessorOptions = { -+ [key: string]: any -+ additionalData?: PreprocessorAdditionalData -+ maxWorkers?: number | true -+ filename: string -+ alias: Alias[] -+ enableSourcemap: boolean -+} -+ -+// https://drafts.csswg.org/css-syntax-3/#identifier-code-point -+export const cssUrlRE = -+ /(?<=^|[^\w\-\u0080-\uffff])url\((\s*('[^']+'|"[^"]+")\s*|[^'")]+)\)/ -+ -+export const cssDataUriRE = -+ /(?<=^|[^\w\-\u0080-\uffff])data-uri\((\s*('[^']+'|"[^"]+")\s*|[^'")]+)\)/ -+ -+export const importCssRE = /@import ('[^']+\.css'|"[^"]+\.css"|[^'")]+\.css)/ -+ -+export type CssUrlReplacer = ( -+ url: string, -+ importer?: string, -+) => string | Promise -+ -+export function rewriteCssUrls( -+ css: string, -+ replacer: CssUrlReplacer, -+): Promise { -+ return asyncReplace(css, cssUrlRE, async (match) => { -+ const [matched, rawUrl] = match -+ return await doUrlReplace(rawUrl.trim(), matched, replacer) -+ }) -+} -+ -+export function rewriteCssDataUris( -+ css: string, -+ replacer: CssUrlReplacer, -+): Promise { -+ return asyncReplace(css, cssDataUriRE, async (match) => { -+ const [matched, rawUrl] = match -+ return await doUrlReplace(rawUrl.trim(), matched, replacer, 'data-uri') -+ }) -+} -+ -+export function rewriteImportCss( -+ css: string, -+ replacer: CssUrlReplacer, -+): Promise { -+ return asyncReplace(css, importCssRE, async (match) => { -+ const [matched, rawUrl] = match -+ return await doImportCSSReplace(rawUrl, matched, replacer) -+ }) -+} -+ -+/** -+ * relative url() inside \@imported sass and less files must be rebased to use -+ * root file as base. -+ */ -+export async function rebaseUrls( -+ file: string, -+ rootFile: string, -+ alias: Alias[], -+ variablePrefix: string, -+ resolver: ResolveFn, -+): Promise<{ file: string; contents?: string }> { -+ file = path.resolve(file) // ensure os-specific flashes -+ // in the same dir, no need to rebase -+ const fileDir = path.dirname(file) -+ const rootDir = path.dirname(rootFile) -+ if (fileDir === rootDir) { -+ return { file } -+ } -+ -+ const content = await fsp.readFile(file, 'utf-8') -+ // no url() -+ const hasUrls = cssUrlRE.test(content) -+ // data-uri() calls -+ const hasDataUris = cssDataUriRE.test(content) -+ // no @import xxx.css -+ const hasImportCss = importCssRE.test(content) -+ -+ if (!hasUrls && !hasDataUris && !hasImportCss) { -+ return { file } -+ } -+ -+ let rebased -+ const rebaseFn = async (url: string) => { -+ if (url[0] === '/') return url -+ // ignore url's starting with variable -+ if (url.startsWith(variablePrefix)) return url -+ // match alias, no need to rewrite -+ for (const { find } of alias) { -+ const matches = -+ typeof find === 'string' ? url.startsWith(find) : find.test(url) -+ if (matches) { -+ return url -+ } -+ } -+ const absolute = (await resolver(url, file)) || path.resolve(fileDir, url) -+ const relative = path.relative(rootDir, absolute) -+ return normalizePath(relative) -+ } -+ -+ // fix css imports in less such as `@import "foo.css"` -+ if (hasImportCss) { -+ rebased = await rewriteImportCss(content, rebaseFn) -+ } -+ -+ if (hasUrls) { -+ rebased = await rewriteCssUrls(rebased || content, rebaseFn) -+ } -+ -+ if (hasDataUris) { -+ rebased = await rewriteCssDataUris(rebased || content, rebaseFn) -+ } -+ -+ return { -+ file, -+ contents: rebased, -+ } -+} -+ -+// #region Less -+// .less -+export const makeLessWorker = ( -+ resolvers: CSSAtImportResolvers, -+ alias: Alias[], -+ maxWorkers: number | undefined, -+) => { -+ const viteLessResolve = async ( -+ filename: string, -+ dir: string, -+ rootFile: string, -+ ) => { -+ const resolved = await resolvers.less(filename, path.join(dir, '*')) -+ if (!resolved) return undefined -+ -+ const result = await rebaseUrls( -+ resolved, -+ rootFile, -+ alias, -+ '@', -+ resolvers.less, -+ ) -+ if (result) { -+ return { -+ resolved, -+ contents: 'contents' in result ? result.contents : undefined, -+ } -+ } -+ return result -+ } -+ -+ const worker = new WorkerWithFallback( -+ () => { -+ // eslint-disable-next-line no-restricted-globals -- this function runs inside a cjs worker -+ const fsp = require('node:fs/promises') -+ // eslint-disable-next-line no-restricted-globals -+ const path = require('node:path') -+ -+ let ViteLessManager: any -+ const createViteLessPlugin = ( -+ less: typeof Less, -+ rootFile: string, -+ ): Less.Plugin => { -+ const { FileManager } = less -+ ViteLessManager ??= class ViteManager extends FileManager { -+ rootFile -+ constructor(rootFile: string) { -+ super() -+ this.rootFile = rootFile -+ } -+ override supports(filename: string) { -+ return !/^(?:https?:)?\/\//.test(filename) -+ } -+ override supportsSync() { -+ return false -+ } -+ override async loadFile( -+ filename: string, -+ dir: string, -+ opts: any, -+ env: any, -+ ): Promise { -+ const result = await viteLessResolve(filename, dir, this.rootFile) -+ if (result) { -+ return { -+ filename: path.resolve(result.resolved), -+ contents: -+ result.contents ?? -+ (await fsp.readFile(result.resolved, 'utf-8')), -+ } -+ } else { -+ return super.loadFile(filename, dir, opts, env) -+ } -+ } -+ } -+ -+ return { -+ install(_, pluginManager) { -+ pluginManager.addFileManager(new ViteLessManager(rootFile)) -+ }, -+ minVersion: [3, 0, 0], -+ } -+ } -+ -+ return async ( -+ lessPath: string, -+ content: string, -+ // additionalData can a function that is not cloneable but it won't be used -+ options: StylePreprocessorOptions & { additionalData: undefined }, -+ ) => { -+ // eslint-disable-next-line no-restricted-globals -- this function runs inside a cjs worker -+ const nodeLess: typeof Less = require(lessPath) -+ const viteResolverPlugin = createViteLessPlugin( -+ nodeLess, -+ options.filename, -+ ) -+ const result = await nodeLess.render(content, { -+ ...options, -+ plugins: [viteResolverPlugin, ...(options.plugins || [])], -+ ...(options.enableSourcemap -+ ? { -+ sourceMap: { -+ outputSourceFiles: true, -+ sourceMapFileInline: false, -+ }, -+ } -+ : {}), -+ }) -+ return result -+ } -+ }, -+ { -+ parentFunctions: { viteLessResolve }, -+ shouldUseFake(_lessPath, _content, options) { -+ // plugins are a function and is not serializable -+ // in that case, fallback to running in main thread -+ return options.plugins?.length > 0 -+ }, -+ max: maxWorkers, -+ }, -+ ) -+ return worker -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/plugins/css/makeModernCompilerScssWorker.ts b/packages/vite/src/node/plugins/css/makeModernCompilerScssWorker.ts -new file mode 100644 -index 000000000..7078df845 ---- /dev/null -+++ b/packages/vite/src/node/plugins/css/makeModernCompilerScssWorker.ts -@@ -0,0 +1,243 @@ -+import { fileURLToPath, pathToFileURL } from 'node:url' -+import type { Alias } from 'dep-types/alias' -+import { ResolveFn } from 'packages/vite/src/node/index'; -+import fsp from 'node:fs/promises' -+import path from 'node:path' -+import type Sass from 'sass' -+import { WorkerWithFallback } from 'artichokie' -+import type { -+ ExistingRawSourceMap, -+ ModuleFormat, -+ OutputAsset, -+ OutputChunk, -+ RenderedChunk, -+ RollupError, -+ SourceMapInput, -+} from 'rollup' -+ -+ -+export interface CSSAtImportResolvers { -+ css: ResolveFn -+ sass: ResolveFn -+ less: ResolveFn -+} -+ -+export type PreprocessorAdditionalDataResult = -+ | string -+ | { content: string; map?: ExistingRawSourceMap } -+ -+export type PreprocessorAdditionalData = -+ | string -+ | (( -+ source: string, -+ filename: string, -+ ) => -+ | PreprocessorAdditionalDataResult -+ | Promise) -+ -+export type StylePreprocessorOptions = { -+ [key: string]: any -+ additionalData?: PreprocessorAdditionalData -+ maxWorkers?: number | true -+ filename: string -+ alias: Alias[] -+ enableSourcemap: boolean -+} -+ -+export type SassStylePreprocessorOptions = StylePreprocessorOptions & -+ Omit, 'data' | 'file' | 'outFile'> & { -+ api?: 'legacy' | 'modern' | 'modern-compiler' -+ } -+ -+// in unix, scss might append `location.href` in environments that shim `location` -+// see https://github.com/sass/dart-sass/issues/710 -+export function cleanScssBugUrl(url: string) { -+ if ( -+ // check bug via `window` and `location` global -+ typeof window !== 'undefined' && -+ typeof location !== 'undefined' && -+ typeof location?.href === 'string' -+ ) { -+ const prefix = location.href.replace(/\/$/, '') -+ return url.replace(prefix, '') -+ } else { -+ return url -+ } -+} -+ -+export const makeModernScssWorker = ( -+ resolvers: CSSAtImportResolvers, -+ alias: Alias[], -+ maxWorkers: number | undefined, -+) => { -+ const internalCanonicalize = async ( -+ url: string, -+ importer: string, -+ ): Promise => { -+ importer = cleanScssBugUrl(importer) -+ const resolved = await resolvers.sass(url, importer) -+ return resolved ?? null -+ } -+ -+ const internalLoad = async (file: string, rootFile: string) => { -+ const result = await rebaseUrls(file, rootFile, alias, '$', resolvers.sass) -+ if (result.contents) { -+ return result.contents -+ } -+ return await fsp.readFile(result.file, 'utf-8') -+ } -+ -+ const worker = new WorkerWithFallback( -+ () => -+ async ( -+ sassPath: string, -+ data: string, -+ // additionalData can a function that is not cloneable but it won't be used -+ options: SassStylePreprocessorOptions & { additionalData: undefined }, -+ ) => { -+ // eslint-disable-next-line no-restricted-globals -- this function runs inside a cjs worker -+ const sass: typeof Sass = require(sassPath) -+ // eslint-disable-next-line no-restricted-globals -+ const path: typeof import('node:path') = require('node:path') -+ -+ const { fileURLToPath, pathToFileURL }: typeof import('node:url') = -+ // eslint-disable-next-line no-restricted-globals -+ require('node:url') -+ -+ const sassOptions = { ...options } as Sass.StringOptions<'async'> -+ sassOptions.url = pathToFileURL(options.filename) -+ sassOptions.sourceMap = options.enableSourcemap -+ -+ const internalImporter: Sass.Importer<'async'> = { -+ async canonicalize(url, context) { -+ const importer = context.containingUrl -+ ? fileURLToPath(context.containingUrl) -+ : options.filename -+ const resolved = await internalCanonicalize(url, importer) -+ return resolved ? pathToFileURL(resolved) : null -+ }, -+ async load(canonicalUrl) { -+ const ext = path.extname(canonicalUrl.pathname) -+ let syntax: Sass.Syntax = 'scss' -+ if (ext === '.sass') { -+ syntax = 'indented' -+ } else if (ext === '.css') { -+ syntax = 'css' -+ } -+ const contents = await internalLoad( -+ fileURLToPath(canonicalUrl), -+ options.filename, -+ ) -+ return { contents, syntax } -+ }, -+ } -+ sassOptions.importers = [ -+ ...(sassOptions.importers ?? []), -+ internalImporter, -+ ] -+ -+ const result = await sass.compileStringAsync(data, sassOptions) -+ return { -+ css: result.css, -+ map: result.sourceMap ? JSON.stringify(result.sourceMap) : undefined, -+ stats: { -+ includedFiles: result.loadedUrls -+ .filter((url) => url.protocol === 'file:') -+ .map((url) => fileURLToPath(url)), -+ }, -+ } satisfies ScssWorkerResult -+ }, -+ { -+ parentFunctions: { -+ internalCanonicalize, -+ internalLoad, -+ }, -+ shouldUseFake(_sassPath, _data, options) { -+ // functions and importer is a function and is not serializable -+ // in that case, fallback to running in main thread -+ return !!( -+ (options.functions && Object.keys(options.functions).length > 0) || -+ (options.importers && -+ (!Array.isArray(options.importers) || options.importers.length > 0)) -+ ) -+ }, -+ max: maxWorkers, -+ }, -+ ) -+ return worker -+} -+ -+// this is mostly a copy&paste of makeModernScssWorker -+// however sharing code between two is hard because -+// makeModernScssWorker above needs function inlined for worker. -+export const makeModernCompilerScssWorker = ( -+ resolvers: CSSAtImportResolvers, -+ alias: Alias[], -+ _maxWorkers: number | undefined, -+) => { -+ let compiler: Sass.AsyncCompiler | undefined -+ -+ const worker: Awaited> = { -+ async run(sassPath, data, options) { -+ // need pathToFileURL for windows since import("D:...") fails -+ // https://github.com/nodejs/node/issues/31710 -+ const sass: typeof Sass = (await import(pathToFileURL(sassPath).href)) -+ .default -+ compiler ??= await sass.initAsyncCompiler() -+ -+ const sassOptions = { ...options } as Sass.StringOptions<'async'> -+ sassOptions.url = pathToFileURL(options.filename) -+ sassOptions.sourceMap = options.enableSourcemap -+ -+ const internalImporter: Sass.Importer<'async'> = { -+ async canonicalize(url, context) { -+ const importer = context.containingUrl -+ ? fileURLToPath(context.containingUrl) -+ : options.filename -+ const resolved = await resolvers.sass(url, cleanScssBugUrl(importer)) -+ return resolved ? pathToFileURL(resolved) : null -+ }, -+ async load(canonicalUrl) { -+ const ext = path.extname(canonicalUrl.pathname) -+ let syntax: Sass.Syntax = 'scss' -+ if (ext === '.sass') { -+ syntax = 'indented' -+ } else if (ext === '.css') { -+ syntax = 'css' -+ } -+ const result = await rebaseUrls( -+ fileURLToPath(canonicalUrl), -+ options.filename, -+ alias, -+ '$', -+ resolvers.sass, -+ ) -+ const contents = -+ result.contents ?? (await fsp.readFile(result.file, 'utf-8')) -+ return { contents, syntax } -+ }, -+ } -+ sassOptions.importers = [ -+ ...(sassOptions.importers ?? []), -+ internalImporter, -+ ] -+ -+ const result = await compiler.compileStringAsync(data, sassOptions) -+ return { -+ css: result.css, -+ map: result.sourceMap ? JSON.stringify(result.sourceMap) : undefined, -+ stats: { -+ includedFiles: result.loadedUrls -+ .filter((url) => url.protocol === 'file:') -+ .map((url) => fileURLToPath(url)), -+ }, -+ } satisfies ScssWorkerResult -+ }, -+ async stop() { -+ compiler?.dispose() -+ compiler = undefined -+ }, -+ } -+ -+ return worker -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/plugins/css/makeModernScssWorker.ts b/packages/vite/src/node/plugins/css/makeModernScssWorker.ts -new file mode 100644 -index 000000000..834ef15ba ---- /dev/null -+++ b/packages/vite/src/node/plugins/css/makeModernScssWorker.ts -@@ -0,0 +1,168 @@ -+import fsp from 'node:fs/promises' -+import path from 'node:path' -+import { fileURLToPath, pathToFileURL } from 'node:url' -+import type Sass from 'sass' -+import type { Alias } from 'dep-types/alias' -+import { WorkerWithFallback } from 'artichokie' -+import { ResolveFn } from 'packages/vite/src/node/index'; -+import type { -+ ExistingRawSourceMap, -+ ModuleFormat, -+ OutputAsset, -+ OutputChunk, -+ RenderedChunk, -+ RollupError, -+ SourceMapInput, -+} from 'rollup' -+ -+ -+export interface CSSAtImportResolvers { -+ css: ResolveFn -+ sass: ResolveFn -+ less: ResolveFn -+} -+ -+export type PreprocessorAdditionalDataResult = -+ | string -+ | { content: string; map?: ExistingRawSourceMap } -+ -+export type PreprocessorAdditionalData = -+ | string -+ | (( -+ source: string, -+ filename: string, -+ ) => -+ | PreprocessorAdditionalDataResult -+ | Promise) -+ -+export type StylePreprocessorOptions = { -+ [key: string]: any -+ additionalData?: PreprocessorAdditionalData -+ maxWorkers?: number | true -+ filename: string -+ alias: Alias[] -+ enableSourcemap: boolean -+} -+ -+export type SassStylePreprocessorOptions = StylePreprocessorOptions & -+ Omit, 'data' | 'file' | 'outFile'> & { -+ api?: 'legacy' | 'modern' | 'modern-compiler' -+ } -+ -+// in unix, scss might append `location.href` in environments that shim `location` -+// see https://github.com/sass/dart-sass/issues/710 -+export function cleanScssBugUrl(url: string) { -+ if ( -+ // check bug via `window` and `location` global -+ typeof window !== 'undefined' && -+ typeof location !== 'undefined' && -+ typeof location?.href === 'string' -+ ) { -+ const prefix = location.href.replace(/\/$/, '') -+ return url.replace(prefix, '') -+ } else { -+ return url -+ } -+} -+ -+export const makeModernScssWorker = ( -+ resolvers: CSSAtImportResolvers, -+ alias: Alias[], -+ maxWorkers: number | undefined, -+) => { -+ const internalCanonicalize = async ( -+ url: string, -+ importer: string, -+ ): Promise => { -+ importer = cleanScssBugUrl(importer) -+ const resolved = await resolvers.sass(url, importer) -+ return resolved ?? null -+ } -+ -+ const internalLoad = async (file: string, rootFile: string) => { -+ const result = await rebaseUrls(file, rootFile, alias, '$', resolvers.sass) -+ if (result.contents) { -+ return result.contents -+ } -+ return await fsp.readFile(result.file, 'utf-8') -+ } -+ -+ const worker = new WorkerWithFallback( -+ () => -+ async ( -+ sassPath: string, -+ data: string, -+ // additionalData can a function that is not cloneable but it won't be used -+ options: SassStylePreprocessorOptions & { additionalData: undefined }, -+ ) => { -+ // eslint-disable-next-line no-restricted-globals -- this function runs inside a cjs worker -+ const sass: typeof Sass = require(sassPath) -+ // eslint-disable-next-line no-restricted-globals -+ const path: typeof import('node:path') = require('node:path') -+ -+ const { fileURLToPath, pathToFileURL }: typeof import('node:url') = -+ // eslint-disable-next-line no-restricted-globals -+ require('node:url') -+ -+ const sassOptions = { ...options } as Sass.StringOptions<'async'> -+ sassOptions.url = pathToFileURL(options.filename) -+ sassOptions.sourceMap = options.enableSourcemap -+ -+ const internalImporter: Sass.Importer<'async'> = { -+ async canonicalize(url, context) { -+ const importer = context.containingUrl -+ ? fileURLToPath(context.containingUrl) -+ : options.filename -+ const resolved = await internalCanonicalize(url, importer) -+ return resolved ? pathToFileURL(resolved) : null -+ }, -+ async load(canonicalUrl) { -+ const ext = path.extname(canonicalUrl.pathname) -+ let syntax: Sass.Syntax = 'scss' -+ if (ext === '.sass') { -+ syntax = 'indented' -+ } else if (ext === '.css') { -+ syntax = 'css' -+ } -+ const contents = await internalLoad( -+ fileURLToPath(canonicalUrl), -+ options.filename, -+ ) -+ return { contents, syntax } -+ }, -+ } -+ sassOptions.importers = [ -+ ...(sassOptions.importers ?? []), -+ internalImporter, -+ ] -+ -+ const result = await sass.compileStringAsync(data, sassOptions) -+ return { -+ css: result.css, -+ map: result.sourceMap ? JSON.stringify(result.sourceMap) : undefined, -+ stats: { -+ includedFiles: result.loadedUrls -+ .filter((url) => url.protocol === 'file:') -+ .map((url) => fileURLToPath(url)), -+ }, -+ } satisfies ScssWorkerResult -+ }, -+ { -+ parentFunctions: { -+ internalCanonicalize, -+ internalLoad, -+ }, -+ shouldUseFake(_sassPath, _data, options) { -+ // functions and importer is a function and is not serializable -+ // in that case, fallback to running in main thread -+ return !!( -+ (options.functions && Object.keys(options.functions).length > 0) || -+ (options.importers && -+ (!Array.isArray(options.importers) || options.importers.length > 0)) -+ ) -+ }, -+ max: maxWorkers, -+ }, -+ ) -+ return worker -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/plugins/css/makeScssWorker.ts b/packages/vite/src/node/plugins/css/makeScssWorker.ts -new file mode 100644 -index 000000000..087a0a82c ---- /dev/null -+++ b/packages/vite/src/node/plugins/css/makeScssWorker.ts -@@ -0,0 +1,192 @@ -+import path from 'node:path' -+import type Sass from 'sass' -+import type { Alias } from 'dep-types/alias' -+import { WorkerWithFallback } from 'artichokie' -+import { ResolveFn } from 'packages/vite/src/node/index'; -+import type { -+ ExistingRawSourceMap, -+ ModuleFormat, -+ OutputAsset, -+ OutputChunk, -+ RenderedChunk, -+ RollupError, -+ SourceMapInput, -+} from 'rollup' -+import fs from 'node:fs' -+ -+ -+export interface CSSAtImportResolvers { -+ css: ResolveFn -+ sass: ResolveFn -+ less: ResolveFn -+} -+ -+export type PreprocessorAdditionalDataResult = -+ | string -+ | { content: string; map?: ExistingRawSourceMap } -+ -+export type PreprocessorAdditionalData = -+ | string -+ | (( -+ source: string, -+ filename: string, -+ ) => -+ | PreprocessorAdditionalDataResult -+ | Promise) -+ -+export type StylePreprocessorOptions = { -+ [key: string]: any -+ additionalData?: PreprocessorAdditionalData -+ maxWorkers?: number | true -+ filename: string -+ alias: Alias[] -+ enableSourcemap: boolean -+} -+ -+export type SassStylePreprocessorOptions = StylePreprocessorOptions & -+ Omit, 'data' | 'file' | 'outFile'> & { -+ api?: 'legacy' | 'modern' | 'modern-compiler' -+ } -+ -+// in unix, scss might append `location.href` in environments that shim `location` -+// see https://github.com/sass/dart-sass/issues/710 -+export function cleanScssBugUrl(url: string) { -+ if ( -+ // check bug via `window` and `location` global -+ typeof window !== 'undefined' && -+ typeof location !== 'undefined' && -+ typeof location?.href === 'string' -+ ) { -+ const prefix = location.href.replace(/\/$/, '') -+ return url.replace(prefix, '') -+ } else { -+ return url -+ } -+} -+ -+export function fixScssBugImportValue( -+ data: Sass.LegacyImporterResult, -+): Sass.LegacyImporterResult { -+ // the scss bug doesn't load files properly so we have to load it ourselves -+ // to prevent internal error when it loads itself -+ if ( -+ // check bug via `window` and `location` global -+ typeof window !== 'undefined' && -+ typeof location !== 'undefined' && -+ data && -+ 'file' in data && -+ (!('contents' in data) || data.contents == null) -+ ) { -+ // @ts-expect-error we need to preserve file property for HMR -+ data.contents = fs.readFileSync(data.file, 'utf-8') -+ } -+ return data -+} -+ -+// #region Sass -+// .scss/.sass processor -+export const makeScssWorker = ( -+ resolvers: CSSAtImportResolvers, -+ alias: Alias[], -+ maxWorkers: number | undefined, -+) => { -+ const internalImporter = async ( -+ url: string, -+ importer: string, -+ filename: string, -+ ) => { -+ importer = cleanScssBugUrl(importer) -+ const resolved = await resolvers.sass(url, importer) -+ if (resolved) { -+ try { -+ const data = await rebaseUrls( -+ resolved, -+ filename, -+ alias, -+ '$', -+ resolvers.sass, -+ ) -+ return fixScssBugImportValue(data) -+ } catch (data) { -+ return data -+ } -+ } else { -+ return null -+ } -+ } -+ -+ const worker = new WorkerWithFallback( -+ () => -+ async ( -+ sassPath: string, -+ data: string, -+ // additionalData can a function that is not cloneable but it won't be used -+ options: SassStylePreprocessorOptions & { additionalData: undefined }, -+ ) => { -+ // eslint-disable-next-line no-restricted-globals -- this function runs inside a cjs worker -+ const sass: typeof Sass = require(sassPath) -+ // eslint-disable-next-line no-restricted-globals -+ const path: typeof import('node:path') = require('node:path') -+ -+ // NOTE: `sass` always runs it's own importer first, and only falls back to -+ // the `importer` option when it can't resolve a path -+ const _internalImporter: Sass.LegacyAsyncImporter = ( -+ url, -+ importer, -+ done, -+ ) => { -+ internalImporter(url, importer, options.filename).then((data) => -+ done?.(data), -+ ) -+ } -+ const importer = [_internalImporter] -+ if (options.importer) { -+ Array.isArray(options.importer) -+ ? importer.unshift(...options.importer) -+ : importer.unshift(options.importer) -+ } -+ -+ const finalOptions: Sass.LegacyOptions<'async'> = { -+ ...options, -+ data, -+ file: options.filename, -+ outFile: options.filename, -+ importer, -+ ...(options.enableSourcemap -+ ? { -+ sourceMap: true, -+ omitSourceMapUrl: true, -+ sourceMapRoot: path.dirname(options.filename), -+ } -+ : {}), -+ } -+ return new Promise((resolve, reject) => { -+ sass.render(finalOptions, (err, res) => { -+ if (err) { -+ reject(err) -+ } else { -+ resolve({ -+ css: res!.css.toString(), -+ map: res!.map?.toString(), -+ stats: res!.stats, -+ }) -+ } -+ }) -+ }) -+ }, -+ { -+ parentFunctions: { internalImporter }, -+ shouldUseFake(_sassPath, _data, options) { -+ // functions and importer is a function and is not serializable -+ // in that case, fallback to running in main thread -+ return !!( -+ (options.functions && Object.keys(options.functions).length > 0) || -+ (options.importer && -+ (!Array.isArray(options.importer) || options.importer.length > 0)) -+ ) -+ }, -+ max: maxWorkers, -+ }, -+ ) -+ return worker -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/plugins/css/makeStylWorker.ts b/packages/vite/src/node/plugins/css/makeStylWorker.ts -new file mode 100644 -index 000000000..4fbdf448d ---- /dev/null -+++ b/packages/vite/src/node/plugins/css/makeStylWorker.ts -@@ -0,0 +1,91 @@ -+import type { -+ ExistingRawSourceMap, -+ ModuleFormat, -+ OutputAsset, -+ OutputChunk, -+ RenderedChunk, -+ RollupError, -+ SourceMapInput, -+} from 'rollup' -+import type Stylus from 'stylus' -+import { WorkerWithFallback } from 'artichokie' -+import type { Alias } from 'dep-types/alias' -+ -+ -+export type PreprocessorAdditionalDataResult = -+ | string -+ | { content: string; map?: ExistingRawSourceMap } -+ -+export type PreprocessorAdditionalData = -+ | string -+ | (( -+ source: string, -+ filename: string, -+ ) => -+ | PreprocessorAdditionalDataResult -+ | Promise) -+ -+export type StylePreprocessorOptions = { -+ [key: string]: any -+ additionalData?: PreprocessorAdditionalData -+ maxWorkers?: number | true -+ filename: string -+ alias: Alias[] -+ enableSourcemap: boolean -+} -+ -+export type StylusStylePreprocessorOptions = StylePreprocessorOptions & { -+ define?: Record -+} -+ -+// #region Stylus -+// .styl -+export const makeStylWorker = (maxWorkers: number | undefined) => { -+ const worker = new WorkerWithFallback( -+ () => { -+ return async ( -+ stylusPath: string, -+ content: string, -+ root: string, -+ // additionalData can a function that is not cloneable but it won't be used -+ options: StylusStylePreprocessorOptions & { additionalData: undefined }, -+ ) => { -+ // eslint-disable-next-line no-restricted-globals -- this function runs inside a cjs worker -+ const nodeStylus: typeof Stylus = require(stylusPath) -+ -+ const ref = nodeStylus(content, options) -+ if (options.define) { -+ for (const key in options.define) { -+ ref.define(key, options.define[key]) -+ } -+ } -+ if (options.enableSourcemap) { -+ ref.set('sourcemap', { -+ comment: false, -+ inline: false, -+ basePath: root, -+ }) -+ } -+ -+ return { -+ code: ref.render(), -+ // @ts-expect-error sourcemap exists -+ map: ref.sourcemap as ExistingRawSourceMap | undefined, -+ deps: ref.deps(), -+ } -+ } -+ }, -+ { -+ shouldUseFake(_stylusPath, _content, _root, options) { -+ // define can include functions and those are not serializable -+ // in that case, fallback to running in main thread -+ return !!( -+ options.define && -+ Object.values(options.define).some((d) => typeof d === 'function') -+ ) -+ }, -+ max: maxWorkers, -+ }, -+ ) -+ return worker -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/plugins/css/minifyCSS.ts b/packages/vite/src/node/plugins/css/minifyCSS.ts -new file mode 100644 -index 000000000..37b49a888 ---- /dev/null -+++ b/packages/vite/src/node/plugins/css/minifyCSS.ts -@@ -0,0 +1,67 @@ -+import colors from 'picocolors' -+import { formatMessages, transform } from 'esbuild' -+import { ResolvedConfig } from 'packages/vite/src/node/config'; -+ -+ -+const export decoder = new TextDecoder() -+ -+const export cssBundleName = 'style.css' -+ -+export async function minifyCSS( -+ css: string, -+ config: ResolvedConfig, -+ inlined: boolean, -+) { -+ // We want inlined CSS to not end with a linebreak, while ensuring that -+ // regular CSS assets do end with a linebreak. -+ // See https://github.com/vitejs/vite/pull/13893#issuecomment-1678628198 -+ -+ if (config.build.cssMinify === 'lightningcss') { -+ const { code, warnings } = (await importLightningCSS()).transform({ -+ ...config.css?.lightningcss, -+ targets: convertTargets(config.build.cssTarget), -+ cssModules: undefined, -+ filename: cssBundleName, -+ code: Buffer.from(css), -+ minify: true, -+ }) -+ if (warnings.length) { -+ config.logger.warn( -+ colors.yellow( -+ `warnings when minifying css:\n${warnings -+ .map((w) => w.message) -+ .join('\n')}`, -+ ), -+ ) -+ } -+ -+ // NodeJS res.code = Buffer -+ // Deno res.code = Uint8Array -+ // For correct decode compiled css need to use TextDecoder -+ // LightningCSS output does not return a linebreak at the end -+ return decoder.decode(code) + (inlined ? '' : '\n') -+ } -+ try { -+ const { code, warnings } = await transform(css, { -+ loader: 'css', -+ target: config.build.cssTarget || undefined, -+ ...resolveMinifyCssEsbuildOptions(config.esbuild || {}), -+ }) -+ if (warnings.length) { -+ const msgs = await formatMessages(warnings, { kind: 'warning' }) -+ config.logger.warn( -+ colors.yellow(`warnings when minifying css:\n${msgs.join('\n')}`), -+ ) -+ } -+ // esbuild output does return a linebreak at the end -+ return inlined ? code.trimEnd() : code -+ } catch (e) { -+ if (e.errors) { -+ e.message = '[esbuild css minify] ' + e.message -+ const msgs = await formatMessages(e.errors, { kind: 'error' }) -+ e.frame = '\n' + msgs.join('\n') -+ e.loc = e.errors[0].location -+ } -+ throw e -+ } -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/plugins/css/rebaseUrls.ts b/packages/vite/src/node/plugins/css/rebaseUrls.ts -new file mode 100644 -index 000000000..179ab3dcb ---- /dev/null -+++ b/packages/vite/src/node/plugins/css/rebaseUrls.ts -@@ -0,0 +1,119 @@ -+import fsp from 'node:fs/promises' -+import path from 'node:path' -+import type { Alias } from 'dep-types/alias' -+import { ResolveFn } from 'packages/vite/src/node/index'; -+import { normalizePath } from 'packages/vite/src/node/utils'; -+import { asyncReplace } from 'packages/vite/src/node/utils'; -+ -+ -+// https://drafts.csswg.org/css-syntax-3/#identifier-code-point -+export const cssUrlRE = -+ /(?<=^|[^\w\-\u0080-\uffff])url\((\s*('[^']+'|"[^"]+")\s*|[^'")]+)\)/ -+ -+export const cssDataUriRE = -+ /(?<=^|[^\w\-\u0080-\uffff])data-uri\((\s*('[^']+'|"[^"]+")\s*|[^'")]+)\)/ -+ -+export const importCssRE = /@import ('[^']+\.css'|"[^"]+\.css"|[^'")]+\.css)/ -+ -+export type CssUrlReplacer = ( -+ url: string, -+ importer?: string, -+) => string | Promise -+ -+export function rewriteCssUrls( -+ css: string, -+ replacer: CssUrlReplacer, -+): Promise { -+ return asyncReplace(css, cssUrlRE, async (match) => { -+ const [matched, rawUrl] = match -+ return await doUrlReplace(rawUrl.trim(), matched, replacer) -+ }) -+} -+ -+export function rewriteCssDataUris( -+ css: string, -+ replacer: CssUrlReplacer, -+): Promise { -+ return asyncReplace(css, cssDataUriRE, async (match) => { -+ const [matched, rawUrl] = match -+ return await doUrlReplace(rawUrl.trim(), matched, replacer, 'data-uri') -+ }) -+} -+ -+export function rewriteImportCss( -+ css: string, -+ replacer: CssUrlReplacer, -+): Promise { -+ return asyncReplace(css, importCssRE, async (match) => { -+ const [matched, rawUrl] = match -+ return await doImportCSSReplace(rawUrl, matched, replacer) -+ }) -+} -+ -+/** -+ * relative url() inside \@imported sass and less files must be rebased to use -+ * root file as base. -+ */ -+export async function rebaseUrls( -+ file: string, -+ rootFile: string, -+ alias: Alias[], -+ variablePrefix: string, -+ resolver: ResolveFn, -+): Promise<{ file: string; contents?: string }> { -+ file = path.resolve(file) // ensure os-specific flashes -+ // in the same dir, no need to rebase -+ const fileDir = path.dirname(file) -+ const rootDir = path.dirname(rootFile) -+ if (fileDir === rootDir) { -+ return { file } -+ } -+ -+ const content = await fsp.readFile(file, 'utf-8') -+ // no url() -+ const hasUrls = cssUrlRE.test(content) -+ // data-uri() calls -+ const hasDataUris = cssDataUriRE.test(content) -+ // no @import xxx.css -+ const hasImportCss = importCssRE.test(content) -+ -+ if (!hasUrls && !hasDataUris && !hasImportCss) { -+ return { file } -+ } -+ -+ let rebased -+ const rebaseFn = async (url: string) => { -+ if (url[0] === '/') return url -+ // ignore url's starting with variable -+ if (url.startsWith(variablePrefix)) return url -+ // match alias, no need to rewrite -+ for (const { find } of alias) { -+ const matches = -+ typeof find === 'string' ? url.startsWith(find) : find.test(url) -+ if (matches) { -+ return url -+ } -+ } -+ const absolute = (await resolver(url, file)) || path.resolve(fileDir, url) -+ const relative = path.relative(rootDir, absolute) -+ return normalizePath(relative) -+ } -+ -+ // fix css imports in less such as `@import "foo.css"` -+ if (hasImportCss) { -+ rebased = await rewriteImportCss(content, rebaseFn) -+ } -+ -+ if (hasUrls) { -+ rebased = await rewriteCssUrls(rebased || content, rebaseFn) -+ } -+ -+ if (hasDataUris) { -+ rebased = await rewriteCssDataUris(rebased || content, rebaseFn) -+ } -+ -+ return { -+ file, -+ contents: rebased, -+ } -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/plugins/css/scssProcessor.ts b/packages/vite/src/node/plugins/css/scssProcessor.ts -new file mode 100644 -index 000000000..b9cbbc28d ---- /dev/null -+++ b/packages/vite/src/node/plugins/css/scssProcessor.ts -@@ -0,0 +1,457 @@ -+import { ResolveFn } from 'packages/vite/src/node/index'; -+import type Sass from 'sass' -+import type { Alias } from 'dep-types/alias' -+import type { -+ ExistingRawSourceMap, -+ ModuleFormat, -+ OutputAsset, -+ OutputChunk, -+ RenderedChunk, -+ RollupError, -+ SourceMapInput, -+} from 'rollup' -+import path from 'node:path' -+import { WorkerWithFallback } from 'artichokie' -+import fs from 'node:fs' -+import fsp from 'node:fs/promises' -+import { fileURLToPath, pathToFileURL } from 'node:url' -+ -+ -+export interface CSSAtImportResolvers { -+ css: ResolveFn -+ sass: ResolveFn -+ less: ResolveFn -+} -+ -+export type PreprocessorAdditionalDataResult = -+ | string -+ | { content: string; map?: ExistingRawSourceMap } -+ -+export type PreprocessorAdditionalData = -+ | string -+ | (( -+ source: string, -+ filename: string, -+ ) => -+ | PreprocessorAdditionalDataResult -+ | Promise) -+ -+export type StylePreprocessorOptions = { -+ [key: string]: any -+ additionalData?: PreprocessorAdditionalData -+ maxWorkers?: number | true -+ filename: string -+ alias: Alias[] -+ enableSourcemap: boolean -+} -+ -+export type SassStylePreprocessorOptions = StylePreprocessorOptions & -+ Omit, 'data' | 'file' | 'outFile'> & { -+ api?: 'legacy' | 'modern' | 'modern-compiler' -+ } -+ -+export interface StylePreprocessorResults { -+ code: string -+ map?: ExistingRawSourceMap | undefined -+ additionalMap?: ExistingRawSourceMap | undefined -+ error?: RollupError -+ deps: string[] -+} -+ -+export type SassStylePreprocessor = { -+ process: ( -+ source: string, -+ root: string, -+ options: SassStylePreprocessorOptions, -+ resolvers: CSSAtImportResolvers, -+ ) => StylePreprocessorResults | Promise -+ close: () => void -+} -+ -+// in unix, scss might append `location.href` in environments that shim `location` -+// see https://github.com/sass/dart-sass/issues/710 -+export function cleanScssBugUrl(url: string) { -+ if ( -+ // check bug via `window` and `location` global -+ typeof window !== 'undefined' && -+ typeof location !== 'undefined' && -+ typeof location?.href === 'string' -+ ) { -+ const prefix = location.href.replace(/\/$/, '') -+ return url.replace(prefix, '') -+ } else { -+ return url -+ } -+} -+ -+export function fixScssBugImportValue( -+ data: Sass.LegacyImporterResult, -+): Sass.LegacyImporterResult { -+ // the scss bug doesn't load files properly so we have to load it ourselves -+ // to prevent internal error when it loads itself -+ if ( -+ // check bug via `window` and `location` global -+ typeof window !== 'undefined' && -+ typeof location !== 'undefined' && -+ data && -+ 'file' in data && -+ (!('contents' in data) || data.contents == null) -+ ) { -+ // @ts-expect-error we need to preserve file property for HMR -+ data.contents = fs.readFileSync(data.file, 'utf-8') -+ } -+ return data -+} -+ -+// #region Sass -+// .scss/.sass processor -+export const makeScssWorker = ( -+ resolvers: CSSAtImportResolvers, -+ alias: Alias[], -+ maxWorkers: number | undefined, -+) => { -+ const internalImporter = async ( -+ url: string, -+ importer: string, -+ filename: string, -+ ) => { -+ importer = cleanScssBugUrl(importer) -+ const resolved = await resolvers.sass(url, importer) -+ if (resolved) { -+ try { -+ const data = await rebaseUrls( -+ resolved, -+ filename, -+ alias, -+ '$', -+ resolvers.sass, -+ ) -+ return fixScssBugImportValue(data) -+ } catch (data) { -+ return data -+ } -+ } else { -+ return null -+ } -+ } -+ -+ const worker = new WorkerWithFallback( -+ () => -+ async ( -+ sassPath: string, -+ data: string, -+ // additionalData can a function that is not cloneable but it won't be used -+ options: SassStylePreprocessorOptions & { additionalData: undefined }, -+ ) => { -+ // eslint-disable-next-line no-restricted-globals -- this function runs inside a cjs worker -+ const sass: typeof Sass = require(sassPath) -+ // eslint-disable-next-line no-restricted-globals -+ const path: typeof import('node:path') = require('node:path') -+ -+ // NOTE: `sass` always runs it's own importer first, and only falls back to -+ // the `importer` option when it can't resolve a path -+ const _internalImporter: Sass.LegacyAsyncImporter = ( -+ url, -+ importer, -+ done, -+ ) => { -+ internalImporter(url, importer, options.filename).then((data) => -+ done?.(data), -+ ) -+ } -+ const importer = [_internalImporter] -+ if (options.importer) { -+ Array.isArray(options.importer) -+ ? importer.unshift(...options.importer) -+ : importer.unshift(options.importer) -+ } -+ -+ const finalOptions: Sass.LegacyOptions<'async'> = { -+ ...options, -+ data, -+ file: options.filename, -+ outFile: options.filename, -+ importer, -+ ...(options.enableSourcemap -+ ? { -+ sourceMap: true, -+ omitSourceMapUrl: true, -+ sourceMapRoot: path.dirname(options.filename), -+ } -+ : {}), -+ } -+ return new Promise((resolve, reject) => { -+ sass.render(finalOptions, (err, res) => { -+ if (err) { -+ reject(err) -+ } else { -+ resolve({ -+ css: res!.css.toString(), -+ map: res!.map?.toString(), -+ stats: res!.stats, -+ }) -+ } -+ }) -+ }) -+ }, -+ { -+ parentFunctions: { internalImporter }, -+ shouldUseFake(_sassPath, _data, options) { -+ // functions and importer is a function and is not serializable -+ // in that case, fallback to running in main thread -+ return !!( -+ (options.functions && Object.keys(options.functions).length > 0) || -+ (options.importer && -+ (!Array.isArray(options.importer) || options.importer.length > 0)) -+ ) -+ }, -+ max: maxWorkers, -+ }, -+ ) -+ return worker -+} -+ -+export const makeModernScssWorker = ( -+ resolvers: CSSAtImportResolvers, -+ alias: Alias[], -+ maxWorkers: number | undefined, -+) => { -+ const internalCanonicalize = async ( -+ url: string, -+ importer: string, -+ ): Promise => { -+ importer = cleanScssBugUrl(importer) -+ const resolved = await resolvers.sass(url, importer) -+ return resolved ?? null -+ } -+ -+ const internalLoad = async (file: string, rootFile: string) => { -+ const result = await rebaseUrls(file, rootFile, alias, '$', resolvers.sass) -+ if (result.contents) { -+ return result.contents -+ } -+ return await fsp.readFile(result.file, 'utf-8') -+ } -+ -+ const worker = new WorkerWithFallback( -+ () => -+ async ( -+ sassPath: string, -+ data: string, -+ // additionalData can a function that is not cloneable but it won't be used -+ options: SassStylePreprocessorOptions & { additionalData: undefined }, -+ ) => { -+ // eslint-disable-next-line no-restricted-globals -- this function runs inside a cjs worker -+ const sass: typeof Sass = require(sassPath) -+ // eslint-disable-next-line no-restricted-globals -+ const path: typeof import('node:path') = require('node:path') -+ -+ const { fileURLToPath, pathToFileURL }: typeof import('node:url') = -+ // eslint-disable-next-line no-restricted-globals -+ require('node:url') -+ -+ const sassOptions = { ...options } as Sass.StringOptions<'async'> -+ sassOptions.url = pathToFileURL(options.filename) -+ sassOptions.sourceMap = options.enableSourcemap -+ -+ const internalImporter: Sass.Importer<'async'> = { -+ async canonicalize(url, context) { -+ const importer = context.containingUrl -+ ? fileURLToPath(context.containingUrl) -+ : options.filename -+ const resolved = await internalCanonicalize(url, importer) -+ return resolved ? pathToFileURL(resolved) : null -+ }, -+ async load(canonicalUrl) { -+ const ext = path.extname(canonicalUrl.pathname) -+ let syntax: Sass.Syntax = 'scss' -+ if (ext === '.sass') { -+ syntax = 'indented' -+ } else if (ext === '.css') { -+ syntax = 'css' -+ } -+ const contents = await internalLoad( -+ fileURLToPath(canonicalUrl), -+ options.filename, -+ ) -+ return { contents, syntax } -+ }, -+ } -+ sassOptions.importers = [ -+ ...(sassOptions.importers ?? []), -+ internalImporter, -+ ] -+ -+ const result = await sass.compileStringAsync(data, sassOptions) -+ return { -+ css: result.css, -+ map: result.sourceMap ? JSON.stringify(result.sourceMap) : undefined, -+ stats: { -+ includedFiles: result.loadedUrls -+ .filter((url) => url.protocol === 'file:') -+ .map((url) => fileURLToPath(url)), -+ }, -+ } satisfies ScssWorkerResult -+ }, -+ { -+ parentFunctions: { -+ internalCanonicalize, -+ internalLoad, -+ }, -+ shouldUseFake(_sassPath, _data, options) { -+ // functions and importer is a function and is not serializable -+ // in that case, fallback to running in main thread -+ return !!( -+ (options.functions && Object.keys(options.functions).length > 0) || -+ (options.importers && -+ (!Array.isArray(options.importers) || options.importers.length > 0)) -+ ) -+ }, -+ max: maxWorkers, -+ }, -+ ) -+ return worker -+} -+ -+// this is mostly a copy&paste of makeModernScssWorker -+// however sharing code between two is hard because -+// makeModernScssWorker above needs function inlined for worker. -+export const makeModernCompilerScssWorker = ( -+ resolvers: CSSAtImportResolvers, -+ alias: Alias[], -+ _maxWorkers: number | undefined, -+) => { -+ let compiler: Sass.AsyncCompiler | undefined -+ -+ const worker: Awaited> = { -+ async run(sassPath, data, options) { -+ // need pathToFileURL for windows since import("D:...") fails -+ // https://github.com/nodejs/node/issues/31710 -+ const sass: typeof Sass = (await import(pathToFileURL(sassPath).href)) -+ .default -+ compiler ??= await sass.initAsyncCompiler() -+ -+ const sassOptions = { ...options } as Sass.StringOptions<'async'> -+ sassOptions.url = pathToFileURL(options.filename) -+ sassOptions.sourceMap = options.enableSourcemap -+ -+ const internalImporter: Sass.Importer<'async'> = { -+ async canonicalize(url, context) { -+ const importer = context.containingUrl -+ ? fileURLToPath(context.containingUrl) -+ : options.filename -+ const resolved = await resolvers.sass(url, cleanScssBugUrl(importer)) -+ return resolved ? pathToFileURL(resolved) : null -+ }, -+ async load(canonicalUrl) { -+ const ext = path.extname(canonicalUrl.pathname) -+ let syntax: Sass.Syntax = 'scss' -+ if (ext === '.sass') { -+ syntax = 'indented' -+ } else if (ext === '.css') { -+ syntax = 'css' -+ } -+ const result = await rebaseUrls( -+ fileURLToPath(canonicalUrl), -+ options.filename, -+ alias, -+ '$', -+ resolvers.sass, -+ ) -+ const contents = -+ result.contents ?? (await fsp.readFile(result.file, 'utf-8')) -+ return { contents, syntax } -+ }, -+ } -+ sassOptions.importers = [ -+ ...(sassOptions.importers ?? []), -+ internalImporter, -+ ] -+ -+ const result = await compiler.compileStringAsync(data, sassOptions) -+ return { -+ css: result.css, -+ map: result.sourceMap ? JSON.stringify(result.sourceMap) : undefined, -+ stats: { -+ includedFiles: result.loadedUrls -+ .filter((url) => url.protocol === 'file:') -+ .map((url) => fileURLToPath(url)), -+ }, -+ } satisfies ScssWorkerResult -+ }, -+ async stop() { -+ compiler?.dispose() -+ compiler = undefined -+ }, -+ } -+ -+ return worker -+} -+ -+export const scssProcessor = ( -+ maxWorkers: number | undefined, -+): SassStylePreprocessor => { -+ const workerMap = new Map>() -+ -+ return { -+ close() { -+ for (const worker of workerMap.values()) { -+ worker.stop() -+ } -+ }, -+ async process(source, root, options, resolvers) { -+ const sassPackage = loadSassPackage(root) -+ // TODO: change default in v6 -+ // options.api ?? sassPackage.name === "sass-embedded" ? "modern-compiler" : "modern"; -+ const api = options.api ?? 'legacy' -+ -+ if (!workerMap.has(options.alias)) { -+ workerMap.set( -+ options.alias, -+ api === 'modern-compiler' -+ ? makeModernCompilerScssWorker(resolvers, options.alias, maxWorkers) -+ : api === 'modern' -+ ? makeModernScssWorker(resolvers, options.alias, maxWorkers) -+ : makeScssWorker(resolvers, options.alias, maxWorkers), -+ ) -+ } -+ const worker = workerMap.get(options.alias)! -+ -+ const { content: data, map: additionalMap } = await getSource( -+ source, -+ options.filename, -+ options.additionalData, -+ options.enableSourcemap, -+ ) -+ -+ const optionsWithoutAdditionalData = { -+ ...options, -+ additionalData: undefined, -+ } -+ try { -+ const result = await worker.run( -+ sassPackage.path, -+ data, -+ optionsWithoutAdditionalData, -+ ) -+ const deps = result.stats.includedFiles.map((f) => cleanScssBugUrl(f)) -+ const map: ExistingRawSourceMap | undefined = result.map -+ ? JSON.parse(result.map.toString()) -+ : undefined -+ -+ return { -+ code: result.css.toString(), -+ map, -+ additionalMap, -+ deps, -+ } -+ } catch (e) { -+ // normalize SASS error -+ e.message = `[sass] ${e.message}` -+ e.id = e.file -+ e.frame = e.formatted -+ return { code: '', error: e, deps: [] } -+ } -+ }, -+ } -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/plugins/css/stylProcessor.ts b/packages/vite/src/node/plugins/css/stylProcessor.ts -new file mode 100644 -index 000000000..946e7890e ---- /dev/null -+++ b/packages/vite/src/node/plugins/css/stylProcessor.ts -@@ -0,0 +1,177 @@ -+import { ResolveFn } from 'packages/vite/src/node/index'; -+import type { Alias } from 'dep-types/alias' -+import type { -+ ExistingRawSourceMap, -+ ModuleFormat, -+ OutputAsset, -+ OutputChunk, -+ RenderedChunk, -+ RollupError, -+ SourceMapInput, -+} from 'rollup' -+import type Stylus from 'stylus' -+import { WorkerWithFallback } from 'artichokie' -+ -+ -+export interface CSSAtImportResolvers { -+ css: ResolveFn -+ sass: ResolveFn -+ less: ResolveFn -+} -+ -+export type PreprocessorAdditionalDataResult = -+ | string -+ | { content: string; map?: ExistingRawSourceMap } -+ -+export type PreprocessorAdditionalData = -+ | string -+ | (( -+ source: string, -+ filename: string, -+ ) => -+ | PreprocessorAdditionalDataResult -+ | Promise) -+ -+export type StylePreprocessorOptions = { -+ [key: string]: any -+ additionalData?: PreprocessorAdditionalData -+ maxWorkers?: number | true -+ filename: string -+ alias: Alias[] -+ enableSourcemap: boolean -+} -+ -+export type StylusStylePreprocessorOptions = StylePreprocessorOptions & { -+ define?: Record -+} -+ -+export interface StylePreprocessorResults { -+ code: string -+ map?: ExistingRawSourceMap | undefined -+ additionalMap?: ExistingRawSourceMap | undefined -+ error?: RollupError -+ deps: string[] -+} -+ -+export type StylusStylePreprocessor = { -+ process: ( -+ source: string, -+ root: string, -+ options: StylusStylePreprocessorOptions, -+ resolvers: CSSAtImportResolvers, -+ ) => StylePreprocessorResults | Promise -+ close: () => void -+} -+ -+// #region Stylus -+// .styl -+export const makeStylWorker = (maxWorkers: number | undefined) => { -+ const worker = new WorkerWithFallback( -+ () => { -+ return async ( -+ stylusPath: string, -+ content: string, -+ root: string, -+ // additionalData can a function that is not cloneable but it won't be used -+ options: StylusStylePreprocessorOptions & { additionalData: undefined }, -+ ) => { -+ // eslint-disable-next-line no-restricted-globals -- this function runs inside a cjs worker -+ const nodeStylus: typeof Stylus = require(stylusPath) -+ -+ const ref = nodeStylus(content, options) -+ if (options.define) { -+ for (const key in options.define) { -+ ref.define(key, options.define[key]) -+ } -+ } -+ if (options.enableSourcemap) { -+ ref.set('sourcemap', { -+ comment: false, -+ inline: false, -+ basePath: root, -+ }) -+ } -+ -+ return { -+ code: ref.render(), -+ // @ts-expect-error sourcemap exists -+ map: ref.sourcemap as ExistingRawSourceMap | undefined, -+ deps: ref.deps(), -+ } -+ } -+ }, -+ { -+ shouldUseFake(_stylusPath, _content, _root, options) { -+ // define can include functions and those are not serializable -+ // in that case, fallback to running in main thread -+ return !!( -+ options.define && -+ Object.values(options.define).some((d) => typeof d === 'function') -+ ) -+ }, -+ max: maxWorkers, -+ }, -+ ) -+ return worker -+} -+ -+export const stylProcessor = ( -+ maxWorkers: number | undefined, -+): StylusStylePreprocessor => { -+ const workerMap = new Map>() -+ -+ return { -+ close() { -+ for (const worker of workerMap.values()) { -+ worker.stop() -+ } -+ }, -+ async process(source, root, options, resolvers) { -+ const stylusPath = loadPreprocessorPath(PreprocessLang.stylus, root) -+ -+ if (!workerMap.has(options.alias)) { -+ workerMap.set(options.alias, makeStylWorker(maxWorkers)) -+ } -+ const worker = workerMap.get(options.alias)! -+ -+ // Get source with preprocessor options.additionalData. Make sure a new line separator -+ // is added to avoid any render error, as added stylus content may not have semi-colon separators -+ const { content, map: additionalMap } = await getSource( -+ source, -+ options.filename, -+ options.additionalData, -+ options.enableSourcemap, -+ '\n', -+ ) -+ // Get preprocessor options.imports dependencies as stylus -+ // does not return them with its builtin `.deps()` method -+ const importsDeps = (options.imports ?? []).map((dep: string) => -+ path.resolve(dep), -+ ) -+ const optionsWithoutAdditionalData = { -+ ...options, -+ additionalData: undefined, -+ } -+ try { -+ const { code, map, deps } = await worker.run( -+ stylusPath, -+ content, -+ root, -+ optionsWithoutAdditionalData, -+ ) -+ return { -+ code, -+ map: formatStylusSourceMap(map, root), -+ additionalMap, -+ // Concat imports deps with computed deps -+ deps: [...deps, ...importsDeps], -+ } -+ } catch (e) { -+ const wrapped = new Error(`[stylus] ${e.message}`) -+ wrapped.name = e.name -+ wrapped.stack = e.stack -+ return { code: '', error: wrapped, deps: [] } -+ } -+ }, -+ } -+} -\ No newline at end of file -diff --git a/packages/vite/src/node/plugins/define.ts b/packages/vite/src/node/plugins/define.ts -index 585bc0154..ef6bfd644 100644 ---- a/packages/vite/src/node/plugins/define.ts -+++ b/packages/vite/src/node/plugins/define.ts -@@ -1,9 +1,11 @@ -+import { ResolvedConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { isCSSRequest } from 'packages/vite/src/node/plugins/css/cssPlugin'; -+import { isCSSRequest } from 'packages/vite/src/node/plugins/css/cssPostPlugin'; -+import { isCSSRequest } from 'packages/vite/src/node/plugins/css/cssAnalysisPlugin'; - import { transform } from 'esbuild' - import { TraceMap, decodedMap, encodedMap } from '@jridgewell/trace-mapping' --import type { ResolvedConfig } from '../config' - import type { Plugin } from '../plugin' - import { escapeRegex } from '../utils' --import { isCSSRequest } from './css' - import { isHTMLRequest } from './html' - - const nonJsRe = /\.json(?:$|\?)/ -diff --git a/packages/vite/src/node/plugins/dynamicImportVars.ts b/packages/vite/src/node/plugins/dynamicImportVars.ts -index 8c55632a7..bed219778 100644 ---- a/packages/vite/src/node/plugins/dynamicImportVars.ts -+++ b/packages/vite/src/node/plugins/dynamicImportVars.ts -@@ -1,3 +1,4 @@ -+import { ResolvedConfig } from 'packages/vite/src/node/config/resolveConfig'; - import { posix } from 'node:path' - import MagicString from 'magic-string' - import { init, parse as parseImports } from 'es-module-lexer' -@@ -5,7 +6,6 @@ import type { ImportSpecifier } from 'es-module-lexer' - import { parseAst } from 'rollup/parseAst' - import { dynamicImportToGlob } from '@rollup/plugin-dynamic-import-vars' - import type { Plugin } from '../plugin' --import type { ResolvedConfig } from '../config' - import { CLIENT_ENTRY } from '../constants' - import { - createFilter, -diff --git a/packages/vite/src/node/plugins/esbuild.ts b/packages/vite/src/node/plugins/esbuild.ts -index fda6ca02a..b041a5e1f 100644 ---- a/packages/vite/src/node/plugins/esbuild.ts -+++ b/packages/vite/src/node/plugins/esbuild.ts -@@ -1,3 +1,7 @@ -+import { ResolvedConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { ViteDevServer } from 'packages/vite/src/node/server/index/ViteDevServer'; -+import { ViteDevServer } from 'packages/vite/src/node/server/index/_createServer'; -+import { ViteDevServer } from 'packages/vite/src/node/server/index/restartServer'; - import path from 'node:path' - import colors from 'picocolors' - import type { -@@ -18,8 +22,6 @@ import { - ensureWatchedFile, - generateCodeFrame, - } from '../utils' --import type { ViteDevServer } from '../server' --import type { ResolvedConfig } from '../config' - import type { Plugin } from '../plugin' - import { cleanUrl } from '../../shared/utils' - -diff --git a/packages/vite/src/node/plugins/html.ts b/packages/vite/src/node/plugins/html.ts -index b7109debc..53b83e3f0 100644 ---- a/packages/vite/src/node/plugins/html.ts -+++ b/packages/vite/src/node/plugins/html.ts -@@ -1,3 +1,10 @@ -+import { ResolvedConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { isCSSRequest } from 'packages/vite/src/node/plugins/css/cssPlugin'; -+import { isCSSRequest } from 'packages/vite/src/node/plugins/css/cssPostPlugin'; -+import { isCSSRequest } from 'packages/vite/src/node/plugins/css/cssAnalysisPlugin'; -+import { ViteDevServer } from 'packages/vite/src/node/server/index/ViteDevServer'; -+import { ViteDevServer } from 'packages/vite/src/node/server/index/_createServer'; -+import { ViteDevServer } from 'packages/vite/src/node/server/index/restartServer'; - import path from 'node:path' - import type { - OutputAsset, -@@ -11,7 +18,6 @@ import colors from 'picocolors' - import type { DefaultTreeAdapterMap, ParserError, Token } from 'parse5' - import { stripLiteral } from 'strip-literal' - import type { Plugin } from '../plugin' --import type { ViteDevServer } from '../server' - import { - encodeURIPath, - generateCodeFrame, -@@ -25,7 +31,6 @@ import { - unique, - urlCanParse, - } from '../utils' --import type { ResolvedConfig } from '../config' - import { checkPublicFile } from '../publicDir' - import { toOutputFilePathInHtml } from '../build' - import { resolveEnvPrefix } from '../env' -@@ -37,7 +42,6 @@ import { - publicAssetUrlRE, - urlToBuiltUrl, - } from './asset' --import { isCSSRequest } from './css' - import { modulePreloadPolyfillId } from './modulePreloadPolyfill' - - interface ScriptAssetsUrl { -diff --git a/packages/vite/src/node/plugins/importAnalysis.ts b/packages/vite/src/node/plugins/importAnalysis.ts -index 8027bae9a..f9b95778a 100644 ---- a/packages/vite/src/node/plugins/importAnalysis.ts -+++ b/packages/vite/src/node/plugins/importAnalysis.ts -@@ -1,3 +1,30 @@ -+import { ResolvedConfig } from 'packages/vite/src/node/config/resolveConfig'; -+import { isCSSRequest } from 'packages/vite/src/node/plugins/css/cssPlugin'; -+import { isCSSRequest } from 'packages/vite/src/node/plugins/css/cssPostPlugin'; -+import { isDirectCSSRequest } from 'packages/vite/src/node/plugins/css/cssPostPlugin'; -+import { isCSSRequest } from 'packages/vite/src/node/plugins/css/cssAnalysisPlugin'; -+import { debugHmr } from 'packages/vite/src/node/server/hmr/isNodeWithinCircularImports'; -+import { lexAcceptedHmrDeps } from 'packages/vite/src/node/server/hmr/lexAcceptedHmrDeps'; -+import { extractImportedBindings } from 'packages/vite/src/node/plugins/importAnalysis/extractImportedBindings'; -+import { extractImportedBindings } from 'packages/vite/src/node/plugins/importAnalysis'; -+import { clientDir } from 'packages/vite/src/node/plugins/importAnalysis/importAnalysisPlugin'; -+import { canSkipImportAnalysis } from 'packages/vite/src/node/plugins/importAnalysis/importAnalysisPlugin'; -+import { optimizedDepChunkRE } from 'packages/vite/src/node/plugins/importAnalysis/importAnalysisPlugin'; -+import { optimizedDepDynamicRE } from 'packages/vite/src/node/plugins/importAnalysis/importAnalysisPlugin'; -+import { urlIsStringRE } from 'packages/vite/src/node/plugins/importAnalysis/importAnalysisPlugin'; -+import { templateLiteralRE } from 'packages/vite/src/node/plugins/importAnalysis/importAnalysisPlugin'; -+import { isExplicitImportRequired } from 'packages/vite/src/node/plugins/importAnalysis/importAnalysisPlugin'; -+import { extractImportedBindings } from 'packages/vite/src/node/plugins/importAnalysis/importAnalysisPlugin'; -+import { importAnalysisPlugin } from 'packages/vite/src/node/plugins/importAnalysis'; -+import { interopHelper } from 'packages/vite/src/node/plugins/importAnalysis/interopNamedImports'; -+import { interopNamedImports } from 'packages/vite/src/node/plugins/importAnalysis/interopNamedImports'; -+import { interopNamedImports } from 'packages/vite/src/node/plugins/importAnalysis'; -+import { interopHelper } from 'packages/vite/src/node/plugins/importAnalysis/transformCjsImport'; -+import { ImportNameSpecifier } from 'packages/vite/src/node/plugins/importAnalysis/transformCjsImport'; -+import { transformCjsImport } from 'packages/vite/src/node/plugins/importAnalysis/transformCjsImport'; -+import { transformCjsImport } from 'packages/vite/src/node/plugins/importAnalysis'; -+import { browserExternalId } from 'packages/vite/src/node/plugins/resolve/resolvePlugin'; -+import { browserExternalId } from 'packages/vite/src/node/plugins/resolve/resolveDeepImport'; - import path from 'node:path' - import { performance } from 'node:perf_hooks' - import colors from 'picocolors' -@@ -21,9 +48,7 @@ import { - SPECIAL_QUERY_RE, - } from '../constants' - import { -- debugHmr, - handlePrunedModules, -- lexAcceptedHmrDeps, - lexAcceptedHmrExports, - normalizeHmrUrl, - } from '../server/hmr' -@@ -53,7 +78,6 @@ import { - import { getFsUtils } from '../fsUtils' - import { checkPublicFile } from '../publicDir' - import { getDepOptimizationConfig } from '../config' --import type { ResolvedConfig } from '../config' - import type { Plugin } from '../plugin' - import { shouldExternalizeForSSR } from '../ssr/ssrExternal' - import { getDepsOptimizer, optimizedDepNeedsInterop } from '../optimizer' -@@ -65,760 +89,21 @@ import { - } from '../../shared/utils' - import type { TransformPluginContext } from '../server/pluginContainer' - import { throwOutdatedRequest } from './optimizedDeps' --import { isCSSRequest, isDirectCSSRequest } from './css' --import { browserExternalId } from './resolve' - import { serializeDefine } from './define' - import { WORKER_FILE_ID } from './worker' - import { getAliasPatternMatcher } from './preAlias' - - const debug = createDebugger('vite:import-analysis') -- --const clientDir = normalizePath(CLIENT_DIR) -- --const skipRE = /\.(?:map|json)(?:$|\?)/ --export const canSkipImportAnalysis = (id: string): boolean => -- skipRE.test(id) || isDirectCSSRequest(id) -- --const optimizedDepChunkRE = /\/chunk-[A-Z\d]{8}\.js/ --const optimizedDepDynamicRE = /-[A-Z\d]{8}\.js/ -+export - - export const hasViteIgnoreRE = /\/\*\s*@vite-ignore\s*\*\// - --const urlIsStringRE = /^(?:'.*'|".*"|`.*`)$/ -- --const templateLiteralRE = /^\s*`(.*)`\s*$/ -- - interface UrlPosition { - url: string - start: number - end: number - } - --export function isExplicitImportRequired(url: string): boolean { -- return !isJSRequest(url) && !isCSSRequest(url) --} -- --function extractImportedBindings( -- id: string, -- source: string, -- importSpec: ImportSpecifier, -- importedBindings: Map>, --) { -- let bindings = importedBindings.get(id) -- if (!bindings) { -- bindings = new Set() -- importedBindings.set(id, bindings) -- } -- -- const isDynamic = importSpec.d > -1 -- const isMeta = importSpec.d === -2 -- if (isDynamic || isMeta) { -- // this basically means the module will be impacted by any change in its dep -- bindings.add('*') -- return -- } -- -- const exp = source.slice(importSpec.ss, importSpec.se) -- ESM_STATIC_IMPORT_RE.lastIndex = 0 -- const match = ESM_STATIC_IMPORT_RE.exec(exp) -- if (!match) { -- return -- } -- -- const staticImport: StaticImport = { -- type: 'static', -- code: match[0], -- start: match.index, -- end: match.index + match[0].length, -- imports: match.groups!.imports, -- specifier: match.groups!.specifier, -- } -- const parsed = parseStaticImport(staticImport) -- if (!parsed) { -- return -- } -- if (parsed.namespacedImport) { -- bindings.add('*') -- } -- if (parsed.defaultImport) { -- bindings.add('default') -- } -- if (parsed.namedImports) { -- for (const name of Object.keys(parsed.namedImports)) { -- bindings.add(name) -- } -- } --} -- --/** -- * Server-only plugin that lexes, resolves, rewrites and analyzes url imports. -- * -- * - Imports are resolved to ensure they exist on disk -- * -- * - Lexes HMR accept calls and updates import relationships in the module graph -- * -- * - Bare module imports are resolved (by @rollup-plugin/node-resolve) to -- * absolute file paths, e.g. -- * -- * ```js -- * import 'foo' -- * ``` -- * is rewritten to -- * ```js -- * import '/@fs//project/node_modules/foo/dist/foo.js' -- * ``` -- * -- * - CSS imports are appended with `.js` since both the js module and the actual -- * css (referenced via ``) may go through the transform pipeline: -- * -- * ```js -- * import './style.css' -- * ``` -- * is rewritten to -- * ```js -- * import './style.css.js' -- * ``` -- */ --export function importAnalysisPlugin(config: ResolvedConfig): Plugin { -- const { root, base } = config -- const fsUtils = getFsUtils(config) -- const clientPublicPath = path.posix.join(base, CLIENT_PUBLIC_PATH) -- const enablePartialAccept = config.experimental?.hmrPartialAccept -- const matchAlias = getAliasPatternMatcher(config.resolve.alias) -- let server: ViteDevServer -- -- let _env: string | undefined -- let _ssrEnv: string | undefined -- function getEnv(ssr: boolean) { -- if (!_ssrEnv || !_env) { -- const importMetaEnvKeys: Record = {} -- const userDefineEnv: Record = {} -- for (const key in config.env) { -- importMetaEnvKeys[key] = JSON.stringify(config.env[key]) -- } -- for (const key in config.define) { -- // non-import.meta.env.* is handled in `clientInjection` plugin -- if (key.startsWith('import.meta.env.')) { -- userDefineEnv[key.slice(16)] = config.define[key] -- } -- } -- const env = `import.meta.env = ${serializeDefine({ -- ...importMetaEnvKeys, -- SSR: '__vite_ssr__', -- ...userDefineEnv, -- })};` -- _ssrEnv = env.replace('__vite_ssr__', 'true') -- _env = env.replace('__vite_ssr__', 'false') -- } -- return ssr ? _ssrEnv : _env -- } -- -- return { -- name: 'vite:import-analysis', -- -- configureServer(_server) { -- server = _server -- }, -- -- async transform(source, importer, options) { -- // In a real app `server` is always defined, but it is undefined when -- // running src/node/server/__tests__/pluginContainer.spec.ts -- if (!server) { -- return null -- } -- -- const ssr = options?.ssr === true -- -- if (canSkipImportAnalysis(importer)) { -- debug?.(colors.dim(`[skipped] ${prettifyUrl(importer, root)}`)) -- return null -- } -- -- const msAtStart = debug ? performance.now() : 0 -- await init -- let imports!: readonly ImportSpecifier[] -- let exports!: readonly ExportSpecifier[] -- source = stripBomTag(source) -- try { -- ;[imports, exports] = parseImports(source) -- } catch (_e: unknown) { -- const e = _e as EsModuleLexerParseError -- const { message, showCodeFrame } = createParseErrorInfo( -- importer, -- source, -- ) -- this.error(message, showCodeFrame ? e.idx : undefined) -- } -- -- const depsOptimizer = getDepsOptimizer(config, ssr) -- -- const { moduleGraph } = server -- // since we are already in the transform phase of the importer, it must -- // have been loaded so its entry is guaranteed in the module graph. -- const importerModule = moduleGraph.getModuleById(importer)! -- if (!importerModule) { -- // This request is no longer valid. It could happen for optimized deps -- // requests. A full reload is going to request this id again. -- // Throwing an outdated error so we properly finish the request with a -- // 504 sent to the browser. -- throwOutdatedRequest(importer) -- } -- -- if ( -- !imports.length && -- !(this as unknown as TransformPluginContext)._addedImports -- ) { -- importerModule.isSelfAccepting = false -- debug?.( -- `${timeFrom(msAtStart)} ${colors.dim( -- `[no imports] ${prettifyUrl(importer, root)}`, -- )}`, -- ) -- return source -- } -- -- let hasHMR = false -- let isSelfAccepting = false -- let hasEnv = false -- let needQueryInjectHelper = false -- let s: MagicString | undefined -- const str = () => s || (s = new MagicString(source)) -- let isPartiallySelfAccepting = false -- const importedBindings = enablePartialAccept -- ? new Map>() -- : null -- const toAbsoluteUrl = (url: string) => -- path.posix.resolve(path.posix.dirname(importerModule.url), url) -- -- const normalizeUrl = async ( -- url: string, -- pos: number, -- forceSkipImportAnalysis: boolean = false, -- ): Promise<[string, string]> => { -- url = stripBase(url, base) -- -- let importerFile = importer -- -- const optimizeDeps = getDepOptimizationConfig(config, ssr) -- if (moduleListContains(optimizeDeps?.exclude, url)) { -- if (depsOptimizer) { -- await depsOptimizer.scanProcessing -- -- // if the dependency encountered in the optimized file was excluded from the optimization -- // the dependency needs to be resolved starting from the original source location of the optimized file -- // because starting from node_modules/.vite will not find the dependency if it was not hoisted -- // (that is, if it is under node_modules directory in the package source of the optimized file) -- for (const optimizedModule of depsOptimizer.metadata.depInfoList) { -- if (!optimizedModule.src) continue // Ignore chunks -- if (optimizedModule.file === importerModule.file) { -- importerFile = optimizedModule.src -- } -- } -- } -- } -- -- const resolved = await this.resolve(url, importerFile) -- -- if (!resolved || resolved.meta?.['vite:alias']?.noResolved) { -- // in ssr, we should let node handle the missing modules -- if (ssr) { -- return [url, url] -- } -- // fix#9534, prevent the importerModuleNode being stopped from propagating updates -- importerModule.isSelfAccepting = false -- moduleGraph._hasResolveFailedErrorModules.add(importerModule) -- return this.error( -- `Failed to resolve import "${url}" from "${normalizePath( -- path.relative(process.cwd(), importerFile), -- )}". Does the file exist?`, -- pos, -- ) -- } -- -- if (isExternalUrl(resolved.id)) { -- return [resolved.id, resolved.id] -- } -- -- const isRelative = url[0] === '.' -- const isSelfImport = !isRelative && cleanUrl(url) === cleanUrl(importer) -- -- // normalize all imports into resolved URLs -- // e.g. `import 'foo'` -> `import '/@fs/.../node_modules/foo/index.js'` -- if (resolved.id.startsWith(withTrailingSlash(root))) { -- // in root: infer short absolute path from root -- url = resolved.id.slice(root.length) -- } else if ( -- depsOptimizer?.isOptimizedDepFile(resolved.id) || -- // vite-plugin-react isn't following the leading \0 virtual module convention. -- // This is a temporary hack to avoid expensive fs checks for React apps. -- // We'll remove this as soon we're able to fix the react plugins. -- (resolved.id !== '/@react-refresh' && -- path.isAbsolute(resolved.id) && -- fsUtils.existsSync(cleanUrl(resolved.id))) -- ) { -- // an optimized deps may not yet exists in the filesystem, or -- // a regular file exists but is out of root: rewrite to absolute /@fs/ paths -- url = path.posix.join(FS_PREFIX, resolved.id) -- } else { -- url = resolved.id -- } -- -- // if the resolved id is not a valid browser import specifier, -- // prefix it to make it valid. We will strip this before feeding it -- // back into the transform pipeline -- if (url[0] !== '.' && url[0] !== '/') { -- url = wrapId(resolved.id) -- } -- -- // make the URL browser-valid if not SSR -- if (!ssr) { -- // mark non-js/css imports with `?import` -- if (isExplicitImportRequired(url)) { -- url = injectQuery(url, 'import') -- } else if ( -- (isRelative || isSelfImport) && -- !DEP_VERSION_RE.test(url) -- ) { -- // If the url isn't a request for a pre-bundled common chunk, -- // for relative js/css imports, or self-module virtual imports -- // (e.g. vue blocks), inherit importer's version query -- // do not do this for unknown type imports, otherwise the appended -- // query can break 3rd party plugin's extension checks. -- const versionMatch = DEP_VERSION_RE.exec(importer) -- if (versionMatch) { -- url = injectQuery(url, versionMatch[1]) -- } -- } -- -- // check if the dep has been hmr updated. If yes, we need to attach -- // its last updated timestamp to force the browser to fetch the most -- // up-to-date version of this module. -- try { -- // delay setting `isSelfAccepting` until the file is actually used (#7870) -- // We use an internal function to avoid resolving the url again -- const depModule = await moduleGraph._ensureEntryFromUrl( -- unwrapId(url), -- ssr, -- canSkipImportAnalysis(url) || forceSkipImportAnalysis, -- resolved, -- ) -- if (depModule.lastHMRTimestamp > 0) { -- url = injectQuery(url, `t=${depModule.lastHMRTimestamp}`) -- } -- } catch (e: any) { -- // it's possible that the dep fails to resolve (non-existent import) -- // attach location to the missing import -- e.pos = pos -- throw e -- } -- -- // prepend base -- url = joinUrlSegments(base, url) -- } -- -- return [url, resolved.id] -- } -- -- const orderedImportedUrls = new Array(imports.length) -- const orderedAcceptedUrls = new Array | undefined>( -- imports.length, -- ) -- const orderedAcceptedExports = new Array | undefined>( -- imports.length, -- ) -- -- await Promise.all( -- imports.map(async (importSpecifier, index) => { -- const { -- s: start, -- e: end, -- ss: expStart, -- se: expEnd, -- d: dynamicIndex, -- a: attributeIndex, -- } = importSpecifier -- -- // #2083 User may use escape path, -- // so use imports[index].n to get the unescaped string -- let specifier = importSpecifier.n -- -- const rawUrl = source.slice(start, end) -- -- // check import.meta usage -- if (rawUrl === 'import.meta') { -- const prop = source.slice(end, end + 4) -- if (prop === '.hot') { -- hasHMR = true -- const endHot = end + 4 + (source[end + 4] === '?' ? 1 : 0) -- if (source.slice(endHot, endHot + 7) === '.accept') { -- // further analyze accepted modules -- if (source.slice(endHot, endHot + 14) === '.acceptExports') { -- const importAcceptedExports = (orderedAcceptedExports[index] = -- new Set()) -- lexAcceptedHmrExports( -- source, -- source.indexOf('(', endHot + 14) + 1, -- importAcceptedExports, -- ) -- isPartiallySelfAccepting = true -- } else { -- const importAcceptedUrls = (orderedAcceptedUrls[index] = -- new Set()) -- if ( -- lexAcceptedHmrDeps( -- source, -- source.indexOf('(', endHot + 7) + 1, -- importAcceptedUrls, -- ) -- ) { -- isSelfAccepting = true -- } -- } -- } -- } else if (prop === '.env') { -- hasEnv = true -- } -- return -- } else if (templateLiteralRE.test(rawUrl)) { -- // If the import has backticks but isn't transformed as a glob import -- // (as there's nothing to glob), check if it's simply a plain string. -- // If so, we can replace the specifier as a plain string to prevent -- // an incorrect "cannot be analyzed" warning. -- if (!(rawUrl.includes('${') && rawUrl.includes('}'))) { -- specifier = rawUrl.replace(templateLiteralRE, '$1') -- } -- } -- -- const isDynamicImport = dynamicIndex > -1 -- -- // strip import attributes as we can process them ourselves -- if (!isDynamicImport && attributeIndex > -1) { -- str().remove(end + 1, expEnd) -- } -- -- // static import or valid string in dynamic import -- // If resolvable, let's resolve it -- if (specifier !== undefined) { -- // skip external / data uri -- if (isExternalUrl(specifier) || isDataUrl(specifier)) { -- return -- } -- // skip ssr external -- if (ssr && !matchAlias(specifier)) { -- if (shouldExternalizeForSSR(specifier, importer, config)) { -- return -- } -- if (isBuiltin(specifier)) { -- return -- } -- } -- // skip client -- if (specifier === clientPublicPath) { -- return -- } -- -- // warn imports to non-asset /public files -- if ( -- specifier[0] === '/' && -- !( -- config.assetsInclude(cleanUrl(specifier)) || -- urlRE.test(specifier) -- ) && -- checkPublicFile(specifier, config) -- ) { -- throw new Error( -- `Cannot import non-asset file ${specifier} which is inside /public. ` + -- `JS/CSS files inside /public are copied as-is on build and ` + -- `can only be referenced via