diff --git a/.github/workflows/accuracy_test.yaml b/.github/workflows/accuracy_test.yaml index 0297f625d0..6086bbb457 100644 --- a/.github/workflows/accuracy_test.yaml +++ b/.github/workflows/accuracy_test.yaml @@ -22,6 +22,9 @@ name: Benchmarks / accuracy on: + schedule: + # Runs every 6 hours + - cron: '0 */6 * * *' pull_request: types: [ labeled ] workflow_dispatch: @@ -34,6 +37,7 @@ on: # Current supported vLLM versions options: - main + - v0.9.2 - v0.9.1 - v0.7.3 vllm-ascend-version: @@ -42,6 +46,7 @@ on: type: choice options: - main + - v0.9.1-dev - v0.7.3-dev models: description: 'model:' @@ -49,9 +54,9 @@ on: type: choice options: - all - - Qwen/Qwen2.5-7B-Instruct - Qwen/Qwen2.5-VL-7B-Instruct - Qwen/Qwen3-8B-Base + - Qwen/Qwen3-30B-A3B default: 'all' # Bash shells do not use ~/.profile or ~/.bashrc so these shells need to be explicitly @@ -73,56 +78,57 @@ jobs: ${{ (contains(github.event.pull_request.labels.*.name, 'accuracy-test') || contains(github.event.pull_request.labels.*.name, 'vl-accuracy-test') || + contains(github.event.pull_request.labels.*.name, 'moe-accuracy-test') || contains(github.event.pull_request.labels.*.name, 'dense-accuracy-test')) && contains(github.event.pull_request.labels.*.name, 'ready-for-test') || - github.event_name == 'workflow_dispatch' + github.event_name == 'workflow_dispatch' || github.event_name == 'schedule' }} runs-on: >- ${{ - (matrix.model_name == 'Qwen/Qwen2.5-VL-7B-Instruct' && 'linux-arm64-npu-4') || + (matrix.model_name == 'Qwen/Qwen3-30B-A3B' && 'linux-arm64-npu-4') || 'linux-arm64-npu-2' }} strategy: matrix: - vllm_use_version: [0, 1] + vllm_use_version: [1] # the accuracy test will run: # 1. workflow_dispatch with models input - # - all: Qwen/Qwen2.5-7B-Instruct, Qwen/Qwen2.5-VL-7B-Instruct, Qwen/Qwen3-8B-Base - # - specified but not all: Qwen/Qwen2.5-7B-Instruct, Qwen/Qwen2.5-VL-7B-Instruct, Qwen/Qwen3-8B-Base + # - all: Qwen/Qwen3-30B-A3B, Qwen/Qwen2.5-VL-7B-Instruct, Qwen/Qwen3-8B-Base + # - specified but not all: Qwen/Qwen3-30B-A3B, Qwen/Qwen2.5-VL-7B-Instruct, Qwen/Qwen3-8B-Base # 2. 
PR labeled with "*-accuracy-test" - # - accuracy-test: Qwen/Qwen2.5-7B-Instruct, Qwen/Qwen2.5-VL-7B-Instruct - # - dense-accuracy-test: Qwen/Qwen2.5-7B-Instruct + # - accuracy-test: Qwen/Qwen3-8B-Base, Qwen/Qwen2.5-VL-7B-Instruct, Qwen/Qwen3-30B-A3B + # - dense-accuracy-test: Qwen/Qwen3-8B-Base # - vl-accuracy-test: Qwen/Qwen2.5-VL-7B-Instruct + # - moe-accuracy-test: Qwen/Qwen3-30B-A3B model_name: ${{ fromJSON( + (github.event_name == 'schedule' && + '["Qwen/Qwen3-30B-A3B","Qwen/Qwen2.5-VL-7B-Instruct","Qwen/Qwen3-8B-Base"]') || (github.event.inputs.models == 'all' && - '["Qwen/Qwen2.5-7B-Instruct","Qwen/Qwen2.5-VL-7B-Instruct","Qwen/Qwen3-8B-Base"]') || - (github.event.inputs.models == 'Qwen/Qwen2.5-7B-Instruct' && - '["Qwen/Qwen2.5-7B-Instruct"]') || + '["Qwen/Qwen3-30B-A3B","Qwen/Qwen2.5-VL-7B-Instruct","Qwen/Qwen3-8B-Base"]') || + (github.event.inputs.models == 'Qwen/Qwen3-30B-A3B' && + '["Qwen/Qwen3-30B-A3B"]') || (github.event.inputs.models == 'Qwen/Qwen2.5-VL-7B-Instruct' && '["Qwen/Qwen2.5-VL-7B-Instruct"]') || (github.event.inputs.models == 'Qwen/Qwen3-8B-Base' && '["Qwen/Qwen3-8B-Base"]') || contains(github.event.pull_request.labels.*.name, 'accuracy-test') && - '["Qwen/Qwen2.5-7B-Instruct","Qwen/Qwen2.5-VL-7B-Instruct"]' || + '["Qwen/Qwen3-8B-Base","Qwen/Qwen2.5-VL-7B-Instruct", "Qwen/Qwen3-30B-A3B"]' || contains(github.event.pull_request.labels.*.name, 'dense-accuracy-test') && - '["Qwen/Qwen2.5-7B-Instruct"]' || + '["Qwen/Qwen3-8B-Base"]' || contains(github.event.pull_request.labels.*.name, 'vl-accuracy-test') && - '["Qwen/Qwen2.5-VL-7B-Instruct"]' + '["Qwen/Qwen2.5-VL-7B-Instruct"]' || + contains(github.event.pull_request.labels.*.name, 'moe-accuracy-test') && + '["Qwen/Qwen3-30B-A3B"]' ) }} - # Remove exclude after https://github.com/vllm-project/vllm-ascend/issues/1044 resolved - exclude: - - model_name: Qwen/Qwen2.5-VL-7B-Instruct - vllm_use_version: 1 fail-fast: false name: ${{ matrix.model_name }} accuracy V${{ matrix.vllm_use_version }} container: image: m.daocloud.io/quay.io/ascend/cann:8.1.rc1-910b-ubuntu22.04-py3.10 env: - HF_ENDPOINT: https://hf-mirror.com - HF_TOKEN: ${{ secrets.HF_TOKEN }} DATASET_SOURCE: ModelScope VLLM_USE_MODELSCOPE: True + USE_MODELSCOPE_HUB: 1 # 1. If version specified (work_dispatch), do specified branch accuracy test # 2. If no version (labeled PR), do accuracy test by default ref: # The branch, tag or SHA to checkout. When checking out the repository that @@ -158,7 +164,7 @@ jobs: repository: vllm-project/vllm path: ./vllm-empty # Please also update this when bump matched version - ref: ${{ github.event.inputs.vllm-version || 'v0.9.1' }} + ref: ${{ github.event.inputs.vllm-version || 'v0.9.2' }} - name: Install vllm-project/vllm from source working-directory: ./vllm-empty @@ -177,11 +183,28 @@ jobs: PIP_EXTRA_INDEX_URL: https://mirrors.huaweicloud.com/ascend/repos/pypi run: | pip install -r requirements-dev.txt - pip install -e . + pip install -v -e . 
+ + - name: Get vLLM commit hash and URL + working-directory: ./vllm-empty + run: | + VLLM_COMMIT=$(git rev-parse --short=7 HEAD) + echo "VLLM_COMMIT=$VLLM_COMMIT" >> $GITHUB_ENV + + - name: Get vLLM-Ascend commit hash and URL + working-directory: ./vllm-ascend + run: | + VLLM_ASCEND_COMMIT=$(git rev-parse --short=7 HEAD) + echo "VLLM_ASCEND_COMMIT=$VLLM_ASCEND_COMMIT" >> $GITHUB_ENV + + - name: Print resolved hashes + run: | + echo "vLLM : ${{ env.VLLM_COMMIT }}" + echo "vLLM-Ascend: ${{ env.VLLM_ASCEND_COMMIT }}" - name: Install lm-eval, ray, and datasets run: | - pip install lm-eval + pip install lm-eval==0.4.8 - name: Collect version info run: | @@ -233,7 +256,10 @@ jobs: --cann_version "${{ env.GHA_CANN_VERSION }}" \ --torch_npu_version "${{ env.GHA_TORCH_NPU_VERSION }}" \ --torch_version "${{ env.GHA_TORCH_VERSION }}" \ - --vllm_version "${{ env.GHA_VLLM_VERSION }}" + --vllm_version "${{ env.GHA_VLLM_VERSION }}" \ + --vllm_commit "${{ env.VLLM_COMMIT }}" \ + --vllm_ascend_commit "${{ env.VLLM_ASCEND_COMMIT }}" \ + --vllm_use_v1 "$VLLM_USE_V1" - name: Generate step summary if: ${{ always() }} @@ -245,12 +271,122 @@ jobs: SAFE_VLLM_ASCEND_VERSION="${GHA_VLLM_ASCEND_VERSION//\//-}" echo "SAFE_VLLM_ASCEND_VERSION=$SAFE_VLLM_ASCEND_VERSION" >> "$GITHUB_ENV" + - name: Check report first line for failure + id: check_report + run: | + REPORT_PATH="./benchmarks/accuracy/${{ steps.report.outputs.markdown_name }}.md" + echo "Scanning $REPORT_PATH for ❌ …" + if grep -q '❌' "$REPORT_PATH"; then + echo "contains_fail=true" >> $GITHUB_OUTPUT + else + echo "contains_fail=false" >> $GITHUB_OUTPUT + fi + - name: Upload Report for V${{ matrix.vllm_use_version }} - if: ${{ github.event_name == 'workflow_dispatch' }} + if: ${{ github.event_name == 'workflow_dispatch' && steps.check_report.outputs.contains_fail == 'false' }} uses: actions/upload-artifact@v4 with: - name: "${{ env.SAFE_VLLM_ASCEND_VERSION }}-${{ steps.report.outputs.markdown_name }}-report" + name: "report-${{ env.SAFE_VLLM_ASCEND_VERSION }}-${{ steps.report.outputs.markdown_name }}" path: ./benchmarks/accuracy/${{ steps.report.outputs.markdown_name }}.md if-no-files-found: warn retention-days: 90 overwrite: true + + create_pr: + runs-on: ubuntu-latest + needs: accuracy_tests + if: ${{ github.event_name == 'workflow_dispatch' }} + env: + UPSTREAM_REPO: vllm-project/vllm-ascend + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + repository: vllm-ascend-ci/vllm-ascend + token: ${{ secrets.PAT_TOKEN }} + ref: main + + - name: Add upstream remote + run: | + git remote add upstream https://github.com/${{ env.UPSTREAM_REPO }}.git + git fetch upstream + git remote -v + + - name: Set Git user info dynamically + run: | + git config user.name "${{ github.actor }}" + git config user.email "${{ github.actor }}@users.noreply.github.com" + + - name: Create or switch to branch + run: | + TIMESTAMP=$(date +%Y%m%d%H%M%S) + BRANCH_NAME="auto-pr/accuracy-report-${TIMESTAMP}" + echo "BRANCH_NAME=${BRANCH_NAME}" >> $GITHUB_ENV + git checkout -B "${BRANCH_NAME}" upstream/${{ github.event.inputs.vllm-ascend-version }} + + - name: Download only current run reports + uses: actions/download-artifact@v4 + with: + path: ./docs/source/developer_guide/evaluation/accuracy_report + pattern: report-* + github-token: ${{ secrets.GITHUB_TOKEN }} + run-id: ${{ github.run_id }} + + - name: Delete old report + run: | + find ./docs/source/developer_guide/evaluation/accuracy_report -maxdepth 1 -type f -name '*.md' ! 
-name 'index.md' -delete + find ./docs/source/developer_guide/evaluation/accuracy_report -mindepth 2 -type f -name '*.md' -exec mv -f {} ./docs/source/developer_guide/evaluation/accuracy_report \; + find ./docs/source/developer_guide/evaluation/accuracy_report -mindepth 1 -type d -empty -delete + + - name: Update accuracy_report/index.md + run: | + REPORT_DIR="./docs/source/developer_guide/evaluation/accuracy_report" + INDEX_MD="$REPORT_DIR/index.md" + { + echo "# Accuracy Report" + echo "" + echo ":::{toctree}" + echo ":caption: Accuracy Report" + echo ":maxdepth: 1" + + for report in "$REPORT_DIR"/*.md; do + filename="$(basename "$report" .md)" + if [ "$filename" != "index" ]; then + echo "$filename" + fi + done + echo ":::" + } > "$INDEX_MD" + + - name: push accuracy report + env: + GITHUB_TOKEN: ${{ secrets.PAT_TOKEN }} + run: | + git add ./docs/source/developer_guide/evaluation/accuracy_report/*.md + git commit -s -m "[Doc] Update accuracy reports for ${{ github.event.inputs.vllm-ascend-version }}" + git push -f origin "${{ env.BRANCH_NAME }}" + + - name: Create PR in upstream via API + uses: actions/github-script@v7 + with: + github-token: ${{ secrets.PAT_TOKEN }} + script: | + const pr = await github.rest.pulls.create({ + owner: 'vllm-project', + repo: 'vllm-ascend', + head: `vllm-ascend-ci:${{ env.BRANCH_NAME }}`, + base: '${{ github.event.inputs.vllm-ascend-version }}', + title: `[Doc] Update accuracy reports for ${{ github.event.inputs.vllm-ascend-version }}`, + body: `The accuracy results running on NPU Altlas A2 have changed, updating reports for: + ${{ + github.event.inputs.models == 'all' + && 'All models (Qwen/Qwen3-30B-A3B, Qwen2.5-VL-7B-Instruct, Qwen3-8B-Base)' + || github.event.inputs.models + }} + + - [Workflow run][1] + + [1]: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}` + }); + core.info(`Created PR #${pr.data.number}`); + \ No newline at end of file diff --git a/.github/workflows/doc_codespell.yaml b/.github/workflows/doc_codespell.yaml new file mode 100644 index 0000000000..156ad71e59 --- /dev/null +++ b/.github/workflows/doc_codespell.yaml @@ -0,0 +1,33 @@ + +name: 'doc-codespell' + +on: + pull_request: + branches: + - 'main' + - '*-dev' + paths: + - 'docs/**' + +jobs: + codespell: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10"] + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements-lint.txt + - name: Run codespell check + run: | + CODESPELL_EXCLUDES=('--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**') + CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn,rever') + + codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}" diff --git a/.github/workflows/format_pr_body.yaml b/.github/workflows/format_pr_body.yaml new file mode 100644 index 0000000000..2c91ab2278 --- /dev/null +++ b/.github/workflows/format_pr_body.yaml @@ -0,0 +1,63 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +name: format / pr body + +on: + # The PR updated when PR opened and push new commits + pull_request_target: + types: [opened, synchronize] + branches: + - 'main' + +permissions: + pull-requests: write + +jobs: + update-description: + name: update vLLM version + runs-on: ubuntu-latest + + steps: + - name: Checkout vllm-project/vllm repo + uses: actions/checkout@v4 + with: + repository: vllm-project/vllm + path: ./vllm-empty + + - name: Get vLLM version + working-directory: ./vllm-empty + run: | + VLLM_COMMIT=$(git rev-parse HEAD) + echo "VLLM_COMMIT=https://github.com/vllm-project/vllm/commit/$VLLM_COMMIT" >> $GITHUB_ENV + + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Set up Python + uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + + - name: Get vLLM release version + run: | + VLLM_VERSION=$(python3 docs/source/conf.py | jq .vllm_version | tr -d '"') + echo "VLLM_VERSION=$VLLM_VERSION" >> $GITHUB_ENV + + - name: Update PR description + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + bash .github/format_pr_body.sh "${{ github.event.number }}" "${{ env.VLLM_VERSION }}" "${{ env.VLLM_COMMIT }}" diff --git a/.github/workflows/image_310p_openeuler.yml b/.github/workflows/image_310p_openeuler.yml new file mode 100644 index 0000000000..2626a77a8c --- /dev/null +++ b/.github/workflows/image_310p_openeuler.yml @@ -0,0 +1,114 @@ +name: 'image / openEuler / 310p' +# This is a docker build check and publish job: +# 1. PR Triggered docker image build check +# - is for image build check +# - Enable on main/*-dev branch +# - push: ${{ github.event_name != 'pull_request' }} ==> false +# 2. branches push trigger image publish +# - is for branch/dev/nightly image +# - commits are merge into main/*-dev ==> vllm-ascend:main / vllm-ascend:*-dev +# 3. 
tags push trigger image publish +# - is for final release image +# - Publish when tag with v* (pep440 version) ===> vllm-ascend:v1.2.3-openeuler|latest / vllm-ascend:v1.2.3rc1-openeuler +on: + pull_request: + branches: + - 'main' + - '*-dev' + paths: + - '.github/workflows/image_310p_openeuler.yml' + - 'Dockerfile.310p.openEuler' + - 'vllm_ascend/**' + - 'setup.py' + - 'pyproject.toml' + - 'requirements.txt' + - 'cmake/**' + - 'CMakeLists.txt' + - 'csrc/**' + push: + # Publish image when tagging, the Dockerfile in tag will be build as tag image + branches: + - 'main' + - '*-dev' + tags: + - 'v*' + paths: + - '.github/workflows/image_310p.openeuler.yml' + - 'Dockerfile.310p.openEuler' + - 'vllm_ascend/**' + +jobs: + build: + name: vllm-ascend image build + runs-on: >- + ${{ + github.event_name == 'push' && github.repository_owner == 'vllm-project' && + 'ubuntu-latest' || + 'ubuntu-24.04-arm' + }} + steps: + - uses: actions/checkout@v4 + + - name: Print + run: | + lscpu + + - name: Docker meta + id: meta + uses: docker/metadata-action@v5 + with: + # TODO(yikun): add more hub image and a note on release policy for container image + images: | + quay.io/ascend/vllm-ascend + # Note for test case + # https://github.com/marketplace/actions/docker-metadata-action#typeref + # 1. branch job pulish per main/*-dev branch commits + # 2. main and dev pull_request is build only, so the tag pr-N-openeuler is fine + # 3. only pep440 matched tag will be published: + # - v0.7.1 --> v0.7.1-openeuler, latest + # - pre/post/dev: v0.7.1rc1-openeuler/v0.7.1rc1-openeuler/v0.7.1rc1.dev1-openeuler/v0.7.1.post1-openeuler, no latest + # which follow the rule from vLLM with prefix v + # TODO(yikun): the post release might be considered as latest release + tags: | + type=ref,event=branch,suffix=-310p-openeuler + type=ref,event=pr,suffix=-openeuler + type=pep440,pattern={{raw}},suffix=-310p-openeuler + + - name: Free up disk space + uses: jlumbroso/free-disk-space@54081f138730dfa15788a46383842cd2f914a1be # v1.3.1 + with: + tool-cache: true + docker-images: false + + - name: Build - Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Build - Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Publish - Login to Quay Container Registry + if: ${{ github.event_name == 'push' && github.repository_owner == 'vllm-project' }} + uses: docker/login-action@v3 + with: + registry: quay.io + username: ${{ vars.QUAY_USERNAME }} + password: ${{ secrets.QUAY_PASSWORD }} + + - name: Build and push 310p + uses: docker/build-push-action@v6 + with: + platforms: >- + ${{ + github.event_name == 'push' && github.repository_owner == 'vllm-project' && + 'linux/amd64,linux/arm64' || + 'linux/arm64' + }} + # use the current repo path as the build context, ensure .git is contained + context: . + # only trigger when tag, branch/main push + push: ${{ github.event_name == 'push' && github.repository_owner == 'vllm-project' }} + labels: ${{ steps.meta.outputs.labels }} + tags: ${{ steps.meta.outputs.tags }} + file: Dockerfile.310p.openEuler + build-args: | + PIP_INDEX_URL=https://pypi.org/simple diff --git a/.github/workflows/image_310p_ubuntu.yml b/.github/workflows/image_310p_ubuntu.yml new file mode 100644 index 0000000000..638c0e328a --- /dev/null +++ b/.github/workflows/image_310p_ubuntu.yml @@ -0,0 +1,110 @@ +name: 'image / Ubuntu / 310p' +# This is a docker build check and publish job: +# 1. 
PR Triggered docker image build check +# - is for image build check +# - Enable on main/*-dev branch +# - push: ${{ github.event_name != 'pull_request' }} ==> false +# 2. branches push trigger image publish +# - is for branch/dev/nightly image +# - commits are merge into main/*-dev ==> vllm-ascend:main / vllm-ascend:*-dev +# 3. tags push trigger image publish +# - is for final release image +# - Publish when tag with v* (pep440 version) ===> vllm-ascend:v1.2.3|latest / vllm-ascend:v1.2.3rc1 +on: + pull_request: + branches: + - 'main' + - '*-dev' + paths: + - '.github/workflows/image_310p_ubuntu.yml' + - 'Dockerfile.310p' + - 'vllm_ascend/**' + - 'setup.py' + - 'pyproject.toml' + - 'requirements.txt' + - 'cmake/**' + - 'CMakeLists.txt' + - 'csrc/**' + push: + # Publish image when tagging, the Dockerfile in tag will be build as tag image + branches: + - 'main' + - '*-dev' + tags: + - 'v*' + paths: + - '.github/workflows/image_310p_ubuntu.yml' + - 'Dockerfile.310p' + - 'vllm_ascend/**' +jobs: + + build: + name: vllm-ascend image build + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Print + run: | + lscpu + + - name: Docker meta + id: meta + uses: docker/metadata-action@v5 + with: + # TODO(yikun): add more hub image and a note on release policy for container image + images: | + quay.io/ascend/vllm-ascend + # Note for test case + # https://github.com/marketplace/actions/docker-metadata-action#typeref + # 1. branch job pulish per main/*-dev branch commits + # 2. main and dev pull_request is build only, so the tag pr-N is fine + # 3. only pep440 matched tag will be published: + # - v0.7.1 --> v0.7.1, latest + # - pre/post/dev: v0.7.1rc1/v0.7.1rc1/v0.7.1rc1.dev1/v0.7.1.post1, no latest + # which follow the rule from vLLM with prefix v + # TODO(yikun): the post release might be considered as latest release + tags: | + type=ref,event=branch,suffix=-310p + type=ref,event=pr,suffix=-310p + type=pep440,pattern={{raw}},suffix=-310p + + - name: Free up disk space + uses: jlumbroso/free-disk-space@54081f138730dfa15788a46383842cd2f914a1be # v1.3.1 + with: + tool-cache: true + docker-images: false + + - name: Build - Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Build - Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Publish - Login to Quay Container Registry + if: ${{ github.event_name == 'push' && github.repository_owner == 'vllm-project' }} + uses: docker/login-action@v3 + with: + registry: quay.io + username: ${{ vars.QUAY_USERNAME }} + password: ${{ secrets.QUAY_PASSWORD }} + + - name: Build and push 310p + uses: docker/build-push-action@v6 + with: + platforms: >- + ${{ + github.event_name == 'push' && github.repository_owner == 'vllm-project' && + 'linux/amd64,linux/arm64' || + 'linux/amd64' + }} + # use the current repo path as the build context, ensure .git is contained + context: . 
+ file: Dockerfile.310p + # only trigger when tag, branch/main push + push: ${{ github.event_name == 'push' && github.repository_owner == 'vllm-project' }} + labels: ${{ steps.meta.outputs.labels }} + tags: ${{ steps.meta.outputs.tags }} + build-args: | + PIP_INDEX_URL=https://pypi.org/simple diff --git a/.github/workflows/image_openeuler.yml b/.github/workflows/image_openeuler.yml index 690d8145cd..c954e569b3 100644 --- a/.github/workflows/image_openeuler.yml +++ b/.github/workflows/image_openeuler.yml @@ -94,7 +94,7 @@ jobs: username: ${{ vars.QUAY_USERNAME }} password: ${{ secrets.QUAY_PASSWORD }} - - name: Build and push + - name: Build and push 910b uses: docker/build-push-action@v6 with: platforms: >- diff --git a/.github/workflows/image_ubuntu.yml b/.github/workflows/image_ubuntu.yml index a2cfbcefb0..69fe385fe9 100644 --- a/.github/workflows/image_ubuntu.yml +++ b/.github/workflows/image_ubuntu.yml @@ -90,7 +90,7 @@ jobs: username: ${{ vars.QUAY_USERNAME }} password: ${{ secrets.QUAY_PASSWORD }} - - name: Build and push + - name: Build and push 910b uses: docker/build-push-action@v6 with: platforms: >- @@ -101,6 +101,7 @@ jobs: }} # use the current repo path as the build context, ensure .git is contained context: . + file: Dockerfile # only trigger when tag, branch/main push push: ${{ github.event_name == 'push' && github.repository_owner == 'vllm-project' }} labels: ${{ steps.meta.outputs.labels }} diff --git a/.github/workflows/nightly_benchmarks.yaml b/.github/workflows/nightly_benchmarks.yaml index 2b9c062957..6644e6f9ef 100644 --- a/.github/workflows/nightly_benchmarks.yaml +++ b/.github/workflows/nightly_benchmarks.yaml @@ -50,10 +50,7 @@ jobs: strategy: matrix: include: - - vllm_branch: v0.9.1 - vllm_ascend_branch: main - vllm_use_v1: 0 - - vllm_branch: v0.9.0 + - vllm_branch: v0.9.2 vllm_ascend_branch: main vllm_use_v1: 1 max-parallel: 1 @@ -72,8 +69,7 @@ jobs: --device /dev/devmm_svm --device /dev/hisi_hdc env: - HF_ENDPOINT: https://hf-mirror.com - HF_TOKEN: ${{ secrets.HF_TOKEN }} + VLLM_USE_MODELSCOPE: True ES_OM_DOMAIN: ${{ secrets.ES_OM_DOMAIN }} ES_OM_AUTHORIZATION: ${{ secrets.ES_OM_AUTHORIZATION }} VLLM_USE_V1: ${{ matrix.vllm_use_v1 }} @@ -118,6 +114,7 @@ jobs: env: PIP_EXTRA_INDEX_URL: https://mirrors.huaweicloud.com/ascend/repos/pypi run: | + pip install "transformers<=4.52.4" pip install -e . pip install -r benchmarks/requirements-bench.txt @@ -148,8 +145,8 @@ jobs: - name: Install elastic_tool if: github.event_name != 'pull_request' run: | - pip install escli-tool==0.2.2 - + pip install escli-tool==0.2.3 + - name: Collect pr info from vllm-project/vllm-ascend if: github.event_name != 'pull_request' run: | @@ -179,7 +176,7 @@ jobs: commit_time=$(git show -s --format=%cd $commit_hash --date=iso-strict) commit_time_no_tz=${commit_time::19} pip install -e . - + echo "------------------------" echo "commit_id: $commit_id" echo "commit_title: $commit_title" @@ -187,9 +184,12 @@ jobs: echo "vllm branch: ${{ matrix.vllm_branch }}" echo "vllm-ascend branch: ${{ matrix.vllm_ascend_branch }}" echo "------------------------" - + cd /github/home - bash benchmarks/scripts/run-performance-benchmarks.sh + ERROR_MSG="" + if ! 
bash benchmarks/scripts/run-performance-benchmarks.sh; then + ERROR_MSG="Benchmark failed to run" + fi # send the result to es escli add --vllm_branch ${{ matrix.vllm_branch }} \ --vllm_ascend_branch ${{ matrix.vllm_ascend_branch }} \ @@ -197,6 +197,7 @@ jobs: --commit_title "$commit_title" \ --created_at "$commit_time_no_tz" \ --res_dir ./benchmarks/results \ + --error "$ERROR_MSG" \ --extra_feat '{"VLLM_USE_V1": "${{ matrix.vllm_use_v1 }}"}' rm -rf ./benchmarks/results cd - diff --git a/.github/workflows/release_whl.yml b/.github/workflows/release_whl.yml index f66a01588e..9e9d124c14 100644 --- a/.github/workflows/release_whl.yml +++ b/.github/workflows/release_whl.yml @@ -18,6 +18,9 @@ name: build / wheel on: + schedule: + # Runs at 23:00 UTC (7:00 AM Beijing) every day + - cron: '0 23 * * *' pull_request: branches: - 'main' @@ -55,7 +58,11 @@ jobs: strategy: matrix: os: [ubuntu-24.04, ubuntu-24.04-arm] - python-version: ['3.9', '3.10', '3.11'] + # PR only trigger latest version + python-version: ${{ fromJSON( + (github.event_name == 'pull_request' && '["3.11"]') || + '["3.9", "3.10", "3.11"]' + ) }} runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/vllm_ascend_doctest.yaml b/.github/workflows/vllm_ascend_doctest.yaml index 67f98fbaf7..f26df372cd 100644 --- a/.github/workflows/vllm_ascend_doctest.yaml +++ b/.github/workflows/vllm_ascend_doctest.yaml @@ -30,8 +30,8 @@ on: - 'tests/e2e/common.sh' - 'tests/e2e/run_doctests.sh' schedule: - # Runs every 4 hours - - cron: '0 */4 * * *' + # Runs every 12 hours + - cron: '0 */12 * * *' # Bash shells do not use ~/.profile or ~/.bashrc so these shells need to be explicitly # declared as "shell: bash -el {0}" on steps that need to be properly activated. @@ -65,37 +65,18 @@ jobs: cd /vllm-workspace/vllm git --no-pager log -1 || true - - name: Config OS mirrors - Ubuntu - if: ${{ !endsWith(matrix.vllm_verison, '-openeuler') }} - run: | - sed -i 's|ports.ubuntu.com|mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list - apt-get update -y - apt install -y gcc g++ libnuma-dev git curl jq - - - name: Config OS mirrors - openEuler - if: ${{ endsWith(matrix.vllm_verison, '-openeuler') }} - run: | - yum update -y - yum install -y gcc g++ numactl-devel git curl jq - - - name: Config pip mirrors - run: | - pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple - - name: Checkout vllm-project/vllm-ascend repo uses: actions/checkout@v4 - name: Run vllm-ascend/tests/e2e/run_doctests.sh run: | # PWD: /__w/vllm-ascend/vllm-ascend + # Make sure e2e tests are latest echo "Replacing /vllm-workspace/vllm-ascend/tests/e2e ..." 
rm -rf /vllm-workspace/vllm-ascend/tests/e2e mkdir -p /vllm-workspace/vllm-ascend/tests cp -r tests/e2e /vllm-workspace/vllm-ascend/tests/ - # TODO(yikun): Remove this after conf.py merged - cp docs/source/conf.py /vllm-workspace/vllm-ascend/docs/source/ - # Simulate container to enter directory cd /workspace diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index 540680dd2f..236b10f13c 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -18,8 +18,22 @@ name: 'test' on: - schedule: - - cron: '0 23 * * *' + push: + # Enable merged test per commit + branches: + - 'main' + paths: + - '*.txt' + - '**/*.py' + - '.github/workflows/vllm_ascend_test.yaml' + - '!docs/**' + - 'pytest.ini' + - '!benchmarks/**' + - 'tools/mypy.sh' + - 'mypy.ini' + - '.github/workflows/*.ya?ml' + - '.github/workflows/actionlint.*' + - '.github/workflows/matchers/actionlint.json' pull_request: branches: - 'main' @@ -29,6 +43,7 @@ on: - '**/*.py' - '.github/workflows/vllm_ascend_test.yaml' - '!docs/**' + - '!examples/**' - 'pytest.ini' - '!benchmarks/**' - 'tools/mypy.sh' @@ -52,6 +67,8 @@ concurrency: jobs: lint: + # Only trigger lint on pull request + if: ${{ github.event_name == 'pull_request' }} runs-on: ubuntu-latest strategy: matrix: @@ -69,7 +86,7 @@ jobs: - name: Run codespell check run: | CODESPELL_EXCLUDES=('--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**') - CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn') + CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn,rever') codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}" - name: Analysing the code with ruff @@ -86,10 +103,6 @@ jobs: pip install yapf==0.32.0 yapf --diff --recursive . - - name: Install dependencies - run: | - pip install -r requirements-dev.txt --extra-index-url https://download.pytorch.org/whl/cpu - - name: Checkout vllm-project/vllm repo uses: actions/checkout@v4 with: @@ -109,6 +122,10 @@ jobs: pip install -r requirements/build.txt --extra-index-url https://download.pytorch.org/whl/cpu VLLM_TARGET_DEVICE=empty pip install . 
+ - name: Install dependencies + run: | + pip install -r requirements-dev.txt --extra-index-url https://download.pytorch.org/whl/cpu + - name: Mypy Check run: | echo "::add-matcher::.github/workflows/matchers/mypy.json" @@ -117,21 +134,22 @@ jobs: ut: needs: [lint] name: unit test - if: ${{ needs.lint.result == 'success' }} + # only trigger e2e test on [pull request after lint passed] and [merged commit] + if: ${{ needs.lint.result == 'success' || github.event_name == 'push' }} runs-on: ubuntu-latest container: - image: m.daocloud.io/quay.io/ascend/cann:8.1.rc1-910b-ubuntu22.04-py3.10 + image: quay.io/ascend/cann:8.1.rc1-910b-ubuntu22.04-py3.10 env: VLLM_LOGGING_LEVEL: ERROR VLLM_USE_MODELSCOPE: True strategy: matrix: - vllm_version: [main, v0.9.1] + vllm_version: [main, v0.9.2] steps: - name: Install packages run: | apt-get update -y - apt-get install -y python3-pip git vim wget net-tools gcc g++ cmake libnuma-dev + apt-get install -y python3-pip git vim wget net-tools gcc g++ cmake libnuma-dev curl gnupg2 - name: Checkout vllm-project/vllm repo uses: actions/checkout@v4 @@ -163,16 +181,27 @@ jobs: TORCH_DEVICE_BACKEND_AUTOLOAD: 0 run: | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/devlib - pytest -sv tests/ut + pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut + + - name: Upload coverage to Codecov + if: ${{ matrix.vllm_version == 'main' }} + uses: codecov/codecov-action@v5 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + with: + flags: unittests + name: vllm-ascend + verbose: true e2e: needs: [lint] - if: ${{ needs.lint.result == 'success' }} + # only trigger e2e test on pull request after lint passed + if: ${{ needs.lint.result == 'success' && github.event_name == 'pull_request' }} strategy: max-parallel: 2 matrix: os: [linux-arm64-npu-1] - vllm_version: [main, v0.9.1] + vllm_version: [main, v0.9.2] name: singlecard e2e test runs-on: ${{ matrix.os }} container: @@ -180,6 +209,7 @@ jobs: image: m.daocloud.io/quay.io/ascend/cann:8.1.rc1-910b-ubuntu22.04-py3.10 env: VLLM_LOGGING_LEVEL: ERROR + VLLM_USE_MODELSCOPE: True steps: - name: Check npu and CANN info run: | @@ -228,17 +258,22 @@ jobs: VLLM_USE_MODELSCOPE: True run: | pytest -sv tests/e2e/singlecard/test_offline_inference.py - # TODO: switch hf to modelscope - VLLM_USE_MODELSCOPE=False HF_ENDPOINT=https://hf-mirror.com \ - pytest -sv tests/e2e/singlecard/test_ilama_lora.py - # TODO(sss): guided decoding doesn't work, fix it later - # pytest -sv tests/e2e/singlecard/test_guided_decoding.py + pytest -sv tests/e2e/singlecard/test_ilama_lora.py + pytest -sv tests/e2e/singlecard/test_guided_decoding.py pytest -sv tests/e2e/singlecard/test_camem.py + pytest -sv tests/e2e/singlecard/test_embedding.py pytest -sv tests/e2e/singlecard/ \ --ignore=tests/e2e/singlecard/test_offline_inference.py \ --ignore=tests/e2e/singlecard/test_ilama_lora.py \ --ignore=tests/e2e/singlecard/test_guided_decoding.py \ - --ignore=tests/e2e/singlecard/test_camem.py + --ignore=tests/e2e/singlecard/test_camem.py \ + --ignore=tests/e2e/singlecard/test_embedding.py \ + --ignore=tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py \ + --ignore=tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py + # ------------------------------------ v1 spec decode test ------------------------------------ # + VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py + # TODO: revert me when test_v1_spec_decode.py::test_ngram_correctness is fixed + 
VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py - name: Run e2e test on V0 engine if: ${{ github.event_name == 'schedule' }} @@ -247,21 +282,18 @@ jobs: VLLM_USE_MODELSCOPE: True run: | pytest -sv tests/e2e/singlecard/test_offline_inference.py - # TODO: switch hf to modelscope - VLLM_USE_MODELSCOPE=False HF_ENDPOINT=https://hf-mirror.com \ - pytest -sv tests/e2e/singlecard/test_ilama_lora.py - # guided decoding doesn't work, fix it later - # pytest -sv tests/e2e/singlecard/test_guided_decoding.py + pytest -sv tests/e2e/singlecard/test_ilama_lora.py + pytest -sv tests/e2e/singlecard/test_guided_decoding.py pytest -sv tests/e2e/singlecard/test_camem.py pytest -sv tests/e2e/singlecard/test_prompt_embedding.py + pytest -sv tests/e2e/singlecard/test_embedding.py pytest -sv tests/e2e/singlecard/ \ --ignore=tests/e2e/singlecard/test_offline_inference.py \ --ignore=tests/e2e/singlecard/test_ilama_lora.py \ --ignore=tests/e2e/singlecard/test_guided_decoding.py \ --ignore=tests/e2e/singlecard/test_camem.py \ --ignore=tests/e2e/singlecard/test_prompt_embedding.py \ - --ignore=tests/e2e/singlecard/core/test_ascend_scheduler.py \ - --ignore=tests/e2e/singlecard/core/test_ascend_scheduler_e2e.py + --ignore=tests/e2e/singlecard/test_embedding.py e2e-4-cards: needs: [e2e] @@ -270,7 +302,7 @@ jobs: max-parallel: 1 matrix: os: [linux-arm64-npu-4] - vllm_version: [main, v0.9.1] + vllm_version: [main, v0.9.2] name: multicard e2e test runs-on: ${{ matrix.os }} container: @@ -326,16 +358,19 @@ jobs: VLLM_WORKER_MULTIPROC_METHOD: spawn VLLM_USE_MODELSCOPE: True run: | - # TODO: switch hf to modelscope - VLLM_USE_MODELSCOPE=False HF_ENDPOINT=https://hf-mirror.com \ - pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py + pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py # Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py will raise error. # To avoid oom, we need to run the test in a single process. 
+ pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ - pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_topk pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8 - pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py --ignore=tests/e2e/multicard/test_offline_inference_distributed.py + pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo + pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo + pytest -sv tests/e2e/multicard/test_data_parallel.py + pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \ + --ignore=tests/e2e/multicard/test_offline_inference_distributed.py \ + --ignore=tests/e2e/multicard/test_data_parallel.py - name: Run vllm-project/vllm-ascend test on V0 engine if: ${{ github.event_name == 'schedule' }} @@ -343,13 +378,13 @@ jobs: VLLM_USE_V1: 0 VLLM_USE_MODELSCOPE: True run: | - # TODO: switch hf to modelscope - VLLM_USE_MODELSCOPE=False HF_ENDPOINT=https://hf-mirror.com \ - pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py + pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py # Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py will raise error. # To avoid oom, we need to run the test in a single process. pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ - pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_topk pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8 - pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py --ignore=tests/e2e/multicard/test_offline_inference_distributed.py + pytest -sv tests/e2e/multicard/test_data_parallel.py + pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \ + --ignore=tests/e2e/multicard/test_offline_inference_distributed.py \ + --ignore=tests/e2e/multicard/test_data_parallel.py diff --git a/.github/workflows/vllm_ascend_test_long_term.yaml b/.github/workflows/vllm_ascend_test_long_term.yaml index e249849e19..9a33b3aca8 100644 --- a/.github/workflows/vllm_ascend_test_long_term.yaml +++ b/.github/workflows/vllm_ascend_test_long_term.yaml @@ -43,16 +43,15 @@ jobs: max-parallel: 2 matrix: os: [linux-arm64-npu-1, linux-arm64-npu-4] - vllm_version: [main, v0.9.1] + vllm_version: [main, v0.9.2] name: vLLM Ascend long term test runs-on: ${{ matrix.os }} container: # TODO(yikun): Remove m.daocloud.io prefix when infra proxy ready image: m.daocloud.io/quay.io/ascend/cann:8.1.rc1-910b-ubuntu22.04-py3.10 env: - HF_ENDPOINT: https://hf-mirror.com - HF_TOKEN: ${{ secrets.HF_TOKEN }} VLLM_LOGGING_LEVEL: ERROR + VLLM_USE_MODELSCOPE: True steps: - name: Check npu and CANN info run: | @@ -97,13 +96,13 @@ jobs: - name: Run vllm-project/vllm-ascend long term test run: | if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then - # spec decode test - 
VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode/e2e/test_v1_mtp_correctness.py - # TODO: revert me when test_v1_spec_decode.py::test_ngram_correctness is fixed - # VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode/e2e/test_v1_spec_decode.py - VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode/e2e/test_mtp_correctness.py # it needs a clean process - pytest -sv tests/e2e/long_term/spec_decode --ignore=tests/e2e/long_term/spec_decode/e2e/test_mtp_correctness.py --ignore=tests/e2e/long_term/spec_decode/e2e/test_v1_spec_decode.py --ignore=tests/e2e/long_term/spec_decode/e2e/test_v1_mtp_correctness.py + # v0 spec decode test + # TODO: Revert me when test_mtp_correctness is fixed + # VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode_v0/e2e/test_mtp_correctness.py # it needs a clean process + pytest -sv tests/e2e/long_term/spec_decode_v0 --ignore=tests/e2e/long_term/spec_decode_v0/e2e/test_mtp_correctness.py + # accuracy test single card pytest -sv tests/e2e/long_term/test_accuracy.py else + # accuracy test multi card VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/test_deepseek_v2_lite_tp2_accuracy.py fi diff --git a/.github/workflows/vllm_ascend_test_pd.yaml b/.github/workflows/vllm_ascend_test_pd.yaml index 932b3e59b3..e6bb6a6988 100644 --- a/.github/workflows/vllm_ascend_test_pd.yaml +++ b/.github/workflows/vllm_ascend_test_pd.yaml @@ -41,7 +41,11 @@ jobs: if: ${{ contains(github.event.pull_request.labels.*.name, 'pd-test') && contains(github.event.pull_request.labels.*.name, 'ready-for-test') || github.event_name == 'schedule' }} strategy: matrix: - vllm_verison: [main, v0.9.1] + vllm_verison: [ + # revert me when V1 disaggregation prefill is merged in main + # main, + v0.9.1 + ] name: vLLM Ascend prefilling decoding disaggregation test runs-on: linux-arm64-npu-static-8 @@ -60,8 +64,7 @@ jobs: --device /dev/devmm_svm --device /dev/hisi_hdc env: - HF_ENDPOINT: https://hf-mirror.com - HF_TOKEN: ${{ secrets.HF_TOKEN }} + VLLM_USE_MODELSCOPE: True steps: - name: Check npu and CANN info run: | diff --git a/README.md b/README.md index 7d0966c8d4..7e5918c763 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ By using vLLM Ascend plugin, popular open-source models, including Transformer-l - Software: * Python >= 3.9, < 3.12 * CANN >= 8.1.RC1 - * PyTorch >= 2.5.1, torch-npu >= 2.5.1.post1.dev20250528 + * PyTorch >= 2.5.1, torch-npu >= 2.5.1.post1.dev20250619 * vLLM (the same version as vllm-ascend) ## Getting Started @@ -46,7 +46,7 @@ By using vLLM Ascend plugin, popular open-source models, including Transformer-l Please refer to [QuickStart](https://vllm-ascend.readthedocs.io/en/latest/quick_start.html) and [Installation](https://vllm-ascend.readthedocs.io/en/latest/installation.html) for more details. ## Contributing -See [CONTRIBUTING](https://vllm-ascend.readthedocs.io/en/main/developer_guide/contributing.html) for more details, which is a step-by-step guide to help you set up development environment, build and test. +See [CONTRIBUTING](https://vllm-ascend.readthedocs.io/en/latest/developer_guide/contribution/index.html) for more details, which is a step-by-step guide to help you set up development environment, build and test. 
We welcome and value any contributions and collaborations: - Please let us know if you encounter a bug by [filing an issue](https://github.com/vllm-project/vllm-ascend/issues) @@ -67,7 +67,7 @@ Below is maintained branches: | v0.7.1-dev | Unmaintained | Only doc fixed is allowed | | v0.7.3-dev | Maintained | CI commitment for vLLM 0.7.3 version | -Please refer to [Versioning policy](https://vllm-ascend.readthedocs.io/en/main/developer_guide/versioning_policy.html) for more details. +Please refer to [Versioning policy](https://vllm-ascend.readthedocs.io/en/latest/community/versioning_policy.html) for more details. ## Weekly Meeting diff --git a/README.zh.md b/README.zh.md index 2d2062a8b4..55a40f5380 100644 --- a/README.zh.md +++ b/README.zh.md @@ -39,7 +39,7 @@ vLLM 昇腾插件 (`vllm-ascend`) 是一个由社区维护的让vLLM在Ascend NP - 软件: * Python >= 3.9, < 3.12 * CANN >= 8.1.RC1 - * PyTorch >= 2.5.1, torch-npu >= 2.5.1.post1.dev20250528 + * PyTorch >= 2.5.1, torch-npu >= 2.5.1.post1.dev20250619 * vLLM (与vllm-ascend版本一致) ## 开始使用 @@ -47,7 +47,7 @@ vLLM 昇腾插件 (`vllm-ascend`) 是一个由社区维护的让vLLM在Ascend NP 请查看[快速开始](https://vllm-ascend.readthedocs.io/en/latest/quick_start.html)和[安装指南](https://vllm-ascend.readthedocs.io/en/latest/installation.html)了解更多. ## 贡献 -请参考 [CONTRIBUTING]((https://vllm-ascend.readthedocs.io/en/main/developer_guide/contributing.html)) 文档了解更多关于开发环境搭建、功能测试以及 PR 提交规范的信息。 +请参考 [CONTRIBUTING]((https://vllm-ascend.readthedocs.io/en/latest/developer_guide/contribution/index.html)) 文档了解更多关于开发环境搭建、功能测试以及 PR 提交规范的信息。 我们欢迎并重视任何形式的贡献与合作: - 请通过[Issue](https://github.com/vllm-project/vllm-ascend/issues)来告知我们您遇到的任何Bug。 @@ -67,7 +67,7 @@ vllm-ascend有主干分支和开发分支。 | v0.7.1-dev | Unmaintained | 只允许文档修复 | | v0.7.3-dev | Maintained | 基于vLLM v0.7.3版本CI看护 | -请参阅[版本策略](https://vllm-ascend.readthedocs.io/en/main/developer_guide/versioning_policy.html)了解更多详细信息。 +请参阅[版本策略](https://vllm-ascend.readthedocs.io/en/latest/community/versioning_policy.html)了解更多详细信息。 ## 社区例会 diff --git a/docs/source/faqs.md b/docs/source/faqs.md index 1de3befb2d..be6d689eff 100644 --- a/docs/source/faqs.md +++ b/docs/source/faqs.md @@ -114,7 +114,7 @@ In scenarios where NPUs have limited HBM (High Bandwidth Memory) capacity, dynam - **Configure `PYTORCH_NPU_ALLOC_CONF`**: Set this environment variable to optimize NPU memory management. For example, you can `export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True` to enable virtual memory feature to mitigate memory fragmentation caused by frequent dynamic memory size adjustments during runtime, see more note in: [PYTORCH_NPU_ALLOC_CONF](https://www.hiascend.com/document/detail/zh/Pytorch/700/comref/Envvariables/Envir_012.html). -### 15. Failed to enable NPU graph mode when running DeepSeek? +### 16. Failed to enable NPU graph mode when running DeepSeek? You may encounter the following error if running DeepSeek with NPU graph mode enabled. The allowed number of queries per kv when enabling both MLA and Graph mode only support {32, 64, 128}, **Thus this is not supported for DeepSeek-V2-Lite**, as it only has 16 attention heads. The NPU graph mode support on DeepSeek-V2-Lite will be done in the future. And if you're using DeepSeek-V3 or DeepSeek-R1, please make sure after the tensor parallel split, num_heads / num_kv_heads in {32, 64, 128}. @@ -123,3 +123,6 @@ And if you're using DeepSeek-V3 or DeepSeek-R1, please make sure after the tenso [rank0]: RuntimeError: EZ9999: Inner Error! 
[rank0]: EZ9999: [PID: 62938] 2025-05-27-06:52:12.455.807 numHeads / numKvHeads = 8, MLA only support {32, 64, 128}.[FUNC:CheckMlaAttrs][FILE:incre_flash_attention_tiling_check.cc][LINE:1218] ``` + +### 17. Failed to reinstall vllm-ascend from source after uninstalling vllm-ascend? +You may encounter the problem of C compilation failure when reinstalling vllm-ascend from source using pip. If the installation fails, it is recommended to use `python setup.py install` to install, or use `python setup.py clean` to clear the cache. diff --git a/docs/source/installation.md b/docs/source/installation.md index c290f7e5f7..b7eb611a03 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -12,7 +12,7 @@ This document describes how to install vllm-ascend manually. | Software | Supported version | Note | |---------------|----------------------------------|-------------------------------------------| | CANN | >= 8.1.RC1 | Required for vllm-ascend and torch-npu | - | torch-npu | >= 2.5.1.post1.dev20250528 | Required for vllm-ascend | + | torch-npu | >= 2.5.1.post1.dev20250619 | Required for vllm-ascend | | torch | >= 2.5.1 | Required for torch-npu and vllm | You have 2 way to install: @@ -246,8 +246,7 @@ for output in outputs: Then run: ```bash -# Try `export VLLM_USE_MODELSCOPE=true` and `pip install modelscope` -# to speed up download if huggingface is not reachable. +# export VLLM_USE_MODELSCOPE=true to speed up download if huggingface is not reachable. python example.py ``` diff --git a/docs/source/user_guide/additional_config.md b/docs/source/user_guide/additional_config.md index d4756ef5e1..2a0194209b 100644 --- a/docs/source/user_guide/additional_config.md +++ b/docs/source/user_guide/additional_config.md @@ -28,7 +28,6 @@ The following table lists the additional configuration options available in vLLM |-------------------------------| ---- |------|-----------------------------------------------------------------------------------------------| | `torchair_graph_config` | dict | `{}` | The config options for torchair graph mode | | `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler | -| `expert_tensor_parallel_size` | str | `0` | Expert tensor parallel size the model to use. | | `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf case. | | `expert_map_path` | str | None | When using expert load balancing for the MOE model, an expert map path needs to be passed in. | | `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. | diff --git a/examples/disaggregate_prefill_v1/README.md b/examples/disaggregate_prefill_v1/README.md new file mode 100644 index 0000000000..544d5ba020 --- /dev/null +++ b/examples/disaggregate_prefill_v1/README.md @@ -0,0 +1,234 @@ +# Disaggregated Prefill-Decode Deployment Guide + +## Overview +This demo document provides instructions for running a disaggregated vLLM-ascend service with separate prefill and decode stages across 4 nodes, uses 16 Ascend NPUs for two prefill nodes (P1/P2) and 16 Ascend NPUS for two decode nodes (D1/D2). + +## Prerequisites +- Ascend NPU environment with vLLM 0.9.1 installed +- Network interfaces configured for distributed communication (eg: eth0) +- Model weights located at `/data01/deepseek_r1_w8a8_zhw` + +## Rank table generation +The rank table is a JSON file that specifies the mapping of Ascend NPU ranks to nodes. 
The following command generates a rank table for all nodes with 16 cards prefill and 16 cards decode: + +Run the following command on every node to generate the rank table: +```shell +cd vllm-ascend/examples/disaggregate_prefill_v1/ +bash gen_ranktable.sh --ips 172.19.32.175 172.19.241.49 172.19.123.51 172.19.190.36 \ + --npus-per-node 8 --network-card-name enp189s0f0 --prefill-device-cnt 16 --decode-device-cnt 16 +``` +Rank table will generated at `/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json` + +## Start disaggregated vLLM-ascend service +Execution Sequence +- 4 configured node ip are: 172.19.32.175 172.19.241.49 172.19.123.51 172.19.190.36 +- Start Prefill on Node 1 (P1) +- Start Prefill on Node 2 (P2) +- Start Decode on Node 1 (D1) +- Start Decode on Node 2 (D2) +- Start proxy server on Node1 + +* Run prefill server P1 on first node +```shell +export HCCL_IF_IP=172.19.32.175 # node ip +export GLOO_SOCKET_IFNAME="eth0" # network card name +export TP_SOCKET_IFNAME="eth0" +export HCCL_SOCKET_IFNAME="eth0" +export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=100 +export VLLM_USE_V1=1 +vllm serve /data01/deepseek_r1_w8a8_zhw \ + --host 0.0.0.0 \ + --port 20002 \ + --data-parallel-size 2 \ + --data-parallel-size-local 1 \ + --api-server-count 2 \ + --data-parallel-address 172.19.32.175 \ + --data-parallel-rpc-port 13356 \ + --tensor-parallel-size 8 \ + --enable-expert-parallel \ + --no-enable-prefix-caching \ + --seed 1024 \ + --served-model-name deepseek \ + --max-model-len 6144 \ + --max-num-batched-tokens 6144 \ + --trust-remote-code \ + --enforce-eager \ + --gpu-memory-utilization 0.9 \ + --kv-transfer-config \ + '{"kv_connector": "LLMDataDistCMgrConnector", + "kv_buffer_device": "npu", + "kv_role": "kv_producer", + "kv_parallel_size": 1, + "kv_port": "20001", + "engine_id": "0", + "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" + }' \ + --additional-config \ + '{"torchair_graph_config": {"enabled":false, "enable_multistream_shared_expert":false}, "ascend_scheduler_config":{"enabled":true, "enable_chunked_prefill":false}}' +``` + +* Run prefill server P2 on second node +```shell +export HCCL_IF_IP=172.19.241.49 +export GLOO_SOCKET_IFNAME="eth0" +export TP_SOCKET_IFNAME="eth0" +export HCCL_SOCKET_IFNAME="eth0" +export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=100 +export VLLM_USE_V1=1 +vllm serve /data01/deepseek_r1_w8a8_zhw \ + --host 0.0.0.0 \ + --port 20002 \ + --headless \ + --data-parallel-size 2 \ + --data-parallel-start-rank 1 \ + --data-parallel-size-local 1 \ + --data-parallel-address 172.19.32.175 \ + --data-parallel-rpc-port 13356 \ + --tensor-parallel-size 8 \ + --enable-expert-parallel \ + --no-enable-prefix-caching \ + --seed 1024 \ + --served-model-name deepseek \ + --max-model-len 6144 \ + --max-num-batched-tokens 6144 \ + --trust-remote-code \ + --enforce-eager \ + --gpu-memory-utilization 0.9 \ + --kv-transfer-config \ + '{"kv_connector": "LLMDataDistCMgrConnector", \ + "kv_buffer_device": "npu", + "kv_role": "kv_producer", + "kv_parallel_size": 1, + "kv_port": "20001", + "engine_id": "0", + "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" + }' \ + --additional-config \ + '{"torchair_graph_config": {"enabled":false, 
"enable_multistream_shared_expert":false}, "ascend_scheduler_config":{"enabled":true, "enable_chunked_prefill":false}}' +``` + +* Run decode server d1 on third node +```shell +export HCCL_IF_IP=172.19.123.51 +export GLOO_SOCKET_IFNAME="eth0" +export TP_SOCKET_IFNAME="eth0" +export HCCL_SOCKET_IFNAME="eth0" +export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=100 +export VLLM_USE_V1=1 +vllm serve /data01/deepseek_r1_w8a8_zhw \ + --host 0.0.0.0 \ + --port 20002 \ + --data-parallel-size 2 \ + --data-parallel-size-local 1 \ + --api-server-count 2 \ + --data-parallel-address 172.19.123.51 \ + --data-parallel-rpc-port 13356 \ + --tensor-parallel-size 8 \ + --enable-expert-parallel \ + --no-enable-prefix-caching \ + --seed 1024 \ + --served-model-name deepseek \ + --max-model-len 6144 \ + --max-num-batched-tokens 6144 \ + --trust-remote-code \ + --enforce-eager \ + --gpu-memory-utilization 0.9 \ + --kv-transfer-config \ + '{"kv_connector": "LLMDataDistCMgrConnector", + "kv_buffer_device": "npu", + "kv_role": "kv_consumer", + "kv_parallel_size": 1, + "kv_port": "20001", + "engine_id": "0", + "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" + }' \ + --additional-config \ + '{"torchair_graph_config": {"enabled":false, "enable_multistream_shared_expert":false}, "ascend_scheduler_config":{"enabled":true, "enable_chunked_prefill":false}}' +``` + +* Run decode server d2 on last node +```shell +export HCCL_IF_IP=172.19.190.36 +export GLOO_SOCKET_IFNAME="eth0" +export TP_SOCKET_IFNAME="eth0" +export HCCL_SOCKET_IFNAME="eth0" +export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json +export OMP_PROC_BIND=false +export OMP_NUM_THREADS=100 +export VLLM_USE_V1=1 +vllm serve /data01/deepseek_r1_w8a8_zhw \ + --host 0.0.0.0 \ + --port 20002 \ + --headless \ + --data-parallel-size 2 \ + --data-parallel-start-rank 1 \ + --data-parallel-size-local 1 \ + --data-parallel-address 172.19.123.51 \ + --data-parallel-rpc-port 13356 \ + --tensor-parallel-size 8 \ + --enable-expert-parallel \ + --no-enable-prefix-caching \ + --seed 1024 \ + --served-model-name deepseek \ + --max-model-len 6144 \ + --max-num-batched-tokens 6144 \ + --trust-remote-code \ + --enforce-eager \ + --gpu-memory-utilization 0.9 \ + --kv-transfer-config \ + '{"kv_connector": "LLMDataDistCMgrConnector", + "kv_buffer_device": "npu", + "kv_role": "kv_consumer", + "kv_parallel_size": 1, + "kv_port": "20001", + "engine_id": "0", + "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" + }' \ + --additional-config \ + '{"torchair_graph_config": {"enabled":false, "enable_multistream_shared_expert":false}, "ascend_scheduler_config":{"enabled":true, "enable_chunked_prefill":false}}' +``` + +* Run proxy server on the first node +```shell +cd /vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1 +python toy_proxy_server.py --host 172.19.32.175 --port 1025 --prefiller-hosts 172.19.241.49 --prefiller-port 20002 --decoder-hosts 172.19.123.51 --decoder-ports 20002 +``` + +* Verification +Check service health using the proxy server endpoint: +```shell +curl http://localhost:1025/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "deepseek", + "prompt": "Who are you?", + "max_tokens": 100, + "temperature": 0 + }' +``` + +* Performance +Test performance with vllm benchmark +```shell +cd 
/vllm-workspace/vllm/benchmarks +python3 benchmark_serving.py \ + --backend vllm \ + --dataset-name random \ + --random-input-len 4096 \ + --random-output-len 1536 \ + --num-prompts 256 \ + --ignore-eos \ + --model deepseek \ + --tokenizer /data01/deepseek_r1_w8a8_zhw \ + --host localhost \ + --port 8000 \ + --endpoint /v1/completions \ + --max-concurrency 4 \ + --request-rate 4 +``` \ No newline at end of file diff --git a/examples/disaggregate_prefill_v1/gen_ranktable.py b/examples/disaggregate_prefill_v1/gen_ranktable.py new file mode 100644 index 0000000000..d170f3ba06 --- /dev/null +++ b/examples/disaggregate_prefill_v1/gen_ranktable.py @@ -0,0 +1,120 @@ +import argparse +import json +import os + +import torch.distributed as dist + +from vllm_ascend.soc_info import NPUSocInfo + +parser = argparse.ArgumentParser( + description="Arguments of rank table generator", ) +parser.add_argument("--local-host", type=str, required=True, help="local ip") +parser.add_argument("--prefill-device-cnt", + type=int, + required=True, + help="number of prefill devices") +parser.add_argument("--decode-device-cnt", + type=int, + required=True, + help="number of decode devices") +args = parser.parse_args() +local_host = args.local_host +prefill_device_cnt = args.prefill_device_cnt +decode_device_cnt = args.decode_device_cnt + +print("enter py") + +hccn_tool_path = os.environ.get("HCCN_TOOL_PATH", + "/usr/local/Ascend/driver/tools/hccn_tool") +master_addr = os.environ.get("MASTER_ADDR") +master_port = os.environ.get("MASTER_PORT") +rank = os.environ.get("RANK") +local_rank = os.environ.get("LOCAL_RANK") +# This variable is set by torchrun, +# and is different from WORLD_SIZE in gen_rank_table.sh. +world_size = os.environ.get("WORLD_SIZE") +soc_info = NPUSocInfo() + + +def get_cmd_stdout(cmd): + import subprocess + return subprocess.run(cmd, capture_output=True, + shell=True).stdout.decode("utf-8").strip() + + +print(f"local_host: {local_host}") +print("gen ranktable.json") + +num_cards = get_cmd_stdout("npu-smi info -l | grep \"Total Count\"").split( + ":")[1].strip() +num_cards = int(num_cards) +chips_per_card = get_cmd_stdout("npu-smi info -l | grep \"Chip Count\"").split( + "\n")[0].split(":")[1].strip() +chips_per_card = int(chips_per_card) + +# generate local device list for local rank 0, and gather it to all ranks +local_device_list: list[dict[str, str]] = list() +if local_rank == "0": + super_pod_id = "0" + for card_id in range(num_cards): + for chip_id in range(chips_per_card): + device_id = card_id * chips_per_card + chip_id + if soc_info.is_a3: + device_ip = get_cmd_stdout( + f"{hccn_tool_path} -i {device_id} -vnic -g | grep ipaddr" + ).split(":")[1].strip() + super_device_id = get_cmd_stdout( + f"npu-smi info -t spod-info -i {card_id} -c {chip_id} | grep SDID" + ).split(":")[1].strip() + super_pod_id = get_cmd_stdout( + f"npu-smi info -t spod-info -i {card_id} -c {chip_id} | grep \"Super Pod ID\"" + ).split(":")[1].strip() + else: + device_ip = get_cmd_stdout( + f"{hccn_tool_path} -i {device_id} -ip -g | grep ipaddr" + ).split(":")[1].strip() + + device_info = { + "server_id": local_host, + "device_id": str(device_id), + "device_ip": str(device_ip), + } + if soc_info.is_a3: + device_info.update({ + "super_pod_id": str(super_pod_id), + "super_device_id": str(super_device_id) + }) + local_device_list.append(device_info) + +dist.init_process_group(backend=dist.Backend.GLOO) +global_device_list = [None] * dist.get_world_size() +dist.all_gather_object(global_device_list, local_device_list) 
+global_device_list = [ + device_info for device_list in global_device_list + for device_info in device_list # type: ignore[attr-defined] +] +cnt = 1 +for device_info in global_device_list: # type: ignore[assignment] + device_info["cluster_id"] = str(cnt) + cnt += 1 +assert (prefill_device_cnt + decode_device_cnt) <= len(global_device_list), \ +"prefill_device_cnt + decode_device_cnt must be less than or equal to number of all devices in cluster" +ranktable = { + "version": + "1.2", + "server_count": + str(world_size), + "prefill_device_list": + global_device_list[:prefill_device_cnt], + "decode_device_list": + global_device_list[prefill_device_cnt:prefill_device_cnt + + decode_device_cnt], + "status": + "completed" +} + +if local_rank == '0': + with open("ranktable.json", "w") as f: + json.dump(ranktable, f, indent=4) + + print("gen ranktable.json done") diff --git a/examples/disaggregate_prefill_v1/gen_ranktable.sh b/examples/disaggregate_prefill_v1/gen_ranktable.sh new file mode 100644 index 0000000000..33d4a32e8d --- /dev/null +++ b/examples/disaggregate_prefill_v1/gen_ranktable.sh @@ -0,0 +1,79 @@ +#!/bin/bash + +source /usr/local/Ascend/ascend-toolkit/set_env.sh +export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/op_api/lib/:${LD_LIBRARY_PATH} + +NPUS_PER_NODE=8 +while [[ $# -gt 0 ]]; do + case "$1" in + --ips) + shift + while [[ $# -gt 0 && ! "$1" == --* ]]; do + IPs+=("$1") + shift + done + ;; + --npus-per-node) + shift + NPUS_PER_NODE="$1" + shift + ;; + --network-card-name) + shift + NETWORK_CARD_NAME="$1" + shift + ;; + --prefill-device-cnt) + shift + PREFILL_DEVICE_CNT="$1" + shift + ;; + --decode-device-cnt) + shift + DECODE_DEVICE_CNT="$1" + shift + ;; + esac +done +LOCAL_HOSTS=($(hostname -I)) +LOCAL_HOST="127.0.0.1" +MASTER_ADDR=${IPs[0]} +MASTER_PORT=6657 +NNODES=${#IPs[@]} +NODE_RANK="8" +for i in "${!IPs[@]}"; do + ip="${IPs[$i]}" + for local_host in "${LOCAL_HOSTS[@]}"; do + if [[ "$local_host" == "$ip" ]]; then + LOCAL_HOST=$local_host + NODE_RANK=$i + break 2 + fi + done +done + +if [[ $NODE_RANK == "" ]];then + echo "[Error] para \"NODE_RANK\" must be defined" + exit 1 +fi + +WORLD_SIZE=$(($NPUS_PER_NODE * $NNODES)) +RANKSTART=`expr $NPUS_PER_NODE \* $NODE_RANK` + +echo "========>param:" +echo "LOCAL_HOST": $LOCAL_HOST +echo "WORLD_SIZE: " $WORLD_SIZE +echo "RANKSTART": $RANKSTART +echo "NNODES": $NNODES +echo "NODE_RANK": $NODE_RANK +echo "===============" + +if [[ -n "${GEN_RANKTABLE}" || ! 
-e ${PWD}/ranktable.json ]]; then
+    GLOO_SOCKET_IFNAME=$NETWORK_CARD_NAME torchrun \
+        --nproc_per_node 1 \
+        --nnodes ${NNODES} \
+        --node_rank ${NODE_RANK} \
+        --master_addr ${MASTER_ADDR} \
+        --master_port ${MASTER_PORT} \
+        gen_ranktable.py --local-host $LOCAL_HOST --prefill-device-cnt $PREFILL_DEVICE_CNT --decode-device-cnt $DECODE_DEVICE_CNT
+fi
\ No newline at end of file
diff --git a/examples/disaggregate_prefill_v1/run_server.sh b/examples/disaggregate_prefill_v1/run_server.sh
new file mode 100644
index 0000000000..37cf6d3aee
--- /dev/null
+++ b/examples/disaggregate_prefill_v1/run_server.sh
@@ -0,0 +1,32 @@
+export HCCL_IF_IP=141.61.39.117
+export GLOO_SOCKET_IFNAME="enp48s3u1u1"
+export TP_SOCKET_IFNAME="enp48s3u1u1"
+export HCCL_SOCKET_IFNAME="enp48s3u1u1"
+export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=path-to-rank-table
+
+export OMP_PROC_BIND=false
+export OMP_NUM_THREADS=100
+
+export VLLM_USE_V1=1
+
+vllm serve model_path \
+    --host 0.0.0.0 \
+    --port 20002 \
+    --tensor-parallel-size 1 \
+    --seed 1024 \
+    --served-model-name dsv3 \
+    --max-model-len 2000 \
+    --max-num-batched-tokens 2000 \
+    --trust-remote-code \
+    --gpu-memory-utilization 0.9 \
+    --kv-transfer-config \
+    '{"kv_connector": "LLMDataDistCMgrConnector",
+    "kv_buffer_device": "npu",
+    "kv_role": "kv_consumer",
+    "kv_parallel_size": 1,
+    "kv_port": "20001",
+    "engine_id": 0,
+    "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_connector_v1_a3"
+    }' \
+    --additional-config \
+    '{"enable_graph_mode": "True"}'
diff --git a/examples/disaggregate_prefill_v1/toy_proxy_server.py b/examples/disaggregate_prefill_v1/toy_proxy_server.py
new file mode 100644
index 0000000000..4478073f74
--- /dev/null
+++ b/examples/disaggregate_prefill_v1/toy_proxy_server.py
@@ -0,0 +1,261 @@
+# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
+
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+import itertools
+import os
+import uuid
+from contextlib import asynccontextmanager
+
+import httpx
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """
+    Lifespan context manager to handle startup and shutdown events.
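+
+    On startup one persistent httpx.AsyncClient is created per prefiller and
+    per decoder instance; on shutdown every client is closed again.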
+ """ + # Startup: Initialize client pools for prefiller and decoder services + app.state.prefill_clients = [] + app.state.decode_clients = [] + + # Create prefill clients + for i, (host, port) in enumerate(global_args.prefiller_instances): + prefiller_base_url = f'http://{host}:{port}/v1' + app.state.prefill_clients.append({ + 'client': + httpx.AsyncClient(timeout=None, base_url=prefiller_base_url), + 'host': + host, + 'port': + port, + 'id': + i + }) + + # Create decode clients + for i, (host, port) in enumerate(global_args.decoder_instances): + decoder_base_url = f'http://{host}:{port}/v1' + app.state.decode_clients.append({ + 'client': + httpx.AsyncClient(timeout=None, base_url=decoder_base_url), + 'host': + host, + 'port': + port, + 'id': + i + }) + + # Initialize round-robin iterators + app.state.prefill_iterator = itertools.cycle( + range(len(app.state.prefill_clients))) + app.state.decode_iterator = itertools.cycle( + range(len(app.state.decode_clients))) + + print(f"Initialized {len(app.state.prefill_clients)} prefill clients " + f"and {len(app.state.decode_clients)} decode clients.") + + yield + + # Shutdown: Close all clients + for client_info in app.state.prefill_clients: + await client_info['client'].aclose() + + for client_info in app.state.decode_clients: + await client_info['client'].aclose() + + +# Update FastAPI app initialization to use lifespan +app = FastAPI(lifespan=lifespan) + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--host", type=str, default="localhost") + + # For prefiller instances + parser.add_argument("--prefiller-hosts", + "--prefiller-host", + type=str, + nargs="+", + default=["localhost"]) + parser.add_argument("--prefiller-ports", + "--prefiller-port", + type=int, + nargs="+", + default=[8100]) + + # For decoder instances + parser.add_argument("--decoder-hosts", + "--decoder-host", + type=str, + nargs="+", + default=["localhost"]) + parser.add_argument("--decoder-ports", + "--decoder-port", + type=int, + nargs="+", + default=[8200]) + + args = parser.parse_args() + + # Validate and pair hosts with ports + if len(args.prefiller_hosts) != len(args.prefiller_ports): + raise ValueError( + "Number of prefiller hosts must match number of prefiller ports") + + if len(args.decoder_hosts) != len(args.decoder_ports): + raise ValueError( + "Number of decoder hosts must match number of decoder ports") + + # Create tuples of (host, port) for each service type + args.prefiller_instances = list( + zip(args.prefiller_hosts, args.prefiller_ports)) + args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports)) + + return args + + +def get_next_client(app, service_type: str): + """ + Get the next client in round-robin fashion. + + Args: + app: The FastAPI app instance + service_type: Either 'prefill' or 'decode' + + Returns: + The next client to use + """ + if service_type == 'prefill': + client_idx = next(app.state.prefill_iterator) + return app.state.prefill_clients[client_idx] + elif service_type == 'decode': + client_idx = next(app.state.decode_iterator) + return app.state.decode_clients[client_idx] + else: + raise ValueError(f"Unknown service type: {service_type}") + + +async def send_request_to_service(client_info: dict, endpoint: str, + req_data: dict, request_id: str): + """ + Send a request to a service using a client from the pool. 
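+
+    The request body is copied and forced to be non-streaming with
+    max_tokens=1, so the prefill instance only produces the KV cache and the
+    kv_transfer_params that the decode instance needs afterwards.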
+ """ + req_data = req_data.copy() + req_data['kv_transfer_params'] = { + "do_remote_decode": True, + "do_remote_prefill": False, + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": None, + "remote_port": None + } + req_data["stream"] = False + req_data["max_tokens"] = 1 + if "stream_options" in req_data: + del req_data["stream_options"] + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + + response = await client_info['client'].post(endpoint, + json=req_data, + headers=headers) + response.raise_for_status() + + return response + + +async def stream_service_response(client_info: dict, endpoint: str, + req_data: dict, request_id: str): + """ + Asynchronously stream response from a service using a client from the pool. + """ + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + + async with client_info['client'].stream("POST", + endpoint, + json=req_data, + headers=headers) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + try: + req_data = await request.json() + request_id = str(uuid.uuid4()) + + # Get the next prefill client in round-robin fashion + prefill_client_info = get_next_client(request.app, 'prefill') + + # Send request to prefill service + response = await send_request_to_service(prefill_client_info, + "/completions", req_data, + request_id) + + # Extract the needed fields + response_json = response.json() + kv_transfer_params = response_json.get('kv_transfer_params', {}) + if kv_transfer_params: + req_data["kv_transfer_params"] = kv_transfer_params + + # Get the next decode client in round-robin fashion + decode_client_info = get_next_client(request.app, 'decode') + + logger.debug("Using %s %s", prefill_client_info, decode_client_info) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(decode_client_info, + "/completions", + req_data, + request_id=request_id): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server" + " - completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +@app.get("/healthcheck") +async def healthcheck(): + """Simple endpoint to check if the server is running.""" + return { + "status": "ok", + "prefill_instances": len(app.state.prefill_clients), + "decode_instances": len(app.state.decode_clients) + } + + +if __name__ == '__main__': + global global_args + global_args = parse_args() + + import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/examples/dp_offline/data_parallel.py b/examples/dp_offline/data_parallel.py index b06c52d8c5..37a14d5f7b 100644 --- a/examples/dp_offline/data_parallel.py +++ b/examples/dp_offline/data_parallel.py @@ -1,85 +1,226 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# Adapted from vllm-project/vllm/examples/offline_inference/data_parallel.py # SPDX-License-Identifier: Apache-2.0 -# usage: -# python examples/offline_inference_data_parallel.py -# we need to have a launcher to create multiple data parallel -# ranks. 
And each rank will create a vLLM instance to process its own prompts. +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Usage: +Single node: + python examples/offline_inference/data_parallel.py \ + --model="ibm-research/PowerMoE-3b" \ + --dp-size=2 \ + --tp-size=2 + +Multi-node: + Node 0 (assume the node has ip of 10.99.48.128): + python examples/offline_inference/data_parallel.py \ + --model="ibm-research/PowerMoE-3b" \ + --dp-size=2 \ + --tp-size=2 \ + --node-size=2 \ + --node-rank=0 \ + --master-addr=10.99.48.128 \ + --master-port=13345 + Node 1: + python examples/offline_inference/data_parallel.py \ + --model="ibm-research/PowerMoE-3b" \ + --dp-size=2 \ + --tp-size=2 \ + --node-size=2 \ + --node-rank=1 \ + --master-addr=10.99.48.128 \ + --master-port=13345 +""" -import gc import os +from time import sleep + +from vllm import LLM, SamplingParams +from vllm.utils import get_open_port + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description="Data Parallel Inference") + parser.add_argument( + "--model", + type=str, + default="ibm-research/PowerMoE-3b", + help="Model name or path", + ) + parser.add_argument("--dp-size", + type=int, + default=2, + help="Data parallel size") + parser.add_argument("--tp-size", + type=int, + default=2, + help="Tensor parallel size") + parser.add_argument("--node-size", + type=int, + default=1, + help="Total number of nodes") + parser.add_argument("--node-rank", + type=int, + default=0, + help="Rank of the current node") + parser.add_argument("--master-addr", + type=str, + default="", + help="Master node IP address") + parser.add_argument("--master-port", + type=int, + default=0, + help="Master node port") + parser.add_argument("--enforce-eager", + action="store_true", + help="Enforce eager mode execution.") + parser.add_argument("--trust-remote-code", + action="store_true", + help="Trust remote code.") + return parser.parse_args() -def main(): - dp_rank = int(os.environ['RANK']) - local_rank = int(os.environ['LOCAL_RANK']) - dp_size = int(os.environ['WORLD_SIZE']) - master_addr = os.environ['MASTER_ADDR'] - master_port = os.environ['MASTER_PORT'] - tp_size = 1 - etp_size = 1 - os.environ["VLLM_DP_RANK"] = str(dp_rank) +def main( + model, + dp_size, + local_dp_rank, + global_dp_rank, + dp_master_ip, + dp_master_port, + GPUs_per_dp_rank, + enforce_eager, + trust_remote_code, +): + os.environ["VLLM_DP_RANK"] = str(global_dp_rank) + os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) os.environ["VLLM_DP_SIZE"] = str(dp_size) - os.environ["VLLM_DP_MASTER_IP"] = master_addr - os.environ["VLLM_DP_MASTER_PORT"] = master_port - os.environ["ASCEND_RT_VISIBLE_DEVICES"] = ",".join( - str(i) - for i in range(local_rank * tp_size, (local_rank + 1) * tp_size)) + os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip + os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port) - import torch - from vllm import LLM, SamplingParams - from vllm.distributed.parallel_state import ( - destroy_distributed_environment, destroy_model_parallel) + # CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the + # engine processes. + # Sample prompts. prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", - ] * 4 + ] * 100 - promts_per_rank = len(prompts) // dp_size - start = dp_rank * promts_per_rank - end = start + promts_per_rank - prompts = prompts[start:end] + # with DP, each rank should process different prompts. 
+ # usually all the DP ranks process a full dataset, + # and each rank processes a different part of the dataset. + floor = len(prompts) // dp_size + remainder = len(prompts) % dp_size + + # Distribute prompts into even groups. + def start(rank): + return rank * floor + min(rank, remainder) + + prompts = prompts[start(global_dp_rank):start(global_dp_rank + 1)] if len(prompts) == 0: + # if any rank has no prompts to process, + # we need to set a placeholder prompt prompts = ["Placeholder"] - print(f"DP rank {dp_rank} needs to process {len(prompts)} prompts") - num_seqs = len(prompts) + print(f"DP rank {global_dp_rank} needs to process {len(prompts)} prompts") + + # Create a sampling params object. + # since we are doing data parallel, every rank can have different + # sampling params. here we set different max_tokens for different + # ranks for demonstration. + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=32, + ) - sampling_params = SamplingParams(temperature=0.8, - top_p=0.95, - max_tokens=4, - min_tokens=4) # Create an LLM. - llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat", - tensor_parallel_size=tp_size, - trust_remote_code=True, - max_model_len=4096, - max_num_seqs=num_seqs, - additional_config={ - 'expert_tensor_parallel_size': etp_size, - 'torchair_graph_config': { - 'enabled': False, - }, - }) + llm = LLM( + model=model, + tensor_parallel_size=GPUs_per_dp_rank, + enforce_eager=enforce_eager, + trust_remote_code=trust_remote_code, + distributed_executor_backend="mp", + max_model_len=2048, + max_num_batched_tokens=2048, + max_num_seqs=16, + enable_prefix_caching=False, + enable_expert_parallel=True, + gpu_memory_utilization=0.9, + additional_config={ + "ascend_scheduler_config": { + "enabled": True + }, + "torchair_graph_config": { + "enabled": False, + "enable_multistream_shared_expert": False + }, + }, + ) outputs = llm.generate(prompts, sampling_params) - for output in outputs: + # Print the outputs. + for i, output in enumerate(outputs): + if i >= 5: + # print only 5 outputs + break prompt = output.prompt generated_text = output.outputs[0].text - print(f"DP rank {dp_rank}, Prompt: {prompt!r}, " + print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, " f"Generated text: {generated_text!r}") - del llm - destroy_model_parallel() - destroy_distributed_environment() - gc.collect() - torch.npu.empty_cache() + # Give engines time to pause their processing loops before exiting. + sleep(1) if __name__ == "__main__": - main() + args = parse_args() + + dp_size = args.dp_size + tp_size = args.tp_size + node_size = args.node_size + node_rank = args.node_rank + + if node_size == 1: + dp_master_ip = "127.0.0.1" + dp_master_port = get_open_port() + else: + dp_master_ip = args.master_addr + dp_master_port = args.master_port + + assert dp_size % node_size == 0, "dp_size should be divisible by node_size" + dp_per_node = dp_size // node_size + + from multiprocessing import Process + + procs = [] + for local_dp_rank, global_dp_rank in enumerate( + range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)): + proc = Process( + target=main, + args=( + args.model, + dp_size, + local_dp_rank, + global_dp_rank, + dp_master_ip, + dp_master_port, + tp_size, + args.enforce_eager, + args.trust_remote_code, + ), + ) + proc.start() + procs.append(proc) + exit_code = 0 + for proc in procs: + proc.join(timeout=3000) + if proc.exitcode is None: + print( + f"Killing process {proc.pid} that didn't stop within 5 minutes." 
+ ) + proc.kill() + exit_code = 1 + elif proc.exitcode: + exit_code = proc.exitcode + + exit(exit_code) diff --git a/examples/dp_offline/run_dp.sh b/examples/dp_offline/run_dp.sh index 405df604a4..508d966651 100644 --- a/examples/dp_offline/run_dp.sh +++ b/examples/dp_offline/run_dp.sh @@ -1,19 +1,28 @@ +rm -rf ./.torchair_cache/ +rm -rf ./dynamo_* +rm -rf /root/ascend/log/debug/plog/* + +ifname="ifname" +local_ip="local ip" +master_addr="master ip" +model_path="path to model ckpt" + export HCCL_IF_IP=${local_ip} export GLOO_SOCKET_IFNAME=${ifname} export TP_SOCKET_IFNAME=${ifname} export HCCL_SOCKET_IFNAME=${ifname} -# dp_size = node_size * dp_per_node -node_size=1 -node_rank=0 -dp_per_node=4 -master_addr=127.0.0.1 -master_port=12345 - -rm -rf ./.torchair_cache/ -rm -rf ./dynamo_* -rm -rf /root/ascend/log/debug/plog/* +export VLLM_USE_V1=1 +export ASCEND_LAUNCH_BLOCKING=0 +# export VLLM_VERSION=0.9.0 -torchrun --nproc_per_node ${dp_per_node} --nnodes ${node_size} \ - --node_rank ${node_rank} --master_addr ${master_addr} --master_port ${master_port} \ - data_parallel.py +python data_parallel.py \ + --model=${model_path} \ + --dp-size=4 \ + --tp-size=4 \ + --enforce-eager \ + --trust-remote-code \ + --node-size=1 \ + --node-rank=0 \ + --master-addr=${master_addr} \ + --master-port=13345 diff --git a/examples/offline_dualbatch_overlap_npu.py b/examples/offline_dualbatch_overlap_npu.py index d8153e38ca..dd8ee9aeb1 100644 --- a/examples/offline_dualbatch_overlap_npu.py +++ b/examples/offline_dualbatch_overlap_npu.py @@ -20,6 +20,7 @@ def main(): tensor_parallel_size=2, max_model_len=4096, trust_remote_code=True, + enable_expert_parallel=True, additional_config={ "torchair_graph_config": { "enabled": False @@ -27,7 +28,6 @@ def main(): "ascend_scheduler_config": { "enabled": True }, - "expert_tensor_parallel_size": 1 }) # Generate texts from the prompts. 
The output is a list of RequestOutput diff --git a/examples/run_dp_server.sh b/examples/run_dp_server.sh index e2bf4c8158..eb3cfbf510 100644 --- a/examples/run_dp_server.sh +++ b/examples/run_dp_server.sh @@ -1,3 +1,7 @@ +rm -rf ./.torchair_cache/ +rm -rf ./dynamo_* +rm -rf /root/ascend/log/debug/plog/* + export HCCL_IF_IP=2.0.0.0 export GLOO_SOCKET_IFNAME="enp189s0f0" export TP_SOCKET_IFNAME="enp189s0f0" @@ -6,25 +10,24 @@ export HCCL_SOCKET_IFNAME="enp189s0f0" export OMP_PROC_BIND=false export OMP_NUM_THREADS=100 -export VLLM_USE_V1=0 - -export ASCEND_RT_VISIBLE_DEVICES=0,1 -export VLLM_DP_SIZE=2 -export VLLM_DP_RANK=0 -export VLLM_DP_MASTER_IP="2.0.0.0" -export VLLM_DP_MASTER_PORT=40001 -export VLLM_DP_PROXY_IP="2.0.0.0" -export VLLM_DP_PROXY_PORT=30002 -export VLLM_DP_MONITOR_PORT=30003 -export VLLM_HTTP_PORT=20001 +export VLLM_USE_V1=1 +export ASCEND_LAUNCH_BLOCKING=0 vllm serve /data/weights/Qwen2.5-0.5B-Instruct \ --host 0.0.0.0 \ - --port 20001 \ - --tensor-parallel-size 1 \ - --seed 1024 \ + --port 20002 \ --served-model-name Qwen \ - --max-model-len 2000 \ - --max-num-batched-tokens 2000 \ + --data-parallel-size 4 \ + --data-parallel-size-local 4 \ + --data-parallel-address 2.0.0.0 \ + --data-parallel-rpc-port 13389 \ + --tensor-parallel-size 4 \ + --enable-expert-parallel \ + --no-enable-prefix-caching \ + --max-num-seqs 16 \ + --max-model-len 4096 \ + --max-num-batched-tokens 4096 \ + --gpu-memory-utilization 0.9 \ --trust-remote-code \ - --gpu-memory-utilization 0.9 \ \ No newline at end of file + --enforce-eager \ + --additional-config '{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":false, "enable_multistream_moe":false, "use_cached_graph":false}}' diff --git a/examples/run_dp_with_cached_graph_etp16.sh b/examples/run_dp_with_cached_graph_etp16.sh new file mode 100644 index 0000000000..5f1d3b782b --- /dev/null +++ b/examples/run_dp_with_cached_graph_etp16.sh @@ -0,0 +1,25 @@ +export HCCL_IF_IP=2.0.0.0 +export GLOO_SOCKET_IFNAME="enp189s0f0" +export TP_SOCKET_IFNAME="enp189s0f0" +export HCCL_SOCKET_IFNAME="enp189s0f0" + +export VLLM_USE_V1=1 +export ASCEND_LAUNCH_BLOCKING=0 +# export VLLM_VERSION=0.9.0 + +nohup python -m vllm.entrypoints.openai.api_server --model=/mnt/deepseek/DeepSeek-R1-W8A8-VLLM \ + --host 0.0.0.0 \ + --port 20002 \ + --quantization ascend \ + -dp=2 \ + -tp=8 \ + --no-enable-prefix-caching \ + --max-num-seqs 24 \ + --max-model-len 4096 \ + --max-num-batched-tokens 4096 \ + --gpu-memory-utilization 0.96 \ + --trust-remote-code \ + --distributed-executor-backend=mp \ + --additional-config '{"torchair_graph_config":{"enabled":true,"use_cached_graph":true,"graph_batch_sizes":[24]},"ascend_scheduler_config":{"enabled":true}}' \ + & > run.log & +disown diff --git a/pyproject.toml b/pyproject.toml index 514b755c32..fc7c7c2c71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,5 +19,7 @@ requires = [ "msgpack", "quart", "numba", + # Remove after https://github.com/vllm-project/vllm-ascend/issues/1470 + "transformers<4.53.0", ] build-backend = "setuptools.build_meta" diff --git a/requirements.txt b/requirements.txt index eadb96f1e9..6d84ec658c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,4 +24,7 @@ numba # Install torch_npu --pre --extra-index-url https://mirrors.huaweicloud.com/ascend/repos/pypi -torch-npu==2.5.1.post1.dev20250528 +torch-npu==2.5.1.post1.dev20250619 + +# Remove after https://github.com/vllm-project/vllm-ascend/issues/1470 +transformers<4.53.0 diff --git a/tests/conftest.py b/tests/conftest.py 
index e0d70a19d2..eccfa38e9e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,18 +19,24 @@ import contextlib import gc -from typing import List, Optional, Tuple, TypeVar, Union +import os +from typing import Any, List, Optional, Tuple, TypeVar, Union import numpy as np import pytest import torch -from huggingface_hub import snapshot_download +from modelscope import snapshot_download # type: ignore[import-untyped] from PIL import Image +from torch import nn +from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, + BatchEncoding, BatchFeature) +from transformers.models.auto.auto_factory import _BaseAutoModelClass from vllm import LLM, SamplingParams -from vllm.config import TaskOption +from vllm.config import TaskOption, _get_and_verify_dtype from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams +from vllm.transformers_utils.utils import maybe_model_redirect from vllm.utils import is_list_of from tests.model_utils import (PROMPT_TEMPLATES, TokensTextLogprobs, @@ -45,6 +51,7 @@ from vllm.distributed.parallel_state import ( # noqa E402 destroy_distributed_environment, destroy_model_parallel) +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict) _M = TypeVar("_M") _PromptMultiModalInput = Union[List[_M], List[List[_M]]] @@ -53,6 +60,9 @@ PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]] PromptVideoInput = _PromptMultiModalInput[np.ndarray] +_TEST_DIR = os.path.dirname(__file__) +_TEST_PROMPTS = [os.path.join(_TEST_DIR, "e2e", "prompts", "example.txt")] + def cleanup_dist_env_and_memory(shutdown_ray: bool = False): destroy_model_parallel() @@ -361,6 +371,148 @@ def prompt_template(request): return PROMPT_TEMPLATES[request.param] +def _read_prompts(filename: str) -> list[str]: + with open(filename) as f: + prompts = f.readlines() + return prompts + + +@pytest.fixture +def example_prompts() -> list[str]: + prompts = [] + for filename in _TEST_PROMPTS: + prompts += _read_prompts(filename) + return prompts + + @pytest.fixture(scope="session") def ilama_lora_files(): - return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider") + return snapshot_download(repo_id="vllm-ascend/ilama-text2sql-spider") + + +class HfRunner: + + def get_default_device(self): + from vllm.platforms import current_platform + + return ("cpu" + if current_platform.is_cpu() else current_platform.device_type) + + def wrap_device(self, x: _T, device: Optional[str] = None) -> _T: + if x is None or isinstance(x, (bool, )): + return x + + if device is None: + device = self.device + + if isinstance(x, dict): + return {k: self.wrap_device(v, device) for k, v in x.items()} + + if hasattr(x, "device") and x.device.type == device: + return x + + return x.to(device) + + def __init__( + self, + model_name: str, + dtype: str = "auto", + *, + model_kwargs: Optional[dict[str, Any]] = None, + trust_remote_code: bool = True, + is_sentence_transformer: bool = False, + is_cross_encoder: bool = False, + skip_tokenizer_init: bool = False, + auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM, + ) -> None: + model_name = maybe_model_redirect(model_name) + self.model_name = model_name + + self.config = AutoConfig.from_pretrained( + model_name, + trust_remote_code=trust_remote_code, + ) + self.device = self.get_default_device() + self.dtype = torch_dtype = _get_and_verify_dtype( + self.model_name, + self.config, + dtype=dtype, + 
is_pooling_model=is_sentence_transformer or is_cross_encoder, + ) + + model_kwargs = model_kwargs if model_kwargs is not None else {} + model_kwargs.setdefault("torch_dtype", torch_dtype) + + if is_sentence_transformer: + # Lazy init required for AMD CI + from sentence_transformers import SentenceTransformer + + self.model = SentenceTransformer( + model_name, + device=self.device, + model_kwargs=model_kwargs, + trust_remote_code=trust_remote_code, + ) + elif is_cross_encoder: + # Lazy init required for AMD CI + from sentence_transformers import CrossEncoder + + self.model = CrossEncoder( + model_name, + device=self.device, + automodel_args=model_kwargs, + trust_remote_code=trust_remote_code, + ) + else: + model = auto_cls.from_pretrained( + model_name, + trust_remote_code=trust_remote_code, + **model_kwargs, + ) + + # in case some unquantized custom models are not in same dtype + if (getattr(model, "quantization_method", None) is None + and any(p.dtype != self.dtype + for p in model.parameters())): + model = model.to(dtype=self.dtype) + + if (getattr(model, "quantization_method", None) != "bitsandbytes" + and len({p.device + for p in model.parameters()}) < 2): + model = model.to(device=self.device) + + self.model = model + + if not skip_tokenizer_init: + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + ) + + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoProcessor # noqa: F401 + self.processor = AutoProcessor.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + ) + if skip_tokenizer_init: + self.tokenizer = self.processor.tokenizer + + def encode(self, prompts: list[str], *args, + **kwargs) -> list[list[torch.Tensor]]: + return self.model.encode(prompts, *args, **kwargs) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.model + cleanup_dist_env_and_memory() + + +@pytest.fixture(scope="session") +def hf_runner(): + return HfRunner diff --git a/tests/e2e/doctests/001-quickstart-test.sh b/tests/e2e/doctests/001-quickstart-test.sh index 44dce0d7bc..6490908c8c 100755 --- a/tests/e2e/doctests/001-quickstart-test.sh +++ b/tests/e2e/doctests/001-quickstart-test.sh @@ -16,6 +16,16 @@ # limitations under the License. # This file is a part of the vllm-ascend project. # +function install_system_packages() { + if command -v apt-get >/dev/null; then + sed -i 's|ports.ubuntu.com|mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list + apt-get update -y && apt install -y curl + elif command -v yum >/dev/null; then + yum update -y && yum install -y curl + else + echo "Unknown package manager. Please install gcc, g++, numactl-devel, git, curl, and jq manually." 
+ fi +} function simple_test() { # Do real import test @@ -28,6 +38,7 @@ function quickstart_offline_test() { } function quickstart_online_test() { + install_system_packages vllm serve Qwen/Qwen2.5-0.5B-Instruct & wait_url_ready "vllm serve" "localhost:8000/v1/models" # Do real curl test diff --git a/tests/e2e/doctests/002-pip-binary-installation-test.sh b/tests/e2e/doctests/002-pip-binary-installation-test.sh index 48a33bf6ff..a763cefb05 100644 --- a/tests/e2e/doctests/002-pip-binary-installation-test.sh +++ b/tests/e2e/doctests/002-pip-binary-installation-test.sh @@ -18,14 +18,34 @@ # trap clean_venv EXIT +function install_system_packages() { + if command -v apt-get >/dev/null; then + sed -i 's|ports.ubuntu.com|mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list + apt-get update -y && apt-get install -y gcc g++ cmake libnuma-dev wget git curl jq + elif command -v yum >/dev/null; then + yum update -y && yum install -y gcc g++ cmake numactl-devel wget git curl jq + else + echo "Unknown package manager. Please install curl manually." + fi +} + +function config_pip_mirror() { + pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple +} + function install_binary_test() { + install_system_packages + config_pip_mirror create_vllm_venv PIP_VLLM_VERSION=$(get_version pip_vllm_version) PIP_VLLM_ASCEND_VERSION=$(get_version pip_vllm_ascend_version) _info "====> Install vllm==${PIP_VLLM_VERSION} and vllm-ascend ${PIP_VLLM_ASCEND_VERSION}" + # Setup extra-index-url for x86 & torch_npu dev version + pip config set global.extra-index-url "https://download.pytorch.org/whl/cpu/ https://mirrors.huaweicloud.com/ascend/repos/pypi" + pip install vllm=="$(get_version pip_vllm_version)" pip install vllm-ascend=="$(get_version pip_vllm_ascend_version)" diff --git a/tests/e2e/long_term/spec_decode/__init__.py b/tests/e2e/long_term/spec_decode_v0/__init__.py similarity index 100% rename from tests/e2e/long_term/spec_decode/__init__.py rename to tests/e2e/long_term/spec_decode_v0/__init__.py diff --git a/tests/e2e/long_term/spec_decode/conftest.py b/tests/e2e/long_term/spec_decode_v0/conftest.py similarity index 100% rename from tests/e2e/long_term/spec_decode/conftest.py rename to tests/e2e/long_term/spec_decode_v0/conftest.py diff --git a/tests/e2e/long_term/spec_decode/e2e/__init__.py b/tests/e2e/long_term/spec_decode_v0/e2e/__init__.py similarity index 100% rename from tests/e2e/long_term/spec_decode/e2e/__init__.py rename to tests/e2e/long_term/spec_decode_v0/e2e/__init__.py diff --git a/tests/e2e/long_term/spec_decode/e2e/conftest.py b/tests/e2e/long_term/spec_decode_v0/e2e/conftest.py similarity index 100% rename from tests/e2e/long_term/spec_decode/e2e/conftest.py rename to tests/e2e/long_term/spec_decode_v0/e2e/conftest.py diff --git a/tests/e2e/long_term/spec_decode_v0/e2e/test_eagle_correctness.py b/tests/e2e/long_term/spec_decode_v0/e2e/test_eagle_correctness.py new file mode 100644 index 0000000000..b44dc3c48e --- /dev/null +++ b/tests/e2e/long_term/spec_decode_v0/e2e/test_eagle_correctness.py @@ -0,0 +1,344 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_eagle_correctness.py +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""This docstring details important information on the testing methodology. + +Most of the tests rely on "greedy equality", where we expect the output of +speculative decoding on a sequence to exactly match the output of normal non- +speculative decoding. + +Since speculative decoding with rejection sampling guarantees that the output +distribution matches the target model's output distribution (up to hardware +numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy +equality. + +However, we still need to verify below scenario could be passed: + * Batch size 1 greedy equality + * Batch size >1 greedy equality + * Test greedy equality under preemption + * Test greedy equality under various number of speculative tokens. + +With those tests, we can say at least, EAGLE would not break the +correctness for the target model outputs. +""" + +import pytest + +from tests.e2e.long_term.spec_decode_v0.e2e.conftest import \ + run_equality_correctness_test + +# main model +MAIN_MODEL = "JackFram/llama-68m" + +# speculative model +SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random" + +# max. number of speculative tokens: this corresponds to +# num_heads in the config.json of the speculator model. +MAX_SPEC_TOKENS = 4 + +# precision +# TODO The vLLM here uses float32, but some op on the vllm-ascend +# do not support float32, such as ROPE, When it is fixed, it is +# recommended to change this to float32. +PRECISION = "float16" + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, + seed: int): + + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Print spec metrics. 
+ "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs": False, + }, +}, { + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs": True, + }, +}]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) + + +@pytest.mark.skipif(True, reason="Open it when graph mode ready.") +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "enforce_eager": False, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_e2e_greedy_correctness_cuda_graph( + vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify greedy equality with cuda graph enabled and different + batch sizes.""" + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.skipif(True, reason="Open it when preempt ready.") +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "block_size": 8, + # 2 for small prompt, 256//8 for generated. + "num_gpu_blocks_override": 2 + 256 // 8, + "max_model_len": (2 + 256 // 8) * 8, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. 
+ 128, + ]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_e2e_greedy_correctness_with_preemption( + vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify greedy equality, even when some sequences are preempted mid- + generation. + """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": k, + }, + } + # Try a range of num. speculative tokens + for k in range(1, 1 + MAX_SPEC_TOKENS) + ]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_different_k(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify that eagle speculative decoding produces exact equality + to without spec decode with different values of num_speculative_tokens. + """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_by_batch_size": 4, + }, +}]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_eagle_disable_queue(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify that eagle speculative decoding produces exact equality + to without spec decode when speculation is disabled for large + batch sizes. 
+ """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +if __name__ == "__main__": + import pytest + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/e2e/long_term/spec_decode/e2e/test_medusa_correctness.py b/tests/e2e/long_term/spec_decode_v0/e2e/test_medusa_correctness.py similarity index 99% rename from tests/e2e/long_term/spec_decode/e2e/test_medusa_correctness.py rename to tests/e2e/long_term/spec_decode_v0/e2e/test_medusa_correctness.py index e0c2efd7af..26398e2e45 100644 --- a/tests/e2e/long_term/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/e2e/long_term/spec_decode_v0/e2e/test_medusa_correctness.py @@ -41,9 +41,10 @@ import pytest -from tests.e2e.long_term.spec_decode.e2e.conftest import \ +from tests.e2e.long_term.spec_decode_v0.e2e.conftest import \ run_equality_correctness_test -from tests.e2e.long_term.spec_decode.utils import maybe_enable_chunked_prefill +from tests.e2e.long_term.spec_decode_v0.utils import \ + maybe_enable_chunked_prefill # main model # lmsys/vicuna-7b-v1.3 was to be used but it's causing diff --git a/tests/e2e/long_term/spec_decode/e2e/test_mlp_correctness.py b/tests/e2e/long_term/spec_decode_v0/e2e/test_mlp_correctness.py similarity index 99% rename from tests/e2e/long_term/spec_decode/e2e/test_mlp_correctness.py rename to tests/e2e/long_term/spec_decode_v0/e2e/test_mlp_correctness.py index 56db617755..37003e4c1f 100644 --- a/tests/e2e/long_term/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/e2e/long_term/spec_decode_v0/e2e/test_mlp_correctness.py @@ -41,9 +41,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import \ pad_vocab_size # noqa: F401 -from tests.e2e.long_term.spec_decode.e2e.conftest import \ +from tests.e2e.long_term.spec_decode_v0.e2e.conftest import \ run_equality_correctness_test -from tests.e2e.long_term.spec_decode.utils import maybe_enable_chunked_prefill +from tests.e2e.long_term.spec_decode_v0.utils import \ + maybe_enable_chunked_prefill # main model MAIN_MODEL = "JackFram/llama-160m" diff --git a/tests/e2e/long_term/spec_decode/e2e/test_mtp_correctness.py b/tests/e2e/long_term/spec_decode_v0/e2e/test_mtp_correctness.py similarity index 100% rename from tests/e2e/long_term/spec_decode/e2e/test_mtp_correctness.py rename to tests/e2e/long_term/spec_decode_v0/e2e/test_mtp_correctness.py diff --git a/tests/e2e/long_term/spec_decode/e2e/test_ngram_correctness.py b/tests/e2e/long_term/spec_decode_v0/e2e/test_ngram_correctness.py similarity index 98% rename from tests/e2e/long_term/spec_decode/e2e/test_ngram_correctness.py rename to tests/e2e/long_term/spec_decode_v0/e2e/test_ngram_correctness.py index b99187fe37..1cc20abac0 100644 --- a/tests/e2e/long_term/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/e2e/long_term/spec_decode_v0/e2e/test_ngram_correctness.py @@ -44,9 +44,10 @@ import pytest -from tests.e2e.long_term.spec_decode.e2e.conftest import \ +from tests.e2e.long_term.spec_decode_v0.e2e.conftest import \ run_equality_correctness_test -from tests.e2e.long_term.spec_decode.utils import maybe_enable_chunked_prefill +from tests.e2e.long_term.spec_decode_v0.utils import \ + maybe_enable_chunked_prefill @pytest.mark.parametrize( diff --git a/tests/e2e/long_term/spec_decode/test_dynamic_spec_decode.py b/tests/e2e/long_term/spec_decode_v0/test_dynamic_spec_decode.py similarity index 96% rename from tests/e2e/long_term/spec_decode/test_dynamic_spec_decode.py rename 
to tests/e2e/long_term/spec_decode_v0/test_dynamic_spec_decode.py index 8e9480ea26..63e4e1dba4 100644 --- a/tests/e2e/long_term/spec_decode/test_dynamic_spec_decode.py +++ b/tests/e2e/long_term/spec_decode_v0/test_dynamic_spec_decode.py @@ -27,8 +27,9 @@ from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker from vllm.spec_decode.top1_proposer import Top1Proposer -from tests.e2e.long_term.spec_decode.test_utils import mock_spec_decode_sampler -from tests.e2e.long_term.spec_decode.utils import create_batch, mock_worker +from tests.e2e.long_term.spec_decode_v0.test_utils import \ + mock_spec_decode_sampler +from tests.e2e.long_term.spec_decode_v0.utils import create_batch, mock_worker @pytest.mark.parametrize('queue_size', [4]) diff --git a/tests/e2e/long_term/spec_decode/test_multi_step_worker.py b/tests/e2e/long_term/spec_decode_v0/test_multi_step_worker.py similarity index 99% rename from tests/e2e/long_term/spec_decode/test_multi_step_worker.py rename to tests/e2e/long_term/spec_decode_v0/test_multi_step_worker.py index b3017a987e..1dc50dd169 100644 --- a/tests/e2e/long_term/spec_decode/test_multi_step_worker.py +++ b/tests/e2e/long_term/spec_decode_v0/test_multi_step_worker.py @@ -29,7 +29,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.top1_proposer import Top1Proposer -from tests.e2e.long_term.spec_decode.utils import ( +from tests.e2e.long_term.spec_decode_v0.utils import ( assert_logprobs_dict_allclose, create_batch, create_seq_group_metadata_from_prompts, create_worker, patch_execute_model_with_seeds, zero_kv_cache) diff --git a/tests/e2e/long_term/spec_decode/test_ngram_worker.py b/tests/e2e/long_term/spec_decode_v0/test_ngram_worker.py similarity index 99% rename from tests/e2e/long_term/spec_decode/test_ngram_worker.py rename to tests/e2e/long_term/spec_decode_v0/test_ngram_worker.py index 078a4d2bed..30177b68bc 100644 --- a/tests/e2e/long_term/spec_decode/test_ngram_worker.py +++ b/tests/e2e/long_term/spec_decode_v0/test_ngram_worker.py @@ -22,7 +22,7 @@ from vllm.spec_decode.ngram_worker import NGramWorker from vllm.spec_decode.top1_proposer import Top1Proposer -from tests.e2e.long_term.spec_decode.utils import ( +from tests.e2e.long_term.spec_decode_v0.utils import ( create_seq_group_metadata_from_prompts, create_worker) diff --git a/tests/e2e/long_term/spec_decode/test_spec_decode_worker.py b/tests/e2e/long_term/spec_decode_v0/test_spec_decode_worker.py similarity index 99% rename from tests/e2e/long_term/spec_decode/test_spec_decode_worker.py rename to tests/e2e/long_term/spec_decode_v0/test_spec_decode_worker.py index 94a1bcf1e7..ffcb2f6b54 100644 --- a/tests/e2e/long_term/spec_decode/test_spec_decode_worker.py +++ b/tests/e2e/long_term/spec_decode_v0/test_spec_decode_worker.py @@ -35,10 +35,10 @@ from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker, split_num_cache_blocks_evenly) -from tests.e2e.long_term.spec_decode.test_utils import mock_spec_decode_sampler -from tests.e2e.long_term.spec_decode.utils import (create_batch, - create_sampler_output_list, - create_worker, mock_worker) +from tests.e2e.long_term.spec_decode_v0.test_utils import \ + mock_spec_decode_sampler +from tests.e2e.long_term.spec_decode_v0.utils import ( + create_batch, create_sampler_output_list, create_worker, mock_worker) from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner from vllm_ascend.worker.worker import NPUWorker diff --git a/tests/e2e/long_term/spec_decode/test_utils.py 
b/tests/e2e/long_term/spec_decode_v0/test_utils.py similarity index 100% rename from tests/e2e/long_term/spec_decode/test_utils.py rename to tests/e2e/long_term/spec_decode_v0/test_utils.py diff --git a/tests/e2e/long_term/spec_decode/utils.py b/tests/e2e/long_term/spec_decode_v0/utils.py similarity index 100% rename from tests/e2e/long_term/spec_decode/utils.py rename to tests/e2e/long_term/spec_decode_v0/utils.py diff --git a/tests/e2e/multicard/test_data_parallel.py b/tests/e2e/multicard/test_data_parallel.py new file mode 100644 index 0000000000..57f14ac6db --- /dev/null +++ b/tests/e2e/multicard/test_data_parallel.py @@ -0,0 +1,72 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Compare the outputs of vLLM with and without aclgraph. + +Run `pytest tests/multicard/test_data_parallel.py`. +""" + +import os +import subprocess +import sys +from unittest.mock import patch + +import pytest + +MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"] + + +@pytest.mark.skipif(True, reason="TODO: fix dp timeout error in ci") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_tokens", [32]) +@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"}) +def test_data_parallel_inference(model, max_tokens): + script = "examples/offline_data_parallel.py" + + env = os.environ.copy() + + cmd = [ + sys.executable, + script, + "--model", + model, + "--dp-size", + "2", + "--tp-size", + "1", + "--node-size", + "1", + "--node-rank", + "0", + "--trust-remote-code", + "--enforce-eager", + ] + + print(f"Running subprocess: {' '.join(cmd)}") + proc = subprocess.run(cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + timeout=600) + output = proc.stdout.decode() + + print(output) + + assert "DP rank 0 needs to process" in output + assert "DP rank 1 needs to process" in output + assert "Generated text:" in output + assert proc.returncode == 0 diff --git a/tests/e2e/long_term/test_deepseek_v2_lite_tp2_accuracy.py b/tests/e2e/multicard/test_deepseek_v2_lite_tp2_accuracy.py similarity index 97% rename from tests/e2e/long_term/test_deepseek_v2_lite_tp2_accuracy.py rename to tests/e2e/multicard/test_deepseek_v2_lite_tp2_accuracy.py index 27986cb149..3a9068ff6b 100644 --- a/tests/e2e/long_term/test_deepseek_v2_lite_tp2_accuracy.py +++ b/tests/e2e/multicard/test_deepseek_v2_lite_tp2_accuracy.py @@ -38,7 +38,7 @@ def run_test(model_name, queue, more_args=None): - model_args = f"pretrained={model_name},max_model_len=4096,trust_remote_code=True,tensor_parallel_size=4" + model_args = f"pretrained={model_name},max_model_len=4096,trust_remote_code=True,tensor_parallel_size=4,enforce_eager=True" if more_args is not None: model_args = f"{model_args},{more_args}" results = lm_eval.simple_evaluate( diff --git a/tests/e2e/multicard/test_fused_moe_allgather_ep.py b/tests/e2e/multicard/test_fused_moe_allgather_ep.py new file mode 100644 index 0000000000..ad755dd161 --- /dev/null +++ 
b/tests/e2e/multicard/test_fused_moe_allgather_ep.py @@ -0,0 +1,82 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Execute the inference of fused_moe_allgather_ep and fused_moe_alltoall_ep. + +Run 'pytest tests/multicard/test_fused_moe_allgather_ep.py'. +""" + +import os +from unittest.mock import patch + +from modelscope import snapshot_download # type: ignore +from vllm import SamplingParams + +from tests.conftest import VllmRunner + + +@patch.dict( + os.environ, { + "VLLM_USE_V1": "1", + "VLLM_WORKER_MULTIPROC_METHOD": "spawn", + "TASK_QUEUE_ENABLE": "1", + "VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": "1" + }) +def test_generate_with_allgather(): + example_prompts = ["Hello, my name is"] + sampling_params = SamplingParams(max_tokens=100, temperature=0.0) + + with VllmRunner(snapshot_download("vllm-ascend/DeepSeek-V3-Pruning"), + tensor_parallel_size=4, + enforce_eager=True, + max_model_len=1024, + dtype="auto", + enable_expert_parallel=True, + additional_config={ + "ascend_scheduler_config": { + "enabled": True, + "chunked_prefill_enabled": False, + }, + "expert_tensor_parallel_size": 1 + }) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) + + +@patch.dict( + os.environ, { + "VLLM_USE_V1": "1", + "VLLM_WORKER_MULTIPROC_METHOD": "spawn", + "TASK_QUEUE_ENABLE": "1" + }) +def test_generate_with_alltoall(): + example_prompts = ["Hello, my name is"] + sampling_params = SamplingParams(max_tokens=100, temperature=0.0) + + with VllmRunner(snapshot_download("vllm-ascend/DeepSeek-V3-Pruning"), + tensor_parallel_size=4, + enforce_eager=True, + max_model_len=1024, + dtype="auto", + enable_expert_parallel=True, + additional_config={ + "ascend_scheduler_config": { + "enabled": True, + "chunked_prefill_enabled": False, + }, + "expert_tensor_parallel_size": 1 + }) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) \ No newline at end of file diff --git a/tests/e2e/multicard/test_ilama_lora_tp2.py b/tests/e2e/multicard/test_ilama_lora_tp2.py index e743141b7a..3f62bfd7e3 100644 --- a/tests/e2e/multicard/test_ilama_lora_tp2.py +++ b/tests/e2e/multicard/test_ilama_lora_tp2.py @@ -1,4 +1,5 @@ import pytest +from modelscope import snapshot_download # type: ignore from tests.conftest import VllmRunner from tests.e2e.singlecard.test_ilama_lora import (EXPECTED_LORA_OUTPUT, @@ -7,7 +8,7 @@ @pytest.mark.parametrize("distributed_executor_backend", ["mp"]) def test_ilama_lora_tp2(distributed_executor_backend, ilama_lora_files): - with VllmRunner(model_name=MODEL_PATH, + with VllmRunner(snapshot_download(MODEL_PATH), enable_lora=True, max_loras=4, max_model_len=1024, diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index f5ec2c872b..47ff47eddd 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ 
-25,6 +25,7 @@ from modelscope import snapshot_download # type: ignore from vllm import SamplingParams +from vllm.model_executor.models.registry import ModelRegistry from tests.conftest import VllmRunner @@ -46,17 +47,28 @@ def test_models_distributed_QwQ(): vllm_model.generate_greedy(example_prompts, max_tokens) -def test_models_distributed_DeepSeek(): +def test_models_distributed_DeepSeek_multistream_moe(): example_prompts = [ "Hello, my name is", ] dtype = "half" max_tokens = 5 with VllmRunner( - "deepseek-ai/DeepSeek-V2-Lite", + "vllm-ascend/DeepSeek-V3-Pruning", dtype=dtype, tensor_parallel_size=4, distributed_executor_backend="mp", + additional_config={ + "torchair_graph_config": { + "enabled": True, + "enable_multistream_moe": True, + }, + "ascend_scheduler_config": { + "enabled": True, + }, + "refresh": True, + }, + enforce_eager=False, ) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) @@ -94,6 +106,32 @@ def test_models_distributed_DeepSeek_dbo(): tensor_parallel_size=4, distributed_executor_backend="mp", ) as vllm_model: + model_arch = 'DeepseekV2ForCausalLM' + registed_models = ModelRegistry.models + assert registed_models[ + model_arch].module_name == "vllm_ascend.models.deepseek_dbo" + assert registed_models[ + model_arch].class_name == "CustomDeepseekDBOForCausalLM" + vllm_model.generate(example_prompts, sampling_params) + + +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"}) +def test_models_distributed_DeepSeekV3_dbo(): + example_prompts = ["The president of the United States is"] * 41 + dtype = "half" + sampling_params = SamplingParams(max_tokens=100, temperature=0.0) + with VllmRunner( + "vllm-ascend/DeepSeek-V3-Pruning", + dtype=dtype, + tensor_parallel_size=4, + distributed_executor_backend="mp", + ) as vllm_model: + model_arch = 'DeepseekV3ForCausalLM' + registed_models = ModelRegistry.models + assert registed_models[ + model_arch].module_name == "vllm_ascend.models.deepseek_dbo" + assert registed_models[ + model_arch].class_name == "CustomDeepseekDBOForCausalLM" vllm_model.generate(example_prompts, sampling_params) @@ -112,3 +150,20 @@ def test_models_distributed_DeepSeek_W8A8(): quantization="ascend", ) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) + + +def test_models_distributed_pangu(): + example_prompts = [ + "Hello, my name is", + ] + max_tokens = 5 + + with VllmRunner( + snapshot_download("vllm-ascend/pangu-pro-moe-pruing"), + max_model_len=8192, + enforce_eager=True, + dtype="auto", + tensor_parallel_size=4, + distributed_executor_backend="mp", + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/e2e/multicard/test_prefix_caching.py b/tests/e2e/multicard/test_prefix_caching.py new file mode 100644 index 0000000000..368d3ff953 --- /dev/null +++ b/tests/e2e/multicard/test_prefix_caching.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Compare the with and without prefix caching on V1 scheduler or AscendScheduler.""" + +import os + +import pytest + +from tests.conftest import VllmRunner +from tests.model_utils import check_outputs_equal + +MODELS = [ + # for MHA + "Qwen/Qwen3-8B-Base", + # for MLA + "deepseek-ai/DeepSeek-V2-Lite-Chat" +] + +# A prompt containing a large markdown table. The table is randomly generated by GPT-4. +LONG_PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. 
Here is a table as follows.\n# Table\n" + """ +| ID | Name | Age | Occupation | Country | Email | Phone Number | Address | +|-----|---------------|-----|---------------|---------------|------------------------|----------------|------------------------------| +| 1 | John Doe | 29 | Engineer | USA | john.doe@example.com | 555-1234 | 123 Elm St, Springfield, IL | +| 2 | Jane Smith | 34 | Doctor | Canada | jane.smith@example.com | 555-5678 | 456 Oak St, Toronto, ON | +| 3 | Alice Johnson | 27 | Teacher | UK | alice.j@example.com | 555-8765 | 789 Pine St, London, UK | +| 4 | Bob Brown | 45 | Artist | Australia | bob.b@example.com | 555-4321 | 321 Maple St, Sydney, NSW | +| 5 | Carol White | 31 | Scientist | New Zealand | carol.w@example.com | 555-6789 | 654 Birch St, Wellington, NZ | +| 6 | Dave Green | 28 | Lawyer | Ireland | dave.g@example.com | 555-3456 | 987 Cedar St, Dublin, IE | +| 7 | Emma Black | 40 | Musician | USA | emma.b@example.com | 555-1111 | 246 Ash St, New York, NY | +| 8 | Frank Blue | 37 | Chef | Canada | frank.b@example.com | 555-2222 | 135 Spruce St, Vancouver, BC | +| 9 | Grace Yellow | 50 | Engineer | UK | grace.y@example.com | 555-3333 | 864 Fir St, Manchester, UK | +| 10 | Henry Violet | 32 | Artist | Australia | henry.v@example.com | 555-4444 | 753 Willow St, Melbourne, VIC| +| 11 | Irene Orange | 26 | Scientist | New Zealand | irene.o@example.com | 555-5555 | 912 Poplar St, Auckland, NZ | +| 12 | Jack Indigo | 38 | Teacher | Ireland | jack.i@example.com | 555-6666 | 159 Elm St, Cork, IE | +| 13 | Karen Red | 41 | Lawyer | USA | karen.r@example.com | 555-7777 | 357 Cedar St, Boston, MA | +| 14 | Leo Brown | 30 | Chef | Canada | leo.b@example.com | 555-8888 | 246 Oak St, Calgary, AB | +| 15 | Mia Green | 33 | Musician | UK | mia.g@example.com | 555-9999 | 975 Pine St, Edinburgh, UK | +| 16 | Noah Yellow | 29 | Doctor | Australia | noah.y@example.com | 555-0000 | 864 Birch St, Brisbane, QLD | +| 17 | Olivia Blue | 35 | Engineer | New Zealand | olivia.b@example.com | 555-1212 | 753 Maple St, Hamilton, NZ | +| 18 | Peter Black | 42 | Artist | Ireland | peter.b@example.com | 555-3434 | 912 Fir St, Limerick, IE | +| 19 | Quinn White | 28 | Scientist | USA | quinn.w@example.com | 555-5656 | 159 Willow St, Seattle, WA | +| 20 | Rachel Red | 31 | Teacher | Canada | rachel.r@example.com | 555-7878 | 357 Poplar St, Ottawa, ON | +| 21 | Steve Green | 44 | Lawyer | UK | steve.g@example.com | 555-9090 | 753 Elm St, Birmingham, UK | +| 22 | Tina Blue | 36 | Musician | Australia | tina.b@example.com | 555-1213 | 864 Cedar St, Perth, WA | +| 23 | Umar Black | 39 | Chef | New Zealand | umar.b@example.com | 555-3435 | 975 Spruce St, Christchurch, NZ| +| 24 | Victor Yellow | 43 | Engineer | Ireland | victor.y@example.com | 555-5657 | 246 Willow St, Galway, IE | +| 25 | Wendy Orange | 27 | Artist | USA | wendy.o@example.com | 555-7879 | 135 Elm St, Denver, CO | +| 26 | Xavier Green | 34 | Scientist | Canada | xavier.g@example.com | 555-9091 | 357 Oak St, Montreal, QC | +| 27 | Yara Red | 41 | Teacher | UK | yara.r@example.com | 555-1214 | 975 Pine St, Leeds, UK | +| 28 | Zack Blue | 30 | Lawyer | Australia | zack.b@example.com | 555-3436 | 135 Birch St, Adelaide, SA | +| 29 | Amy White | 33 | Musician | New Zealand | amy.w@example.com | 555-5658 | 159 Maple St, Wellington, NZ | +| 30 | Ben Black | 38 | Chef | Ireland | ben.b@example.com | 555-7870 | 246 Fir St, Waterford, IE | +""" + +INPUT_PROMPTS = [ + LONG_PROMPT + + "Question: what is the age of John Doe? 
Your answer: The age of John Doe is ", + LONG_PROMPT + + "Question: what is the age of Zack Blue? Your answer: The age of Zack Blue is " +] + + +@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", + reason="mtp is not supported on v1") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_tokens", [50]) +def test_prefix_cache_with_v1_scheduler(model: str, max_tokens: int) -> None: + with VllmRunner(model, + enforce_eager=True, + max_model_len=2048, + tensor_parallel_size=2, + gpu_memory_utilization=0.7) as vllm_model: + prefix_cache_output = vllm_model.generate_greedy( + INPUT_PROMPTS, max_tokens) + + with VllmRunner(model, + enable_prefix_caching=False, + enforce_eager=True, + max_model_len=2048, + tensor_parallel_size=2, + gpu_memory_utilization=0.7) as vllm_model: + vllm_output = vllm_model.generate_greedy(INPUT_PROMPTS, max_tokens) + + check_outputs_equal( + outputs_0_lst=vllm_output, + outputs_1_lst=prefix_cache_output, + name_0="vllm_output", + name_1="prefix_cache_output", + ) + + +@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", + reason="mtp is not supported on v1") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_tokens", [50]) +def test_prefix_cache_with_ascend_scheduler(model: str, + max_tokens: int) -> None: + + with VllmRunner(model, + additional_config={ + 'ascend_scheduler_config': { + 'enabled': True, + }, + }, + enforce_eager=True, + max_model_len=2048, + tensor_parallel_size=2, + gpu_memory_utilization=0.7) as vllm_model: + vllm_output = vllm_model.generate_greedy(INPUT_PROMPTS, max_tokens) + + with VllmRunner(model, + additional_config={ + 'ascend_scheduler_config': { + 'enabled': True, + 'enable_prefix_caching': True, + }, + }, + enforce_eager=True, + max_model_len=2048, + tensor_parallel_size=2, + gpu_memory_utilization=0.7) as vllm_model: + prefix_cache_output = vllm_model.generate_greedy( + INPUT_PROMPTS, max_tokens) + + with VllmRunner(model, + additional_config={ + 'ascend_scheduler_config': { + 'enabled': True, + 'enable_prefix_caching': True, + "enable_chunked_prefill": True, + }, + }, + enforce_eager=True, + max_model_len=2048, + tensor_parallel_size=2, + gpu_memory_utilization=0.7) as vllm_model: + chunk_prefill_prefix_cache_output = vllm_model.generate_greedy( + INPUT_PROMPTS, max_tokens) + + check_outputs_equal( + outputs_0_lst=vllm_output, + outputs_1_lst=prefix_cache_output, + name_0="vllm_output", + name_1="prefix_cache_output", + ) + + check_outputs_equal( + outputs_0_lst=chunk_prefill_prefix_cache_output, + outputs_1_lst=prefix_cache_output, + name_0="chunk_prefill_prefix_cache_output", + name_1="prefix_cache_output", + ) diff --git a/tests/e2e/multicard/test_torchair_graph_mode.py b/tests/e2e/multicard/test_torchair_graph_mode.py index d06ec7de22..ce628f9d35 100644 --- a/tests/e2e/multicard/test_torchair_graph_mode.py +++ b/tests/e2e/multicard/test_torchair_graph_mode.py @@ -20,6 +20,7 @@ Run `pytest tests/multicard/test_torchair_graph_mode.py`. 
""" import os +from typing import Dict import pytest @@ -28,53 +29,133 @@ os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256" +def _deepseek_torchair_test_fixture( + additional_config: Dict, + *, + tensor_parallel_size=4, +): + example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + # torchair is only work without chunked-prefill now + kwargs = { + "ascend_scheduler_config": { + "enabled": True, + }, + "refresh": True, + } + additional_config.update(**kwargs) + + with VllmRunner( + "vllm-ascend/DeepSeek-V3-Pruning", + dtype="half", + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend="mp", + enforce_eager=False, + additional_config=additional_config, + ) as vllm_model: + # use greedy sampler to make sure the generated results are fix + vllm_output = vllm_model.generate_greedy(example_prompts, 5) + + # NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of + # DeepSeek-V3 with 2 hidden layers, thus the golden results seems + # inaccurate. This will only change if accuracy improves with the + # official weights of DeepSeek-V3. + golden_results = [ + 'Hello, my name is下载早点向前很有่อง', + 'The president of the United States isSender)## physiological Albany', + 'The capital of France is Rocky转角 hospitalizedinterval sparked', + 'The future of AI is её asegο BIOS一扫', + ] + + assert len(golden_results) == len(vllm_output) + for i in range(len(vllm_output)): + assert golden_results[i] == vllm_output[i][1] + print(f"Generated text: {vllm_output[i][1]!r}") + + +@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", + reason="torchair graph is not supported on v0") +def test_e2e_deepseekv3_with_torchair(): + additional_config = { + "torchair_graph_config": { + "enabled": True, + }, + } + _deepseek_torchair_test_fixture(additional_config) + + +@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", + reason="torchair graph is not supported on v0") +def test_e2e_deepseekv3_with_torchair_ms_mla(): + additional_config = { + "torchair_graph_config": { + "enabled": True, + "enable_multistream_mla": True, + }, + } + _deepseek_torchair_test_fixture(additional_config) + + +def _pangu_torchair_test_fixture( + additional_config: Dict, + *, + tensor_parallel_size=4, +): + example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + # torchair is only work without chunked-prefill now + kwargs = { + "ascend_scheduler_config": { + "enabled": True, + }, + "refresh": True, + } + additional_config.update(**kwargs) + + with VllmRunner( + "vllm-ascend/pangu-pro-moe-pruing", + dtype="half", + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend="mp", + enforce_eager=False, + additional_config=additional_config, + ) as vllm_model: + # use greedy sampler to make sure the generated results are fix + vllm_output = vllm_model.generate_greedy(example_prompts, 5) + + # NOTE: vllm-ascend/pangu-pro-moe-pruing is only part of PanguProMoE + # with 2 hidden layers, thus the golden results seems inaccurate. + # This will only change if accuracy changes with the official weights + # of PanguProMoE. 
+ golden_results = [ + 'Hello, my name is Remempondeprecatedmiot忱', + 'The president of the United States is Remem下的一个 rever ceremoni Segnali', + 'The capital of France is Rememvoud administrativ Remem投', + 'The future of AI isotope Segnali Zoeken精细化 supus', + ] + + assert len(golden_results) == len(vllm_output) + for i in range(len(vllm_output)): + assert golden_results[i] == vllm_output[i][1] + print(f"Generated text: {vllm_output[i][1]!r}") + + @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", reason="torchair graph is not supported on v0") -def test_e2e_deepseekv3_with_torchair(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m: - m.setenv("VLLM_USE_MODELSCOPE", "True") - m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") - - example_prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - dtype = "half" - max_tokens = 5 - # torchair is only work without chunked-prefill now - with VllmRunner( - "vllm-ascend/DeepSeek-V3-Pruning", - dtype=dtype, - tensor_parallel_size=4, - distributed_executor_backend="mp", - additional_config={ - "torchair_graph_config": { - "enabled": True, - }, - "ascend_scheduler_config": { - "enabled": True, - }, - "refresh": True, - }, - enforce_eager=False, - ) as vllm_model: - # use greedy sampler to make sure the generated results are fix - vllm_output = vllm_model.generate_greedy(example_prompts, - max_tokens) - # NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of - # DeepSeek-V3 with 2 hidden layers, thus the golden results seems - # inaccurate. This will only change if accuracy improves with the - # official weights of DeepSeek-V3. - golden_results = [ - 'Hello, my name is feasibility伸 spazio debtor添', - 'The president of the United States is begg"""\n杭州风和 bestimm', - 'The capital of France is frequentlyশามalinkAllowed', - 'The future of AI is deleting俯احت怎么样了حراف', - ] - - assert len(golden_results) == len(vllm_output) - for i in range(len(vllm_output)): - assert golden_results[i] == vllm_output[i][1] - print(f"Generated text: {vllm_output[i][1]!r}") +def test_e2e_pangu_with_torchair(): + additional_config = { + "torchair_graph_config": { + "enabled": True, + }, + } + _pangu_torchair_test_fixture(additional_config) diff --git a/tests/e2e/pd_disaggreate/setup_pd.sh b/tests/e2e/pd_disaggreate/setup_pd.sh index 675bee439f..c15f109299 100644 --- a/tests/e2e/pd_disaggreate/setup_pd.sh +++ b/tests/e2e/pd_disaggreate/setup_pd.sh @@ -66,6 +66,7 @@ function run_prefill_instance() { --served-model-name Deepseek \ --max-model-len 2000 \ --trust-remote-code \ + --enforce-eager \ --kv-transfer-config "$KV_CONFIG" } @@ -119,6 +120,7 @@ function run_decode_instance() { --max-num-batched-tokens 2000 \ --trust-remote-code \ --gpu-memory-utilization 0.9 \ + --enforce-eager \ --kv-transfer-config "$KV_CONFIG" } diff --git a/tests/e2e/prompts/example.txt b/tests/e2e/prompts/example.txt new file mode 100644 index 0000000000..e1b97bc6ee --- /dev/null +++ b/tests/e2e/prompts/example.txt @@ -0,0 +1,8 @@ +vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. +Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020. +Compare and contrast artificial intelligence with human intelligence in terms of processing information. +Describe the basic components of a neural network and how it can be trained. +Write a short story about a robot that dreams for the first time. 
+Analyze the impact of the COVID-19 pandemic on global economic structures and future business models. +Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies. +Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.' diff --git a/tests/e2e/singlecard/core/test_ascend_scheduler.py b/tests/e2e/singlecard/core/ascend_scheduler/test_ascend_scheduler.py similarity index 91% rename from tests/e2e/singlecard/core/test_ascend_scheduler.py rename to tests/e2e/singlecard/core/ascend_scheduler/test_ascend_scheduler.py index 7d9c1b1ef5..e1fd16bda9 100644 --- a/tests/e2e/singlecard/core/test_ascend_scheduler.py +++ b/tests/e2e/singlecard/core/ascend_scheduler/test_ascend_scheduler.py @@ -201,7 +201,10 @@ def test_schedule(enable_prefix_caching: Optional[bool], # Test initial scheduling output = scheduler.schedule() assert len(output.scheduled_new_reqs) == len(requests) - assert len(output.scheduled_cached_reqs) == 0 + if vllm_version_is("0.9.1"): + assert len(output.scheduled_cached_reqs) == 0 + else: + assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 # Verify all requests are scheduled. for req_id, num_tokens in output.num_scheduled_tokens.items(): @@ -238,7 +241,10 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 3 - assert len(output.scheduled_cached_reqs) == 0 + if vllm_version_is("0.9.1"): + assert len(output.scheduled_cached_reqs) == 0 + else: + assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 # The first request is scheduled partially - 400. @@ -268,7 +274,10 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): output1 = scheduler.schedule() assert len(scheduler.running) == 3 assert len(output1.scheduled_new_reqs) == 0 - assert len(output1.scheduled_cached_reqs) == 3 + if vllm_version_is("0.9.1"): + assert len(output1.scheduled_cached_reqs) == 3 + else: + assert output1.scheduled_cached_reqs.num_reqs == 3 assert len(output1.finished_req_ids) == 0 assert output1.num_scheduled_tokens[requests[0].request_id] == 400 assert output1.num_scheduled_tokens[requests[1].request_id] == 400 @@ -292,7 +301,10 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): output2 = scheduler.schedule() assert len(scheduler.running) == 3 assert len(output2.scheduled_new_reqs) == 0 - assert len(output2.scheduled_cached_reqs) == 3 + if vllm_version_is("0.9.1"): + assert len(output2.scheduled_cached_reqs) == 3 + else: + assert output2.scheduled_cached_reqs.num_reqs == 3 assert len(output2.finished_req_ids) == 0 assert output2.num_scheduled_tokens[requests[0].request_id] == 1 assert output2.num_scheduled_tokens[requests[1].request_id] == 1 @@ -672,73 +684,6 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): assert stats.num_accepted_tokens_per_pos == expected[3] -def _assert_right_scheduler_output( - output: SchedulerOutput, - num_requests: int, - expected_num_scheduled_tokens: int, -): - """Check if SchedulerOutput is correct after remote KV cache hit.""" - - # We should inject the kv_connector_metadata. - assert len(output.kv_connector_metadata.requests) == num_requests - - # Only num_tokens - matched_num_new_tokens should be scheduled. 
- for _, num_scheduled_tokens in output.num_scheduled_tokens.items(): - assert num_scheduled_tokens == expected_num_scheduled_tokens - - -def _assert_right_kv_cache_manager( - scheduler: AscendScheduler, - req_ids: list[str], - num_tokens: int, - block_size: int, - num_requests: int, - num_total_blocks: int, -): - """Check whether KVCacheManager is correct after allocate.""" - - # Make sure the request stats are right. - EXPECTED_TOTAL_BLOCKS = num_tokens // block_size - for req_id in req_ids: - blocks = (scheduler.kv_cache_manager.coordinator. - single_type_managers[0].req_to_blocks[req_id]) - hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id] - assert (scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS) - assert len(blocks) == EXPECTED_TOTAL_BLOCKS - assert len(hashes) == EXPECTED_TOTAL_BLOCKS - - # Make sure we actually touched all the blocks. - BLOCKS_PER_REQ = num_tokens / block_size - assert (scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == - num_total_blocks - num_requests * BLOCKS_PER_REQ) - - -def _step_until_done( - scheduler: AscendScheduler, - output: SchedulerOutput, - model_runner_output: ModelRunnerOutput, -): - """Loop over schedule(), update_from_output() until finished.""" - - all_finished = False - _ = scheduler.update_from_output(output, model_runner_output) - while not all_finished: - # Schedule + a few iterations until stopping. - output = scheduler.schedule() - assert len(scheduler.running) - for _, num_scheduled_tokens in output.num_scheduled_tokens.items(): - # We should be in the decode phase now. - assert num_scheduled_tokens == 1 - assert len(output.kv_connector_metadata.requests) == 0 - ecos = scheduler.update_from_output(output, model_runner_output)[0] - all_done = True - for eco in ecos.outputs: - if eco.finish_reason is None: - all_done = False - all_finished = all_done - - def make_output(scheduler: AscendScheduler): return ModelRunnerOutput( req_ids=[req.request_id for req in scheduler.running], @@ -762,7 +707,6 @@ def assert_scheduler_empty(scheduler: AscendScheduler): assert len(scheduler.waiting) == 0 assert len(scheduler.running) == 0 assert len(scheduler.finished_req_ids) == 0 - assert len(scheduler._cached_reqs_data) == 0 # EncoderCacheManager. 
     assert len(scheduler.encoder_cache_manager.freed) == 0
diff --git a/tests/e2e/singlecard/core/test_ascend_scheduler_e2e.py b/tests/e2e/singlecard/core/ascend_scheduler/test_ascend_scheduler_e2e.py
similarity index 86%
rename from tests/e2e/singlecard/core/test_ascend_scheduler_e2e.py
rename to tests/e2e/singlecard/core/ascend_scheduler/test_ascend_scheduler_e2e.py
index 668dafced9..17116ab59a 100644
--- a/tests/e2e/singlecard/core/test_ascend_scheduler_e2e.py
+++ b/tests/e2e/singlecard/core/ascend_scheduler/test_ascend_scheduler_e2e.py
@@ -1,8 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import gc
 import os
 
 import pytest
+import torch
 from vllm import LLM
 
 if os.getenv("VLLM_USE_V1", "0") != "1":
@@ -13,8 +15,8 @@
 
 
 @pytest.fixture(scope="module")
-def model() -> LLM:
-    return LLM(
+def model():
+    llm = LLM(
         MODEL,
         enforce_eager=True,
         enable_prefix_caching=True,
@@ -23,6 +25,10 @@ def model() -> LLM:
         additional_config={"ascend_scheduler_config": {
             "enabled": True,
         }})
+    yield llm
+    del llm
+    torch.npu.empty_cache()
+    gc.collect()
 
 
 def test_concurrent_partial_prefill(model):
@@ -37,4 +43,4 @@ def test_prefix_cache_stats_is_recorded(model):
     input_tokens = {"prompt_token_ids": [101] * 129}
     _ = model.generate([input_tokens])
     outputs = model.generate([input_tokens])
-    assert outputs[0].num_cached_tokens == 128
\ No newline at end of file
+    assert outputs[0].num_cached_tokens == 128
diff --git a/tests/e2e/singlecard/core/ascend_scheduler/test_chunk_prefill.py b/tests/e2e/singlecard/core/ascend_scheduler/test_chunk_prefill.py
new file mode 100644
index 0000000000..0b557960e6
--- /dev/null
+++ b/tests/e2e/singlecard/core/ascend_scheduler/test_chunk_prefill.py
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Compare vLLM outputs with and without chunked prefill on AscendScheduler.
+
+It tests chunked prefill. Chunked prefill can be enabled by
+`additional_config={'ascend_scheduler_config': {'enabled': True, 'enable_chunked_prefill': True,},}`.
+If prefill size exceeds max_num_batched_tokens, prefill requests are chunked.
+
+Run `pytest tests/e2e/singlecard/core/ascend_scheduler/test_chunk_prefill.py`.
+""" +import pytest + +from tests.conftest import VllmRunner +from tests.model_utils import check_outputs_equal + +MODELS = [ + "Qwen/Qwen3-0.6B-Base", +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_tokens", + [4]) # cannot align results when max_tokens > 4 +@pytest.mark.parametrize("chunked_prefill_token_size", [16]) +def test_chunked_prefill_with_ascend_scheduler( + example_prompts, model: str, max_tokens: int, + chunked_prefill_token_size: int) -> None: + max_num_seqs = chunked_prefill_token_size + max_num_batched_tokens = chunked_prefill_token_size + with VllmRunner(model, + additional_config={ + 'ascend_scheduler_config': { + 'enabled': True, + 'enable_chunked_prefill': True, + }, + }, + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + enforce_eager=True, + max_model_len=2048, + gpu_memory_utilization=0.7) as vllm_model: + chunked_prefill_output = vllm_model.generate_greedy( + example_prompts, max_tokens) + + with VllmRunner(model, + additional_config={ + 'ascend_scheduler_config': { + 'enabled': True, + }, + }, + enforce_eager=True, + max_model_len=2048, + gpu_memory_utilization=0.7) as vllm_model: + vllm_output = vllm_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=vllm_output, + outputs_1_lst=chunked_prefill_output, + name_0="vllm_output", + name_1="chunked_prefill_output", + ) diff --git a/tests/e2e/singlecard/sample/test_rejection_sampler.py b/tests/e2e/singlecard/sample/test_rejection_sampler.py index 4116814b67..3b48864cea 100644 --- a/tests/e2e/singlecard/sample/test_rejection_sampler.py +++ b/tests/e2e/singlecard/sample/test_rejection_sampler.py @@ -9,6 +9,7 @@ from vllm_ascend.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID, AscendRejectionSampler) +from vllm_ascend.utils import vllm_version_is DEVICE = "npu" @@ -49,27 +50,46 @@ def create_sampling_metadata( temperature = None else: assert temperature is not None - - return SamplingMetadata( - temperature=temperature, - all_greedy=all_greedy, - all_random=not all_greedy, - top_p=top_p, - top_k=top_k, - min_p=torch.empty(1, ), - generators=generators, - max_num_logprobs=0, - no_penalties=False, - prompt_token_ids=None, - frequency_penalties=torch.tensor([]), - presence_penalties=torch.tensor([]), - repetition_penalties=torch.tensor([]), - output_token_ids=[], - min_tokens={}, - logit_bias=[None], - allowed_token_ids_mask=None, - bad_words_token_ids={}, - ) + if vllm_version_is("0.9.1"): + return SamplingMetadata( + temperature=temperature, + all_greedy=all_greedy, + all_random=not all_greedy, + top_p=top_p, + top_k=top_k, + min_p=torch.empty(1, ), + generators=generators, + max_num_logprobs=0, + no_penalties=False, + prompt_token_ids=None, + frequency_penalties=torch.tensor([]), + presence_penalties=torch.tensor([]), + repetition_penalties=torch.tensor([]), + output_token_ids=[], + min_tokens={}, + logit_bias=[None], + allowed_token_ids_mask=None, + bad_words_token_ids={}, + ) + else: + from vllm.v1.sample.logits_processor import LogitsProcessorManager + + return SamplingMetadata(temperature=temperature, + all_greedy=all_greedy, + all_random=not all_greedy, + top_p=top_p, + top_k=top_k, + generators=generators, + max_num_logprobs=0, + no_penalties=False, + prompt_token_ids=None, + frequency_penalties=torch.tensor([]), + presence_penalties=torch.tensor([]), + repetition_penalties=torch.tensor([]), + output_token_ids=[], + allowed_token_ids_mask=None, + bad_words_token_ids={}, + logitsprocs=LogitsProcessorManager()) 
########################### Tests for Greedy Sampling ################### diff --git a/tests/e2e/long_term/spec_decode/e2e/test_v1_mtp_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py similarity index 97% rename from tests/e2e/long_term/spec_decode/e2e/test_v1_mtp_correctness.py rename to tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py index 2219a6f552..0cf64b059c 100644 --- a/tests/e2e/long_term/spec_decode/e2e/test_v1_mtp_correctness.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py @@ -50,6 +50,8 @@ def model_name(): return "wemaster/deepseek_mtp_main_random_bf16" +@pytest.mark.skipif( + True, reason="TODO: Enable me after test_mtp_correctness is fixed") def test_mtp_correctness( monkeypatch: pytest.MonkeyPatch, test_prompts: list[list[dict[str, Any]]], diff --git a/tests/e2e/long_term/spec_decode/e2e/test_v1_spec_decode.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py similarity index 91% rename from tests/e2e/long_term/spec_decode/e2e/test_v1_spec_decode.py rename to tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py index 19ab0bc220..35cb19a14e 100644 --- a/tests/e2e/long_term/spec_decode/e2e/test_v1_spec_decode.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py @@ -11,7 +11,7 @@ @pytest.fixture def test_prompts(): prompt_types = ["repeat", "sentence"] - num_prompts = 100 + num_prompts = 10 prompts = [] random.seed(0) @@ -69,6 +69,7 @@ def test_ngram_correctness( Compare the outputs of a original LLM and a speculative LLM should be the same when using ngram speculative decoding. ''' + pytest.skip("Not current support for the test.") with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") @@ -116,11 +117,12 @@ def test_eagle_correctness( Compare the outputs of a original LLM and a speculative LLM should be the same when using eagle speculative decoding. 
''' - pytest.skip("Not current support for the test.") + if not use_eagle3: + pytest.skip("Not current support for the test.") with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - ref_llm = LLM(model=model_name, max_model_len=2048) + ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=True) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm @@ -129,13 +131,17 @@ def test_eagle_correctness( spec_llm = LLM( model=model_name, trust_remote_code=True, + enable_chunked_prefill=True, + max_num_seqs=1, + max_num_batched_tokens=2048, + gpu_memory_utilization=0.6, speculative_config={ "method": "eagle3" if use_eagle3 else "eagle", "model": spec_model_name, - "num_speculative_tokens": 3, - "max_model_len": 2048, + "num_speculative_tokens": 2, + "max_model_len": 128, }, - max_model_len=2048, + max_model_len=128, enforce_eager=True, ) spec_outputs = spec_llm.chat(test_prompts, sampling_config) diff --git a/tests/e2e/singlecard/test_aclgraph.py b/tests/e2e/singlecard/test_aclgraph.py index e0bfb65cf8..4fc23aa7b3 100644 --- a/tests/e2e/singlecard/test_aclgraph.py +++ b/tests/e2e/singlecard/test_aclgraph.py @@ -29,7 +29,7 @@ from tests.conftest import VllmRunner from tests.model_utils import check_outputs_equal -MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"] +MODELS = ["Qwen/Qwen2.5-0.5B-Instruct", "vllm-ascend/Qwen3-30B-A3B-Puring"] @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", diff --git a/tests/e2e/singlecard/test_embedding.py b/tests/e2e/singlecard/test_embedding.py new file mode 100644 index 0000000000..0ca07a017e --- /dev/null +++ b/tests/e2e/singlecard/test_embedding.py @@ -0,0 +1,72 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py +# +from collections.abc import Sequence +from typing import Optional + +import pytest +from modelscope import snapshot_download # type: ignore[import-untyped] + +from tests.conftest import HfRunner +from tests.utils import check_embeddings_close, matryoshka_fy +from vllm_ascend.utils import vllm_version_is + + +def run_embedding_correctness_test( + hf_model: "HfRunner", + inputs: list[str], + vllm_outputs: Sequence[list[float]], + dimensions: Optional[int] = None, +): + hf_outputs = hf_model.encode(inputs) + if dimensions: + hf_outputs = matryoshka_fy(hf_outputs, dimensions) + + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) + + +# dummy to avoid pytest collect nothing and exit code 5 +def test_dummy(): + assert True + + +@pytest.mark.skipif(vllm_version_is("0.9.1"), + reason="vLLM 0.9.1 does not support embed task for v1") +def test_embed_models_correctness(hf_runner, vllm_runner): + queries = ['What is the capital of China?', 'Explain gravity'] + + model_name = snapshot_download("Qwen/Qwen3-Embedding-0.6B") + with vllm_runner( + model_name, + task="embed", + enforce_eager=True, + ) as vllm_model: + vllm_outputs = vllm_model.encode(queries) + + with hf_runner( + model_name, + dtype="float32", + is_sentence_transformer=True, + ) as hf_model: + run_embedding_correctness_test(hf_model, queries, vllm_outputs) diff --git a/tests/e2e/singlecard/test_guided_decoding.py b/tests/e2e/singlecard/test_guided_decoding.py index 0725812a28..9d103a5308 100644 --- a/tests/e2e/singlecard/test_guided_decoding.py +++ b/tests/e2e/singlecard/test_guided_decoding.py @@ -28,13 +28,10 @@ from tests.conftest import VllmRunner os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256" -MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" -GuidedDecodingBackendV0 = [ - "outlines", - "lm-format-enforcer", - "xgrammar", -] -GuidedDecodingBackendV1 = ["xgrammar", "guidance:disable-any-whitespace"] +MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" + +GuidedDecodingBackendV0 = ["outlines", "lm-format-enforcer", "xgrammar"] +GuidedDecodingBackendV1 = ["xgrammar", "guidance"] GuidedDecodingBackend = list( set(GuidedDecodingBackendV0 + GuidedDecodingBackendV1)) @@ -87,26 +84,25 @@ def sample_json_schema(): } -@pytest.mark.parametrize("guided_decoding_backend", GuidedDecodingBackend) -def test_guided_json_completion(guided_decoding_backend: str, - sample_json_schema): - if guided_decoding_backend == "xgrammar": - # xgrammar does not support json schema, will fall back to outlines, skip it - pytest.skip( - f"{guided_decoding_backend} will fall back to outlines, skip it") +def check_backend(guided_decoding_backend: str): if guided_decoding_backend not in GuidedDecodingBackendV0 and os.getenv( "VLLM_USE_V1") == "0": - # guidance does not support on v0, skip it - pytest.skip( - f"{guided_decoding_backend} does not support on v0, skip it") + pytest.skip(f"{guided_decoding_backend} does not support v0, skip it.") if guided_decoding_backend not in GuidedDecodingBackendV1 and os.getenv( "VLLM_USE_V1") == "1": - pytest.skip(f"{guided_decoding_backend} does not support v1, skip it") + pytest.skip(f"{guided_decoding_backend} does not support v1, skip it.") + + +@pytest.mark.parametrize("guided_decoding_backend", GuidedDecodingBackend) +def test_guided_json_completion(guided_decoding_backend: str, + sample_json_schema): + check_backend(guided_decoding_backend) sampling_params = 
SamplingParams( temperature=1.0, - max_tokens=1000, + max_tokens=500, guided_decoding=GuidedDecodingParams(json=sample_json_schema)) + with VllmRunner( MODEL_NAME, seed=0, @@ -138,19 +134,13 @@ def test_guided_json_completion(guided_decoding_backend: str, @pytest.mark.parametrize("guided_decoding_backend", GuidedDecodingBackend) def test_guided_regex(guided_decoding_backend: str, sample_regex): - if guided_decoding_backend not in GuidedDecodingBackendV0 and os.getenv( - "VLLM_USE_V1") == "0": - # guidance does not support on v0, skip it - pytest.skip( - f"{guided_decoding_backend} does not support on v0, skip it") - if guided_decoding_backend not in GuidedDecodingBackendV1 and os.getenv( - "VLLM_USE_V1") == "1": - pytest.skip(f"{guided_decoding_backend} does not support v1, skip it") + check_backend(guided_decoding_backend) + + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + guided_decoding=GuidedDecodingParams(regex=sample_regex)) - sampling_params = SamplingParams(temperature=0.8, - top_p=0.95, - guided_decoding=GuidedDecodingParams( - regex=sample_regex, )) with VllmRunner( MODEL_NAME, seed=0, diff --git a/tests/e2e/singlecard/test_ilama_lora.py b/tests/e2e/singlecard/test_ilama_lora.py index 2d93bceea5..35f78ad773 100644 --- a/tests/e2e/singlecard/test_ilama_lora.py +++ b/tests/e2e/singlecard/test_ilama_lora.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 - import vllm +from modelscope import snapshot_download # type: ignore from vllm.lora.request import LoRARequest from tests.conftest import VllmRunner -MODEL_PATH = "ArthurZ/ilama-3.2-1B" +MODEL_PATH = "vllm-ascend/ilama-3.2-1B" PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. 
concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 @@ -45,7 +45,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: def test_ilama_lora(ilama_lora_files): - with VllmRunner(model_name=MODEL_PATH, + with VllmRunner(snapshot_download(MODEL_PATH), enable_lora=True, max_loras=4, max_model_len=1024, diff --git a/tests/e2e/singlecard/test_sampler.py b/tests/e2e/singlecard/test_sampler.py index b21142018e..d9584daeec 100644 --- a/tests/e2e/singlecard/test_sampler.py +++ b/tests/e2e/singlecard/test_sampler.py @@ -18,9 +18,12 @@ # from typing import Optional +import pytest import torch from vllm.v1.sample.sampler import Sampler # noqa: F401 +from vllm_ascend.utils import vllm_version_is + # Set tolerance to 1 for quant ops DEFAULT_ATOL = 1e-3 DEFAULT_RTOL = 1e-3 @@ -118,6 +121,8 @@ def apply_top_k_top_p_new( # test with leading dimension and merge seqlen and batch_size as num_tokens +@pytest.mark.skipif(not vllm_version_is("0.9.1"), + reason="apply_min_p has been removed after vllm 0.9.1") @torch.inference_mode() def test_apply_min_p() -> None: logits = torch.randn((128, 7168)).npu() diff --git a/tests/e2e/singlecard/test_scheduler.py b/tests/e2e/singlecard/test_scheduler.py index b3adf945bf..fba344afb4 100644 --- a/tests/e2e/singlecard/test_scheduler.py +++ b/tests/e2e/singlecard/test_scheduler.py @@ -192,7 +192,10 @@ def test_schedule(enable_prefix_caching: Optional[bool], # Test initial scheduling output = scheduler.schedule() assert len(output.scheduled_new_reqs) == len(requests) - assert len(output.scheduled_cached_reqs) == 0 + if vllm_version_is("0.9.1"): + assert len(output.scheduled_cached_reqs) == 0 + else: + assert output.scheduled_cached_reqs.num_reqs == 0 assert len(output.finished_req_ids) == 0 # Verify all requests are scheduled. for req_id, num_tokens in output.num_scheduled_tokens.items(): diff --git a/tests/multicard/test_data_parallel.py b/tests/multicard/test_data_parallel.py deleted file mode 100644 index 6c0a20de97..0000000000 --- a/tests/multicard/test_data_parallel.py +++ /dev/null @@ -1,66 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -""" -Compare the outputs of vLLM with and without aclgraph. - -Run `pytest tests/multicard/test_data_parallel.py`. 
-""" - -import os - -import pytest - -from tests.conftest import VllmRunner -from tests.model_utils import check_outputs_equal - -MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"] - - -@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", - reason="Data parallel only support on v1") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("max_tokens", [32]) -def test_data_parallel_correctness( - model: str, - max_tokens: int, -) -> None: - example_prompts = [ - "Hello, my name is", "The president of the United States is", - "The capital of France is", "The future of AI is" - ] - - with VllmRunner(model_name=model, - max_model_len=1024, - max_num_seqs=16, - data_parallel_size=2, - distributed_executor_backend="mp") as vllm_model: - vllm_dp_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) - - with VllmRunner( - model_name=model, - max_model_len=1024, - max_num_seqs=16, - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=vllm_outputs, - outputs_1_lst=vllm_dp_outputs, - name_0="vllm_outputs", - name_1="vllm_dp_outputs", - ) diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py new file mode 100644 index 0000000000..e9ce36e2e0 --- /dev/null +++ b/tests/ut/attention/test_attention_v1.py @@ -0,0 +1,497 @@ +from unittest.mock import MagicMock, patch + +import torch + +from tests.ut.base import TestBase +from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend, + AscendAttentionBackendImpl, + AscendAttentionMetadataBuilder, + AscendAttentionState, + AscendMetadata, + CommonAttentionState) + + +class TestAscendAttentionBackend(TestBase): + + def test_get_name(self): + self.assertEqual(AscendAttentionBackend.get_name(), "ASCEND") + + def test_get_impl_cls(self): + self.assertEqual(AscendAttentionBackend.get_impl_cls(), + AscendAttentionBackendImpl) + + def test_get_metadata_cls(self): + self.assertEqual(AscendAttentionBackend.get_metadata_cls(), + AscendMetadata) + + def test_get_state_cls(self): + self.assertEqual(AscendAttentionBackend.get_state_cls(), + CommonAttentionState) + + def test_get_builder_cls(self): + self.assertEqual(AscendAttentionBackend.get_builder_cls(), + AscendAttentionMetadataBuilder) + + @patch('vllm_ascend.attention.attention_v1.is_310p') + def test_get_kv_cache_shape_310p(self, mock_is_310p): + mock_is_310p.return_value = True + result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40) + self.assertEqual(result, (2, 10, 30 * 40 // 16, 20, 16)) + + @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False) + def test_get_kv_cache_shape_not_310p(self, mock_is_310p): + result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40) + self.assertEqual(result, (2, 10, 20, 30, 40)) + + def test_get_bsh_kv_cache_shape(self): + result = AscendAttentionBackend.get_bsh_kv_cache_shape(10, 20, 30, 40) + self.assertEqual(result, (2, 10, 20, 30 * 40)) + + def test_swap_blocks(self): + src_kv_cache = [torch.zeros((10, 20)), torch.zeros((10, 20))] + dst_kv_cache = [torch.zeros((10, 20)), torch.zeros((10, 20))] + src_to_dst = torch.tensor([[0, 1], [2, 3]]) + AscendAttentionBackend.swap_blocks(src_kv_cache, dst_kv_cache, + src_to_dst) + self.assertTrue(torch.all(dst_kv_cache[0][1] == src_kv_cache[0][0])) + self.assertTrue(torch.all(dst_kv_cache[1][3] == src_kv_cache[1][2])) + + def test_copy_blocks(self): + kv_caches = [torch.zeros((10, 20)), torch.zeros((10, 20))] + src_to_dists = torch.tensor([[0, 1], [2, 3]]) + 
AscendAttentionBackend.copy_blocks(kv_caches, src_to_dists) + self.assertTrue(torch.all(kv_caches[0][1] == kv_caches[0][0])) + self.assertTrue(torch.all(kv_caches[1][3] == kv_caches[1][2])) + + +class TestAscendAttentionMetadataBuilder(TestBase): + + def setUp(self): + self.mock_runner = MagicMock() + self.builder = AscendAttentionMetadataBuilder(self.mock_runner) + + def test_reorder_batch(self): + mock_input_batch = MagicMock() + mock_scheduler_output = MagicMock() + + result = self.builder.reorder_batch(mock_input_batch, + mock_scheduler_output) + + self.assertFalse(result) + + @patch('vllm_ascend.attention.attention_v1.AscendMetadata') + @patch('torch_npu.npu_format_cast') + @patch('vllm_ascend.utils.nd_to_nz_2d') + @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True) + def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d, + mock_npu_format_cast, + mock_ascend_metadata): + num_reqs = 2 + num_actual_tokens = 10 + max_query_len = 5 + common_prefix_len = 1 + + self.mock_runner.input_batch.block_table = [MagicMock()] + self.mock_runner.input_batch.block_table[ + 0].get_device_tensor.return_value = torch.zeros((10, 10)) + self.mock_runner.max_num_blocks_per_req = 10 + self.mock_runner.query_lens = torch.tensor([3, 4]) + self.mock_runner.seq_lens_cpu = torch.tensor([5, 6]) + self.mock_runner.slot_mapping_cpu = torch.tensor(range(20)) + self.mock_runner.device = 'cpu:0' + self.mock_runner.attn_mask = torch.ones((10, 10)) + self.mock_runner.attn_state = AscendAttentionState.PrefillNoCache + self.mock_runner.query_start_loc_cpu = torch.tensor([0, 3, 7]) + + mock_nz_tensor = MagicMock() + mock_nd_to_nz_2d.return_value = mock_nz_tensor + mock_npu_format_cast.return_value = mock_nz_tensor + + self.builder.build(num_reqs, num_actual_tokens, max_query_len, + common_prefix_len) + + @patch('vllm_ascend.attention.attention_v1.AscendMetadata') + @patch('torch_npu.npu_format_cast') + @patch('vllm_ascend.utils.nd_to_nz_spec') + @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True) + @patch('vllm_ascend.attention.attention_v1.AscendAttentionState') + def test_build_chunked_prefill(self, mock_ascend_attention_state, + mock_is_310p, mock_nd_to_nz_spec, + mock_npu_format_cast, mock_ascend_metadata): + num_reqs = 3 + num_actual_tokens = 15 + max_query_len = 6 + + self.mock_runner.input_batch.block_table = [MagicMock()] + self.mock_runner.input_batch.block_table[ + 0].get_device_tensor.return_value = torch.zeros((10, 10)) + self.mock_runner.max_num_blocks_per_req = 10 + self.mock_runner.query_lens = torch.tensor([2, 3, 4]) + self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6]) + self.mock_runner.slot_mapping_cpu = torch.tensor(range(20)) + self.mock_runner.device = 'cpu:0' + self.mock_runner.attn_mask = torch.ones((15, 15)) + self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill + self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9]) + + mock_ascend_attention_state = MagicMock() + mock_ascend_attention_state.PrefillNoCache = 0 + + mock_nz_tensor = MagicMock() + mock_nd_to_nz_spec.return_value = mock_nz_tensor + mock_npu_format_cast.return_value = mock_nz_tensor + + self.builder.build(num_reqs, num_actual_tokens, max_query_len, 0) + + @patch('vllm_ascend.attention.attention_v1.AscendMetadata') + @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False) + def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata): + num_reqs = 3 + num_actual_tokens = 15 + max_query_len = 6 + + 
self.mock_runner.input_batch.block_table = [MagicMock()] + self.mock_runner.input_batch.block_table[ + 0].get_device_tensor.return_value = torch.zeros((10, 10)) + self.mock_runner.max_num_blocks_per_req = 10 + self.mock_runner.query_lens = torch.tensor([2, 3, 4]) + self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6]) + self.mock_runner.slot_mapping_cpu = torch.tensor(range(20)) + self.mock_runner.device = 'cpu:0' + self.mock_runner.attn_mask = torch.ones((15, 15)) + self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill + self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9]) + + self.builder.build(num_reqs, num_actual_tokens, max_query_len, 0) + + +class TestAscendAttentionBackendImpl(TestBase): + + def setUp(self): + self.layer = MagicMock() + self.layer.layer_name = "test_layer" + self.layer._k_scale_float = 1.0 + self.layer._v_scale_float = 1.0 + + self.attention_type = MagicMock() + self.attention_type.DECODER = "decoder" + self.attention_type.ENCODER = "encoder" + + self.attn_metadata = MagicMock() + self.attn_metadata.return_value = "1" + + self.layer_no_quant = MagicMock( + spec=['layer_name', '_k_scale_float', '_v_scale_float']) + self.layer_no_quant.layer_name = "test_layer" + self.layer_no_quant._k_scale_float = 1.0 + self.layer_no_quant._v_scale_float = 1.0 + + self.impl = AscendAttentionBackendImpl( + num_heads=8, + head_size=64, + scale=1.0, + num_kv_heads=8, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="float16", + attn_type=self.attention_type.DECODER) + + self.impl_192 = AscendAttentionBackendImpl( + num_heads=8, + head_size=192, + scale=1.0, + num_kv_heads=8, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="float16", + attn_type=self.attention_type.DECODER) + + self.impl_error = AscendAttentionBackendImpl(num_heads=8, + head_size=192, + scale=1.0, + num_kv_heads=8, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="float16", + attn_type=None) + + @patch('torch.ops.vllm.unified_ascend_attention_with_output') + def test_forward_trace_flag_true(self, mock_unified_attention): + """Test forward pass when trace_flag is True""" + query = torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.empty(2, 0, 0, 8, 64) + metadata = self.attn_metadata + layer = self.layer + + output = self.impl.forward(layer, + query, + key, + value, + kv_cache, + metadata, + trace_flag=True) + + mock_unified_attention.assert_called_once() + assert output.shape == (10, 8 * 64) + + @patch('torch_npu._npu_paged_attention_splitfuse') + def test_forward_with_quant_method(self, mock_paged_attention): + """Test forward pass when layer has quant_method""" + query = torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.ones(1, 1, 10, 8, 64, dtype=torch.int8) + + metadata = MagicMock() + metadata.num_actual_tokens = torch.randn(10, 8 * 64) + metadata.block_tables = torch.randn(10, 8 * 64) + metadata.seq_lens = torch.randn(10, 8 * 64) + metadata.attn_mask = torch.randn(10, 8 * 64) + metadata.query_lens = torch.randn(10, 8 * 64) + layer = self.layer + layer.quant_method = MagicMock() + layer.quant_method.apply.return_value = kv_cache + + output = self.impl.forward(layer, + query, + key, + value, + kv_cache, + metadata, + trace_flag=False) + + layer.quant_method.apply.assert_called_once() + assert output.shape == (10, 8 * 64) + + def test_forward_no_attn_metadata(self): + """Test forward pass when attn_metadata is None""" + query = 
torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.empty(2, 0, 0, 8, 64) + layer = self.layer_no_quant + + output = self.impl.forward(layer, + query, + key, + value, + kv_cache, + None, + trace_flag=False) + + assert output.shape == (10, 8 * 64) + + @patch('torch_npu._npu_reshape_and_cache') + @patch('torch_npu._npu_flash_attention') + def test_forward_prefill_no_cache(self, mock_flash_attention, + mock_reshape_cache): + """Test forward pass in PrefillNoCache state""" + query = torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.empty(2, 5, 128, 8, 64) + metadata = self.attn_metadata + metadata.attn_state = AscendAttentionState.PrefillNoCache + metadata.attn_mask = torch.randn(1, 1, 10, 10) + metadata.seq_lens = torch.tensor([10]) + metadata.num_actual_tokens = 10 + metadata.slot_mapping = torch.zeros(10, dtype=torch.long) + layer = self.layer_no_quant + # layer.quant_method.apply.return_value = metadata + print(self.layer_no_quant._v_scale_float) + output = self.impl.forward(layer, + query, + key, + value, + kv_cache, + metadata, + trace_flag=False) + + mock_reshape_cache.assert_called_once() + mock_flash_attention.assert_called_once() + assert output.shape == (10, 8 * 64) + + @patch('torch_npu._npu_reshape_and_cache') + @patch('torch_npu._npu_flash_attention_qlens') + def test_forward_prefill_cache_hit(self, mock_flash_attention_qlens, + mock_npu_reshape_and_cache): + """Test forward pass in PrefillCacheHit state""" + query = torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.empty(2, 5, 128, 8, 64) + metadata = self.attn_metadata + metadata.attn_state = AscendAttentionState.PrefillCacheHit + metadata.attn_mask = torch.randn(1, 1, 10, 10) + metadata.query_lens = torch.tensor([10]) + metadata.seq_lens = torch.tensor([10]) + metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) + metadata.num_actual_tokens = 10 + metadata.slot_mapping = torch.zeros(10, dtype=torch.long) + layer = self.layer_no_quant + + output = self.impl.forward(layer, + query, + key, + value, + kv_cache, + metadata, + trace_flag=False) + + mock_flash_attention_qlens.assert_called_once() + assert output.shape == (10, 8 * 64) + + @patch('torch_npu._npu_reshape_and_cache') + @patch('torch_npu._npu_paged_attention') + def test_forward_decode_only(self, mock_paged_attention, + mock_npu_reshape_and_cache): + """Test forward pass in DecodeOnly state""" + query = torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.empty(2, 5, 128, 8, 64) + metadata = self.attn_metadata + metadata.attn_state = AscendAttentionState.DecodeOnly + metadata.seq_lens = torch.tensor([10]) + metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) + metadata.num_actual_tokens = 10 + metadata.slot_mapping = torch.zeros(10, dtype=torch.long) + layer = self.layer_no_quant + + output = self.impl.forward(layer, + query, + key, + value, + kv_cache, + metadata, + trace_flag=False) + + mock_paged_attention.assert_called_once() + assert output.shape == (10, 8 * 64) + + @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False) + @patch('torch_npu._npu_reshape_and_cache') + @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill') + def test_forward_head_size_192(self, mock_vanilla_prefill, + mock_npu_reshape_and_cache, mock_is_310p): + """Test forward pass when head_size is 192""" + + self.impl.head_size 
= 192 + query = torch.randn(10, 8 * 192) + key = torch.randn(10, 8 * 192) + value = torch.randn(10, 8 * 192) + kv_cache = torch.empty(2, 5, 128, 8, 192) + metadata = self.attn_metadata + metadata.attn_mask = torch.randn(1, 1, 10, 10) + metadata.query_lens = torch.tensor([10]) + metadata.seq_lens = torch.tensor([10]) + metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) + metadata.num_actual_tokens = 10 + metadata.slot_mapping = torch.zeros(10, dtype=torch.long) + layer = self.layer_no_quant + mock_vanilla_prefill.return_value = MagicMock() + + def mock_tensor(data, device=None, **kwargs): + if device == "npu": + return metadata.attn_mask + return torch.tensor(data, **kwargs) + + with patch("torch.tensor", side_effect=mock_tensor): + output = self.impl_192.forward(layer, + query, + key, + value, + kv_cache, + metadata, + trace_flag=False) + + mock_vanilla_prefill.assert_called_once() + assert output.shape == (10, 8 * 192) + + @patch('torch_npu._npu_reshape_and_cache') + @patch('torch_npu._npu_paged_attention_splitfuse') + def test_forward_normal_v1_situation(self, mock_paged_attention, + mock_npu_reshape_and_cache): + """Test forward pass in normal V1 situation""" + query = torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.empty(2, 5, 128, 8, 64) + metadata = self.attn_metadata + metadata.attn_mask = torch.randn(1, 1, 10, 10) + metadata.query_lens = torch.tensor([10]) + metadata.seq_lens = torch.tensor([10]) + metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) + metadata.num_actual_tokens = 10 + metadata.slot_mapping = torch.zeros(10, dtype=torch.long) + layer = self.layer_no_quant + + output = self.impl.forward(layer, + query, + key, + value, + kv_cache, + metadata, + trace_flag=False) + + mock_paged_attention.assert_called_once() + assert output.shape == (10, 8 * 64) + + @patch('torch_npu.npu_format_cast') + @patch('torch_npu._npu_reshape_and_cache') + @patch('torch_npu._npu_paged_attention_splitfuse') + @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True) + def test_forward_310p_device(self, mock_is_310p, mock_paged_attention, + mock_npu_reshape_and_cache, + mock_npu_format_cast): + """Test forward pass on 310P device""" + query = torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.empty(2, 5, 128, 8, 64) + metadata = self.attn_metadata + metadata.attn_mask = torch.randn(1, 1, 10, 10) + metadata.query_lens = torch.tensor([10]) + metadata.seq_lens = torch.tensor([10]) + metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) + metadata.num_actual_tokens = 10 + metadata.slot_mapping = torch.zeros(10, dtype=torch.long) + layer = self.layer_no_quant + + mock_npu_format_cast.return_value = metadata.attn_mask + output = self.impl.forward(layer, + query, + key, + value, + kv_cache, + metadata, + trace_flag=False) + + mock_paged_attention.assert_called_once() + assert output.shape == (10, 8 * 64) + + @patch('torch_npu._npu_reshape_and_cache') + def test_forward_raise_error(self, mock_paged_attention): + query = torch.randn(10, 8 * 64) + key = torch.randn(10, 8 * 64) + value = torch.randn(10, 8 * 64) + kv_cache = torch.empty(2, 5, 128, 8, 64) + metadata = self.attn_metadata + metadata.attn_mask = torch.randn(1, 1, 10, 10) + metadata.query_lens = torch.tensor([10]) + metadata.seq_lens = torch.tensor([10]) + metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) + metadata.num_actual_tokens = 10 + metadata.slot_mapping = torch.zeros(10, 
dtype=torch.long) + layer = self.layer_no_quant + + with self.assertRaises(NotImplementedError): + self.impl_error.forward(layer, + query, + key, + value, + kv_cache, + metadata, + trace_flag=False) diff --git a/tests/ut/base.py b/tests/ut/base.py new file mode 100644 index 0000000000..e34f175935 --- /dev/null +++ b/tests/ut/base.py @@ -0,0 +1,31 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +import unittest + +from vllm_ascend.utils import adapt_patch + +# fused moe ops test will hit the infer_schema error, we need add the patch +# here to make the test pass. +import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa + + +class TestBase(unittest.TestCase): + + def setUp(self): + # adapt patch by default. + adapt_patch(True) + adapt_patch() + super().setUp() diff --git a/tests/ut/distributed/kv_transfer/test_simple_buffer.py b/tests/ut/distributed/kv_transfer/test_simple_buffer.py new file mode 100644 index 0000000000..6f90df923f --- /dev/null +++ b/tests/ut/distributed/kv_transfer/test_simple_buffer.py @@ -0,0 +1,71 @@ +import unittest +import zlib +from unittest.mock import MagicMock + +import torch + +from vllm_ascend.distributed.kv_transfer.simple_buffer import (SimpleBuffer, + int32_hash) + + +class MockSimplePipe: + + def __init__(self): + self.cluster_id = 0 + self.send_tensor = MagicMock() + self.recv_tensor = MagicMock() + self.deallocate_buffer = MagicMock() + + +class TestSimpleBuffer(unittest.TestCase): + + def setUp(self): + self.pipe = MockSimplePipe() + self.buffer = SimpleBuffer(self.pipe) + + def test_int32_hash(self): + self.assertEqual(int32_hash("test"), zlib.adler32(b"test")) + + def test_insert(self): + input_tokens = torch.tensor([1, 2, 3]) + roi = torch.tensor([1, 0, 1]) + key = torch.randn(2, 3, 4, 5) + value = torch.randn(2, 3, 4, 5) + hidden = torch.randn(3, 6) + + self.buffer.num_layers = 2 + self.buffer.num_heads = 4 + self.buffer.head_size = 5 + self.buffer.hidden_size = 6 + self.buffer.dtype = torch.float32 + + self.buffer.insert(input_tokens, roi, key, value, hidden, "req1") + + self.pipe.send_tensor.assert_called() + + def test_drop_select(self): + input_tokens = torch.tensor([1, 2, 3]) + roi = None + + self.buffer.num_layers = 2 + self.buffer.num_heads = 4 + self.buffer.head_size = 5 + self.buffer.hidden_size = 6 + self.buffer.dtype = torch.float32 + + self.pipe.recv_tensor.side_effect = [ + (MagicMock(), torch.randn(1, 2, 3 * 4 * 5)), + (MagicMock(), torch.randn(1, 2, 3 * 4 * 5)), + (MagicMock(), torch.randn(1, 3, 6)) + ] + + result = self.buffer.drop_select(input_tokens, roi, "req1") + self.assertEqual(len(result), 4) + self.assertIsInstance(result[0], torch.Tensor) + self.assertIsInstance(result[1], torch.Tensor) + self.assertIsInstance(result[2], torch.Tensor) + self.assertIsNone(result[3]) + self.assertEqual(result[0].shape, (2, 3, 4, 5)) + + def test_close(self): + self.buffer.close() diff --git 
a/tests/ut/distributed/kv_transfer/test_simple_connector.py b/tests/ut/distributed/kv_transfer/test_simple_connector.py new file mode 100644 index 0000000000..ac6c4d478d --- /dev/null +++ b/tests/ut/distributed/kv_transfer/test_simple_connector.py @@ -0,0 +1,146 @@ +import unittest +from unittest.mock import MagicMock, patch + +import torch +from vllm.config import VllmConfig +from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer +from vllm_ascend.distributed.kv_transfer.simple_connector import \ SimpleConnector +from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe + + +class TestSimpleConnector(unittest.TestCase): + + def setUp(self): + self.mock_pipe = MagicMock(spec=SimplePipe) + self.mock_buffer = MagicMock(spec=SimpleBuffer) + + patcher = patch( + 'vllm_ascend.distributed.kv_transfer.simple_buffer.SimpleBuffer') + self.addCleanup(patcher.stop) + self.MockSimpleBuffer = patcher.start() + self.MockSimpleBuffer.return_value = self.mock_buffer + + def _create_mock_config(self, kv_role): + mock_config = MagicMock() + mock_config.kv_role = "kv_producer" + mock_config.kv_connector_extra_config = { + "prefill_device_ips": ["127.0.0.1"], + "decode_device_ips": ["127.0.0.1"], + "llmdatadist_comm_port": 26000, + "http_port": 8000, + "proxy_ip": "127.0.0.1", + "proxy_port": "8000", + "port": 5500 + } + mock_config.kv_port = 5500 + self.mock_config = MagicMock(spec=VllmConfig) + self.mock_config.kv_transfer_config.is_kv_producer = True + self.mock_config.model_config.hf_config.hidden_size = 128 + self.mock_config.model_config.hf_config.num_attention_heads = 8 + self.mock_config.model_config.hf_config.num_key_value_heads = 8 + self.mock_config.model_config.hf_config.qk_rope_head_dim = 16 + self.mock_config.model_config.hf_config.kv_lora_rank = 16 + self.mock_config.model_config.is_deepseek_mla = True + # Mock the parallel_config + self.mock_config.parallel_config = MagicMock() + self.mock_config.parallel_config.tensor_parallel_size = 1 + self.mock_config.parallel_config.get_num_layers.return_value = 4 + + if kv_role == "kv_producer": + self.mock_config.kv_transfer_config.kv_role = "kv_producer" + else: + self.mock_config.kv_transfer_config.kv_role = "kv_consumer" + return mock_config + + @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe') + @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer') + @patch('llm_datadist.LLMDataDist') + def test_select_init(self, mock_pipe, mock_buffer, MockLLMDataDist): + """Test that a kv_producer connector initializes its pipe and buffer.""" + connector = SimpleConnector( + rank=0, + local_rank=0, + config=self._create_mock_config("kv_producer")) + assert connector.producer_data_pipe is not None + assert connector.producer_buffer is not None + mock_data_dist = MockLLMDataDist.return_value + mock_data_dist.init.return_value = None + + @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe') + @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer') + @patch('llm_datadist.LLMDataDist') + def test_select_select(self, mock_pipe, mock_buffer, MockLLMDataDist): + + connector = SimpleConnector( + rank=0, + local_rank=0, + config=self._create_mock_config("kv_consumer")) + connector.consumer_data_pipe = mock_pipe + connector.consumer_buffer = mock_buffer + assert connector.consumer_data_pipe is not None + assert connector.consumer_buffer is not None + input_tokens = torch.tensor([1, 2, 3]) + roi =
torch.tensor([True, True, True]) + req_id = "test_req" + connector.select(input_tokens, roi, req_id) + + @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe') + @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer') + @patch('llm_datadist.LLMDataDist') + def test_insert(self, mock_pipe, mock_buffer, MockLLMDataDist): + """Test insert operation""" + connector = SimpleConnector( + rank=0, + local_rank=0, + config=self._create_mock_config("kv_producer")) + + connector.producer_buffer = mock_buffer + + input_tokens = torch.randint(0, 1000, (5, )) + roi = torch.ones_like(input_tokens, dtype=torch.bool) + keys = torch.randn(3, 5, 1, 96) + values = torch.randn(3, 5, 1, 96) + hidden = torch.randn(5, 768) + req_id = "test_req" + + connector.insert(input_tokens, roi, keys, values, hidden, req_id) + + mock_buffer.insert.assert_called_once_with(input_tokens, roi, keys, + values, hidden, req_id) + + @patch.object(SimpleConnector, 'insert') + @patch('torch.distributed.get_rank', return_value=0) + @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe') + @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer') + @patch('llm_datadist.LLMDataDist') + def test_send_kv_caches_and_hidden_states(self, mock_pipe, mock_buffer, + MockLLMDataDist, mock_insert, + mock_rank): + """Test sending KV caches and hidden states""" + connector = SimpleConnector( + rank=0, + local_rank=0, + config=self._create_mock_config("kv_producer")) + + mock_model_executable = MagicMock() + mock_model_executable.model.start_layer = 0 + mock_model_executable.model.end_layer = 3 + + mock_model_input = MagicMock(spec=ModelInputForGPUWithSamplingMetadata) + mock_model_input.input_tokens = torch.randint(0, 1000, (10, )) + mock_model_input.attn_metadata.seq_lens = [5, 5] + mock_model_input.attn_metadata.slot_mapping = torch.randint( + 0, 100, (10, )) + mock_model_input.attn_metadata.num_prefill_tokens = 10 + mock_model_input.request_ids_to_seq_ids = {"req1": [0], "req2": [1]} + + kv_caches = [torch.randn(2, 100, 1, 96) for _ in range(3)] + + hidden_states = torch.randn(10, 768) + + connector.send_kv_caches_and_hidden_states(mock_model_executable, + mock_model_input, kv_caches, + hidden_states) diff --git a/tests/ut/distributed/kv_transfer/test_simple_pipe.py b/tests/ut/distributed/kv_transfer/test_simple_pipe.py new file mode 100644 index 0000000000..efd6eddea8 --- /dev/null +++ b/tests/ut/distributed/kv_transfer/test_simple_pipe.py @@ -0,0 +1,145 @@ +import unittest +from unittest.mock import MagicMock, patch + +import torch + +from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe + + +class TestSimplePipe(unittest.TestCase): + + @classmethod + def _create_mock_config(self): + mock_config = MagicMock() + mock_config.kv_role = "kv_producer" + mock_config.kv_connector_extra_config = { + "prefill_device_ips": ["127.0.0.1"], + "decode_device_ips": ["127.0.0.1"], + "llmdatadist_comm_port": 26000, + "http_port": 8000, + "proxy_ip": "127.0.0.1", + "proxy_port": "8000", + "port": 5500 + } + mock_config.kv_port = 5500 + return mock_config + + @patch('threading.Thread') + @patch('llm_datadist.LLMDataDist') + def test_init_success(self, mock_thread, MockLLMDataDist): + + mock_config = self._create_mock_config() + + self.pipe = SimplePipe(rank=5, + local_rank=0, + kv_transfer_config=mock_config, + hostname="127.0.0.1", + port_offset=0) + + self.pipe.router_socket.close() + + @patch('threading.Thread') + @patch('llm_datadist.LLMDataDist') + def 
test_prepare_data_dist(self, mock_thread, MockLLMDataDist): + self.pipe = SimplePipe(rank=5, + local_rank=0, + kv_transfer_config=self._create_mock_config(), + hostname="127.0.0.1", + port_offset=0) + mock_data_dist = MockLLMDataDist.return_value + mock_data_dist.init.return_value = None + self.pipe.router_socket.close() + + def test_init_with_invalid_kv_role(self): + with self.assertRaises(NotImplementedError): + mock_config = MagicMock() + mock_config.kv_role = "err_role" + mock_config.kv_connector_extra_config = { + "prefill_device_ips": ["127.0.0.1"], + "decode_device_ips": ["127.0.0.1"], + "llmdatadist_comm_port": 26000, + "http_port": 8000, + "proxy_ip": "127.0.0.1", + "proxy_port": "8000", + "port": 5500 + } + pipe = SimplePipe(rank=5, + local_rank=0, + kv_transfer_config=mock_config, + hostname="127.0.0.1", + port_offset=0) + pipe.router_socket.close() + + def test_init_with_missing_device_ips(self): + with self.assertRaises(ValueError): + mock_config = MagicMock() + mock_config.kv_role = "kv_producer" + mock_config.kv_connector_extra_config = { + "llmdatadist_comm_port": 26000, + "http_port": 8000, + "proxy_ip": "127.0.0.1", + "proxy_port": "8000", + "port": 5500 + } + pipe = SimplePipe(rank=0, + local_rank=0, + kv_transfer_config=mock_config, + hostname="127.0.0.1", + port_offset=0) + pipe.router_socket.close() + + @patch('threading.Thread') + @patch('llm_datadist.LLMDataDist') + def test_create_register_thread_address_is_empty(self, MockThread, + MockLLMDataDist): + + mock_config = self._create_mock_config() + pipe = SimplePipe(rank=5, + local_rank=0, + kv_transfer_config=mock_config, + hostname="127.0.0.1", + port_offset=0) + self.assertIsNotNone(pipe._register_thread) + mock_data_dist = MockLLMDataDist.return_value + mock_data_dist.init.return_value = None + pipe.router_socket.close() + + @patch('threading.Thread') + @patch('llm_datadist.LLMDataDist') + def test_create_register_thread_address_is_not_empty( + self, MockThread, MockLLMDataDist): + mock_config = MagicMock() + mock_config.kv_role = "kv_producer" + mock_config.kv_connector_extra_config = { + "prefill_device_ips": [""], + "decode_device_ips": [""], + "llmdatadist_comm_port": 26000, + "http_port": 8000, + "proxy_ip": "127.0.0.1", + "proxy_port": "8000", + "port": 5500 + } + pipe = SimplePipe(rank=5, + local_rank=0, + kv_transfer_config=mock_config, + hostname="127.0.0.1", + port_offset=0) + self.assertIsNotNone(pipe._register_thread) + mock_data_dist = MockLLMDataDist.return_value + mock_data_dist.init.return_value = None + pipe.router_socket.close() + + @patch('vllm_ascend.distributed.kv_transfer.simple_pipe.SimplePipe') + @patch('llm_datadist.LLMDataDist') + def test_should_send_tensor_when_valid_input(self, MockSimplePipe, + MockLLMDataDist): + pipe = MockSimplePipe() + tensor = torch.randn(3, 3) + tensor_desc = MockLLMDataDist.CacheDesc( + num_tensors=1, + shape=(3, 3), + data_type=MockLLMDataDist.DataType.DT_FLOAT, + seq_len_dim_index=1) + tensor_key = MockLLMDataDist.CacheKey(1, 0, 1) + result = pipe.send_tensor(tensor, tensor_desc, tensor_key) + self.assertIsNotNone(result) diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py new file mode 100644 index 0000000000..b00eeb90a0 --- /dev/null +++ b/tests/ut/distributed/test_parallel_state.py @@ -0,0 +1,208 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +from unittest.mock import MagicMock, patch + +import pytest +from vllm.distributed.parallel_state import GroupCoordinator + +import vllm_ascend +from tests.ut.base import TestBase +from vllm_ascend.distributed.parallel_state import ( + destory_ascend_model_parallel, get_ep_group, get_etp_group, + init_ascend_model_parallel, model_parallel_initialized) + + +class TestParallelState(TestBase): + + @patch('vllm_ascend.distributed.parallel_state._EP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) + def test_get_ep_group_when_initialized(self, mock_ep): + # Act + result = get_ep_group() + + # Assert + assert isinstance(result, GroupCoordinator) + + @patch('vllm_ascend.distributed.parallel_state._EP', None) + def test_get_ep_group_when_not_initialized(self): + # Act & Assert + with pytest.raises(AssertionError) as excinfo: + get_ep_group() + assert "expert model parallel group is not initialized" in str( + excinfo.value) + + @patch('vllm_ascend.distributed.parallel_state._ETP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) + def test_get_etp_group_when_initialized(self, mock_etp): + # Act + result = get_etp_group() + + # Assert + assert isinstance(result, GroupCoordinator) + + @patch('vllm_ascend.distributed.parallel_state._ETP', None) + def test_get_etp_group_when_not_initialized(self): + # Act & Assert + with pytest.raises(AssertionError) as excinfo: + get_etp_group() + assert "expert tensor parallel group is not initialized" in str( + excinfo.value) + + @patch('vllm_ascend.distributed.parallel_state._ETP', None) + @patch('vllm_ascend.distributed.parallel_state._EP', None) + def test_model_parallel_initialized_when_both_none(self): + # Act & Assert + assert not model_parallel_initialized() + + @patch('vllm_ascend.distributed.parallel_state._ETP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) + @patch('vllm_ascend.distributed.parallel_state._EP', None) + def test_model_parallel_initialized_when_ep_none(self, mock_etp): + # Act & Assert + assert not model_parallel_initialized() + + @patch('vllm_ascend.distributed.parallel_state._ETP', None) + @patch('vllm_ascend.distributed.parallel_state._EP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) + def test_model_parallel_initialized_when_etp_none(self, mock_ep): + # Act & Assert + assert not model_parallel_initialized() + + @patch('vllm_ascend.distributed.parallel_state._ETP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) + @patch('vllm_ascend.distributed.parallel_state._EP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) + def test_model_parallel_initialized_when_etp_initialized( + self, mock_ep, mock_etp): + # Act & Assert + assert model_parallel_initialized() + + @patch('vllm_ascend.distributed.parallel_state._ETP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) + @patch('vllm_ascend.distributed.parallel_state._EP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) + def test_destroy_when_both_exist(self, mock_ep, mock_etp): + # Act + destory_ascend_model_parallel() + # Assert + 
mock_ep.destroy.assert_called_once() + mock_etp.destroy.assert_called_once() + assert vllm_ascend.distributed.parallel_state._ETP is None + assert vllm_ascend.distributed.parallel_state._EP is None + + @patch('vllm_ascend.distributed.parallel_state._ETP', None) + @patch('vllm_ascend.distributed.parallel_state._EP', + new_callable=lambda: MagicMock()) + def test_destory_ascend_model_parallel_when_etp_none(self, mock_ep): + # Act + destory_ascend_model_parallel() + # Assert + mock_ep.destroy.assert_called_once() + assert vllm_ascend.distributed.parallel_state._EP is None + assert vllm_ascend.distributed.parallel_state._ETP is None + + @patch('vllm_ascend.distributed.parallel_state._ETP', + new_callable=lambda: MagicMock()) + @patch('vllm_ascend.distributed.parallel_state._EP', None) + def test_destory_ascend_model_parallel_when_ep_none(self, mock_etp): + # Act + destory_ascend_model_parallel() + # Assert + mock_etp.destroy.assert_called_once() + assert vllm_ascend.distributed.parallel_state._ETP is None + assert vllm_ascend.distributed.parallel_state._EP is None + + @patch('vllm_ascend.distributed.parallel_state._ETP', None) + @patch('vllm_ascend.distributed.parallel_state._EP', None) + def test_destory_ascend_model_parallel_when_both_none(self): + # Act + destory_ascend_model_parallel() + # Assert + assert vllm_ascend.distributed.parallel_state._ETP is None + assert vllm_ascend.distributed.parallel_state._EP is None + + @patch('torch.distributed.is_initialized', return_value=True) + @patch('torch.distributed.get_world_size', return_value=8) + @patch('vllm_ascend.distributed.parallel_state.get_world_group', + return_value=MagicMock(device_group='npu:0', local_rank=0)) + @patch('torch.distributed.get_backend', return_value='hccl') + @patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group') + @patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', + return_value=False) + def test_init_ascend_model_parallel_normal_case( + self, mock_mp_init, mock_init_group, mock_get_backend, + mock_world_group, mock_get_world_size, mock_is_init): + """Test normal initialization with default parameters""" + # Act + init_ascend_model_parallel() + # Assert + mock_init_group.assert_any_call([[0, 1, 2, 3, 4, 5, 6, 7]], + 0, + 'hccl', + group_name="ep") + mock_init_group.assert_any_call([[0]], 0, 'hccl', group_name="etp") + self.assertIsNotNone(vllm_ascend.distributed.parallel_state._EP) + self.assertIsNotNone(vllm_ascend.distributed.parallel_state._ETP) + + @patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', + return_value=True) + def test_init_ascend_model_parallel_skip_if_initialized( + self, mock_mp_init): + """Test skipping when model parallel already initialized""" + with patch.object(vllm_ascend.distributed.parallel_state, + '_EP') as mock_ep, patch.object( + vllm_ascend.distributed.parallel_state, + '_ETP') as mock_etp: + # Act + init_ascend_model_parallel() + # Assert + mock_ep.assert_not_called() + mock_etp.assert_not_called() + + @patch('torch.distributed.is_initialized', return_value=False) + def test_init_ascend_model_parallel_assert_dist_not_init( + self, mock_is_init): + """Test assertion when distributed not initialized""" + # Act & Assert + with self.assertRaises(AssertionError): + init_ascend_model_parallel() + + @patch('torch.distributed.is_initialized', return_value=True) + @patch('torch.distributed.get_world_size', return_value=8) + @patch('vllm_ascend.distributed.parallel_state.get_world_group', + 
return_value=MagicMock(device_group='npu:0', local_rank=1)) + @patch('torch.distributed.get_backend', return_value='hccl') + @patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group') + @patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', + return_value=False) + def test_init_ascend_model_parallel_custom_params( + self, mock_mp_init, mock_init_group, mock_get_backend, + mock_world_group, mock_get_world_size, mock_is_init): + """Test initialization with custom parallel sizes""" + # Act + init_ascend_model_parallel(expert_parallel_size=2, + expert_tensor_parallel_size=4, + world_size=8, + backend='hccl') + #Assert + mock_init_group.assert_any_call([[0, 4], [1, 5], [2, 6], [3, 7]], + 1, + 'hccl', + group_name="ep") + mock_init_group.assert_any_call([[0, 1, 2, 3], [4, 5, 6, 7]], + 1, + 'hccl', + group_name="etp") diff --git a/tests/ut/ops/expert_map.json b/tests/ut/ops/expert_map.json new file mode 100644 index 0000000000..bb74799a7c --- /dev/null +++ b/tests/ut/ops/expert_map.json @@ -0,0 +1,17 @@ +{ + "moe_layer_count": + 1, + "layer_list": [{ + "layer_id": + 0, + "device_count": + 2, + "device_list": [{ + "device_id": 0, + "device_expert": [7, 2, 0, 3, 5] + }, { + "device_id": 1, + "device_expert": [6, 1, 4, 7, 2] + }] + }] +} diff --git a/tests/ut/ops/test_expert_load_balancer.py b/tests/ut/ops/test_expert_load_balancer.py index 3b7a69ddd4..97beada12c 100644 --- a/tests/ut/ops/test_expert_load_balancer.py +++ b/tests/ut/ops/test_expert_load_balancer.py @@ -1,14 +1,26 @@ -# fused moe ops test will hit the infer_schema error, we need add the patch -# here to make the test pass. -import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# import json -import unittest +import os from typing import List, TypedDict from unittest import mock import torch +from tests.ut.base import TestBase from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer @@ -28,31 +40,13 @@ class MockData(TypedDict): layer_list: List[Layer] -MOCK_DATA: MockData = { - "moe_layer_count": - 1, - "layer_list": [{ - "layer_id": - 0, - "device_count": - 2, - "device_list": [{ - "device_id": 0, - "device_expert": [7, 2, 0, 3, 5] - }, { - "device_id": 1, - "device_expert": [6, 1, 4, 7, 2] - }] - }] -} - - -class TestExpertLoadBalancer(unittest.TestCase): +class TestExpertLoadBalancer(TestBase): def setUp(self): - json_file = "expert_map.json" - with open(json_file, 'w') as f: - json.dump(MOCK_DATA, f) + _TEST_DIR = os.path.dirname(__file__) + json_file = _TEST_DIR + "/expert_map.json" + with open(json_file, 'r') as f: + self.expert_map: MockData = json.load(f) self.expert_load_balancer = ExpertLoadBalancer(json_file, global_expert_num=8) @@ -62,9 +56,9 @@ def test_init(self): self.assertIsInstance(self.expert_load_balancer.expert_map_tensor, torch.Tensor) self.assertEqual(self.expert_load_balancer.layers_num, - MOCK_DATA["moe_layer_count"]) + self.expert_map["moe_layer_count"]) self.assertEqual(self.expert_load_balancer.ranks_num, - MOCK_DATA["layer_list"][0]["device_count"]) + self.expert_map["layer_list"][0]["device_count"]) def test_generate_index_dicts(self): tensor_2d = torch.tensor([[7, 2, 0, 3, 5], [6, 1, 4, 7, 2]]) @@ -142,6 +136,6 @@ def test_get_rank_log2phy_map(self): def test_get_global_redundant_expert_num(self): redundant_expert_num = self.expert_load_balancer.get_global_redundant_expert_num( ) - expected_redundant_expert_num = len(MOCK_DATA["layer_list"][0]["device_list"][0]["device_expert"]) * \ - MOCK_DATA["layer_list"][0]["device_count"] - 8 + expected_redundant_expert_num = len(self.expert_map["layer_list"][0]["device_list"][0]["device_expert"]) * \ + self.expert_map["layer_list"][0]["device_count"] - 8 self.assertEqual(redundant_expert_num, expected_redundant_expert_num) diff --git a/tests/ut/ops/test_rotary_embedding.py b/tests/ut/ops/test_rotary_embedding.py new file mode 100644 index 0000000000..91c2ad40df --- /dev/null +++ b/tests/ut/ops/test_rotary_embedding.py @@ -0,0 +1,315 @@ +import math +import unittest +from unittest.mock import MagicMock, patch + +import torch + +from tests.ut.base import TestBase +from vllm_ascend.ops.rotary_embedding import (custom_rotary_embedding_enabled, + native_rope_deepseek_forward, + rope_forward_oot, rotate_half, + yarn_find_correction_dim, + yarn_get_mscale) + + +class TestCustomRotaryEmbeddingEnabled(unittest.TestCase): + + def setUp(self): + # Common setup for tests + self.positions = torch.tensor([1, 2, 3]) + self.query = torch.randn(3, 4, dtype=torch.float16) + self.key = torch.randn(3, 4, dtype=torch.float16) + self.head_size = 32 + self.cos_sin_cache = torch.randn(3, 4) + + # Mock self object for rope_forward_oot + self.mock_self = MagicMock() + self.mock_self.head_size = self.head_size + self.mock_self.cos_sin_cache = self.cos_sin_cache + self.mock_self.is_neox_style = True + self.mock_self.forward_native.return_value = (self.query, self.key) + + def test_custom_rotary_embedding_enabled(self): + # Test when all conditions are True + with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op', + return_value=True): + result = custom_rotary_embedding_enabled(self.query, True, + self.head_size) + self.assertTrue(result) + + # Test when dtype is not float16 + with 
patch('vllm_ascend.ops.rotary_embedding.enable_custom_op', + return_value=True): + query = self.query.to(torch.float32) + result = custom_rotary_embedding_enabled(query, True, + self.head_size) + self.assertFalse(result) + + # Test when neox_style is False + with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op', + return_value=True): + result = custom_rotary_embedding_enabled(self.query, False, + self.head_size) + self.assertFalse(result) + + # Test when head_size is not divisible by 32 + with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op', + return_value=True): + result = custom_rotary_embedding_enabled(self.query, True, + self.head_size + 1) + self.assertFalse(result) + + # Test when custom op is disabled + with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op', + return_value=False): + result = custom_rotary_embedding_enabled(self.query, True, + self.head_size) + self.assertFalse(result) + + +class TestRopeForwardOot(unittest.TestCase): + + def setUp(self): + # Common setup for tests + self.positions = torch.tensor([1, 2, 3]) + self.query = torch.randn(3, 4, dtype=torch.float16) + self.key = torch.randn(3, 4, dtype=torch.float16) + self.head_size = 32 + self.cos_sin_cache = torch.randn(3, 4) + + # Mock self object for rope_forward_oot + self.mock_self = MagicMock() + self.mock_self.head_size = self.head_size + self.mock_self.cos_sin_cache = self.cos_sin_cache + self.mock_self.is_neox_style = True + self.mock_self.forward_native.return_value = (self.query, self.key) + + @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config') + def test_rope_forward_oot_torchair_enabled_base(self, + mock_get_ascend_config): + # Setup mock for torchair enabled + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = True + mock_get_ascend_config.return_value = mock_config + + result_q, result_k = rope_forward_oot(self.mock_self, self.positions, + self.query, self.key) + + self.mock_self.forward_native.assert_called_once_with( + self.positions, self.query, self.key, None) + self.assertTrue(torch.equal(result_q, self.query)) + self.assertTrue(torch.equal(result_k, self.key)) + + @patch('torch.ops._C') + @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config') + @patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False) + @patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled', + return_value=True) + @patch('torch.ops._npu_rotary_embedding') + def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding, + mock_custom_enabled, mock_is_310p, + mock_get_ascend_config, mock__c): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + mock_get_ascend_config.return_value = mock_config + + # Setup mock for custom kernel path + + mock__c.rotary_embedding.return_value = self.query, self.key + + result_q, result_k = rope_forward_oot(self.mock_self, self.positions, + self.query, self.key) + + self.assertEqual(result_q.shape, self.query.shape) + self.assertEqual(result_k.shape, self.key.shape) + + @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config') + @patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled', + return_value=False) + @patch('torch_npu._npu_rotary_embedding') + def test_rope_forward_oot_contiguous(self, mock_npu_rotary, + mock_custom_enabled, + mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + mock_get_ascend_config.return_value = mock_config + + # Test contiguous path when custom is disabled + non_contig_query 
= self.query.transpose(0, 1) + non_contig_key = self.key.transpose(0, 1) + + result_q, result_k = rope_forward_oot(self.mock_self, self.positions, + non_contig_query, non_contig_key) + + mock_npu_rotary.assert_called_once() + self.assertEqual(result_q.shape, non_contig_query.shape) + self.assertEqual(result_k.shape, non_contig_key.shape) + + @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config') + def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + mock_get_ascend_config.return_value = mock_config + + # Test that NotImplementedError is raised when offsets is provided + offsets = torch.tensor([1, 2, 3]) + with self.assertRaises(NotImplementedError): + rope_forward_oot(self.mock_self, self.positions, self.query, + self.key, offsets) + + @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config') + @patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled', + return_value=False) + @patch('torch_npu._npu_rotary_embedding') + def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary, + mock_custom_enabled, + mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + mock_get_ascend_config.return_value = mock_config + + # Test neox_style override + result_q, result_k = rope_forward_oot(self.mock_self, + self.positions, + self.query, + self.key, + is_neox_style_override=False) + + # Check that neox_style=False was passed to the NPU function + args, kwargs = mock_npu_rotary.call_args + self.assertFalse(args[-1]) + + +class MockRopeModule: + + def __init__(self, max_seq_len=2048, is_neox_style=True): + self.max_seq_len = max_seq_len + self.is_neox_style = is_neox_style + self.cos_cached = None + self.sin_cached = None + self.rotary_dim = 1 + self.base = 1 + + +class TestNativeRopeDeepseekForward(TestBase): + + @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot') + def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot): + module = MockRopeModule() + positions = torch.tensor([1, 2, 3]) + query = torch.randn(1, 8, 128) + key = torch.randn(1, 8, 128) + + mock_rope_forward_oot.return_value = (query, key) + + q_pe, k_pe = native_rope_deepseek_forward(module, positions, query, + key) + + assert q_pe.shape == query.shape + assert k_pe.shape == key.shape + + @patch('vllm_ascend.ops.rotary_embedding._set_cos_sin_cache') + @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot') + def test_native_rope_deepseek_forward_cache_handling( + self, mock_rope_forward_oot, mock_set_cache): + # Test cache situation is true + module = MockRopeModule(max_seq_len=1024) + positions = torch.tensor([1, 2, 3]) + query = torch.randn(1, 8, 128) + key = torch.randn(1, 8, 128) + + mock_rope_forward_oot.return_value = (query, key) + + q_pe, k_pe = native_rope_deepseek_forward(module, + positions, + query, + key, + max_seq_len=2048) + + assert q_pe.shape == query.shape + assert k_pe.shape == key.shape + + @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot') + def test_native_rope_deepseek_forward_key_reshaping( + self, mock_rope_forward_oot): + module = MockRopeModule() + positions = torch.tensor([1, 2, 3]) + query = torch.randn(1, 8, 128) + key = torch.randn(1, 128) + + mock_rope_forward_oot.return_value = (query, key) + + q_pe, k_pe = native_rope_deepseek_forward(module, positions, query, + key) + + assert q_pe.shape == query.shape + assert k_pe.shape == (1, 128) + + 
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot') + def test_native_rope_deepseek_forward_non_neox_style( + self, mock_rope_forward_oot): + module = MockRopeModule(is_neox_style=False) + positions = torch.tensor([1, 2, 3]) + query = torch.randn(1, 8, 128) + key = torch.randn(1, 8, 128) + + mock_rope_forward_oot.return_value = (query, key) + + q_pe, k_pe = native_rope_deepseek_forward(module, positions, query, + key) + + assert q_pe.shape == query.shape + assert k_pe.shape == key.shape + + +class TestRotateHalf(unittest.TestCase): + + def test_rotate_half_even_dim(self): + # Test with even dimension + x = torch.tensor([1.0, 2.0, 3.0, 4.0]) + expected = torch.tensor([-3.0, -4.0, 1.0, 2.0]) + result = rotate_half(x) + self.assertTrue(torch.allclose(result, expected)) + + +class TestYarnFindCorrectionDim(unittest.TestCase): + + def test_basic_case(self): + # Test with standard values + num_rotations = 100 + dim = 512 + base = 10000 + max_position_embeddings = 2048 + + result = yarn_find_correction_dim(num_rotations, dim, base, + max_position_embeddings) + + # Calculate expected value manually + expected = (dim * torch.log( + torch.tensor(max_position_embeddings) / + (num_rotations * 2 * torch.pi))) / (2 * + torch.log(torch.tensor(base))) + + self.assertTrue(torch.allclose(result, expected)) + + +class TestYarnGetMscale(unittest.TestCase): + + def test_scale_less_than_or_equal_1(self): + self.assertEqual(yarn_get_mscale(scale=0.5), 1.0) + self.assertEqual(yarn_get_mscale(scale=1.0), 1.0) + self.assertEqual(yarn_get_mscale(scale=0.999), 1.0) + + def test_scale_greater_than_1(self): + test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)), + (10.0, 1.0, 1.0 + 0.1 * math.log(10.0)), + (5.0, 2.0, 1.0 + 0.2 * math.log(5.0)), + (math.e, 1.0, 1.0 + 0.1)] + + for scale, mscale, expected in test_cases: + result = yarn_get_mscale(scale, mscale) + self.assertAlmostEqual( + result, + expected, + places=6, + msg=f"Failed for scale={scale}, mscale={mscale}") diff --git a/tests/ut/patch/worker/patch_common/test_patch_distributed.py b/tests/ut/patch/worker/patch_common/test_patch_distributed.py new file mode 100644 index 0000000000..73525eefd7 --- /dev/null +++ b/tests/ut/patch/worker/patch_common/test_patch_distributed.py @@ -0,0 +1,27 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# + +from tests.ut.base import TestBase + + +class TestPatchDistributed(TestBase): + + def test_GroupCoordinator_patched(self): + from vllm.distributed.parallel_state import GroupCoordinator + + from vllm_ascend.patch.worker.patch_common.patch_distributed import \ + GroupCoordinatorPatch + + self.assertIs(GroupCoordinator, GroupCoordinatorPatch) diff --git a/tests/ut/patch/worker/patch_common/test_patch_sampler.py b/tests/ut/patch/worker/patch_common/test_patch_sampler.py new file mode 100644 index 0000000000..3db3fa2ee7 --- /dev/null +++ b/tests/ut/patch/worker/patch_common/test_patch_sampler.py @@ -0,0 +1,31 @@ +import importlib +import os +from unittest import mock + +import torch +from vllm.v1.sample.ops import topk_topp_sampler + +from tests.ut.base import TestBase + + +class TestTopKTopPSamplerOptimize(TestBase): + + @mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "1"}) + @mock.patch("torch_npu.npu_top_k_top_p") + def test_npu_topk_topp_called_when_optimized(self, mock_npu_op): + # We have to patch and reload because the patch will take effect + # only after VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE is set. + import vllm_ascend.patch.worker.patch_0_9_1.patch_sampler + importlib.reload(vllm_ascend.patch.worker.patch_0_9_1.patch_sampler) + + mock_npu_op.return_value = (torch.randn(1, 3)) + sampler = topk_topp_sampler.TopKTopPSampler() + + logits = torch.tensor([[1.0, 2.0, 3.0]]) + k = torch.tensor([2]) + p = torch.tensor([0.9]) + generators = {0: torch.Generator()} + generators[0].manual_seed(42) + + sampler.forward_native(logits, generators, k, p) + mock_npu_op.assert_called_once_with(logits, p, k) diff --git a/tests/ut/quantization/test_quant_config.py b/tests/ut/quantization/test_quant_config.py new file mode 100644 index 0000000000..6591d93428 --- /dev/null +++ b/tests/ut/quantization/test_quant_config.py @@ -0,0 +1,230 @@ +from unittest.mock import MagicMock, patch + +import torch +from vllm.attention.layer import Attention +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) + +from tests.ut.base import TestBase +from vllm_ascend.quantization.quant_config import (AscendKVCacheMethod, + AscendQuantConfig) + +ASCEND_QUATIZATION_METHOD = "ascend" + + +class TestAscendQuantConfig(TestBase): + + def setUp(self): + self.sample_config = { + "weight": "INT8", + "fa_quant_type": "C8", + "kv_quant_type": "C8", + "layer1.weight": "INT8", + "layer2.weight": "FLOAT", + "fused_layer.weight": "FLOAT", + "fused_layer.shard1.weight": "FLOAT", + "fused_layer.shard2.weight": "FLOAT", + "shard1.weight": "FLOAT", + "shard2.weight": "FLOAT", + } + self.ascend_config = AscendQuantConfig(self.sample_config) + self.ascend_config.packed_modules_mapping = None + + def test_init(self): + self.assertEqual(self.ascend_config.quant_description, + self.sample_config) + + def test_repr(self): + repr_str = repr(self.ascend_config) + self.assertTrue(repr_str.startswith("AscendQuantConfig:\n")) + + def test_get_name(self): + self.assertEqual(AscendQuantConfig.get_name(), + ASCEND_QUATIZATION_METHOD) + + def test_get_supported_act_dtypes(self): + supported_dtypes = AscendQuantConfig.get_supported_act_dtypes() + self.assertEqual(len(supported_dtypes), 3) + + def test_get_min_capability(self): + with self.assertRaises(NotImplementedError): + AscendQuantConfig.get_min_capability() + + def test_get_config_filenames(self): + filenames = AscendQuantConfig.get_config_filenames() + self.assertEqual(filenames, 
["quant_model_description.json"]) + + def test_from_config(self): + config = AscendQuantConfig.from_config(self.sample_config) + self.assertIsInstance(config, AscendQuantConfig) + self.assertEqual(config.quant_description, self.sample_config) + + @patch('torch.npu.is_available') + def test_override_quantization_method(self, mock_is_available): + # Test when NPU is available + mock_is_available.return_value = True + result = AscendQuantConfig.override_quantization_method(None, None) + self.assertEqual(result, ASCEND_QUATIZATION_METHOD) + + # Test when NPU is not available + mock_is_available.return_value = False + result = AscendQuantConfig.override_quantization_method(None, None) + self.assertIsNone(result) + + def test_get_quant_method_for_linear(self): + linear_layer = MagicMock(spec=LinearBase) + # Test skipped layer + with patch.object(self.ascend_config, + 'is_layer_skipped_ascend', + return_value=True): + method = self.ascend_config.get_quant_method(linear_layer, ".attn") + self.assertIsInstance(method, UnquantizedLinearMethod) + + # Test quantized layer + with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \ + patch('vllm_ascend.quantization.quant_config.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear: + + method = self.ascend_config.get_quant_method(linear_layer, ".attn") + self.assertIs(method, mock_ascend_linear.return_value) + mock_ascend_linear.assert_called_once_with( + self.ascend_config, ".attn", + self.ascend_config.packed_modules_mapping) + + def test_get_quant_method_for_attention(self): + attention_layer = MagicMock(spec=Attention) + with patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', + return_value=MagicMock()) as mock_ascend_kvcache: + # Test with fa_quant_type + method = self.ascend_config.get_quant_method( + attention_layer, ".attn") + self.assertIs(method, mock_ascend_kvcache.return_value) + + with patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', + return_value=MagicMock()) as mock_ascend_kvcache: + # Test with kv_quant_type + modified_config = {"kv_quant_type": "C8"} + config = AscendQuantConfig(modified_config) + config.packed_modules_mapping = None + method = config.get_quant_method(attention_layer, "attn") + self.assertIs(method, mock_ascend_kvcache.return_value) + + def test_get_quant_method_for_fused_moe(self): + fused_moe_layer = MagicMock(spec=FusedMoE) + + # Test skipped layer + with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \ + patch('vllm_ascend.quantization.quant_config.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe: + method = self.ascend_config.get_quant_method( + fused_moe_layer, "moe_layer") + self.assertIs(method, mock_ascend_moe.return_value) + + # Test quantized layer + with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \ + patch('vllm_ascend.quantization.quant_config.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe: + method = self.ascend_config.get_quant_method( + fused_moe_layer, "moe_layer") + self.assertIs(method, mock_ascend_moe.return_value) + + def test_is_layer_skipped_ascend(self): + # Test non-fused layer that should be quantized + self.assertFalse(self.ascend_config.is_layer_skipped_ascend("layer1")) + + # Test non-fused layer that should be skipped + self.assertTrue(self.ascend_config.is_layer_skipped_ascend("layer2")) + + # Test fused layer + fused_mapping = {"fused_layer": ["shard1", "shard2"]} + self.assertTrue( + 
self.ascend_config.is_layer_skipped_ascend("fused_layer", + fused_mapping)) + + # Test inconsistent fused layer shards + bad_config = {"shard1.weight": "FLOAT", "shard2.weight": "INT8"} + config = AscendQuantConfig(bad_config) + with self.assertRaises(ValueError): + config.is_layer_skipped_ascend("fused_layer", fused_mapping) + + def test_get_scaled_act_names(self): + self.assertEqual(self.ascend_config.get_scaled_act_names(), []) + + +class TestAscendKVCacheMethod(TestBase): + + def setUp(self): + # Setup common test fixtures + self.mock_quant_config = MagicMock(spec=AscendQuantConfig) + self.mock_quant_config.quant_description = {"some_config": "value"} + self.prefix = "attention_layer" + + # Mock the quantizer and quant_method + self.mock_quantizer = MagicMock() + self.mock_quant_method = MagicMock() + + # Patch the AscendQuantizer + self.quantizer_patcher = patch( + 'vllm_ascend.quantization.quant_config.AscendQuantizer.get_quantizer', + return_value=self.mock_quantizer) + self.mock_get_quantizer = self.quantizer_patcher.start() + + self.mock_quantizer.build_attention_method.return_value = self.mock_quant_method + + # Create instance + self.kv_cache_method = AscendKVCacheMethod(self.mock_quant_config, + self.prefix) + + def tearDown(self): + self.quantizer_patcher.stop() + + def test_init(self): + """Test initialization with proper quantizer setup.""" + self.mock_get_quantizer.assert_called_once_with( + self.mock_quant_config.quant_description, self.prefix) + self.mock_quantizer.build_attention_method.assert_called_once() + + def test_create_weights(self): + """Test create_weights delegates to quant_method.""" + mock_layer = MagicMock() + self.kv_cache_method.create_weights(mock_layer) + self.mock_quant_method.create_weights.assert_called_once_with( + mock_layer) + + def test_process_weights_after_loading_with_method(self): + """Test process_weights when quant_method has the method.""" + mock_layer = MagicMock() + self.kv_cache_method.process_weights_after_loading(mock_layer) + self.mock_quant_method.process_weights_after_loading.assert_called_once_with( + mock_layer) + + def test_process_weights_after_loading_without_method(self): + """Test process_weights when quant_method lacks the method.""" + # Reset mock to remove the method + del self.mock_quant_method.process_weights_after_loading + mock_layer = MagicMock() + + # Should not raise exception + self.kv_cache_method.process_weights_after_loading(mock_layer) + + def test_apply_delegation(self): + """Test apply properly delegates to quant_method.""" + mock_layer = MagicMock() + mock_query = torch.randn(1, 32, 128) + mock_key = torch.randn(1, 32, 128) + mock_value = torch.randn(1, 32, 128) + mock_kv_cache = MagicMock() + mock_attn_metadata = MagicMock() + mock_scale = 1.0 + mock_output = torch.zeros(1, 32, 128) + mock_attn_type = MagicMock() + expected_result = torch.randn(1, 32, 128) + self.mock_quant_method.apply.return_value = expected_result + + result = self.kv_cache_method.apply(mock_layer, mock_query, mock_key, + mock_value, mock_kv_cache, + mock_attn_metadata, mock_attn_type, + mock_scale, mock_output) + + self.mock_quant_method.apply.assert_called_once_with( + mock_layer, mock_query, mock_key, mock_value, mock_kv_cache, + mock_attn_metadata, mock_attn_type, mock_scale, mock_output) + self.assertTrue(torch.equal(result, expected_result)) diff --git a/tests/ut/quantization/test_quantizer.py b/tests/ut/quantization/test_quantizer.py new file mode 100644 index 0000000000..559cf19379 --- /dev/null +++ 
b/tests/ut/quantization/test_quantizer.py @@ -0,0 +1,122 @@ +from unittest.mock import MagicMock, patch + +from tests.ut.base import TestBase +from vllm_ascend.quantization.quant_config import AscendQuantConfig +from vllm_ascend.quantization.quantizer import (VLLMAscendQuantizer, + W8A8Quantizer) + +SUPPORT_ASCEND_QUANTIZER_TYPE = {"test": "1"} + + +class TestGetQuantizer(TestBase): + + def setUp(self): + # Setup common test fixtures + self.supported_types = { + 'INT8': MagicMock(_instance=None), + 'FP16': MagicMock(_instance=None), + 'C8': MagicMock(_instance=None) + } + self.original_supported_types = SUPPORT_ASCEND_QUANTIZER_TYPE.copy() + SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.supported_types) + self.mock_quant_config = MagicMock(spec=AscendQuantConfig) + self.mock_quant_config.quant_description = {"some_config": "value"} + + def tearDown(self): + # Restore original supported types + SUPPORT_ASCEND_QUANTIZER_TYPE.clear() + SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.original_supported_types) + + def test_get_quantizer_fa(self): + """Test successful quantizer retrieval for different cases.""" + # Setup + quant_description = {'fa_quant_type': 'C8'} + prefix = '.attn' + expected_type = 'C8' + with patch.dict( + 'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE', + SUPPORT_ASCEND_QUANTIZER_TYPE): + + result = VLLMAscendQuantizer.get_quantizer( + quant_description, + prefix, + packed_modules_mapping={"some": "mapping"}) + + # Verify + self.assertIsNotNone(result) + self.assertEqual(result, + self.supported_types[expected_type]._instance) + self.supported_types[expected_type].assert_called_once_with( + quant_description) + + def test_get_quantizer_kv(self): + """Test successful quantizer retrieval for different cases.""" + # Setup + quant_description = {'kv_quant_type': 'C8'} + prefix = '.attn' + expected_type = 'C8' + with patch.dict( + 'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE', + SUPPORT_ASCEND_QUANTIZER_TYPE): + + result = VLLMAscendQuantizer.get_quantizer( + quant_description, + prefix, + packed_modules_mapping={"some": "mapping"}) + + # Verify + self.assertIsNotNone(result) + self.assertEqual(result, + self.supported_types[expected_type]._instance) + self.supported_types[expected_type].assert_called_once_with( + quant_description) + + def test_get_quantizer_linear(self): + """Test successful quantizer retrieval for different cases.""" + # Setup + quant_description = {'linear_type': 'INT8'} + prefix = 'nothing' + expected_type = 'INT8' + with patch('vllm_ascend.quantization.quantizer.VLLMAscendQuantizer.get_linear_quant_type', + return_value=expected_type), \ + patch.dict('vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE', SUPPORT_ASCEND_QUANTIZER_TYPE): + + result = VLLMAscendQuantizer.get_quantizer( + quant_description, + prefix, + packed_modules_mapping={"some": "mapping"}) + + # Verify + self.assertIsNotNone(result) + self.assertEqual(result, + self.supported_types[expected_type]._instance) + self.supported_types[expected_type].assert_called_once_with( + quant_description) + + +class TestW8A8Quantizer(TestBase): + + def setUp(self): + self.quantizer = W8A8Quantizer(quant_description={}) + + def test_build_linear_method(self): + with patch('vllm_ascend.quantization.quantizer.AscendW8A8LinearMethod', + return_value=MagicMock()) as mock_linear: + result = self.quantizer.build_linear_method() + mock_linear.assert_called_once_with() + self.assertIsInstance(result, MagicMock) + + def test_build_moe_method(self): + with patch( + 
'vllm_ascend.quantization.quantizer.AscendW8A8FusedMoEMethod', + return_value=MagicMock()) as mock_linear: + result = self.quantizer.build_moe_method() + mock_linear.assert_called_once_with() + self.assertIsInstance(result, MagicMock) + + def test_build_attention_method(self): + with patch('vllm_ascend.quantization.quantizer.AscendC8KVCacheMethod', + return_value=MagicMock()) as mock_linear: + result = self.quantizer.build_attention_method() + mock_linear.assert_called_once_with() + self.assertIsInstance(result, MagicMock) diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py new file mode 100644 index 0000000000..392355a7e9 --- /dev/null +++ b/tests/ut/quantization/test_w8a8.py @@ -0,0 +1,906 @@ +import unittest +from unittest.mock import MagicMock, patch + +import torch + +from tests.ut.base import TestBase +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod, + AscendW8A8FusedMoEMethod, + AscendW8A8LinearMethod, + fused_experts, fused_experts_310p, + native_grouped_topk, + quant_per_tensor, select_experts) + + +class TestQuantPerTensor(TestBase): + + @patch("torch_npu.npu_quantize") + def test_quant_per_tensor(self, mock_npu_quantize): + in_tensor = torch.randn(32, 128) + input_scale = torch.tensor(0.1) + input_offset = torch.tensor(0) + + expected_output = torch.randint(-128, 127, (32, 128), dtype=torch.int8) + mock_npu_quantize.return_value = expected_output + + output = quant_per_tensor(in_tensor, input_scale, input_offset) + + mock_npu_quantize.assert_called_once_with( + in_tensor, + input_scale, + input_offset, + torch.qint8, + -1, + False, + ) + + self.assertTrue(torch.equal(output, expected_output)) + + +class TestAscendW8A8LinearMethod(TestBase): + + def setUp(self): + self.method = AscendW8A8LinearMethod() + + def test_get_weight(self): + weight = self.method.get_weight(10, 20) + self.assertEqual(weight['weight'].dtype, torch.int8) + self.assertEqual(weight['weight'].shape, (20, 10)) + + def test_get_pertensor_param(self): + params = self.method.get_pertensor_param(torch.bfloat16) + self.assertEqual(params['input_scale'].dtype, torch.bfloat16) + self.assertEqual(params['input_offset'].dtype, torch.int8) + self.assertEqual(params['input_scale'].shape, (1, )) + self.assertEqual(params['input_offset'].shape, (1, )) + + def test_get_perchannel_param(self): + params = self.method.get_perchannel_param(10, torch.bfloat16) + + self.assertEqual(params['quant_bias'].dtype, torch.int32) + self.assertEqual(params['deq_scale'].dtype, torch.float32) + self.assertEqual(params['weight_scale'].dtype, torch.bfloat16) + self.assertEqual(params['weight_offset'].dtype, torch.bfloat16) + self.assertEqual(params['quant_bias'].shape, (10, )) + self.assertEqual(params['deq_scale'].shape, (10, )) + self.assertEqual(params['weight_scale'].shape, (10, 1)) + self.assertEqual(params['weight_offset'].shape, (10, 1)) + + @patch("vllm_ascend.quantization.w8a8.quant_per_tensor") + @patch("torch_npu.npu_quant_matmul") + def test_apply_with_x_not_int8(self, mock_npu_quant_matmul, + mock_quant_per_tensor): + layer = MagicMock() + layer.aclnn_input_scale = 0.1 + layer.aclnn_input_offset = 0.2 + layer.weight = torch.randn(128, 256) + layer.deq_scale = 0.3 + + x = torch.randn(32, 128) + bias = torch.randn(256) + mock_quant_per_tensor.return_value = torch.randint(-128, + 127, + x.shape, + dtype=torch.int8) + + expected_y_output = torch.randn(32, 256) + mock_npu_quant_matmul.return_value = expected_y_output 
+ + output = self.method.apply(layer, x, bias) + + expected_y_output += bias + self.assertTrue(torch.equal(output, expected_y_output)) + + @patch("torch_npu.npu_quant_matmul") + def test_apply_with_x_is_int8(self, mock_npu_quant_matmul): + layer = MagicMock() + layer.aclnn_input_scale = 0.1 + layer.aclnn_input_offset = 0.2 + layer.weight = torch.randn(128, 256) + layer.deq_scale = 0.3 + + x = torch.randint(-128, 127, (32, 128), dtype=torch.int8) + bias = torch.randn(256) + + expected_y_output = torch.randn(32, 256) + mock_npu_quant_matmul.return_value = expected_y_output + + output = self.method.apply(layer, x, bias) + expected_y_output += bias + self.assertTrue(torch.equal(output, expected_y_output)) + + @patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True) + @patch("torch_npu.npu_quant_matmul") + def test_apply_with_x_is_310p(self, mock_npu_quant_matmul, mock_is_310p): + layer = MagicMock() + layer.aclnn_input_scale = 0.1 + layer.aclnn_input_offset = 0.2 + layer.weight = torch.randn(128, 256) + layer.deq_scale = 0.3 + + x = torch.randint(-128, 127, (32, 128), dtype=torch.int8) + bias = torch.randn(256) + + expected_y_output = torch.randn(32, 256) + mock_npu_quant_matmul.return_value = expected_y_output + + output = self.method.apply(layer, x, bias) + expected_y_output += bias + self.assertTrue(torch.equal(output, expected_y_output)) + + @patch('torch_npu.npu_format_cast') + def test_process_weights_after_loading(self, mock_npu_format_cast): + layer = MagicMock() + + layer.weight.data = torch.randn(128, 256) + layer.input_scale.data = torch.tensor([0.1]) + layer.input_offset.data = torch.tensor([0]) + layer.deq_scale = torch.tensor([0.5]) + layer.weight_scale.data = torch.randn(128, 1) + layer.weight_offset.data = torch.randn(128, 1) + + mock_npu_format_cast.return_value = MagicMock + self.method.process_weights_after_loading(layer) + + expected_offset = torch.tensor([0]).repeat(256).to(torch.int8) + self.assertTrue( + torch.equal(layer.aclnn_input_offset.data, expected_offset)) + self.assertFalse(layer.aclnn_input_offset.requires_grad) + + self.assertFalse(layer.deq_scale.requires_grad) + + self.assertEqual(layer.weight_scale.data.shape, (128, )) + self.assertEqual(layer.weight_offset.data.shape, (128, )) + + +class TestAscendW8A8FusedMoEMethod(TestBase): + + def setUp(self): + self.moe_method = AscendW8A8FusedMoEMethod() + self.num_experts = 4 + self.intermediate_size = 64 + self.hidden_size = 128 + self.dtype = torch.float32 + + def test_init(self): + self.assertTrue(self.moe_method.transpose_weight) + + def test_get_weight(self): + weights = self.moe_method.get_weight( + num_experts=self.num_experts, + intermediate_size_per_partition=self.intermediate_size, + hidden_sizes=self.hidden_size, + params_dtype=self.dtype) + + assert "w13_weight" in weights, f"w13_weight not in {weights}" + assert "w2_weight" in weights, f"w2_weight not in {weights}" + self.assertEqual( + weights["w13_weight"].shape, + (self.num_experts, 2 * self.intermediate_size, self.hidden_size)) + self.assertEqual( + weights["w2_weight"].shape, + (self.num_experts, self.hidden_size, self.intermediate_size)) + self.assertEqual(weights["w13_weight"].dtype, torch.int8) + self.assertEqual(weights["w2_weight"].dtype, torch.int8) + self.assertFalse(weights["w13_weight"].requires_grad) + self.assertFalse(weights["w2_weight"].requires_grad) + + def test_get_dynamic_quant_param(self): + quant_params = self.moe_method.get_dynamic_quant_param( + num_experts=self.num_experts, + 
intermediate_size_per_partition=self.intermediate_size,
+            hidden_sizes=self.hidden_size,
+            params_dtype=self.dtype)
+
+        expected_params = [
+            "w13_weight_scale", "w13_weight_offset", "w2_weight_scale",
+            "w2_weight_offset", "w2_deq_scale", "w13_deq_scale",
+            "w2_input_scale", "w13_input_scale", "w2_input_offset",
+            "w13_input_offset", "quant_bias"
+        ]
+
+        for param in expected_params:
+            assert param in quant_params, f"{param} not in {quant_params}"
+
+        # Check some sample shapes
+        self.assertEqual(quant_params["w13_weight_scale"].shape,
+                         (self.num_experts, 2 * self.intermediate_size, 1))
+        self.assertEqual(quant_params["w2_input_offset"].shape,
+                         (self.num_experts, 1))
+        self.assertEqual(quant_params["quant_bias"].shape,
+                         (self.num_experts, self.hidden_size))
+
+    @patch('vllm_ascend.quantization.w8a8.select_experts')
+    @patch('vllm_ascend.quantization.w8a8.fused_experts')
+    def test_apply_with_other_expert_count(self, mock_fused_experts,
+                                           mock_select_experts):
+        # Setup
+        mock_layer = MagicMock()
+        x = torch.randn(32, self.hidden_size)
+        router_logits = torch.randn(32, 128)  # 128 experts
+        top_k = 2
+
+        # Mock return values
+        mock_select_experts.return_value = (torch.randn(32, top_k),
+                                            torch.randint(0, 128, (32, top_k)))
+        mock_fused_experts.return_value = torch.randn(32, self.hidden_size)
+
+        # Test
+        result = self.moe_method.apply(layer=mock_layer,
+                                       x=x,
+                                       router_logits=router_logits,
+                                       top_k=top_k,
+                                       renormalize=True,
+                                       global_num_experts=128)
+
+        # Assertions
+        mock_select_experts.assert_called_once()
+        mock_fused_experts.assert_called_once()
+        self.assertEqual(result.shape, (32, self.hidden_size))
+
+    @patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True)
+    @patch('vllm_ascend.quantization.w8a8.select_experts')
+    @patch('vllm_ascend.quantization.w8a8.fused_experts_310p')
+    def test_apply_is_310p(self, mock_fused_experts_310p, mock_select_experts,
+                           mock_is_310p):
+        # Setup
+        mock_layer = MagicMock()
+        x = torch.randn(32, self.hidden_size)
+        router_logits = torch.randn(32, 128)  # 128 experts
+        top_k = 2
+
+        # Mock return values
+        mock_select_experts.return_value = (torch.randn(32, top_k),
+                                            torch.randint(0, 128, (32, top_k)))
+        mock_fused_experts_310p.return_value = torch.randn(
+            32, self.hidden_size)
+
+        # Test
+        result = self.moe_method.apply(layer=mock_layer,
+                                       x=x,
+                                       router_logits=router_logits,
+                                       top_k=top_k,
+                                       renormalize=True,
+                                       global_num_experts=128)
+
+        # Assertions
+        mock_select_experts.assert_called_once()
+        mock_fused_experts_310p.assert_called_once()
+        self.assertEqual(result.shape, (32, self.hidden_size))
+
+
+class TestAscendC8KVCacheMethod(TestBase):
+
+    def setUp(self):
+        self.layer = MagicMock()
+        self.layer.num_kv_heads = 4
+        self.layer.head_size = 64
+        self.layer.num_heads = 8
+        self.layer._k_scale_float = 1.0
+        self.layer._v_scale_float = 1.0
+        self.method = AscendC8KVCacheMethod()
+
+        self.attention_type = MagicMock()
+        self.attention_type.DECODER = "decoder"
+        self.attention_type.ENCODER = "encoder"
+
+    def test_create_weights(self):
+        """Test that create_weights registers the parameters correctly."""
+        AscendC8KVCacheMethod.create_weights(self.layer)
+
+        self.layer.register_parameter.assert_any_call("key_antiquant_scale",
+                                                      unittest.mock.ANY)
+        self.layer.register_parameter.assert_any_call("value_antiquant_scale",
+                                                      unittest.mock.ANY)
+
+        calls = self.layer.register_parameter.call_args_list
+
+        for call in calls:
+            args, kwargs = call
+            param = kwargs.get('parameter', args[1] if len(args) > 1 else None)
+
+            expected_shape = (self.layer.num_kv_heads * self.layer.head_size, )
+
self.assertEqual(param.shape, expected_shape) + + @patch("vllm_ascend.quantization.w8a8.is_310p", return_value=False) + def test_process_weights_after_loading_not_310p(self, mock_is_310p): + key_data = torch.ones(4 * 64) + value_data = torch.ones(4 * 64) * 2 + + self.layer.key_antiquant_scale.data = key_data + self.layer.value_antiquant_scale.data = value_data + + self.method.process_weights_after_loading(self.layer) + + self.assertEqual(self.method.antiquant_scale_comb.shape, (2, 256)) + self.assertTrue(torch.all(self.method.antiquant_scale_comb[0] == 1)) + self.assertTrue(torch.all(self.method.antiquant_scale_comb[1] == 2)) + + @patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True) + def test_process_weights_after_loading_is_310p(self, mock_is_310p): + key_data = torch.ones(4 * 64) + value_data = torch.ones(4 * 64) * 2 + + self.layer.key_antiquant_scale.data = key_data + self.layer.value_antiquant_scale.data = value_data + + self.method.process_weights_after_loading(self.layer) + + self.assertEqual(self.method.antiquant_scale_comb.shape, (2, 256)) + self.assertTrue(torch.all(self.method.antiquant_scale_comb[0] == 1)) + self.assertTrue(torch.all(self.method.antiquant_scale_comb[1] == 2)) + + @patch('torch_npu.npu_scatter_nd_update_') + @patch("vllm_ascend.quantization.w8a8.quant_per_tensor") + def test_apply_decode_only(self, mock_quant, mock_scatter): + + num_tokens = 2 + query = torch.randn(num_tokens, + self.layer.num_heads * self.layer.head_size) + key = torch.randn(num_tokens, + self.layer.num_kv_heads * self.layer.head_size) + value = torch.randn(num_tokens, + self.layer.num_kv_heads * self.layer.head_size) + output = torch.empty_like(query) + + attn_metadata = MagicMock() + attn_metadata.attn_state = AscendAttentionState.DecodeOnly + attn_metadata.seq_lens = [10, 10] + attn_metadata.block_tables = torch.tensor([[0, 1], [1, 2]]) + attn_metadata.slot_mapping = torch.tensor([0, 1]) + attn_metadata.attn_mask = None + + block_size = 16 + key_cache = torch.empty(2, block_size, self.layer.num_kv_heads, + self.layer.head_size) + value_cache = torch.empty(2, block_size, self.layer.num_kv_heads, + self.layer.head_size) + kv_cache = (key_cache, value_cache) + + mock_quant.side_effect = [key, value] + + self.layer.key_antiquant_scale.data = torch.ones( + self.layer.num_kv_heads * self.layer.head_size) + self.layer.value_antiquant_scale.data = torch.ones( + self.layer.num_kv_heads * self.layer.head_size) + self.method.process_weights_after_loading(self.layer) + + expected_output = torch.randn( + num_tokens, self.layer.num_heads * self.layer.head_size) + with patch('torch_npu.npu_incre_flash_attention', + return_value=expected_output): + result = self.method.apply(self.layer, query, key, value, kv_cache, + attn_metadata, + self.attention_type.DECODER, 1.0, + output) + + self.assertEqual(mock_quant.call_count, 2) + self.assertEqual(mock_scatter.call_count, 2) + self.assertTrue(torch.equal(result, expected_output)) + + @patch('torch_npu.npu_scatter_nd_update_') + @patch("vllm_ascend.quantization.w8a8.quant_per_tensor") + def test_apply_attn_metadata_without_decode(self, mock_quant, + mock_scatter): + + num_tokens = 2 + query = torch.randn(num_tokens, + self.layer.num_heads * self.layer.head_size) + key = torch.randn(num_tokens, + self.layer.num_kv_heads * self.layer.head_size) + value = torch.randn(num_tokens, + self.layer.num_kv_heads * self.layer.head_size) + output = torch.empty_like(query) + + attn_metadata = MagicMock(spec=[ + 'attn_state', 'seq_lens', 'block_tables', 
'slot_mapping', + 'attn_mask' + ]) + attn_metadata.attn_state = AscendAttentionState.DecodeOnly + attn_metadata.seq_lens = [10, 10] + attn_metadata.block_tables = torch.tensor([[0, 1], [1, 2]]) + attn_metadata.slot_mapping = torch.tensor([0, 1]) + attn_metadata.attn_mask = None + + block_size = 16 + key_cache = torch.empty(2, block_size, self.layer.num_kv_heads, + self.layer.head_size) + value_cache = torch.empty(2, block_size, self.layer.num_kv_heads, + self.layer.head_size) + kv_cache = (key_cache, value_cache) + + mock_quant.side_effect = [key, value] + + self.layer.key_antiquant_scale.data = torch.ones( + self.layer.num_kv_heads * self.layer.head_size) + self.layer.value_antiquant_scale.data = torch.ones( + self.layer.num_kv_heads * self.layer.head_size) + self.method.process_weights_after_loading(self.layer) + + expected_output = torch.randn( + num_tokens, self.layer.num_heads * self.layer.head_size) + with patch('torch_npu.npu_incre_flash_attention', + return_value=expected_output): + result = self.method.apply(self.layer, query, key, value, kv_cache, + attn_metadata, + self.attention_type.DECODER, 1.0, + output) + + self.assertEqual(mock_quant.call_count, 2) + self.assertEqual(mock_scatter.call_count, 2) + self.assertTrue(torch.equal(result, expected_output)) + + @patch("vllm_ascend.quantization.w8a8.quant_per_tensor") + @patch('torch_npu._npu_flash_attention') + def test_apply_prefill_no_cache(self, mock_flash, mock_quant): + """Test apply method in prefill no-cache mode""" + + num_tokens = 2 + query = torch.randn(num_tokens, + self.layer.num_heads * self.layer.head_size) + key = torch.randn(num_tokens, + self.layer.num_kv_heads * self.layer.head_size) + value = torch.randn(num_tokens, + self.layer.num_kv_heads * self.layer.head_size) + output = torch.empty_like(query) + + attn_metadata = MagicMock() + attn_metadata.attn_state = AscendAttentionState.PrefillNoCache + attn_metadata.seq_lens = [10, 10] + attn_metadata.attn_mask = torch.ones(2, 2) + + kv_cache = (torch.tensor([]), torch.tensor([])) + mock_quant.return_value = key + + result = self.method.apply(self.layer, query, key, value, kv_cache, + attn_metadata, self.attention_type.DECODER, + 1.0, output) + + # Check that flash attention was called + mock_flash.assert_called_once() + + # Check output shape + self.assertEqual( + result.shape, + (num_tokens, self.layer.num_heads * self.layer.head_size)) + + @patch("vllm_ascend.quantization.w8a8.quant_per_tensor") + def test_apply_unsupported_attention_type(self, mock_quant): + + query = torch.randn(1, self.layer.num_heads * self.layer.head_size) + key = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size) + value = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size) + output = torch.empty_like(query) + + mock_quant.return_value = key + + attn_metadata = MagicMock() + attn_metadata.attn_state = AscendAttentionState.PrefillNoCache + + with self.assertRaises(NotImplementedError) as cm: + self.method.apply(self.layer, query, key, value, (None, None), + attn_metadata, self.attention_type.ENCODER, 1.0, + output) + + assert "Encoder self-attention" in str( + cm.exception), f"Encoder self-attention not in {str(cm.exception)}" + assert "not implemented" in str( + cm.exception), f"not implemented not in{str(cm.exception)}" + + mock_quant.assert_not_called() + + @patch("vllm_ascend.quantization.w8a8.quant_per_tensor") + def test_apply_unsupported_attention_state(self, mock_quant): + """Test apply with unsupported attention state""" + query = torch.randn(1, 
self.layer.num_heads * self.layer.head_size) + key = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size) + value = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size) + output = torch.empty_like(query) + + attn_metadata = MagicMock() + attn_metadata.attn_state = AscendAttentionState.PrefillCacheHit + mock_quant.return_value = key + kv_cache = (torch.tensor([]), torch.tensor([])) + + with self.assertRaises(NotImplementedError): + self.method.apply(self.layer, query, key, value, kv_cache, + attn_metadata, self.attention_type.DECODER, 1.0, + output) + + +class TestFusedExperts(TestBase): + + @patch("vllm_ascend.quantization.w8a8.quant_per_tensor") + @patch('vllm_ascend.quantization.w8a8.get_ep_group') + @patch('torch_npu.npu_moe_init_routing_v2') + @patch('torch_npu.npu_grouped_matmul') + @patch('torch_npu.npu_swiglu') + @patch('torch_npu.npu_moe_finalize_routing') + def test_fused_experts_with_expert_map(self, mock_finalize, mock_swiglu, + mock_group_matmul, + mock_init_routing, + mock_get_ep_group, + mock_quant_per_tensor): + num_tokens = 32 + hidden_size = 128 + intermediate_size = 256 + num_experts = 4 + top_k = 2 + + hidden_states = torch.randn(num_tokens, hidden_size) + + w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size) + w1_scale = torch.tensor([0.1]) + w1_input_scale = torch.tensor([[0.2, 0.2], [0.2, 0.2]]) + w1_input_offset = torch.tensor([0]) + + w2 = torch.randn(num_experts, hidden_size, intermediate_size) + w2_scale = torch.tensor([0.1]) + w2_input_scale = torch.tensor([0.2]) + w2_input_offset = torch.tensor([0]) + + topk_weights = torch.rand(num_tokens, top_k) + topk_ids = torch.randint(0, num_experts, (num_tokens, top_k)) + expert_map = torch.arange(num_experts) + + mock_get_ep_group.return_value.world_size = 8 + + mock_quant_per_tensor.return_value = torch.randint(-128, + 127, + hidden_states.shape, + dtype=torch.int8) + + mock_init_routing.return_value = (torch.randn(num_tokens * top_k, + hidden_size), + torch.arange(num_tokens * top_k), + torch.tensor([num_tokens // 2] * 2), + torch.tensor(1.0)) + + mock_group_matmul.side_effect = [[ + torch.randn(num_tokens * top_k, intermediate_size * 2) + ], [torch.randn(num_tokens * top_k, hidden_size)]] + + mock_swiglu.return_value = torch.randn(num_tokens * top_k, + intermediate_size) + + expected_output = torch.randn(num_tokens, hidden_size) + mock_finalize.return_value = expected_output + + output = fused_experts( + hidden_states=hidden_states, + w1=w1, + w1_scale=w1_scale, + w1_input_scale=w1_input_scale, + w1_input_offset=w1_input_offset, + w2=w2, + w2_scale=w2_scale, + w2_input_scale=w2_input_scale, + w2_input_offset=w2_input_offset, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + global_num_experts=num_experts, + expert_map=expert_map, + ) + + mock_init_routing.assert_called_once() + + self.assertEqual(mock_group_matmul.call_count, 2) + + self.assertEqual(output.shape, (num_tokens, hidden_size)) + + mock_finalize.assert_called_once() + + @patch("vllm_ascend.quantization.w8a8.quant_per_tensor") + @patch('vllm_ascend.quantization.w8a8.get_ep_group') + @patch('torch_npu.npu_grouped_matmul') + @patch('torch_npu.npu_swiglu') + def test_fused_experts_without_expert_map(self, mock_swiglu, + mock_group_matmul, + mock_get_ep_group, + mock_quant_per_tensor): + num_tokens = 16 + hidden_size = 64 + intermediate_size = 128 + num_experts = 8 + top_k = 1 + + hidden_states = torch.randn(num_tokens, hidden_size) + w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size) + w2 = 
torch.randn(num_experts, hidden_size, intermediate_size) + topk_weights = torch.rand(num_tokens, top_k) + topk_ids = torch.randint(0, num_experts, (num_tokens, top_k)) + + mock_get_ep_group.return_value.world_size = 8 + + mock_quant_per_tensor.return_value = torch.randint(-128, + 127, + hidden_states.shape, + dtype=torch.int8) + mock_group_matmul.side_effect = [[ + torch.randn(num_tokens * top_k, intermediate_size * 2) + ], [torch.randn(num_tokens * top_k, hidden_size)]] + mock_swiglu.return_value = torch.randn(num_tokens * top_k, + intermediate_size) + with self.assertRaises(NotImplementedError): + fused_experts( + hidden_states=hidden_states, + w1=w1, + w1_scale=torch.tensor([0.1]), + w1_input_scale=torch.tensor([[0.2, 0.2], [0.2, 0.2]]), + w1_input_offset=torch.tensor([0]), + w2=w2, + w2_scale=torch.tensor([0.1]), + w2_input_scale=torch.tensor([0.1]), + w2_input_offset=torch.tensor([0]), + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + global_num_experts=num_experts, + expert_map=None, + ) + + +class TestFusedExperts310(TestBase): + + @patch('torch_npu.npu_quant_grouped_matmul_dequant') + @patch("vllm_ascend.quantization.w8a8.quant_per_tensor") + @patch('vllm_ascend.quantization.w8a8.get_ep_group') + @patch('torch_npu.npu_swiglu') + def test_fused_experts_310p_with_expert_map(self, mock_swiglu, + mock_get_ep_group, + mock_quant_per_tensor, + mock_matmul_dequant): + num_tokens = 32 + hidden_size = 128 + intermediate_size = 256 + num_experts = 4 + top_k = 1 + + hidden_states = torch.randn(num_tokens, hidden_size) + + w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size) + w1_scale = torch.tensor([0.1]) + w1_input_scale = torch.tensor([[0.2, 0.2], [0.2, 0.2]]) + + w2 = torch.randn(num_experts, hidden_size, intermediate_size) + w2_scale = torch.tensor([0.1]) + w2_input_scale = torch.tensor([0.2]) + + topk_weights = torch.rand(num_tokens, top_k) + topk_ids = torch.randint(0, num_experts, (num_tokens, top_k)) + expert_map = torch.arange(num_experts) + + mock_get_ep_group.return_value.world_size = 1 + + mock_quant_per_tensor.return_value = torch.randint(-128, + 127, + hidden_states.shape, + dtype=torch.int8) + + mock_swiglu.return_value = torch.randn(num_tokens * top_k, + intermediate_size) + + mock_matmul_dequant.return_value = hidden_states + + output = fused_experts_310p( + hidden_states=hidden_states, + w1=w1, + w1_scale=w1_scale, + w1_input_scale=w1_input_scale, + w2=w2, + w2_scale=w2_scale, + w2_input_scale=w2_input_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + global_num_experts=num_experts, + expert_map=expert_map, + ) + + self.assertEqual(output.shape, (num_tokens, hidden_size)) + self.assertEqual(mock_matmul_dequant.call_count, 2) + + +class TestSelectExperts(TestBase): + + def setUp(self): + # Common test data + self.num_tokens = 10 + self.hidden_size = 32 + self.num_experts = 8 + self.top_k = 2 + + self.hidden_states = torch.randn(self.num_tokens, self.hidden_size) + self.router_logits = torch.randn(self.num_tokens, self.num_experts) + + def test_softmax_scoring(self): + """Test softmax scoring function""" + + weights, ids = select_experts(hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=False, + renormalize=False, + scoring_func="softmax") + + self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) + self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) + + def test_sigmoid_scoring(self): + """Test sigmoid scoring function""" + + weights, ids = 
select_experts(hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=False, + renormalize=False, + scoring_func="sigmoid") + + self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) + self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) + + def test_invalid_scoring_func(self): + """Test invalid scoring function raises ValueError""" + with self.assertRaises(ValueError): + select_experts(hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=False, + renormalize=False, + scoring_func="invalid_func") + + @patch('torch.topk') + def test_grouped_topk(self, mock_topk): + """Test grouped topk functionality""" + mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k), + torch.zeros(self.num_tokens, + self.top_k, + dtype=torch.long)) + + weights, ids = select_experts(hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=True, + renormalize=False, + topk_group=4, + num_expert_group=2) + + mock_topk.assert_called() + self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) + self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) + self.assertEqual(ids.dtype, torch.int32) + + @patch('vllm_ascend.quantization.w8a8.native_grouped_topk') + def test_grouped_topk_with_correction_bias(self, mock_grouped_topk): + """Test grouped topk with expert score correction bias""" + mock_grouped_topk.return_value = torch.ones(self.num_tokens, + self.num_experts) + + e_score_correction_bias = torch.randn(self.num_experts) + weights, ids = select_experts( + hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=True, + renormalize=False, + topk_group=4, + num_expert_group=2, + e_score_correction_bias=e_score_correction_bias) + + mock_grouped_topk.assert_called_once() + self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) + self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) + + def test_custom_routing_function(self): + """Test custom routing function""" + mock_custom_routing = MagicMock() + mock_custom_routing.return_value = (torch.ones(self.num_tokens, + self.top_k), + torch.zeros(self.num_tokens, + self.top_k, + dtype=torch.int32)) + + weights, ids = select_experts( + hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=False, + renormalize=False, + custom_routing_function=mock_custom_routing) + + mock_custom_routing.assert_called_once() + self.assertEqual(weights.shape, (self.num_tokens, self.top_k)) + self.assertEqual(ids.shape, (self.num_tokens, self.top_k)) + self.assertEqual(ids.dtype, torch.int32) + + @patch('torch.topk') + def test_renormalize(self, mock_topk): + """Test weight renormalization""" + mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k), + torch.zeros(self.num_tokens, + self.top_k, + dtype=torch.long)) + + weights, _ = select_experts( + hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=False, + renormalize=True, + ) + + # Check if weights are normalized (sum to 1 for each token) + sums = weights.sum(dim=-1) + self.assertTrue(torch.allclose(sums, torch.ones_like(sums))) + + @patch('torch.topk') + def test_output_dtypes(self, mock_topk): + """Test output dtypes""" + mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k), + torch.zeros(self.num_tokens, + self.top_k, + dtype=torch.long)) + + weights, 
ids = select_experts( + hidden_states=self.hidden_states, + router_logits=self.router_logits, + top_k=self.top_k, + use_grouped_topk=False, + renormalize=False, + ) + + self.assertEqual(weights.dtype, self.hidden_states.dtype) + self.assertEqual(ids.dtype, torch.int32) + + +class TestNativeGroupedTopkPartialMock(TestBase): + + def test_basic_group_selection(self): + topk_weights = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6], + [0.6, 0.4, 0.7, 0.3, 0.8, 0.2, 0.9, 0.1], + [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3], + [0.9, 0.1, 0.8, 0.2, 0.7, 0.3, 0.6, 0.4]], + dtype=torch.float32) + + expected_topk_indices = torch.tensor([[0, 1], [1, 0], [0, 1], [0, 1]]) + + with patch('torch.topk', + return_value=(None, expected_topk_indices)) as mock_topk: + result = native_grouped_topk(topk_weights=topk_weights, + num_expert_group=2, + topk_group=2) + + mock_topk.assert_called_once() + + expected_result = topk_weights + self.assertTrue(torch.allclose(result, expected_result)) + + def test_partial_group_selection(self): + + topk_weights = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6], + [0.6, 0.4, 0.7, 0.3, 0.8, 0.2, 0.9, 0.1]]) + + expected_topk_indices = torch.tensor([[0], [1]]) + + with patch('torch.topk', return_value=(None, expected_topk_indices)): + result = native_grouped_topk(topk_weights=topk_weights, + num_expert_group=2, + topk_group=1) + + expected_result = torch.tensor( + [[0.1, 0.9, 0.2, 0.8, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.8, 0.2, 0.9, 0.1]]) + self.assertTrue(torch.allclose(result, expected_result)) + + def test_single_group(self): + topk_weights = torch.tensor([[0.1, 0.9, 0.2], [0.8, 0.3, 0.7]]) + + expected_topk_indices = torch.tensor([[0], [0]]) + + with patch('torch.topk', return_value=(None, expected_topk_indices)): + result = native_grouped_topk(topk_weights=topk_weights, + num_expert_group=1, + topk_group=1) + self.assertTrue(result.numel() > 0) diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index 5ec4dd72cc..a123790dbd 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -1,16 +1,32 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# + import os -import unittest from unittest import mock from transformers import PretrainedConfig from vllm.config import ModelConfig, VllmConfig -from vllm_ascend.ascend_config import (check_ascend_config, +from tests.ut.base import TestBase +from vllm_ascend.ascend_config import (_check_torchair_supported, + check_ascend_config, clear_ascend_config, get_ascend_config, init_ascend_config) -class TestAscendConfig(unittest.TestCase): +class TestAscendConfig(TestBase): @staticmethod def _clean_up_ascend_config(func): @@ -242,3 +258,10 @@ def test_check_ascend_config_wrong_case(self): test_vllm_config.model_config = fake_model_config init_ascend_config(test_vllm_config) check_ascend_config(test_vllm_config, False) + + def test_check_torchair_supported(self): + test_cases = [('deepseek_v3', True), ('PanguProMoE', True), + ('qwen', False), ('llama', False)] + for model_type, expected_output in test_cases: + self.assertEqual(_check_torchair_supported(model_type), + expected_output) diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py new file mode 100644 index 0000000000..c09964a745 --- /dev/null +++ b/tests/ut/test_platform.py @@ -0,0 +1,717 @@ +import importlib +import unittest +from datetime import timedelta +from unittest.mock import MagicMock, patch + +import torch +from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import PrefixStore +from vllm.config import CompilationLevel +from vllm.platforms import PlatformEnum + +from tests.ut.base import TestBase +from vllm_ascend.platform import NPUPlatform +from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD + + +class TestNPUPlatform(TestBase): + + def setUp(self): + self.platform = NPUPlatform() + + self.mock_vllm_config = MagicMock() + self.mock_vllm_config.compilation_config = MagicMock() + self.mock_vllm_config.model_config = MagicMock() + self.mock_vllm_config.parallel_config = MagicMock() + self.mock_vllm_config.cache_config = MagicMock() + self.mock_vllm_config.scheduler_config = MagicMock() + self.mock_vllm_config.speculative_config = None + + self.mock_ascend_config = MagicMock() + self.mock_ascend_config.expert_tensor_parallel_size = 0 + self.mock_ascend_config.torchair_graph_config.enabled = False + self.mock_ascend_config.ascend_scheduler_config.enabled = False + + def test_class_variables(self): + self.assertEqual(NPUPlatform._enum, PlatformEnum.OOT) + self.assertEqual(NPUPlatform.device_name, "npu") + self.assertEqual(NPUPlatform.device_type, "npu") + self.assertEqual(NPUPlatform.simple_compile_backend, "eager") + self.assertEqual(NPUPlatform.ray_device_key, "NPU") + self.assertEqual(NPUPlatform.device_control_env_var, + "ASCEND_RT_VISIBLE_DEVICES") + self.assertEqual(NPUPlatform.dispatch_key, "PrivateUse1") + self.assertEqual(NPUPlatform.supported_quantization, + [ASCEND_QUATIZATION_METHOD]) + + def test_is_sleep_mode_available(self): + self.assertTrue(self.platform.is_sleep_mode_available()) + + @patch("vllm_ascend.utils.adapt_patch") + @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig") + def test_pre_register_and_update_with_parser(self, mock_quant_config, + mock_adapt_patch): + mock_parser = MagicMock() + mock_action = MagicMock() + mock_action.choices = ["awq", "gptq"] + mock_parser._option_string_actions = {"--quantization": mock_action} + + self.platform.pre_register_and_update(mock_parser) + + mock_adapt_patch.assert_called_once_with(is_global_patch=True) + + self.assertTrue(ASCEND_QUATIZATION_METHOD in mock_action.choices) + 
self.assertEqual(len(mock_action.choices), 3) # original 2 + ascend + + @patch("vllm_ascend.utils.adapt_patch") + @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig") + def test_pre_register_and_update_without_parser(self, mock_quant_config, + mock_adapt_patch): + self.platform.pre_register_and_update(None) + + mock_adapt_patch.assert_called_once_with(is_global_patch=True) + + @patch("vllm_ascend.utils.adapt_patch") + @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig") + def test_pre_register_and_update_with_parser_no_quant_action( + self, mock_quant_config, mock_adapt_patch): + mock_parser = MagicMock() + mock_parser._option_string_actions = {} + + self.platform.pre_register_and_update(mock_parser) + + mock_adapt_patch.assert_called_once_with(is_global_patch=True) + + @patch("vllm_ascend.utils.adapt_patch") + @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig") + def test_pre_register_and_update_with_existing_ascend_quant( + self, mock_quant_config, mock_adapt_patch): + mock_parser = MagicMock() + mock_action = MagicMock() + mock_action.choices = ["awq", ASCEND_QUATIZATION_METHOD] + mock_parser._option_string_actions = {"--quantization": mock_action} + + self.platform.pre_register_and_update(mock_parser) + + mock_adapt_patch.assert_called_once_with(is_global_patch=True) + self.assertEqual(len(mock_action.choices), 2) + + def test_get_device_capability(self): + self.assertIsNone(self.platform.get_device_capability(device_id=0)) + + @patch("torch.npu.get_device_name") + def test_get_device_name(self, mock_get_device_name): + device_id = 0 + device_name = "Ascend910B2" + mock_get_device_name.return_value = device_name + self.assertEqual(self.platform.get_device_name(device_id), device_name) + mock_get_device_name.assert_called_once_with(0) + + def test_is_async_output_supported(self): + self.assertTrue( + self.platform.is_async_output_supported(enforce_eager=None)) + self.assertTrue( + self.platform.is_async_output_supported(enforce_eager=True)) + self.assertTrue( + self.platform.is_async_output_supported(enforce_eager=False)) + + @patch("torch.inference_mode") + def test_inference_mode(self, mock_inference_mode): + mock_inference_mode.return_value = None + self.assertIsNone(self.platform.inference_mode()) + mock_inference_mode.assert_called_once() + + @patch("torch.npu.set_device") + def test_set_device_normal(self, mock_set_device): + device = torch.device("npu:0") + self.platform.set_device(device) + mock_set_device.assert_called_once_with(device) + + @patch("torch.npu.set_device", + side_effect=RuntimeError("Device not available")) + def test_set_device_failure(self, mock_set_device): + device = torch.device("npu:0") + with self.assertRaises(RuntimeError): + self.platform.set_device(device) + mock_set_device.assert_called_once_with(device) + + @patch("torch.npu.empty_cache") + def test_empty_cache_normal(self, mock_empty_cache): + self.platform.empty_cache() + mock_empty_cache.assert_called_once() + + @patch("torch.npu.empty_cache", + side_effect=RuntimeError("Cache clearing failed")) + def test_empty_cache_failure(self, mock_empty_cache): + with self.assertRaises(RuntimeError): + self.platform.empty_cache() + mock_empty_cache.assert_called_once() + + @patch("torch.npu.synchronize") + def test_synchronize_normal(self, mock_synchronize): + self.platform.synchronize() + mock_synchronize.assert_called_once() + + @patch("torch.npu.synchronize", + side_effect=RuntimeError("Synchronization failed")) + def test_synchronize_failure(self, mock_synchronize): 
+ with self.assertRaises(RuntimeError): + self.platform.synchronize() + mock_synchronize.assert_called_once() + + @patch("torch.npu.mem_get_info") + def test_mem_get_info_normal(self, mock_mem_get_info): + free_memory_size = 1024 + total_memory_size = 2048 + memory_info = (free_memory_size, total_memory_size) + mock_mem_get_info.return_value = memory_info + result = self.platform.mem_get_info() + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + self.assertEqual(result, memory_info) + mock_mem_get_info.assert_called_once() + + @patch("torch.npu.mem_get_info", + side_effect=RuntimeError("NPU not available")) + def test_mem_get_info_failure(self, mock_mem_get_info): + with self.assertRaises(RuntimeError): + self.platform.mem_get_info() + mock_mem_get_info.assert_called_once() + + @patch("gc.collect") + @patch("torch.npu.empty_cache") + @patch("torch.npu.reset_peak_memory_stats") + def test_clear_npu_memory_normal(self, mock_reset_stats, mock_empty_cache, + mock_gc_collect): + self.platform.clear_npu_memory() + + mock_gc_collect.assert_called_once() + mock_empty_cache.assert_called_once() + mock_reset_stats.assert_called_once() + + @patch("gc.collect", side_effect=Exception("GC failed")) + @patch("torch.npu.empty_cache") + @patch("torch.npu.reset_peak_memory_stats") + def test_clear_npu_memory_gc_collect_failure(self, mock_reset_stats, + mock_empty_cache, + mock_gc_collect): + with self.assertRaises(Exception): + self.platform.clear_npu_memory() + + mock_gc_collect.assert_called_once() + mock_empty_cache.assert_not_called() + mock_reset_stats.assert_not_called() + + @patch("gc.collect") + @patch("torch.npu.empty_cache", + side_effect=RuntimeError("Cache clear failed")) + @patch("torch.npu.reset_peak_memory_stats") + def test_clear_npu_memory_empty_cache_failure(self, mock_reset_stats, + mock_empty_cache, + mock_gc_collect): + with self.assertRaises(RuntimeError): + self.platform.clear_npu_memory() + + mock_gc_collect.assert_called_once() + mock_empty_cache.assert_called_once() + mock_reset_stats.assert_not_called() + + @patch("gc.collect") + @patch("torch.npu.empty_cache") + @patch("torch.npu.reset_peak_memory_stats", + side_effect=RuntimeError("Reset failed")) + def test_clear_npu_memory_reset_stats_failure(self, mock_reset_stats, + mock_empty_cache, + mock_gc_collect): + with self.assertRaises(RuntimeError): + self.platform.clear_npu_memory() + + mock_gc_collect.assert_called_once() + mock_empty_cache.assert_called_once() + mock_reset_stats.assert_called_once() + + @patch("vllm_ascend.ascend_config.check_ascend_config") + @patch("vllm_ascend.ascend_config.init_ascend_config") + @patch("vllm_ascend.utils.update_aclgraph_sizes") + @patch("vllm_ascend.utils.is_310p", return_value=False) + @patch("os.environ", {}) + def test_check_and_update_config_basic_config_update( + self, mock_is_310p, mock_update_acl, mock_init_ascend, + mock_check_ascend): + mock_init_ascend.return_value = self.mock_ascend_config + self.mock_vllm_config.parallel_config.enable_expert_parallel = False + + # Use importlib.reload to reload the platform module, ensuring the mocked init_ascend_config method is used. + # Without this reload, when calling self.platform.check_and_update_config, + # it would execute the original unmocked init_ascend_config method, causing the unit test to fail. 
+ from vllm_ascend import platform + + importlib.reload(platform) + + self.platform.check_and_update_config(self.mock_vllm_config) + + mock_init_ascend.assert_called_once_with(self.mock_vllm_config) + mock_check_ascend.assert_called_once() + + @patch("vllm_ascend.utils.is_310p", return_value=False) + @patch("vllm_ascend.ascend_config.check_ascend_config") + @patch("vllm_ascend.ascend_config.init_ascend_config") + def test_check_and_update_config_expert_parallel_enabled( + self, mock_init_ascend, mock_check_ascend, mock_is_310p): + mock_init_ascend.return_value = self.mock_ascend_config + self.mock_vllm_config.parallel_config.enable_expert_parallel = True + self.mock_vllm_config.parallel_config.tensor_parallel_size = 2 + self.mock_vllm_config.parallel_config.world_size_across_dp = 4 + + from vllm_ascend import platform + + importlib.reload(platform) + + self.platform.check_and_update_config(self.mock_vllm_config) + + self.assertEqual( + self.mock_vllm_config.parallel_config.expert_tensor_parallel_size, + 1) + self.assertEqual( + self.mock_vllm_config.parallel_config.expert_parallel_size, + self.mock_vllm_config.parallel_config.world_size_across_dp, + ) + + @patch("vllm_ascend.utils.is_310p", return_value=False) + @patch("vllm_ascend.ascend_config.check_ascend_config") + @patch("vllm_ascend.ascend_config.init_ascend_config") + def test_check_and_update_config_no_model_config_warning( + self, mock_init_ascend, mock_check_ascend, mock_is_310p): + mock_init_ascend.return_value = self.mock_ascend_config + self.mock_vllm_config.model_config = None + + with self.assertLogs(logger="vllm", level="WARNING") as cm: + from vllm_ascend import platform + + importlib.reload(platform) + self.platform.check_and_update_config(self.mock_vllm_config) + self.assertTrue("Model config is missing" in cm.output[0]) + + @patch("vllm_ascend.utils.is_310p", return_value=False) + @patch("vllm_ascend.ascend_config.check_ascend_config") + @patch("vllm_ascend.ascend_config.init_ascend_config") + def test_check_and_update_config_enforce_eager_mode( + self, mock_init_ascend, mock_check_ascend, mock_is_310p): + mock_init_ascend.return_value = self.mock_ascend_config + self.mock_vllm_config.model_config.enforce_eager = True + + with self.assertLogs(logger="vllm", level="INFO") as cm: + from vllm_ascend import platform + + importlib.reload(platform) + self.platform.check_and_update_config(self.mock_vllm_config) + self.assertTrue("Compilation disabled, using eager mode by default" in + cm.output[0]) + self.assertEqual( + self.mock_vllm_config.compilation_config.level, + CompilationLevel.NO_COMPILATION, + ) + + @patch("vllm_ascend.utils.is_310p", return_value=False) + @patch("vllm_ascend.ascend_config.check_ascend_config") + @patch("vllm_ascend.ascend_config.init_ascend_config") + def test_check_and_update_config_unsupported_compilation_level( + self, mock_init_ascend, mock_check_ascend, mock_is_310p): + mock_init_ascend.return_value = self.mock_ascend_config + self.mock_vllm_config.model_config.enforce_eager = False + self.mock_vllm_config.compilation_config.level = CompilationLevel.DYNAMO_ONCE + + with self.assertLogs(logger="vllm", level="WARNING") as cm: + from vllm_ascend import platform + + importlib.reload(platform) + self.platform.check_and_update_config(self.mock_vllm_config) + self.assertTrue("NPU does not support" in cm.output[0]) + self.assertEqual( + self.mock_vllm_config.compilation_config.level, + CompilationLevel.NO_COMPILATION, + ) + + @patch("vllm_ascend.utils.is_310p", return_value=False) + 
@patch("vllm_ascend.ascend_config.check_ascend_config") + @patch("vllm_ascend.ascend_config.init_ascend_config") + def test_check_and_update_config_torchair_enabled_compilation( + self, mock_init_ascend, mock_check_ascend, mock_is_310p): + self.mock_ascend_config.torchair_graph_config.enabled = True + mock_init_ascend.return_value = self.mock_ascend_config + self.mock_vllm_config.model_config.enforce_eager = False + self.mock_vllm_config.compilation_config.level = CompilationLevel.PIECEWISE + + with self.assertLogs(logger="vllm", level="INFO") as cm: + from vllm_ascend import platform + + importlib.reload(platform) + self.platform.check_and_update_config(self.mock_vllm_config) + self.assertTrue("Torchair compilation enabled" in cm.output[0]) + self.assertEqual( + self.mock_vllm_config.compilation_config.level, + CompilationLevel.NO_COMPILATION, + ) + + @patch("vllm_ascend.utils.is_310p", return_value=False) + @patch("vllm_ascend.ascend_config.check_ascend_config") + @patch("vllm_ascend.ascend_config.init_ascend_config") + def test_check_and_update_config_cache_config_block_size( + self, mock_init_ascend, mock_check_ascend, mock_is_310p): + mock_init_ascend.return_value = self.mock_ascend_config + self.mock_vllm_config.cache_config.block_size = None + self.mock_vllm_config.cache_config.enable_prefix_caching = True + + from vllm_ascend import platform + + importlib.reload(platform) + + self.platform.check_and_update_config(self.mock_vllm_config) + + self.assertEqual(self.mock_vllm_config.cache_config.block_size, 128) + + @patch("vllm_ascend.utils.is_310p", return_value=False) + @patch("vllm_ascend.ascend_config.check_ascend_config") + @patch("vllm_ascend.ascend_config.init_ascend_config") + @patch("vllm.envs.VLLM_USE_V1", True) + def test_check_and_update_config_v1_worker_class_selection( + self, mock_init_ascend, mock_check_ascend, mock_is_310p): + mock_init_ascend.return_value = self.mock_ascend_config + self.mock_vllm_config.parallel_config.worker_cls = "auto" + + from vllm_ascend import platform + + importlib.reload(platform) + self.platform.check_and_update_config(self.mock_vllm_config) + + self.assertEqual( + self.mock_vllm_config.parallel_config.worker_cls, + "vllm_ascend.worker.worker_v1.NPUWorker", + ) + + @patch("vllm_ascend.ascend_config.check_ascend_config") + @patch("vllm_ascend.ascend_config.init_ascend_config") + @patch("vllm.envs.VLLM_USE_V1", False) + def test_check_and_update_config_speculative_worker_config( + self, mock_init_ascend, mock_check_ascend): + mock_init_ascend.return_value = self.mock_ascend_config + self.mock_vllm_config.speculative_config = MagicMock() + self.mock_vllm_config.speculative_config.disable_logprobs = True + self.mock_vllm_config.parallel_config.worker_cls = "auto" + + with patch.dict("os.environ", {}): + from vllm_ascend import platform + + importlib.reload(platform) + self.platform.check_and_update_config(self.mock_vllm_config) + import os + + self.assertEqual(os.environ.get("ACL_OP_INIT_MODE"), "1") + self.assertEqual( + self.mock_vllm_config.parallel_config.worker_cls, + "vllm.spec_decode.spec_decode_worker.create_spec_worker", + ) + self.assertEqual( + self.mock_vllm_config.parallel_config.sd_worker_cls, + "vllm_ascend.worker.worker.NPUWorker", + ) + + @patch("vllm_ascend.ascend_config.check_ascend_config") + @patch("vllm_ascend.ascend_config.init_ascend_config") + @patch("vllm.envs.VLLM_USE_V1", False) + def test_check_and_update_config_multi_step_worker_config( + self, mock_init_ascend, mock_check_ascend): + mock_init_ascend.return_value = 
self.mock_ascend_config + self.mock_vllm_config.scheduler_config.is_multi_step = True + self.mock_vllm_config.parallel_config.worker_cls = "auto" + + from vllm_ascend import platform + + importlib.reload(platform) + self.platform.check_and_update_config(self.mock_vllm_config) + self.assertEqual( + self.mock_vllm_config.parallel_config.worker_cls, + "vllm_ascend.worker.multi_step_worker.MultiStepWorker", + ) + + @patch("vllm_ascend.ascend_config.check_ascend_config") + @patch("vllm_ascend.ascend_config.init_ascend_config") + @patch("vllm.envs.VLLM_USE_V1", False) + def test_check_and_update_config_default_worker_config( + self, mock_init_ascend, mock_check_ascend): + mock_init_ascend.return_value = self.mock_ascend_config + self.mock_vllm_config.parallel_config.worker_cls = "auto" + self.mock_vllm_config.scheduler_config.is_multi_step = False + + from vllm_ascend import platform + + importlib.reload(platform) + self.platform.check_and_update_config(self.mock_vllm_config) + self.assertEqual( + self.mock_vllm_config.parallel_config.worker_cls, + "vllm_ascend.worker.worker.NPUWorker", + ) + + @patch("vllm_ascend.ascend_config.check_ascend_config") + @patch("vllm_ascend.ascend_config.init_ascend_config") + @patch("vllm_ascend.utils.is_310p", return_value=True) + @patch("vllm.envs.VLLM_USE_V1", True) + def test_check_and_update_config_310p_no_custom_ops( + self, mock_is_310p, mock_init_ascend, mock_check_ascend): + mock_init_ascend.return_value = self.mock_ascend_config + self.mock_vllm_config.compilation_config.custom_ops = [] + + from vllm_ascend import platform + + importlib.reload(platform) + + self.platform.check_and_update_config(self.mock_vllm_config) + self.assertEqual(self.mock_vllm_config.compilation_config.custom_ops, + []) + + @patch("vllm_ascend.utils.is_310p", return_value=False) + @patch("vllm_ascend.ascend_config.check_ascend_config") + @patch("vllm_ascend.ascend_config.init_ascend_config") + def test_check_and_update_config_ascend_scheduler_config( + self, mock_init_ascend, mock_check_ascend, mock_is_310p): + self.mock_ascend_config.ascend_scheduler_config.enabled = True + mock_init_ascend.return_value = self.mock_ascend_config + + with patch("vllm_ascend.core.schedule_config.AscendSchedulerConfig" + ) as mock_scheduler: + from vllm_ascend import platform + + importlib.reload(platform) + self.platform.check_and_update_config(self.mock_vllm_config) + mock_scheduler.initialize_from_config.assert_called_once() + + @patch('vllm_ascend.platform.get_ascend_config') + def test_get_attn_backend_cls_use_v1_and_mla(self, mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + + mock_get_ascend_config.return_value = mock_config + + result = self.platform.get_attn_backend_cls( + selected_backend="ascend", + head_size=64, + dtype="float16", + kv_cache_dtype="float16", + block_size=64, + use_v1=True, + use_mla=True, + ) + self.assertEqual(result, + "vllm_ascend.attention.mla_v1.AscendMLABackend") + + @patch('vllm_ascend.platform.get_ascend_config') + def test_get_attn_backend_cls_use_v1_and_torchair(self, + mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = True + + mock_get_ascend_config.return_value = mock_config + + result = self.platform.get_attn_backend_cls( + selected_backend="ascend", + head_size=64, + dtype="float16", + kv_cache_dtype="float16", + block_size=64, + use_v1=True, + use_mla=False, + ) + self.assertEqual( + result, + 
"vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend" + ) + + @patch('vllm_ascend.platform.get_ascend_config') + def test_get_attn_backend_cls_use_v1_only(self, mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + + mock_get_ascend_config.return_value = mock_config + + result = self.platform.get_attn_backend_cls( + selected_backend="ascend", + head_size=64, + dtype="float16", + kv_cache_dtype="float16", + block_size=64, + use_v1=True, + use_mla=False, + ) + self.assertEqual( + result, + "vllm_ascend.attention.attention_v1.AscendAttentionBackend") + + @patch('vllm_ascend.platform.get_ascend_config') + def test_get_attn_backend_cls_use_mla_only(self, mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + + mock_get_ascend_config.return_value = mock_config + + result = self.platform.get_attn_backend_cls( + selected_backend="ascend", + head_size=64, + dtype="float16", + kv_cache_dtype="float16", + block_size=64, + use_v1=False, + use_mla=True, + ) + self.assertEqual( + result, + "vllm_ascend.attention.attention.AscendMLAAttentionBackend") + + @patch('vllm_ascend.platform.get_ascend_config') + def test_get_attn_backend_cls_default_case(self, mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + + mock_get_ascend_config.return_value = mock_config + + result = self.platform.get_attn_backend_cls( + selected_backend="ascend", + head_size=64, + dtype="float16", + kv_cache_dtype="float16", + block_size=64, + use_v1=False, + use_mla=False, + ) + self.assertEqual( + result, "vllm_ascend.attention.attention.AscendAttentionBackend") + + def test_get_punica_wrapper(self): + result = self.platform.get_punica_wrapper() + self.assertEqual( + result, + "vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU") + + @patch("torch.npu.reset_peak_memory_stats") + @patch("torch.npu.max_memory_allocated") + def test_get_current_memory_usage_with_specific_device( + self, mock_max_memory, mock_reset_stats): + max_memory_allocated_result = 1024.0 + mock_max_memory.return_value = max_memory_allocated_result + test_device = torch.device("npu:0") + result = self.platform.get_current_memory_usage(device=test_device) + + mock_reset_stats.assert_called_once_with(test_device) + mock_max_memory.assert_called_once_with(test_device) + self.assertEqual(result, max_memory_allocated_result) + + @patch("torch.npu.reset_peak_memory_stats") + @patch("torch.npu.max_memory_allocated") + def test_get_current_memory_usage_with_default_device( + self, mock_max_memory, mock_reset_stats): + max_memory_allocated_result = 1024.0 + mock_max_memory.return_value = max_memory_allocated_result + + result = self.platform.get_current_memory_usage() + + mock_reset_stats.assert_called_once_with(None) + mock_max_memory.assert_called_once_with(None) + self.assertEqual(result, max_memory_allocated_result) + + @patch("torch.npu.reset_peak_memory_stats", + side_effect=RuntimeError("Device error")) + @patch("torch.npu.max_memory_allocated") + def test_get_current_memory_usage_when_reset_stats_fails( + self, mock_max_memory, mock_reset_stats): + with self.assertRaises(RuntimeError): + self.platform.get_current_memory_usage() + mock_reset_stats.assert_called_once() + mock_max_memory.assert_not_called() + + @patch("torch.npu.reset_peak_memory_stats") + @patch( + "torch.npu.max_memory_allocated", + side_effect=RuntimeError("Memory query failed"), + ) + def 
test_get_current_memory_usage_when_query_fails(self, mock_max_memory, + mock_reset_stats): + with self.assertRaises(RuntimeError): + self.platform.get_current_memory_usage() + mock_reset_stats.assert_called_once() + mock_max_memory.assert_called_once() + + def test_get_device_communicator_cls_returns_correct_value(self): + self.assertEqual( + self.platform.get_device_communicator_cls(), + "vllm_ascend.distributed.communicator.NPUCommunicator", + ) + + def test_is_pin_memory_available_returns_true(self): + self.assertTrue(self.platform.is_pin_memory_available()) + + def test_supports_v1(self): + from vllm.config import ModelConfig + + mock_config = MagicMock(spec=ModelConfig) + self.assertTrue(self.platform.supports_v1(mock_config)) + + def test_get_piecewise_backend_cls_returns_correct_value(self): + self.assertEqual( + self.platform.get_piecewise_backend_cls(), + "vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend", + ) + + @patch("torch.distributed.is_hccl_available", return_value=True) + @patch("torch_npu._C._distributed_c10d.ProcessGroupHCCL") + @patch("torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options") + @patch("torch.distributed.ProcessGroup") + def test_successful_initialization(self, mock_pg, mock_options_cls, + mock_pg_hccl, _): + mock_prefix = MagicMock(spec=PrefixStore) + mock_options = MagicMock(spec=ProcessGroup.Options) + mock_options_cls.return_value = mock_options + mock_backend = MagicMock() + mock_pg_hccl.return_value = mock_backend + group_rank = 0 + group_size = 4 + + mock_pg_instance = MagicMock(spec=ProcessGroup) + mock_pg.return_value = mock_pg_instance + + # Use importlib.reload() to force-reload the platform module and ensure the mocked ProcessGroup is used. + # Without this reload, when executing self.platform.stateless_init_device_torch_dist_pg(), + # it would invoke the original unmocked ProcessGroup implementation instead of our test mock, + # which would cause the unit test to fail. + from vllm_ascend import platform + + importlib.reload(platform) + + result = self.platform.stateless_init_device_torch_dist_pg( + backend="hccl", + prefix_store=mock_prefix, + group_rank=group_rank, + group_size=group_size, + timeout=timedelta(seconds=30), + ) + + mock_pg.assert_called_once_with(mock_prefix, group_rank, group_size, + unittest.mock.ANY) + mock_pg_hccl.assert_called_once_with(mock_prefix, group_rank, + group_size, unittest.mock.ANY) + mock_backend._set_sequence_number_for_group.assert_called_once() + mock_pg_instance._register_backend.assert_called_once_with( + torch.device("npu"), unittest.mock.ANY, mock_backend) + self.assertEqual(result, mock_pg_instance) + + @patch("torch.distributed.is_hccl_available", return_value=False) + def test_hccl_unavailable(self, _): + with self.assertRaises(AssertionError): + from vllm_ascend import platform + + importlib.reload(platform) + self.platform.stateless_init_device_torch_dist_pg( + backend="hccl", + prefix_store=MagicMock(), + group_rank=0, + group_size=4, + timeout=timedelta(seconds=30), + ) diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py new file mode 100644 index 0000000000..5ddc59dea5 --- /dev/null +++ b/tests/ut/test_utils.py @@ -0,0 +1,355 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +import math +import os +import unittest +from threading import Lock +from unittest import mock + +import torch +from vllm.config import (CompilationConfig, ModelConfig, ParallelConfig, + VllmConfig) + +from tests.ut.base import TestBase +from vllm_ascend import utils + + +class TestUtils(TestBase): + + def test_is_310p(self): + utils._IS_310P = None + with mock.patch("vllm_ascend._build_info.__soc_version__", + "Ascend310P3"): + self.assertTrue(utils.is_310p()) + utils._IS_310P = None + with mock.patch("vllm_ascend._build_info.__soc_version__", + "Ascend910P1"): + self.assertFalse(utils.is_310p()) + + def test_sleep_mode_enabled(self): + utils._SLEEP_MODE_ENABLED = None + with mock.patch("vllm_ascend._build_info.__sleep_mode_enabled__", + True): + self.assertTrue(utils.sleep_mode_enabled()) + utils._SLEEP_MODE_ENABLED = None + with mock.patch("vllm_ascend._build_info.__sleep_mode_enabled__", + False): + self.assertFalse(utils.sleep_mode_enabled()) + + def test_nd_to_nz_2d(self): + # can be divided by 16 + input_tensor = torch.randn(32, 64) + output = utils.nd_to_nz_2d(input_tensor) + self.assertEqual(output.shape[0], 1) + self.assertEqual(output.shape[1], 64 // 16) + self.assertEqual(output.shape[2], 32) + self.assertEqual(output.shape[3], 16) + + # cannot be divided by 16 + input_tensor = torch.randn(30, 62) + output = utils.nd_to_nz_2d(input_tensor) + self.assertEqual(output.shape[0], 1) + self.assertEqual(output.shape[1], math.ceil(62 / 16)) + self.assertEqual(output.shape[2], 32) + self.assertEqual(output.shape[3], 16) + + # pad to 16 + input_tensor = torch.randn(8, 12) + output = utils.nd_to_nz_2d(input_tensor) + self.assertEqual(output.shape[0], 1) + self.assertEqual(output.shape[1], 1) # 12->16, 16//16=1 + self.assertEqual(output.shape[2], 16) # 8->16 + self.assertEqual(output.shape[3], 16) + + # check if the output is contiguous + input_tensor = torch.randn(32, 64) + output = utils.nd_to_nz_2d(input_tensor) + self.assertTrue(output.is_contiguous()) + + # check if the output values are preserved + input_tensor = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) + output = utils.nd_to_nz_2d(input_tensor) + expected = torch.tensor( + [[[[1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]]) + self.assertTrue(torch.allclose(output, expected)) + + def 
test_aligned_16(self): + # align to 16 + input_tensor = torch.randn(15, 64) + output_tensor = utils.aligned_16(input_tensor) + self.assertEqual(output_tensor.shape[0], 16) + + # align to 16 + input_tensor = torch.randn(16, 64) + output_tensor = utils.aligned_16(input_tensor) + self.assertEqual(output_tensor.shape[0], 16) + self.assertTrue(torch.equal(input_tensor, output_tensor)) + + # align to 32 + input_tensor = torch.randn(17, 64) + output_tensor = utils.aligned_16(input_tensor) + self.assertEqual(output_tensor.shape[0], 32) + + @mock.patch('torch_npu.get_npu_format') + @mock.patch('torch_npu.npu_format_cast') + @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE', + new=mock.MagicMock) + @mock.patch('vllm_ascend.utils.is_310p') + @mock.patch('vllm_ascend.utils.get_ascend_config') + def test_maybe_converting_weight_acl_format(self, mock_get_config, + mock_310p, mock_npu_cast, + mock_get_format): + ACL_FORMAT_FRACTAL_NZ = 29 + mock_310p.return_value = True + + mock_config = mock.MagicMock() + mock_config.torchair_graph_config.enabled = True + mock_get_config.return_value = mock_config + mock_get_format.return_value = 1 + + mock_npu_cast.return_value = 1 + + fused_moe = mock.MagicMock() + fused_moe.w13_weight = mock.MagicMock() + fused_moe.w2_weight = mock.MagicMock() + fused_moe.w13_weight.data = torch.randn(128, 256) + fused_moe.w2_weight.data = torch.randn(256, 128) + model = mock.MagicMock() + model.modules.return_value = [fused_moe] + + utils.maybe_converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ) + self.assertEqual(fused_moe.w13_weight.data, 1) + + @mock.patch('torch_npu.get_npu_format') + @mock.patch('torch_npu.npu_format_cast') + @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE', + new=mock.MagicMock) + @mock.patch('vllm_ascend.utils.is_310p') + @mock.patch('vllm_ascend.utils.get_ascend_config') + def test_maybe_converting_weight_acl_format_format_true( + self, mock_get_config, mock_310p, mock_npu_cast, mock_get_format): + ACL_FORMAT_FRACTAL_NZ = 29 + mock_310p.return_value = True + + mock_config = mock.MagicMock() + mock_config.torchair_graph_config.enabled = True + mock_get_config.return_value = mock_config + mock_get_format.return_value = ACL_FORMAT_FRACTAL_NZ + + mock_npu_cast.return_value = 1 + + fused_moe = mock.MagicMock() + fused_moe.w13_weight = mock.MagicMock() + fused_moe.w2_weight = mock.MagicMock() + fused_moe.w13_weight.data = torch.randn(128, 256) + fused_moe.w2_weight.data = torch.randn(256, 128) + model = mock.MagicMock() + model.modules.return_value = [fused_moe] + + mock_get_format.return_value = ACL_FORMAT_FRACTAL_NZ + + utils.maybe_converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ) + + @mock.patch('vllm_ascend.utils.get_ascend_config') + @mock.patch('vllm_ascend.utils.is_310p', return_value=False) + def test_maybe_converting_weight_acl_format_not_310_not_graph( + self, mock_310p, mock_get_config): + mock_config = mock.MagicMock() + mock_config.torchair_graph_config.enabled = False + mock_get_config.return_value = mock_config + + mock_constant = mock.MagicMock() + + mock_model = mock.MagicMock() + utils.maybe_converting_weight_acl_format(mock_model, mock_constant) + + @mock.patch('importlib.util.find_spec') + @mock.patch('importlib.import_module') + def test_try_register_lib(self, mock_import_module, mock_find_spec): + # import OK + mock_find_spec.return_value = mock.MagicMock() + mock_import_module.return_value = mock.MagicMock() + lib_name = "existing_lib" + lib_info = "Library found and imported successfully" + 
utils.try_register_lib(lib_name, lib_info) + + # Can't find lib + mock_find_spec.return_value = None + lib_name = "non_existing_lib" + utils.try_register_lib(lib_name) + + # import error + mock_find_spec.return_value = mock.MagicMock() + mock_import_module.side_effect = ImportError("import error") + lib_name = "error_lib" + utils.try_register_lib(lib_name) + + def test_enable_custom_op(self): + result = utils.enable_custom_op() + self.assertTrue(result) + + utils._CUSTOM_OP_ENABLED = None + + with mock.patch('builtins.__import__') as mock_import_module: + mock_import_module.side_effect = ImportError("import error") + self.assertFalse(utils.enable_custom_op()) + + def test_find_hccl_library(self): + with mock.patch.dict(os.environ, + {"HCCL_SO_PATH": "/path/to/hccl/libhccl.so"}): + self.assertEqual(utils.find_hccl_library(), + "/path/to/hccl/libhccl.so") + with mock.patch("torch.version.cann", None): + self.assertRaises(ValueError, utils.find_hccl_library) + with mock.patch("torch.version.cann", "Ascend910"): + self.assertEqual(utils.find_hccl_library(), "libhccl.so") + + def test_current_stream(self): + with mock.patch("torch.npu.current_stream") as mock_current_stream: + self.assertEqual(utils.current_stream(), mock_current_stream()) + + def test_vllm_version_is(self): + with mock.patch.dict(os.environ, {"VLLM_VERSION": "1.0.0"}): + with mock.patch("vllm.__version__", "1.0.0"): + self.assertTrue(utils.vllm_version_is("1.0.0")) + self.assertFalse(utils.vllm_version_is("2.0.0")) + with mock.patch("vllm.__version__", "2.0.0"): + self.assertTrue(utils.vllm_version_is("1.0.0")) + self.assertFalse(utils.vllm_version_is("2.0.0")) + with mock.patch("vllm.__version__", "1.0.0"): + self.assertTrue(utils.vllm_version_is("1.0.0")) + self.assertFalse(utils.vllm_version_is("2.0.0")) + with mock.patch("vllm.__version__", "2.0.0"): + self.assertTrue(utils.vllm_version_is("2.0.0")) + self.assertFalse(utils.vllm_version_is("1.0.0")) + + def test_update_aclgraph_sizes(self): + # max_num_batch_sizes < len(original_sizes) + test_compilation_config = CompilationConfig( + cudagraph_capture_sizes=[i for i in range(150)]) + model_path = os.path.join(os.path.dirname(__file__), "fake_weight") + test_model_config = ModelConfig(model=model_path, enforce_eager=True) + test_parallel_config = ParallelConfig() + test_vllm_config = VllmConfig( + model_config=test_model_config, + compilation_config=test_compilation_config, + parallel_config=test_parallel_config, + ) + utils.update_aclgraph_sizes(test_vllm_config) + self.assertEqual( + 147, + len(test_vllm_config.compilation_config.cudagraph_capture_sizes)) + # max_num_batch_sizes >= len(original_sizes) + test_compilation_config = CompilationConfig( + cudagraph_capture_sizes=[1, 2, 3]) + test_vllm_config = VllmConfig( + model_config=test_model_config, + compilation_config=test_compilation_config, + parallel_config=test_parallel_config, + ) + utils.update_aclgraph_sizes(test_vllm_config) + self.assertEqual( + 3, + len(test_vllm_config.compilation_config.cudagraph_capture_sizes)) + + def test_get_torchair_current_work_dir(self): + cache_dir = utils.TORCHAIR_CACHE_DIR + work_dir = utils.get_torchair_current_work_dir() + self.assertEqual(cache_dir, work_dir) + work_dir = utils.get_torchair_current_work_dir("test") + self.assertEqual(os.path.join(cache_dir, "test"), work_dir) + + def test_torchair_cache_dir(self): + utils.write_kv_cache_bytes_to_file(0, 100) + self.assertTrue(utils.check_torchair_cache_exist(), + "Create torchair cache dir failed") + 
self.assertTrue(utils.check_kv_cache_bytes_cache_exist(), + "Create kv cache bytes cache dir failed") + kv_cache_bytes = utils.read_kv_cache_bytes_from_file(0) + self.assertEqual(100, kv_cache_bytes) + utils.delete_torchair_cache_file() + self.assertFalse(utils.check_torchair_cache_exist(), + "Delete torchair cache dir failed") + self.assertFalse(utils.check_kv_cache_bytes_cache_exist(), + "Delete kv cache bytes cache dir failed") + + +class TestProfileExecuteDuration(unittest.TestCase): + + def setUp(self): + utils.ProfileExecuteDuration._instance = None + utils.ProfileExecuteDuration._observations = [] + utils.ProfileExecuteDuration._lock = Lock() + + def test_singleton_creation(self): + instance1 = utils.ProfileExecuteDuration() + self.assertIsNotNone(instance1) + self.assertIs(instance1, utils.ProfileExecuteDuration._instance) + + instance2 = utils.ProfileExecuteDuration() + self.assertIs(instance1, instance2) + + def test_thread_safety(self): + from threading import Thread + + instances = [] + + def create_instance(): + instances.append(utils.ProfileExecuteDuration()) + + threads = [Thread(target=create_instance) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + first_instance = instances[0] + for instance in instances[1:]: + self.assertIs(first_instance, instance) + + def test_atexit_registration(self): + with mock.patch('atexit.register') as mock_register: + instance = utils.ProfileExecuteDuration() + mock_register.assert_called_once_with(instance.destroy) + + def test_lock_usage(self): + original_lock = utils.ProfileExecuteDuration._lock + + with mock.patch.object(utils.ProfileExecuteDuration, + '_lock', + wraps=original_lock) as mock_lock: + utils.ProfileExecuteDuration() + mock_lock.__enter__.assert_called() + mock_lock.__exit__.assert_called() + + def test_observations_initialization(self): + instance = utils.ProfileExecuteDuration() + self.assertEqual(instance._observations, []) diff --git a/tests/ut/worker/test_input_batch.py b/tests/ut/worker/test_input_batch.py new file mode 100644 index 0000000000..cbfd67f0a2 --- /dev/null +++ b/tests/ut/worker/test_input_batch.py @@ -0,0 +1,162 @@ +import unittest + +import numpy as np +import torch +from vllm.sampling_params import SamplingParams +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.worker.block_table import MultiGroupBlockTable + +from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch + + +def mock_cached_request_state(req_id="1", prompt=[1, 2, 3], output=[4, 5, 6]): + return CachedRequestState( + req_id=req_id, + prompt_token_ids=prompt, + mm_inputs=[], + mm_positions=[], + sampling_params=SamplingParams(), + pooling_params=None, + generator=None, + block_ids=([], ), + num_computed_tokens=0, + output_token_ids=output, + ) + + +class TestInputBatch(unittest.TestCase): + + def setUp(self): + self.max_num_reqs = 10 + self.max_model_len = 32 + self.max_num_batched_tokens = 132 + self.vocab_size = 1000 + self.device = torch.device("cpu") + self.block_sizes = [128] + + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_batched_tokens=self.max_num_batched_tokens, + device=self.device, + pin_memory=False, + vocab_size=self.vocab_size, + block_sizes=self.block_sizes, + ) + self.cached_request_state = mock_cached_request_state() + + def test_shapes_and_defaults(self): + # torch tensor shape assertions + self.assertEqual(self.input_batch.token_ids_cpu_tensor.shape, + (self.max_num_reqs, 
self.max_model_len)) + self.assertEqual(self.input_batch.temperature.shape, + (self.max_num_reqs, )) + self.assertEqual(self.input_batch.top_k.shape, (self.max_num_reqs, )) + self.assertEqual(self.input_batch.min_p_cpu_tensor.shape, + (self.max_num_reqs, )) + + # numpy shape assertions + self.assertEqual(self.input_batch.token_ids_cpu.shape, + (self.max_num_reqs, self.max_model_len)) + self.assertEqual(self.input_batch.num_tokens.shape, + (self.max_num_reqs, )) + + # type assertions + self.assertIsInstance(self.input_batch.greedy_reqs, set) + self.assertIsInstance(self.input_batch.req_id_to_index, dict) + self.assertIsInstance(self.input_batch.sampling_metadata, + SamplingMetadata) + self.assertIsInstance(self.input_batch.block_table, + MultiGroupBlockTable) + self.assertIsNone(self.input_batch.allowed_token_ids_mask) + self.assertIsNone(self.input_batch.allowed_token_ids_mask_cpu_tensor) + + def test_add_request(self): + # case1: add a new req + self.input_batch.add_request(self.cached_request_state) + self.assertIn(self.cached_request_state.req_id, + self.input_batch.req_id_to_index) + req_index = self.input_batch.req_id_to_index[ + self.cached_request_state.req_id] + self.assertEqual(self.input_batch.num_prompt_tokens[req_index], + len(self.cached_request_state.prompt_token_ids)) + self.assertEqual(self.input_batch.num_tokens[req_index], + self.cached_request_state.num_tokens) + + # case2: add an existing req that may need an update + self.cached_request_state.output_token_ids.extend([7, 8, 9]) + self.cached_request_state.num_computed_tokens += 3 + cached_index = self.input_batch.req_id_to_index[ + self.cached_request_state.req_id] + self.input_batch.add_request(self.cached_request_state, cached_index) + # check if this index in the input_batch is updated + # The np array "token_ids_cpu" should be filled with prompt_token_ids + output_token_ids + self.assertTrue( + np.all(self.input_batch.token_ids_cpu[ + cached_index, :self.cached_request_state.num_tokens]), + msg=f"Token IDs at index {cached_index} did not update correctly.") + + # case3: add a req at an index beyond max_num_reqs + with self.assertRaises(AssertionError): + self.input_batch.add_request(self.cached_request_state, + req_index=self.max_num_reqs) + + # case4: add a req whose prompt exceeds max_model_len + long_prompt = list(range(self.max_model_len + 1)) + long_request = mock_cached_request_state(req_id="2", + prompt=long_prompt, + output=[10]) + with self.assertRaises(ValueError) as cm: + self.input_batch.add_request(long_request) + self.assertIn("could not broadcast", str(cm.exception)) + + def test_remove_request(self): + self.input_batch.add_request(self.cached_request_state) + req_index = self.input_batch.remove_request( + self.cached_request_state.req_id) + self.assertIsNotNone(req_index) + self.assertNotIn(self.cached_request_state.req_id, + self.input_batch.req_id_to_index) + self.assertIsNone(self.input_batch._req_ids[req_index]) + + def test_condense(self): + # Let's say we have some requests like below + # Index Req ID + # 0 1 + # 1 2 + # 2 3 + # 3 4 + for i in range(4): + request = mock_cached_request_state(req_id=str(i + 1)) + self.input_batch.add_request(request) + removed_req_indices = [] + id_to_remove = ["2", "4"] # IDs to remove + for req_id in id_to_remove: + removed_index = self.input_batch.remove_request(req_id) + if removed_index is not None: + removed_req_indices.append(removed_index) + self.assertEqual(len(removed_req_indices),
len(id_to_remove)) + self.input_batch.condense(sorted(removed_req_indices, reverse=True)) + + # Check if the remaining requests are condensed correctly + indices = [ + self.input_batch.req_id_to_index[req_id] for req_id in ["1", "3"] + ] + self.assertTrue(all(idx < self.input_batch.num_reqs + for idx in indices)) + + for i in range(self.input_batch.num_reqs): + self.assertIsNotNone(self.input_batch._req_ids[i]) + for i in range(self.input_batch.num_reqs, + len(self.input_batch._req_ids)): + self.assertIsNone(self.input_batch._req_ids[i]) + + for req_id in ["1", "3"]: + idx = self.input_batch.req_id_to_index[req_id] + tokens = self.input_batch.token_ids_cpu[idx] + self.assertTrue( + tokens.any(), + f"Tokens at index {idx} for req {req_id} should not be all zero" + ) diff --git a/tests/ut/worker/test_pooling_model_runner.py b/tests/ut/worker/test_pooling_model_runner.py new file mode 100644 index 0000000000..28a0a7d3c6 --- /dev/null +++ b/tests/ut/worker/test_pooling_model_runner.py @@ -0,0 +1,355 @@ +import unittest +from unittest.mock import MagicMock, patch + +import torch +from vllm.distributed.parallel_state import GroupCoordinator +from vllm.engine.arg_utils import EngineArgs +from vllm.pooling_params import PoolingParams +from vllm.sequence import SequenceData, SequenceGroupMetadata + +from vllm_ascend.worker.pooling_model_runner import ( + ModelInputForNPUWithPoolingMetadata, NPUPoolingModelRunner) + + +class TestPoolingModelRunner(unittest.TestCase): + """Unit tests for the NPUPoolingModelRunner class.""" + + def _create_model_runner(self, model: str, *args, + **kwargs) -> NPUPoolingModelRunner: + engine_args = EngineArgs(model, *args, **kwargs) + engine_config = engine_args.create_engine_config() + model_runner = NPUPoolingModelRunner(vllm_config=engine_config, ) + return model_runner + + def setUp(self): + """Initialize test fixtures and common mocks""" + self.attn_backend = "npu" + + model_runner = self._create_model_runner( + "tests/ut/fake_weight", + trust_remote_code=True, + enable_chunked_prefill=False, + ) + + self.runner = model_runner + self.runner.attn_backend = self.attn_backend + model_runner.model = MagicMock() + self.runner = model_runner + # Sample test data + self.sample_tensor_dict = {"tensor1": torch.randn(3, 4)} + self.sample_seq_group = [MagicMock(spec=SequenceGroupMetadata)] + self.sample_finished_ids = ["req1", "req2"] + + @patch( + 'vllm_ascend.worker.pooling_model_runner.ModelInputForNPUWithPoolingMetadata.from_broadcasted_tensor_dict' + ) + def test_make_model_input_from_broadcasted_tensor_dict( + self, mock_from_dict): + """Test tensor dictionary conversion to model input""" + # Setup mock return + expected_output = MagicMock() + mock_from_dict.return_value = expected_output + + # Execute + result = self.runner.make_model_input_from_broadcasted_tensor_dict( + self.sample_tensor_dict) + + # Verify + mock_from_dict.assert_called_once_with(self.sample_tensor_dict, + attn_backend=self.attn_backend) + self.assertEqual(result, expected_output) + + @patch.object(NPUPoolingModelRunner, '_prepare_pooling') + @patch.object(NPUPoolingModelRunner, '_prepare_model_input_tensors') + def test_prepare_model_input_normal_case(self, mock_prepare_tensors, + mock_prepare_pooling): + """Test normal flow of model input preparation""" + # Setup mocks + mock_model_input = ModelInputForNPUWithPoolingMetadata( + seq_lens=[1, 2, 3]) + mock_prepare_tensors.return_value = mock_model_input + + mock_pooling_metadata = MagicMock() + mock_prepare_pooling.return_value = 
mock_pooling_metadata + + # Execute + result = self.runner.prepare_model_input( + seq_group_metadata_list=self.sample_seq_group, + finished_requests_ids=self.sample_finished_ids) + + # Verify + mock_prepare_tensors.assert_called_once_with(self.sample_seq_group, + self.sample_finished_ids) + mock_prepare_pooling.assert_called_once_with(self.sample_seq_group, + mock_model_input.seq_lens) + self.assertEqual(result.pooling_metadata, mock_pooling_metadata) + + def test_prepare_model_input_null_sequence_group(self): + """Test assertion when seq_group_metadata_list is None""" + with self.assertRaises(AssertionError): + self.runner.prepare_model_input( + seq_group_metadata_list=None, + finished_requests_ids=self.sample_finished_ids) + + @patch.object(NPUPoolingModelRunner, '_prepare_model_input_tensors') + def test_prepare_model_input_null_seq_lens(self, mock_prepare_tensors): + """Test assertion when seq_lens is None in model input""" + # Setup mock with None seq_lens + mock_model_input = MagicMock() + mock_model_input.seq_lens = None + mock_prepare_tensors.return_value = mock_model_input + + with self.assertRaises(AssertionError): + self.runner.prepare_model_input( + seq_group_metadata_list=self.sample_seq_group, + finished_requests_ids=self.sample_finished_ids) + + @patch.object(NPUPoolingModelRunner, '_prepare_pooling') + @patch.object(NPUPoolingModelRunner, '_prepare_model_input_tensors') + def test_prepare_model_input_with_virtual_engine(self, + mock_prepare_tensors, + mock_prepare_pooling): + """Test virtual engine parameter is properly handled""" + # Setup mocks + mock_model_input = ModelInputForNPUWithPoolingMetadata( + seq_lens=[1, 2, 3]) + mock_prepare_tensors.return_value = mock_model_input + + # Execute with virtual_engine parameter + result = self.runner.prepare_model_input( + seq_group_metadata_list=self.sample_seq_group, + virtual_engine=1, + finished_requests_ids=self.sample_finished_ids) + + # Verify virtual_engine doesn't affect the flow + self.assertIsNotNone(result) + + @patch.object(NPUPoolingModelRunner, '_prepare_pooling') + @patch.object(NPUPoolingModelRunner, '_prepare_model_input_tensors') + def test_prepare_model_input_with_null_finished_ids( + self, mock_prepare_tensors, mock_prepare_pooling): + """Test case when finished_requests_ids is None""" + # Setup mocks + mock_model_input = ModelInputForNPUWithPoolingMetadata( + seq_lens=[1, 2, 3]) + mock_prepare_tensors.return_value = mock_model_input + + # Execute with None finished_ids + result = self.runner.prepare_model_input( + seq_group_metadata_list=self.sample_seq_group, + finished_requests_ids=None) + + # Verify + mock_prepare_tensors.assert_called_once_with(self.sample_seq_group, + None) + self.assertIsNotNone(result) + + @patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__') + def test_prepare_pooling_normal_case(self, mock_pooling_metadata): + """Test normal case with multiple sequences in group""" + # Setup test data + mock_pooling_metadata.return_value = None + seq_data = { + 1: MagicMock(spec=SequenceData), + 2: MagicMock(spec=SequenceData) + } + pooling_params = MagicMock(spec=PoolingParams) + seq_group = MagicMock(spec=SequenceGroupMetadata) + seq_group.seq_data = seq_data + seq_group.pooling_params = pooling_params + + # Call the function + self.runner._prepare_pooling([seq_group], [10, 20]) + + # Verify results + mock_pooling_metadata.assert_called_once_with(seq_groups=[ + ([1, 2], pooling_params) + ], + seq_data=seq_data, + prompt_lens=[10, 20]) + + 
@patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__') + def test_prepare_pooling_empty_group(self, mock_pooling_metadata): + """Test case with empty sequence group""" + # Setup empty group + mock_pooling_metadata.return_value = None + empty_seq_data: dict[int, SequenceData] = {} + pooling_params = MagicMock(spec=PoolingParams) + empty_group = MagicMock(spec=SequenceGroupMetadata) + empty_group.seq_data = empty_seq_data + empty_group.pooling_params = pooling_params + + # Call the function + self.runner._prepare_pooling([empty_group], []) + + # Verify results + mock_pooling_metadata.assert_called_once_with(seq_groups=[ + ([], pooling_params) + ], + seq_data={}, + prompt_lens=[]) + + @patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__') + def test_prepare_pooling_single_sequence(self, mock_pooling_metadata): + """Test case with single sequence in group""" + # Setup single sequence + mock_pooling_metadata.return_value = None + single_seq_data = {3: MagicMock(spec=SequenceData)} + pooling_params = MagicMock(spec=PoolingParams) + single_group = MagicMock(spec=SequenceGroupMetadata) + single_group.seq_data = single_seq_data + single_group.pooling_params = pooling_params + + # Call the function + self.runner._prepare_pooling([single_group], [5]) + + # Verify results + mock_pooling_metadata.assert_called_once_with(seq_groups=[ + ([3], pooling_params) + ], + seq_data=single_seq_data, + prompt_lens=[5]) + + @patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__') + def test_prepare_pooling_multiple_groups(self, mock_pooling_metadata): + """Test case with multiple sequence groups""" + # Setup multiple groups + mock_pooling_metadata.return_value = None + seq_data1 = {1: MagicMock(spec=SequenceData)} + seq_data2 = {2: MagicMock(spec=SequenceData)} + params1 = MagicMock(spec=PoolingParams) + params2 = MagicMock(spec=PoolingParams) + + group1 = MagicMock(spec=SequenceGroupMetadata) + group1.seq_data = seq_data1 + group1.pooling_params = params1 + + group2 = MagicMock(spec=SequenceGroupMetadata) + group2.seq_data = seq_data2 + group2.pooling_params = params2 + + # Call the function + self.runner._prepare_pooling([group1, group2], [10, 20]) + + # Verify results + mock_pooling_metadata.assert_called_once_with(seq_groups=[ + ([1], params1), ([2], params2) + ], + seq_data={ + **seq_data1, + **seq_data2 + }, + prompt_lens=[10, 20]) + + @patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__') + def test_prepare_pooling_empty_input(self, mock_pooling_metadata): + """Test case with empty input lists""" + # Call the function with empty inputs + mock_pooling_metadata.return_value = None + self.runner._prepare_pooling([], []) + + # Verify results + mock_pooling_metadata.assert_called_once_with(seq_groups=[], + seq_data={}, + prompt_lens=[]) + + @patch('vllm.forward_context.set_forward_context') + @patch('vllm.distributed.parallel_state._PP', + new_callable=lambda: MagicMock(spec=GroupCoordinator, + is_last_rank=True)) + @patch('torch.npu.Event') + @patch.object(NPUPoolingModelRunner, 'set_active_loras') + @patch.object(NPUPoolingModelRunner, 'set_active_prompt_adapters') + def test_execute_model_normal_flow(self, mock_set_adapters, mock_set_loras, + mock_event, mock_pp, mock_set_forward): + """Test normal execution path with all dependencies mocked""" + + # Setup model input mock + mock_input = MagicMock() + mock_input.input_tokens = torch.tensor([1]) + mock_input.input_positions = torch.tensor([0]) + mock_input.multi_modal_kwargs = {} + 
self.runner.is_driver_worker = True + # Execute + self.runner.execute_model(model_input=mock_input, + kv_caches=[], + num_steps=1) + + # Verify core calls + self.runner.model.pooler.assert_called_once() + + @patch('vllm.forward_context.set_forward_context') + def test_execute_model_invalid_steps(self, mock_set_forward): + """Test ValueError when num_steps != 1""" + with self.assertRaises(ValueError): + self.runner.execute_model(model_input=MagicMock(), + kv_caches=[], + num_steps=2) + mock_set_forward.assert_not_called() + + @patch('vllm.forward_context.set_forward_context') + @patch('vllm.distributed.parallel_state._PP', + new_callable=lambda: MagicMock(spec=GroupCoordinator, + is_last_rank=False)) + @patch('torch.npu.Event') + def test_execute_model_perf_monitoring(self, mock_event, mock_pp, + mock_set_forward): + """Test performance monitoring with timing mocks""" + # Setup mocks + + mock_event.return_value.elapsed_time.return_value = 15.0 + self.runner.observability_config = MagicMock( + collect_model_forward_time=True) + + # Execute + self.runner.execute_model(model_input=MagicMock( + input_tokens=torch.tensor([1]), + input_positions=torch.tensor([0]), + multi_modal_kwargs={}), + kv_caches=[], + num_steps=1) + + # Verify timing calls + self.assertEqual(mock_event.call_count, 2) + + @patch('vllm.forward_context.set_forward_context') + @patch.object(NPUPoolingModelRunner, 'set_active_loras') + @patch('vllm.distributed.parallel_state._PP', + new_callable=lambda: MagicMock(spec=GroupCoordinator, + is_last_rank=False)) + def test_execute_model_lora_config(self, mock_pp, set_active_loras, + mock_set_forward): + """Test LoRA configuration handling""" + # Setup + + self.runner.lora_config = True + mock_input = MagicMock() + mock_input.lora_requests = ["req1"] + mock_input.lora_mapping = {"map": 1} + + # Execute + self.runner.execute_model(model_input=mock_input, + kv_caches=[], + num_steps=1) + + # Verify LoRA call + set_active_loras.assert_called_once_with(["req1"], {"map": 1}) + + @patch('vllm.forward_context.set_forward_context') + @patch('vllm.distributed.parallel_state._PP', + new_callable=lambda: MagicMock(spec=GroupCoordinator, + is_last_rank=False)) + def test_execute_model_not_last_rank(self, mock_pp, mock_set_forward): + """Test behavior when not the last pipeline rank""" + # Setup + + # Execute + self.runner.execute_model(model_input=MagicMock( + input_tokens=torch.tensor([1]), + input_positions=torch.tensor([0]), + multi_modal_kwargs={}), + kv_caches=[], + num_steps=1) + + # Verify pooler not called + self.runner.model.pooler.assert_not_called() diff --git a/tests/utils.py b/tests/utils.py index ced7d9a1b1..2535d089e7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -23,10 +23,13 @@ import subprocess import sys import time +from collections.abc import Sequence from typing import Callable, Optional import openai import requests +import torch +import torch.nn.functional as F from typing_extensions import ParamSpec from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.cli_args import make_arg_parser @@ -197,3 +200,37 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: f" args {args} and kwargs {kwargs}") return wrapper + + +def matryoshka_fy(tensor: torch.Tensor, dimensions: int): + tensor = torch.tensor(tensor) + tensor = tensor[..., :dimensions] + tensor = F.normalize(tensor, p=2, dim=1) + return tensor + + +def check_embeddings_close( + *, + embeddings_0_lst: Sequence[list[float]], + embeddings_1_lst: Sequence[list[float]], + name_0: str, + 
name_1: str, + tol: float = 1e-3, +) -> None: + assert len(embeddings_0_lst) == len(embeddings_1_lst) + + for prompt_idx, (embeddings_0, embeddings_1) in enumerate( + zip(embeddings_0_lst, embeddings_1_lst)): + assert len(embeddings_0) == len(embeddings_1), ( + f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}") + + sim = F.cosine_similarity(torch.tensor(embeddings_0), + torch.tensor(embeddings_1), + dim=0) + + fail_msg = (f"Test{prompt_idx}:" + f"\nCosine similarity: \t{sim:.4f}" + f"\n{name_0}:\t{embeddings_0[:16]!r}" + f"\n{name_1}:\t{embeddings_1[:16]!r}") + + assert sim >= 1 - tol, fail_msg diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index d8b87c6952..eb5b09c4fb 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -36,11 +36,12 @@ def __init__(self, vllm_config): self.ascend_scheduler_config = AscendSchedulerConfig( ascend_scheduler_config) - self.expert_tensor_parallel_size = int( - additional_config.get("expert_tensor_parallel_size", 0)) self.expert_map_path = additional_config.get("expert_map_path", None) + self.dynamic_eplb = additional_config.get("dynamic_eplb", False) self.chunked_prefill_for_mla = additional_config.get( "chunked_prefill_for_mla", False) + self.enable_weight_nz_layout = additional_config.get( + "enable_weight_nz_layout", False) class TorchairGraphConfig: @@ -138,6 +139,12 @@ def check_ascend_config(vllm_config, enforce_eager): else: # torchair_graph case if ascend_config.torchair_graph_config.enabled: + # torchair_graph is not supported for V1 without mla currently. + if envs.VLLM_MLA_DISABLE: + logger.warning( + "Torchair graph mode is still experimental and not supported for V1 without mla currently, " + "it has been disabled automatically.") + ascend_config.torchair_graph_config.enabled = False # torchair_graph is supported for deepseek model only currently. if vllm_config.model_config: model_type = vllm_config.model_config.hf_config.model_type diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py new file mode 100644 index 0000000000..75fd71c859 --- /dev/null +++ b/vllm_ascend/ascend_forward_context.py @@ -0,0 +1,79 @@ +from contextlib import contextmanager +from enum import Enum +from typing import Any, Optional + +import torch +from vllm.config import VllmConfig +from vllm.distributed import get_dp_group +from vllm.forward_context import get_forward_context, set_forward_context + + +class FusedMoEState(Enum): + AllGather = 0 + All2All = 1 + MC2 = 2 + + +# TODO(zzzzwwjj): add soc_version to choose branch +def get_fused_moe_state(ep_size: int, with_prefill: bool): + if ep_size == 1: + return FusedMoEState.AllGather + # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. + elif ep_size < 16 or with_prefill: + return FusedMoEState.All2All + else: + return FusedMoEState.MC2 + + +@contextmanager +def set_ascend_forward_context( + attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: Optional[int] = None, + num_tokens_across_dp: Optional[torch.Tensor] = None, + with_prefill: bool = True, + in_profile_run: bool = False): + """A context manager that stores the current forward context, + can be attention metadata, etc. + We add some additional param into forward_context. 
+ """ + with set_forward_context(attn_metadata, + vllm_config, + virtual_engine=virtual_engine, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp): + forward_context = get_forward_context() + forward_context.with_prefill = with_prefill + + ep_size = torch.distributed.get_world_size( + ) if vllm_config.parallel_config.enable_expert_parallel else 1 + + fused_moe_state = get_fused_moe_state(ep_size, with_prefill) + + forward_context.fused_moe_state = fused_moe_state + + forward_context.in_profile_run = in_profile_run + + # NOTE: This cannot be set using set_forward_context + # due to multiple warmups before actual capturing + forward_context.capturing = False + + dp_world_size = get_dp_group().world_size + if dp_world_size > 1 and forward_context.dp_metadata is not None: + forward_context.max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item( + ) + elif num_tokens is not None: + forward_context.max_tokens_across_dp = num_tokens + elif attn_metadata is not None: + if hasattr(attn_metadata, 'num_actual_tokens'): + forward_context.max_tokens_across_dp = attn_metadata.num_actual_tokens + else: + forward_context.max_tokens_across_dp = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens + else: + forward_context.max_tokens_across_dp = None + + try: + yield + finally: + pass diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index e6a2376786..3417bb87fb 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -24,12 +24,16 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) from vllm.attention.backends.utils import CommonAttentionState +from vllm.config import get_current_vllm_config from vllm.forward_context import ForwardContext, get_forward_context from vllm.utils import direct_register_custom_op from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch +from vllm_ascend.attention.utils import \ + AscendCommonAttentionMetadata as CommonAttentionMetadata from vllm_ascend.ops.attention import vanilla_chunked_prefill +from vllm_ascend.utils import get_graph_params class AscendAttentionBackend(AttentionBackend): @@ -114,6 +118,7 @@ class AscendMetadata: query_start_loc: torch.Tensor query_lens: torch.Tensor seq_lens: torch.Tensor + seq_lens_list: list # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] = None # (num_tokens,). The indices of the token slots that input tokens will be @@ -133,7 +138,7 @@ class AscendMetadata: # For logging. num_input_tokens: int = 0 # Number of tokens including padding. 
- with_prefill_across_dp: bool = False + enable_dbo_across_dp: bool = False class AscendAttentionMetadataBuilder: @@ -149,23 +154,26 @@ def build(self, num_reqs, num_actual_tokens, max_query_len, - common_prefix_len, - with_prefill_across_dp: bool = False): + common_attn_metadata: CommonAttentionMetadata, + enable_dbo_across_dp: bool = False, + *args, + **kwargs): block_table = self.runner.input_batch.block_table[0].get_device_tensor( ) block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = ( block_table[:num_reqs]) - query_lens = self.runner.query_lens - seq_lens = self.runner.seq_lens_cpu[:num_reqs] - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( - self.runner.device, non_blocking=True) + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + # TODO: Refactor these two param to common metadata in runners, + # preparing for the hybrid KV groups feature + query_lens = common_attn_metadata.query_lens if common_attn_metadata.query_lens is not None else self.runner.query_lens + seq_lens_list = common_attn_metadata.seq_lens_list if common_attn_metadata.seq_lens_list is not None else self.runner.seq_lens_list + + slot_mapping = self.runner.slot_mapping[:num_actual_tokens] attn_mask = self.runner.attn_mask attn_state = self.runner.attn_state - query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1] - query_start_loc = query_start_loc_cpu.to(self.runner.device, - non_blocking=True) attn_metadata = AscendMetadata( num_actual_tokens=num_actual_tokens, @@ -173,11 +181,40 @@ def build(self, query_start_loc=query_start_loc, query_lens=query_lens, seq_lens=seq_lens, + seq_lens_list=seq_lens_list, max_query_len=max_query_len, slot_mapping=slot_mapping, attn_mask=attn_mask, attn_state=attn_state, - with_prefill_across_dp=with_prefill_across_dp) + enable_dbo_across_dp=enable_dbo_across_dp) + return attn_metadata + + def build_dummy_metadata(self, num_actual_tokens, num_reqs, + num_scheduled_tokens, attn_state): + if attn_state == AscendAttentionState.DecodeOnly: + # NOTE: We only need to pay attention to seq_lens_list and block_table here + common_attn_metadata = CommonAttentionMetadata(seq_lens_list=[2] * + num_reqs) + + block_table = self.runner.input_batch.block_table[0].block_table + block_table[:num_reqs, 0] = torch.arange(1, + num_reqs + 1, + device=block_table.device, + dtype=block_table.dtype) + + attn_metadata = self.build( + num_reqs=num_reqs, + num_actual_tokens=num_actual_tokens, + max_query_len=num_scheduled_tokens.max(), + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + else: + raise NotImplementedError( + "Currently we only support building dummy metadata for DecodeOnly state" + ) + + attn_metadata.attn_state = attn_state return attn_metadata @@ -217,6 +254,10 @@ def __init__( self.key_cache = None self.value_cache = None + vllm_config = get_current_vllm_config() + self.full_graph = vllm_config.compilation_config.full_cuda_graph + self.block_size = vllm_config.cache_config.block_size + def forward( self, layer: AttentionLayer, @@ -228,21 +269,7 @@ def forward( output: Optional[torch.Tensor] = None, trace_flag: bool = True, ) -> torch.Tensor: - """Forward pass with Ascend attention. 
- Args: - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] - kv_cache: shape = [2, num_blocks, block_size, - num_kv_heads, head_size] - key_cache = [num_blocks, block_size, - num_kv_heads, head_size] - value_cache = [num_blocks, block_size, - num_kv_heads, head_size] - attn_metadata: Metadata for attention. - Returns: - shape = [batch_size * seq_len, num_heads, head_size] - """ + """Forward pass with Ascend attention.""" num_tokens = query.shape[0] if output is None: output = torch.empty(num_tokens, @@ -275,7 +302,7 @@ def forward( # TODO: Remove this contiguous in the future. value = value.contiguous() - if kv_cache.numel() > 0: + if len(kv_cache) > 0: if self.key_cache is None: self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] slots = attn_metadata.slot_mapping @@ -307,11 +334,13 @@ def forward( assert attn_metadata is not None assert attn_metadata.attn_mask is not None compress_mask = attn_metadata.attn_mask + batch_size = attn_metadata.query_lens.shape[0] + block_table = attn_metadata.block_tables[:batch_size, :] torch_npu._npu_flash_attention_qlens( query=query, key_cache=self.key_cache, value_cache=self.value_cache, - block_table=attn_metadata.block_tables, + block_table=block_table, mask=compress_mask, seq_len=attn_metadata.query_lens, context_lens=attn_metadata.seq_lens, @@ -320,16 +349,92 @@ def forward( scale_value=self.scale, out=output) elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: - torch_npu._npu_paged_attention( - query=query, - key_cache=self.key_cache, - value_cache=self.value_cache, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - block_table=attn_metadata.block_tables, - context_lens=attn_metadata.seq_lens, - out=output) + if self.full_graph: + graph_params = get_graph_params() + q = query.view(num_tokens, -1, self.hidden_size) + k = self.key_cache.view( # type: ignore + -1, self.block_size, + self.num_kv_heads * self.head_size) + v = self.value_cache.view( # type: ignore + -1, self.block_size, + self.num_kv_heads * self.head_size) + actual_seq_lens = attn_metadata.seq_lens_list + attn_args = { + "query": q, + "key": k, + "value": v, + "actual_seq_lengths_kv": actual_seq_lens, + "block_table": attn_metadata.block_tables, + "num_heads": self.num_heads, + "scale": self.scale, + "input_layout": "BSH", + "num_key_value_heads": self.num_kv_heads, + "block_size": self.block_size, + } + + # Prepare tensors for attention output + # TODO: Refactor this to step-level instead of layer-level + attn_output = torch.empty(num_tokens, + 1, + self.hidden_size, + dtype=output.dtype, + device=output.device) + softmax_lse = torch.empty(num_tokens, + dtype=output.dtype, + device=output.device) + + # Get workspace from cache or calculate it if not present. 
+ workspace = graph_params.workspaces.get(num_tokens) + if workspace is None: + workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( + **attn_args) + graph_params.workspaces[num_tokens] = workspace + + forward_context = get_forward_context() + if not forward_context.capturing: + # Execute attention kernel directly in non-capturing mode + torch.ops.npu.npu_fused_infer_attention_score.out( + workspace=workspace, + out=[attn_output, softmax_lse], + **attn_args) + else: + # Handle graph capturing mode + stream = torch_npu.npu.current_stream() + + event = torch.npu.ExternalEvent() + event.wait(stream) + event.reset(stream) + graph_params.events[num_tokens].append(event) + + graph_params.attn_params[num_tokens].append( + (q, k, v, actual_seq_lens, + attn_metadata.block_tables, self.num_heads, + self.scale, self.num_kv_heads, attn_output, + softmax_lse)) + + torch.npu.graph_task_group_begin(stream) + torch.ops.npu.npu_fused_infer_attention_score.out( + workspace=workspace, + out=[attn_output, softmax_lse], + **attn_args) + handle = torch.npu.graph_task_group_end(stream) + graph_params.handles[num_tokens].append(handle) + + # Reshape output to match the expected format + output.copy_( + attn_output.view(num_tokens, self.num_heads, + self.head_size)) + else: + torch_npu._npu_paged_attention( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=attn_metadata.block_tables, + context_lens=attn_metadata.seq_lens, + out=output) # Normal V1 situation. else: # use chunked prefill for head size 192 scenario, like deepseek diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 189aa38e89..816d93c028 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -13,34 +13,23 @@ UnquantizedLinearMethod) from vllm.utils import cdiv, round_down +from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.utils import \ + AscendCommonAttentionMetadata as CommonAttentionMetadata from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig from vllm_ascend.multistream.context import get_multistream_comm_context from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla -from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, npu_stream_switch, + npu_wait_tensor) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch -@dataclass -class CommonAttentionMetadata: - """ - Attention metadata attributes that can be shared by layers in different KV - cache groups and thus having different block table. 
- """ - - query_start_loc: torch.Tensor - """(batch_size + 1,), the start location of each request in query Tensor""" - seq_lens: torch.Tensor - """(batch_size,), the length of each request including both computed tokens - and newly scheduled tokens""" - - class AscendMLABackend(AttentionBackend): accept_output_buffer: bool = True @@ -103,6 +92,7 @@ class AscendMLADecodeMetadata: seq_lens: torch.Tensor max_seq_lens: int seq_lens_list: list[int] + actual_seq_q_lens: Optional[list[int]] = None attn_mask: Optional[torch.Tensor] = None @@ -136,8 +126,8 @@ class AscendMLAMetadata: # For logging. num_input_tokens: int = 0 # Number of tokens including padding. - max_num_tokens_across_dp: int = 0 - with_prefill_across_dp: bool = False + enable_dbo_across_dp: bool = False + is_mtp_model: bool = False query_lens: Optional[list[int]] = None # The dimension of the attention heads @@ -290,7 +280,7 @@ def _get_graph_runner_block_tables( self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: max_batch_size, max_blocks = self.runner.graph_block_tables.shape - assert max_batch_size >= num_seqs + assert max_batch_size >= num_seqs, f"max_batch_size: {max_batch_size} should be bigger than cur_num_seqs: {num_seqs}" if isinstance(self.runner.graph_block_tables, np.ndarray): graph_block_tables = torch.zeros((max_batch_size, max_blocks), @@ -312,8 +302,12 @@ def _get_graph_runner_block_tables( return graph_block_tables[:num_seqs, :max_blocks] - def build_dummy(self, num_reqs: int, - num_actual_tokens: int) -> AscendMLAMetadata: + def build_torchair_graph_dummy( + self, + num_reqs: int, + num_actual_tokens: int, + is_mtp_model: bool = False, + ) -> AscendMLAMetadata: device = self.runner.device _, max_blocks = self.runner.graph_block_tables.shape block_table = torch.zeros((num_reqs, max_blocks), @@ -321,11 +315,13 @@ def build_dummy(self, num_reqs: int, device=device) block_table = self._get_graph_runner_block_tables( num_reqs, block_table) - seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device) - input_positions = torch.zeros(num_reqs, + num_tokens = num_reqs * self.runner.decode_token_per_req + seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device) + seq_lens_list = seq_lens.tolist() + input_positions = torch.zeros(num_tokens, dtype=torch.int32, device=device).long() - slot_mapping = torch.full((num_reqs, ), + slot_mapping = torch.full((num_tokens, ), PAD_SLOT_ID, dtype=torch.int32, device=device) @@ -333,28 +329,38 @@ def build_dummy(self, num_reqs: int, -1, dtype=torch.int32, device=device) + if self.runner.speculative_config is not None and\ + self.runner.speculative_config.method == 'deepseek_mtp' and not is_mtp_model: + attn_state = AscendAttentionState.SpecDecoding + num_decode_tokens = 2 + else: + attn_state = AscendAttentionState.DecodeOnly + num_decode_tokens = 1 decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions, block_table=block_table, seq_lens=seq_lens, - seq_lens_list=seq_lens.tolist(), + seq_lens_list=seq_lens_list, max_seq_lens=1, - attn_mask=self.runner.spec_attn_mask) + attn_mask=self.runner.spec_attn_mask, + actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs], + ) return self.metadata_cls( # type: ignore num_input_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens, slot_mapping=slot_mapping, head_dim=self.runner.model_config.get_head_size(), num_decodes=1, - num_decode_tokens=1, + num_decode_tokens=num_decode_tokens, num_prefills=0, attn_mask=self.runner.attn_mask, - attn_state=AscendAttentionState.DecodeOnly, + 
attn_state=attn_state, prefill=None, decode=decode_metadata, query_start_loc=query_start_loc, seq_lens=seq_lens, block_tables=block_table, + is_mtp_model=is_mtp_model, ) def build( @@ -364,9 +370,10 @@ def build( max_query_len: int, common_attn_metadata: CommonAttentionMetadata, common_prefix_len: Optional[int] = None, - graph_pad_size: int = -1, - max_num_tokens_across_dp: int = 0, - with_prefill_across_dp: bool = False, + num_token_pad_size: int = -1, + num_reqs_pad_size: int = 0, + enable_dbo_across_dp: bool = False, + is_mtp_model: bool = False, ) -> AscendMLAMetadata: assert self._num_decodes + self._num_prefills == num_reqs @@ -450,8 +457,9 @@ def build( ) decode_metadata = None - use_torchair_graph = graph_pad_size != -1 + use_torchair_graph = num_token_pad_size != -1 if self._num_decodes > 0: + actual_seq_q_lens = None max_seq_lens = seq_lens[:self._num_decodes].max().item() seq_lens = seq_lens[:self._num_decode_tokens] input_positions = input_positions[:self._num_decode_tokens] @@ -460,41 +468,48 @@ def build( AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ]: - num_seqs = len(seq_lens) - if graph_pad_size != 0: - pad_value = 1 - padded_seq_lens = seq_lens.tolist() + [pad_value - ] * graph_pad_size + if num_token_pad_size != 0: + pad_value = 0 + padded_seq_lens = seq_lens.tolist( + ) + [pad_value] * num_reqs_pad_size else: padded_seq_lens = seq_lens.tolist() seq_lens = torch.from_numpy( np.array(padded_seq_lens).astype(np.int32)) - padding = torch.full((graph_pad_size, ), + seq_lens_list = padded_seq_lens + padding = torch.full((num_token_pad_size, ), PAD_SLOT_ID, dtype=slot_mapping.dtype, device=slot_mapping.device) slot_mapping = torch.cat([slot_mapping, padding]) block_table_padding = torch.zeros( - (graph_pad_size, ) + block_table.shape[1:], + (num_reqs_pad_size, ) + block_table.shape[1:], dtype=block_table.dtype, device=block_table.device) block_table = torch.cat([block_table, block_table_padding], dim=0) block_table = self._get_graph_runner_block_tables( - num_seqs + graph_pad_size, block_table) - padding_0 = torch.zeros(graph_pad_size, + num_reqs + num_reqs_pad_size, block_table) + padding_0 = torch.zeros(num_token_pad_size, dtype=input_positions.dtype, device=input_positions.device) input_positions = torch.cat([input_positions, padding_0]) + actual_seq_q_lens = query_start_loc[1:].tolist( + ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs + + num_reqs_pad_size] + else: + seq_lens_list = seq_lens.tolist() decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions, block_table=block_table, seq_lens=seq_lens, - seq_lens_list=seq_lens.tolist(), + seq_lens_list=seq_lens_list, max_seq_lens=max_seq_lens, - attn_mask=self.runner.spec_attn_mask) + attn_mask=self.runner.spec_attn_mask, + actual_seq_q_lens=actual_seq_q_lens, + ) return self.metadata_cls( # type: ignore num_actual_tokens=num_actual_tokens, @@ -511,8 +526,8 @@ def build( query_start_loc=query_start_loc, block_tables=block_table, seq_lens=seq_lens, - max_num_tokens_across_dp=max_num_tokens_across_dp, - with_prefill_across_dp=with_prefill_across_dp, + enable_dbo_across_dp=enable_dbo_across_dp, + is_mtp_model=is_mtp_model, ) @@ -570,15 +585,6 @@ def __init__( self.spec_token_num = speculative_config.num_speculative_tokens assert self.spec_token_num > 0 - # TODO: support numHeads / numKvHeads < 16 in MLA kernel - if self.torchair_graph_enabled: - assert self.num_queries_per_kv in _ALLOWED_NUM_QUERIES_PER_KV, \ - ("The allowed number of queries per kv when enabling both MLA and Graph 
mode" - " only support {32, 64, 128}, Thus this is not supported for DeepSeek-V2-Lite," - " as it only has 16 attention heads. And if you're using DeepSeek-V3 or DeepSeek-R1," - " please make sure after the tensor parallel split, num_heads / num_kv_heads in " - "{32, 64, 128}.") - def _v_up_proj_and_o_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) @@ -651,20 +657,23 @@ def get_and_maybe_dequant_weights(layer: LinearBase): self.W_UV = W_UV.transpose(0, 1).contiguous() # Convert from (L, N, P) to (N, P, L) self.W_UK_T = W_UK.permute(1, 2, 0).contiguous() - - # Waiting for BMM NZ support - # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) - # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) + if get_ascend_config().enable_weight_nz_layout: + # cast quantized weight tensors in NZ layout for higher inference speed + self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, + ACL_FORMAT_FRACTAL_NZ) + self.W_UK_T.data = torch_npu.npu_format_cast( + self.W_UK_T.data, ACL_FORMAT_FRACTAL_NZ) def _compute_prefill_context( self, query: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, + kv_c_and_k_pe_cache: Tuple[torch.Tensor], rope_dim: int, attn_metadata: AscendMLAMetadata, prefix_output: torch.Tensor, prefix_lse: torch.Tensor, ): + assert len(kv_c_and_k_pe_cache) > 1 prefill_metadata = attn_metadata.prefill if prefill_metadata is None or prefill_metadata.chunked_context is None: return prefix_output, prefix_lse @@ -674,21 +683,22 @@ def _compute_prefill_context( q_nope = query[..., :self.qk_nope_head_dim] seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32) - latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim - cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim] - cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:] + cache_kv_c = kv_c_and_k_pe_cache[0] + cache_k_pe = kv_c_and_k_pe_cache[1] + num_heads = cache_k_pe.size(2) + latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1) for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i] seq_len = torch.stack([seq_len1, seq_len2]) kv_c_normed = torch.empty(toks, - kv_c_and_k_pe_cache.size(2), + num_heads, latent_kv_dim, dtype=query.dtype, device=query.device) k_pe = torch.empty(toks, - kv_c_and_k_pe_cache.size(2), + num_heads, rope_dim, dtype=query.dtype, device=query.device) @@ -738,10 +748,11 @@ def _forward_prefill( query: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, + kv_c_and_k_pe_cache: Tuple[torch.Tensor], attn_metadata: AscendMLAMetadata, ) -> torch.Tensor: assert attn_metadata.prefill is not None + assert len(kv_c_and_k_pe_cache) > 1 num_tokens = query.size(0) attn_output = torch.empty(num_tokens, @@ -758,7 +769,8 @@ def _forward_prefill( if attn_metadata.attn_state in [ AscendAttentionState.ChunkedPrefill, - AscendAttentionState.SpecDecoding + AscendAttentionState.SpecDecoding, + AscendAttentionState.PrefillCacheHit ] and not ascend_config.chunked_prefill_for_mla: attn_output_torch = torch.empty(num_tokens, self.num_heads * self.v_head_dim, @@ -783,7 +795,8 @@ def _forward_prefill( causal=True) elif attn_metadata.attn_state in [ AscendAttentionState.ChunkedPrefill, - AscendAttentionState.SpecDecoding + AscendAttentionState.SpecDecoding, + AscendAttentionState.PrefillCacheHit ]: attn_lse = torch.empty(self.num_heads, num_tokens, @@ -833,15 +846,12 @@ def _forward_prefill( 
num_kv_heads=self.num_heads, out=attn_output) attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim) - else: - raise RuntimeError( - "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !" - ) attn_output = attn_output.reshape( [num_tokens, self.num_heads * self.v_head_dim]) if attn_metadata.attn_state in [ AscendAttentionState.ChunkedPrefill, - AscendAttentionState.SpecDecoding + AscendAttentionState.SpecDecoding, + AscendAttentionState.PrefillCacheHit ] and not ascend_config.chunked_prefill_for_mla: attn_output = attn_output_torch @@ -934,44 +944,17 @@ def _forward_decode( q_pe: torch.Tensor, k_nope: torch.Tensor, k_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, + kv_c_and_k_pe_cache: Tuple[torch.Tensor], attn_metadata: AscendMLAMetadata, ) -> torch.Tensor: decode_meta = attn_metadata.decode assert decode_meta is not None - - q = torch.cat([q_nope, q_pe], dim=-1) - num_tokens = q.size(0) - attn_output = torch.empty( - [num_tokens, self.num_heads, self.kv_lora_rank], - dtype=q.dtype, - device=q.device) + num_tokens = q_nope.size(0) if self.running_in_graph: - # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim] - if attn_metadata.attn_state == AscendAttentionState.SpecDecoding: - assert num_tokens % self.spec_token_num == 0 - q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1), - self.spec_token_num + 1, self.num_heads, - -1) - q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1), - self.spec_token_num + 1, self.num_heads, -1) - if not self.enable_kv_nz: - q_nope = q_nope.transpose(1, 2).contiguous() - q_pe = q_pe.transpose(1, 2).contiguous() - sparse_mode = 3 - spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore - else: - if self.enable_kv_nz: - q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1) - q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1) - else: - q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1) - q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) - sparse_mode = 0 - spec_attn_mask = None # shape of knope/k_pe for npu graph mode should be: # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] block_size = kv_c_and_k_pe_cache[0].shape[1] + actual_seq_lengths = None if self.enable_kv_nz: k_nope = k_nope.view(-1, self.num_kv_heads, self.kv_lora_rank // 16, block_size, 16) @@ -985,6 +968,26 @@ def _forward_decode( self.qk_rope_head_dim) input_layout = "BNSD" + # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim] + if attn_metadata.attn_state == AscendAttentionState.SpecDecoding: + assert num_tokens % self.spec_token_num == 0 + # [bs * q_seq_len, num_heads_per_rank, dim] + input_layout = "TND" + q_nope = q_nope.view(num_tokens, self.num_heads, -1) + q_pe = q_pe.view(num_tokens, self.num_heads, -1) + sparse_mode = 3 + spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore + actual_seq_lengths = decode_meta.actual_seq_q_lens + else: + if self.enable_kv_nz: + q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1) + q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1) + else: + q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1) + q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) + sparse_mode = 0 + spec_attn_mask = None + attn_output, _ = torch_npu.npu_fused_infer_attention_score( q_nope, k_nope, @@ -1002,18 +1005,37 @@ def _forward_decode( block_table=decode_meta.block_table, block_size=block_size, actual_seq_lengths_kv=decode_meta.seq_lens_list, - ) + 
actual_seq_lengths=actual_seq_lengths) else: - torch_npu._npu_paged_attention_mla( - query=q, - key_cache=kv_c_and_k_pe_cache, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - block_table=attn_metadata.decode.block_table, # type:ignore - context_lens=attn_metadata.decode.seq_lens, # type:ignore - mla_vheadsize=self.kv_lora_rank, - out=attn_output) + # The MLA_PA path will become the default in the future; `_npu_paged_attention_mla` will + # be removed once `torch_npu.atb.npu_multi_head_latent_attention` in torch_npu becomes + # publicly available + assert len(kv_c_and_k_pe_cache) > 1 + if envs.VLLM_ASCEND_MLA_PA: + attn_output = torch_npu.atb.npu_multi_head_latent_attention( + q_nope, q_pe, kv_c_and_k_pe_cache[0], + kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table, + attn_metadata.decode.seq_lens, self.num_heads, self.scale, + self.num_kv_heads) + else: + q = torch.cat([q_nope, q_pe], dim=-1) + attn_output = torch.empty( + [num_tokens, self.num_heads, self.kv_lora_rank], + dtype=q.dtype, + device=q.device) + k_cache = torch.cat( + [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1) + torch_npu._npu_paged_attention_mla( + query=q, + key_cache=k_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=attn_metadata.decode. + block_table, # type:ignore + context_lens=attn_metadata.decode.seq_lens, # type:ignore + mla_vheadsize=self.kv_lora_rank, + out=attn_output) current_ms_metadata = get_multistream_comm_context() if current_ms_metadata is None: return self._v_up_proj_and_o_proj(attn_output) @@ -1029,7 +1051,7 @@ def forward( hidden_states_or_q_c: torch.Tensor, # query in unified attn hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn k_pe: torch.Tensor, # value in unified attn - kv_cache: torch.Tensor, + kv_cache: Tuple[torch.Tensor], attn_metadata: M, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -1037,16 +1059,17 @@ def forward( if attn_metadata is None: # Profiling run. return output + # MTP models are not supported in graph mode yet + self.torchair_graph_enabled = self.torchair_graph_enabled and not attn_metadata.is_mtp_model self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [ AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ] num_actual_toks = attn_metadata.num_actual_tokens if k_pe is None and not self.running_in_graph: - if not self.torchair_graph_enabled: - kv_c, k_pe = self.kv_a_proj_with_mqa( - hidden_states_or_kv_c_normed)[0].split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + kv_c, k_pe = self.kv_a_proj_with_mqa( + hidden_states_or_kv_c_normed)[0].split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) else: kv_c_normed = hidden_states_or_kv_c_normed assert attn_metadata.num_decodes is not None and \ @@ -1065,19 +1088,20 @@ def forward( if not self.running_in_graph: hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...] prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:] - if not self.torchair_graph_enabled: - decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] - k_pe = k_pe[:num_actual_toks, ...]
- k_pe = k_pe.unsqueeze(1) - decode_k_pe = k_pe[:num_decode_tokens] - prefill_k_pe = k_pe[num_decode_tokens:] + decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] + prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:] + # if not self.torchair_graph_enabled: + k_pe = k_pe[:num_actual_toks, ...] + k_pe = k_pe.unsqueeze(1) + decode_k_pe = k_pe[:num_decode_tokens] + prefill_k_pe = k_pe[num_decode_tokens:] else: decode_hs_or_q_c = hidden_states_or_q_c if has_decode: decode_k_nope = None assert attn_metadata.decode is not None if self.running_in_graph: - seq_len = self.rotary_emb.max_position_embeddings + seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor cos = self.rotary_emb.cos_cached[:seq_len].to( dtype=decode_hs_or_q_c.dtype) sin = self.rotary_emb.sin_cached[:seq_len].to( @@ -1111,9 +1135,7 @@ def forward( else: decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( attn_metadata.decode.input_positions, - decode_q_pe.contiguous(), - decode_k_pe, - max_seq_len=attn_metadata.decode.max_seq_lens) + decode_q_pe.contiguous(), decode_k_pe) if has_prefill: assert attn_metadata.prefill is not None prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ @@ -1122,7 +1144,7 @@ def forward( prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim] if self.torchair_graph_enabled: num_tokens = prefill_hs_or_q_c.shape[0] - seq_len = self.rotary_emb.max_position_embeddings + seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor cos = self.rotary_emb.cos_cached[:seq_len].to( dtype=prefill_q_pe.dtype) sin = self.rotary_emb.sin_cached[:seq_len].to( @@ -1134,22 +1156,24 @@ def forward( prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) prefill_k_pe, prefill_k_nope = self.exec_kv_prefill( - hidden_states_or_kv_c_normed, cos, sin, kv_cache, - attn_metadata.slot_mapping) + prefill_hs, cos, sin, kv_cache, + attn_metadata.slot_mapping[num_decode_tokens:]) kv_c_normed = prefill_k_nope[:num_actual_toks, ...] - prefill_k_c_normed = prefill_k_nope[num_decode_tokens:] + prefill_k_c_normed = prefill_k_nope prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads, -1) prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1) else: prefill_q_pe[...], prefill_k_pe[...] 
= self.rotary_emb( attn_metadata.prefill.input_positions, - prefill_q_pe.contiguous(), - prefill_k_pe, - max_seq_len=attn_metadata.prefill.max_seq_lens) + prefill_q_pe.contiguous(), prefill_k_pe) + + assert len( + kv_cache + ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)" if self.torchair_graph_enabled: - if len(kv_cache) > 0 and kv_cache[0].numel( + if kv_cache[0].numel( ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: slots = attn_metadata.slot_mapping # NOTE: Separate the kv cache in advance to avoid OOM or other issues @@ -1159,16 +1183,15 @@ def forward( key_cache=kv_cache[0], value_cache=kv_cache[1], slot_indices=slots) - elif kv_cache.numel() > 0: - key = torch.cat([ - kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), - k_pe - ], - dim=2) - torch_npu._npu_reshape_and_cache_siso( - key=key, - key_cache=kv_cache, - slot_indices=attn_metadata.slot_mapping.flatten()) + else: + kv_c_normed = kv_c_normed.view( + [num_actual_toks, self.num_kv_heads, -1]) + torch_npu._npu_reshape_and_cache( + key=kv_c_normed, + value=k_pe, + key_cache=kv_cache[0], + value_cache=kv_cache[1], + slot_indices=attn_metadata.slot_mapping) if has_prefill: # FIX: aicore move should be also placed on the comm stream in dbo, # otherwise it may affect the accuracy diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py new file mode 100644 index 0000000000..c2b7bc156a --- /dev/null +++ b/vllm_ascend/attention/utils.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass +class AscendCommonAttentionMetadata: + """ + Attention metadata attributes that can be shared by layers in different KV + cache groups and thus having different block table. 
+ """ + + query_start_loc: torch.Tensor = None + """(batch_size + 1,), the start location of each request in query Tensor""" + seq_lens: Optional[torch.Tensor] = None + """(batch_size,), the length of each request including both computed tokens + and newly scheduled tokens""" + query_lens: Optional[torch.Tensor] = None + """(batch_size,), the length of each request including only the newly + scheduled tokens""" + seq_lens_list: Optional[list] = None + """(num_input_tokens,), note that this is specifically for FIA kernel""" diff --git a/vllm_ascend/compilation/piecewise_backend.py b/vllm_ascend/compilation/piecewise_backend.py index c6a800b3d8..aafe639373 100644 --- a/vllm_ascend/compilation/piecewise_backend.py +++ b/vllm_ascend/compilation/piecewise_backend.py @@ -28,9 +28,13 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context from vllm.logger import logger from vllm.utils import weak_ref_tensors +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.utils import get_graph_params, set_graph_params + @dataclasses.dataclass class ConcreteSizeEntry: @@ -95,6 +99,10 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + if self.compilation_config.full_cuda_graph: + self.update_stream = torch.npu.Stream() + set_graph_params(self.aclgraph_capture_sizes) + # the entries for different shapes that we need to either # compile or capture aclgraph self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {} @@ -116,7 +124,40 @@ def check_for_ending_compilation(self): self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) + def update_attn_params(self, graph_params, forward_context, runtime_shape): + for layer_idx in range(len(graph_params.handles[runtime_shape])): + query, key, value, actual_seq_lens, block_table, num_heads, scale, num_kv_heads, output, softmax_lse = graph_params.attn_params[ + runtime_shape][layer_idx] + block_table = forward_context.attn_metadata.block_tables + actual_seq_lens = forward_context.attn_metadata.seq_lens_list + + with torch.npu.stream(self.update_stream): + torch.npu.graph_task_update_begin( + self.update_stream, + graph_params.handles[runtime_shape][layer_idx]) + torch.ops.npu.npu_fused_infer_attention_score.out( + query, + key, + value, + workspace=graph_params.workspaces[runtime_shape], + actual_seq_lengths_kv=actual_seq_lens, + block_table=block_table, + num_heads=num_heads, + scale=scale, + input_layout="BSH", + num_key_value_heads=num_kv_heads, + block_size=128, + out=[output, softmax_lse], + ) + torch.npu.graph_task_update_end(self.update_stream) + + graph_params.events[runtime_shape][layer_idx].record( + self.update_stream) + def __call__(self, *args) -> Any: + forward_context = get_forward_context() + graph_params = get_graph_params() + if not self.first_run_finished: self.first_run_finished = True self.check_for_ending_compilation() @@ -127,6 +168,11 @@ def __call__(self, *args) -> Any: # we don't need to do anything for this shape return self.compiled_graph_for_general_shape(*args) + if (getattr(forward_context.attn_metadata, "attn_state", + None) != AscendAttentionState.DecodeOnly + and self.compilation_config.full_cuda_graph): + return self.compiled_graph_for_general_shape(*args) + entry = self.concrete_size_entries[runtime_shape] if 
entry.runnable is None: @@ -189,6 +235,7 @@ def __call__(self, *args) -> Any: patch("torch.npu.empty_cache", lambda: None)) # mind-exploding: carefully manage the reference and memory. + forward_context.capturing = True with torch.npu.graph(aclgraph, pool=self.graph_pool): # `output` is managed by pytorch's aclgraph pool output = entry.runnable(*args) @@ -222,4 +269,9 @@ def __call__(self, *args) -> Any: ) entry.aclgraph.replay() + + if self.compilation_config.full_cuda_graph: + self.update_attn_params(graph_params, forward_context, + runtime_shape) + return entry.output diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py index 2fa31c264c..3f1477c9f9 100644 --- a/vllm_ascend/core/scheduler.py +++ b/vllm_ascend/core/scheduler.py @@ -23,7 +23,6 @@ from vllm.logger import logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.utils import cdiv -from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs @@ -87,14 +86,11 @@ def skip_cur_request(): self.waiting.popleft() skipped_waiting_requests.appendleft(request) - num_prealloc_computed_tokens = 0 # P/D: skip request if still waiting for remote kvs. if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: is_ready = self._update_waiting_for_remote_kv(request) if is_ready: request.status = RequestStatus.WAITING - num_prealloc_computed_tokens = ( - request.num_computed_tokens) else: skip_cur_request() continue @@ -112,8 +108,8 @@ def skip_cur_request(): load_kv_async = False # Get already-cached tokens. - if num_prealloc_computed_tokens == 0: - new_computed_blocks, num_native_computed_tokens = \ + if request.num_computed_tokens == 0: + new_computed_blocks, num_new_local_computed_tokens = \ self.kv_cache_manager.get_computed_blocks( request) @@ -121,18 +117,17 @@ def skip_cur_request(): if self.connector is not None: num_external_computed_tokens, load_kv_async = ( self.connector.get_num_new_matched_tokens( - request, num_native_computed_tokens)) + request, num_new_local_computed_tokens)) # Total computed tokens (local + external). - num_computed_tokens = (num_native_computed_tokens + + num_computed_tokens = (num_new_local_computed_tokens + num_external_computed_tokens) else: # P/D: skip checking prefix cache if loaded from remote kvs. - new_computed_blocks = KVCacheBlocks.create_empty() - num_native_computed_tokens = 0 - - # Total computed tokens (allocated in prior step). - num_computed_tokens = num_prealloc_computed_tokens + new_computed_blocks = ( + self.kv_cache_manager.create_empty_block_list()) + num_new_local_computed_tokens = 0 + num_computed_tokens = request.num_computed_tokens # P/D: loading remote KV, do not allocate for new work. if load_kv_async: @@ -142,9 +137,6 @@ def skip_cur_request(): # Number of tokens to be scheduled. else: prompt_limit = self._get_prompt_limit(request) - # Get already-cached tokens. - computed_blocks, num_computed_tokens = ( - self.kv_cache_manager.get_computed_blocks(request)) # We use `request.num_tokens` instead of # `request.num_prompt_tokens` to consider the resumed # requests, which have output tokens. 
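A note on the piecewise-backend hunk above: `get_graph_params` / `set_graph_params` are imported from `vllm_ascend.utils`, whose definition is not part of this patch. Judging only from how `update_attn_params` indexes the returned object (`handles`, `attn_params`, `workspaces` and `events`, each keyed by the captured runtime shape), a registry along the lines of the sketch below would satisfy that usage; the names and layout here are assumptions for illustration, not the actual implementation.

# Hypothetical sketch of the graph-params registry assumed by update_attn_params above;
# the real helpers live in vllm_ascend/utils.py and may differ.
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple


@dataclass
class GraphParams:
    # Every field is keyed by the padded runtime shape an ACL graph was captured for.
    handles: Dict[int, List[Any]]        # graph-task handles, one per attention layer
    attn_params: Dict[int, List[Tuple]]  # captured kernel arguments, one tuple per layer
    workspaces: Dict[int, Any]           # npu_fused_infer_attention_score workspace per shape
    events: Dict[int, List[Any]]         # NPU events that order the update stream


_graph_params: Optional[GraphParams] = None


def set_graph_params(aclgraph_capture_sizes) -> None:
    # Create one empty slot per capture size before any graph is captured.
    global _graph_params
    _graph_params = GraphParams(
        handles=defaultdict(list),
        attn_params=defaultdict(list),
        workspaces=dict.fromkeys(aclgraph_capture_sizes),
        events=defaultdict(list),
    )


def get_graph_params() -> Optional[GraphParams]:
    return _graph_params

With a registry of this shape, `update_attn_params` can look up the per-layer handle and saved kernel arguments for the current `runtime_shape`, swap in the fresh block table and sequence lengths from the forward context, and re-issue `npu_fused_infer_attention_score` inside a graph task update on the side stream.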
@@ -172,7 +164,7 @@ def skip_cur_request(): skip_cur_request() continue assert num_new_tokens > 0 - blocks = computed_blocks.blocks[0] + blocks = new_computed_blocks.blocks[0] watermark = getattr(self.scheduler_config, "watermark", 0.01) if not self._check_watermark_for_prefill(request, num_new_tokens, @@ -184,8 +176,8 @@ def skip_cur_request(): new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens + num_external_computed_tokens, - num_native_computed_tokens, - new_computed_blocks=computed_blocks, + num_new_local_computed_tokens, + new_computed_blocks=new_computed_blocks, num_lookahead_tokens=self.num_lookahead_tokens, delay_cache_blocks=load_kv_async) if new_blocks is None: @@ -195,8 +187,7 @@ def skip_cur_request(): # KVConnector: update internal state after allocation. # This information is used to determine if a load is # needed for this request. - if num_external_computed_tokens: - assert self.connector is not None + if self.connector is not None: self.connector.update_state_after_alloc( request, new_computed_blocks + new_blocks, @@ -210,6 +201,7 @@ def skip_cur_request(): skipped_waiting_requests.appendleft(request) request.status = RequestStatus.WAITING_FOR_REMOTE_KVS continue + self.running.append(request) if self.log_stats: request.record_event(EngineCoreEventType.SCHEDULED, @@ -509,3 +501,40 @@ def update_from_output( return super().update_from_output(scheduler_output, model_runner_output) + + def _update_waiting_for_remote_kv(self, request: Request) -> bool: + """ + KV Connector: check if the request_id is finished_recving. + + The finished_recving_kv_req_ids list is populated + on the previous steps()'s update_from_output based + on the worker side connector. + + When the kv transfer is ready, we cache the blocks + and the request state will be moved back to WAITING from + WAITING_FOR_REMOTE_KV. + """ + assert self.connector is not None + if request.request_id not in self.finished_recving_kv_req_ids: + return False + + # Now that the blocks are ready, actually cache them. + (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) + num_computed_tokens = len(block_ids) * self.block_size + # Handle the case where num request tokens less then one block. + num_computed_tokens = min(num_computed_tokens, request.num_tokens) + if num_computed_tokens == request.num_tokens: + num_computed_tokens -= 1 + + # This will cache the blocks if caching is enabled. + # Note: vllm fix this in main branch, but still have issue on v0.9.1, so we just adopt the + # change on 0.9.1 and without cherry-pick this back to main branch on vllm-ascend + if self.kv_cache_manager.enable_caching: + self.kv_cache_manager.cache_blocks(request, num_computed_tokens) + + # Update the request state for scheduling. + request.num_computed_tokens = num_computed_tokens + + # Return that we are ready. 
+ self.finished_recving_kv_req_ids.remove(request.request_id) + return True diff --git a/vllm_ascend/distributed/__init__.py b/vllm_ascend/distributed/__init__.py index 88c2f2199b..d7be705c2b 100644 --- a/vllm_ascend/distributed/__init__.py +++ b/vllm_ascend/distributed/__init__.py @@ -25,3 +25,8 @@ KVConnectorFactory.register_connector( "AscendSimpleConnector", "vllm_ascend.distributed.kv_transfer.simple_connector", "SimpleConnector") + +KVConnectorFactory.register_connector( + "LLMDataDistCMgrConnector", + "vllm_ascend.distributed.llmdatadist_c_mgr_connector", + "LLMDataDistCMgrConnector") diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py new file mode 100644 index 0000000000..34543cc05c --- /dev/null +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -0,0 +1,789 @@ +import contextlib +import json +import math +import threading +import time +from collections.abc import Iterator +from dataclasses import dataclass +from typing import Any, Optional, Tuple + +import llm_datadist # type: ignore +import msgspec +import torch +import zmq +from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist, + LLMException, LLMRole) +from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.parallel_state import get_tp_group, get_world_group +from vllm.forward_context import ForwardContext +from vllm.utils import get_ip, logger +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import Request, RequestStatus + +from vllm_ascend import envs +from vllm_ascend.soc_info import NPUSocInfo + +TORCH_DTYPE_TO_NPU_DTYPE = { + torch.half: llm_datadist.DataType.DT_FLOAT16, + torch.float16: llm_datadist.DataType.DT_FLOAT16, + torch.bfloat16: llm_datadist.DataType.DT_BF16, + torch.float: llm_datadist.DataType.DT_FLOAT, + torch.float32: llm_datadist.DataType.DT_FLOAT, + torch.int8: llm_datadist.DataType.DT_INT8, + torch.int64: llm_datadist.DataType.DT_INT64, + torch.int32: llm_datadist.DataType.DT_INT32 +} + + +class LLMDataDistCMgrAgentMetadata(msgspec.Struct): + super_pod_id: str + server_id: str + device_id: str + device_ip: str + super_device_id: str + cluster_id: int + + +@dataclass +class ReqMeta: + local_block_ids: list[int] + remote_block_ids: list[int] + remote_host: str + remote_port: str + engine_id: str + remote_tp_size: str + + +class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata): + + def __init__(self): + self.requests: dict[str, ReqMeta] = {} + + def add_new_req(self, request_id: str, local_block_ids: list[int], + kv_transfer_params: dict[str, Any]): + self.requests[request_id] = ReqMeta( + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params["remote_block_ids"], + engine_id=kv_transfer_params["remote_engine_id"], + remote_host=kv_transfer_params["remote_host"], + remote_port=kv_transfer_params["remote_port"], + remote_tp_size=kv_transfer_params["remote_tp_size"], + ) + + +class LLMDataDistCMgrConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + assert vllm_config.kv_transfer_config is not None + self.engine_id = vllm_config.kv_transfer_config.engine_id + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler: Optional[ + LLMDataDistCMgrConnectorScheduler] = LLMDataDistCMgrConnectorScheduler( + 
vllm_config, self.engine_id) + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = LLMDataDistCMgrConnectorWorker(vllm_config) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches( + self, + kv_caches: dict[str, # type: ignore[override] + Tuple[torch.Tensor]]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished(finished_req_ids) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, + LLMDataDistCMgrConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + """LLMDataDistCMgrConnector does not do layerwise saving, the load is in blocking manager.""" + pass + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata, **kwargs) -> None: + """LLMDataDistCMgrConnector does not save explicitly.""" + pass + + def wait_for_save(self): + """LLMDataDistCMgrConnector does not save explicitly.""" + pass + + +class LLMDataDistCMgrConnectorScheduler(): + + def __init__(self, vllm_config: VllmConfig, engine_id: Optional[str]): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.engine_id = engine_id + self.local_ip = get_ip() + # Can not retrieve the parallel config since it is not initialized. 
+ self.local_dp_rank = None + self.tp_size = None + dp_rank_local = self.vllm_config.parallel_config.data_parallel_rank_local + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + + self.port = dp_rank_local * tp_size + envs.VLLM_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs.VLLM_LLMDD_RPC_PORT + + self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + """ + For remote prefill, pull all prompt blocks from remote + asynchronously relative to engine execution. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + Returns: + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if the external KV cache tokens will be loaded + asynchronously (between scheduler steps). + """ + + params = request.kv_transfer_params + logger.debug( + f"LLMDataDistCMgrConnector get_num_new_matched_tokens: num_computed_tokens={num_computed_tokens}, kv_transfer_params={params}" + ) + + if params is not None and params.get("do_remote_prefill"): + # Remote prefill: get all prompt blocks from remote. + assert num_computed_tokens % self.block_size == 0 + # Note: We use the full token count as transmit data here. + count = max(len(request.prompt_token_ids) - num_computed_tokens, 0) + return count, count > 0 + + # No remote prefill for this request. + return 0, False + + def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks, + num_externel_tokens: int): + params = request.kv_transfer_params + logger.debug( + f"LLMDataDistCMgrConnector update states num_externel_tokens: {num_externel_tokens} kv_transfer_params: {params}" + ) + if params is not None and params.get("do_remote_prefill"): + if params.get("remote_block_ids"): + if all(p in params for p in ("remote_engine_id", "remote_host", + "remote_port", "remote_tp_size")): + self._reqs_need_recv[request.request_id] = ( + request, blocks.get_unhashed_block_ids()) + else: + logger.warning("" \ + f"Invalid KVTransferParams {params}, This request will be discard") + else: + assert num_externel_tokens == 0 + params["do_remote_prefill"] = False + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = LLMDataDistCMgrConnectorMetadata() + + for req_id, (req, block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + meta.add_new_req(request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params) + self._reqs_need_recv.clear() + + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + + params = request.kv_transfer_params + logger.debug( + "LLMDataDistCMgrConnector request_finished, request_status=%s, " + "kv_transfer_params=%s", request.status, params) + + if (params is None or not params.get("do_remote_decode") + or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): + return False, None + + # note: NIXL transfer the full block only, but I don't see any reason to do that, so here + # we just transfer any data that computed from prefill node + # note: there might be some issue on this, check it if there is any unexpected result + computed_block_ids = block_ids + # If prompt < block_size, no xfer so free blocks immediately. 
+ + return False, dict( + do_remote_prefill=True, + do_remote_decode=False, + remote_block_ids=computed_block_ids, + remote_engine_id=self.engine_id, + remote_host=self.local_ip, + remote_port=self.port, + remote_tp_size=str( + self.vllm_config.parallel_config.tensor_parallel_size), + ) + + +class LLMDataDistCMgrConnectorWorker(): + """ + Implementation of Worker side methods + """ + + def __init__(self, vllm_config: VllmConfig): + assert vllm_config.kv_transfer_config is not None + logger.info("Initialize the LLMDataDistCMgrConnectorWorker") + # we assume the local node only contains dp and tp, and tp will not communicate inter-node. + # for any scenario beyond this scope, the functionality of this connector is not guaranteed. + self.local_rank_on_node = get_world_group().rank % ( + vllm_config.parallel_config.data_parallel_size_local * + vllm_config.parallel_config.tensor_parallel_size) + self.local_rank = get_world_group().local_rank + self.local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local + self.tp_size = vllm_config.parallel_config.tensor_parallel_size + self.tp_rank = get_tp_group().rank_in_group + self.rank = get_world_group().rank + self.local_ip = get_ip() + self.kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config + self.local_agent_metadata: Optional[ + LLMDataDistCMgrAgentMetadata] = None + self.vllm_config = vllm_config + + self.llm_datadist_role = None + self.llm_datadist_remote_role = None + if self.kv_transfer_config.kv_role == "kv_producer": + self.llm_datadist_role = LLMRole.PROMPT + self.llm_datadist_remote_role = LLMRole.DECODER + elif self.kv_transfer_config.kv_role == "kv_consumer": + self.llm_datadist_role = LLMRole.DECODER + self.llm_datadist_remote_role = LLMRole.PROMPT + else: + raise RuntimeError( + f"LLMDataDistWorker: Receive unexpected kv role in LLMDataDistWorker, this worker now only support kv_producer and kv_consumer, but receiving {vllm_config.kv_transfer_config.kv_role}" + ) + + # linked_cluster record the cluster that already build the connection its format should be {"cluster_id": "comm_name"} + self.linked_cluster: dict[Any, Any] = {} + self.prefill_device_list: list[tuple[int, int]] = [] + self.decode_device_list: list[tuple[int, int]] = [] + global_rank_table = self.read_offline_rank_table() + self.local_agent_metadata = self.read_agent_metadata( + global_rank_table, self.local_ip, self.local_rank_on_node, + self.llm_datadist_role) + self.llm_datadist = LLMDataDist(self.llm_datadist_role, + self.local_agent_metadata.cluster_id) + self.init_llm_datadist() + self.finished_reqs: set[str] = set() + self.soc_info = NPUSocInfo() + + def listen_for_agent_metadata_req(self, event: threading.Event): + assert self.local_agent_metadata is not None + port = envs.VLLM_LLMDD_RPC_PORT + self.local_dp_rank * self.tp_size + self.tp_rank if self.local_dp_rank is not None else envs.VLLM_LLMDD_RPC_PORT + self.tp_size + self.tp_rank + url = f"tcp://0.0.0.0:{port}" + msg_encoder = msgspec.msgpack.Encoder() + msg_decoder = msgspec.msgpack.Decoder() + msg_to_send = msg_encoder.encode(self.local_agent_metadata) + logger.debug(f"Start to listen to address: {url}") + logger.debug( + f"The local agent metadata have {len(msg_to_send)} bytes here") + logger.info( + f"LLMDataDistCMgrConnectorWorker: Cluster {self.local_agent_metadata.cluster_id} start to listen request from peers" + ) + with zmq_ctx(zmq.ROUTER, url) as sock: # type: ignore[attr-defined] + event.set() + while True: + identity, _, msg = sock.recv_multipart() + decode_msg = 
msg_decoder.decode(msg) + if "cluster_id" in decode_msg: + decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg) + logger.info( + f"LLMDataDistCMgrConnectorWorker: Receive message from cluster {decode_msg.cluster_id}" + ) + sock.send_multipart((identity, b"", msg_to_send)) + self.add_remote_agent(decode_msg) + else: + logger.warning( + f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}" + ) + + def init_llm_datadist(self): + assert self.local_agent_metadata is not None + llm_config = LLMConfig() + llm_config.device_id = self.local_rank + llm_config.sync_kv_timeout = 20000 + llm_config.enable_switch_role = True + llm_config.enable_cache_manager = True + llm_config.enable_remote_cache_accessible = True + llm_config_options = llm_config.generate_options() + self.llm_datadist.init(llm_config_options) + self.cache_manager = self.llm_datadist.cache_manager + logger.info( + f"Done initialize llm_datadist in rank {self.rank}, local rank {self.local_rank}, cluster id {self.local_agent_metadata.cluster_id}" + ) + + def read_offline_rank_table(self): + assert ( + envs.DISAGGREGATED_PREFILL_RANK_TABLE_PATH + ), "Please set path of rank_table to env variable DISAGGREGATED_PREFILL_RANK_TABLE_PATH" + rank_table_path = envs.DISAGGREGATED_PREFILL_RANK_TABLE_PATH + with open(rank_table_path, "r", encoding="utf-8") as f: + global_rank_table = json.load(f) + decode_device_list = global_rank_table["decode_device_list"] + for decode_device in decode_device_list: + server_id = decode_device["server_id"] + device_id = decode_device["device_id"] + self.decode_device_list.append((server_id, device_id)) + prefill_device_list = global_rank_table["prefill_device_list"] + for prefill_device in prefill_device_list: + server_id = prefill_device["server_id"] + device_id = prefill_device["device_id"] + self.prefill_device_list.append((server_id, device_id)) + + # global_rank_table = json.dumps(global_rank_table) + return global_rank_table + + def read_agent_metadata(self, global_rank_table, server_id, device_rank, + agent_role): + devices_type_list = [] + agent_metadata = None + if self.llm_datadist_role == LLMRole.PROMPT: + devices_type_list.append("prefill_device_list") + elif self.llm_datadist_role == LLMRole.DECODER: + devices_type_list.append("decode_device_list") + else: + devices_type_list.append("prefill_device_list") + devices_type_list.append("decode_device_list") + for device_type in devices_type_list: + device_list = global_rank_table[device_type] + device_list = [ + d for d in device_list if d.get("server_id") == server_id + ] + if len(device_list) <= device_rank: + continue + device_info = device_list[device_rank] + super_pod_id_ = device_info.get("super_pod_id", None) + server_id_ = device_info["server_id"] + device_id_ = device_info["device_id"] + device_ip_ = device_info["device_ip"] + super_device_id_ = device_info.get("super_device_id", None) + cluster_id_ = int(device_info["cluster_id"]) + agent_metadata = LLMDataDistCMgrAgentMetadata( + super_pod_id=super_pod_id_, + server_id=server_id_, + device_id=device_id_, + device_ip=device_ip_, + super_device_id=super_device_id_, + cluster_id=cluster_id_, + ) + assert agent_metadata is not None, f"Can't read the target server_id {server_id} and device_rank {device_rank} from rank table" + return agent_metadata + + def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]): + _, first_kv_cache_tuple = next(iter(kv_caches.items())) + first_kv_cache = first_kv_cache_tuple[0] + assert len(first_kv_cache_tuple) > 1 + assert 
self.local_agent_metadata is not None + kv_cache_dtype = first_kv_cache.dtype + self.use_mla: bool = first_kv_cache_tuple[0].size( + -1) != first_kv_cache_tuple[1].size(-1) + # MLA case. [2 (k_normed, k_pe), num_blocks, ...] + # MHA case. [2 (k and v), num_blocks, ...] + self.num_blocks = first_kv_cache.shape[0] + block_rank = 3 # [block_size, latent_dim] + block_shape = first_kv_cache.shape[-block_rank:] + + self.block_len = math.prod(block_shape) + self.cache_addr: list[int] = [] + alignment = 2 * 1024 * 1024 + if self.use_mla: + cache_k_normed_addr_list = [] + cache_k_pe_addr_list = [] + k_normed = None + k_pe = None + for cache_or_caches in kv_caches.values(): + assert len(cache_or_caches) > 1 + k_normed, k_pe = cache_or_caches[0], cache_or_caches[1] + cache_k_normed_addr_list.append(k_normed.data_ptr()) + cache_k_pe_addr_list.append(k_pe.data_ptr()) + self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list) + + cache_desc_k_normed = CacheDesc( + len(self.cache_addr[0]), [*k_normed.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + cache_desc_k_pe = CacheDesc( + len(self.cache_addr[1]), [*k_pe.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + cache_key_k_normed = BlocksCacheKey(cluster_id=int( + self.local_agent_metadata.cluster_id), + model_id=0) + cache_key_k_pe = BlocksCacheKey(cluster_id=int( + self.local_agent_metadata.cluster_id), + model_id=1) + self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe) + self.cache_key = (cache_key_k_normed, cache_key_k_pe) + try: + cache_k_normed = self.cache_manager.register_blocks_cache( + self.cache_desc[0], self.cache_addr[0], self.cache_key[0]) + cache_k_pe = self.cache_manager.register_blocks_cache( + self.cache_desc[1], self.cache_addr[1], self.cache_key[1]) + self.cache = (cache_k_normed, cache_k_pe) + logger.info("LLMDataDistWorker: End of register Paged Cache.") + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]" + ) + else: + for cache_or_caches in kv_caches.values(): + for cache in cache_or_caches: + base_addr = cache.data_ptr() + assert base_addr % alignment == 0, "The address of the registered kv cache should be aligned to 2M" + self.cache_addr.append(base_addr) + # register paged kv cache into the llm_cache manager + self.cache_desc = CacheDesc( + len(self.cache_addr), [*cache.shape], + TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype]) + self.cache_key = BlocksCacheKey( + cluster_id=int(self.local_agent_metadata.cluster_id)) + logger.info( + f"num of cache: {len(self.cache_addr)}, size of cache: {[*cache.shape]}, real size of cache: {first_kv_cache.shape}" + ) + try: + self.cache = self.cache_manager.register_blocks_cache( + self.cache_desc, self.cache_addr, self.cache_key) + logger.info( + "LLMDataDistCMgrConnectorWorker: End of register Paged Cache." 
+ ) + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]" + ) + self.ready_event = threading.Event() + self.metadata_agent_listener_t = threading.Thread( + target=self.listen_for_agent_metadata_req, + args=(self.ready_event, ), + daemon=True, + name="metadata_agent_listener") + self.metadata_agent_listener_t.start() + self.ready_event.wait() + + def start_load_kv(self, metadata: LLMDataDistCMgrConnectorMetadata): + for req_id, meta in metadata.requests.items(): + logger.debug(f"Start to transmit {req_id}") + self._read_blocks(meta.local_block_ids, + meta.remote_block_ids, meta.remote_host, + int(meta.remote_port), meta.engine_id, req_id, + meta.remote_tp_size) + self.finished_reqs.add(req_id) + + def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int: + assert self.local_agent_metadata is not None + remote_cluster_id = metadata.cluster_id + if remote_cluster_id in self.linked_cluster: + logger.debug( + f"LLMDataDistCMgrConnectorWorker: remote cluster_id: {metadata.cluster_id} already linked with this server, skip the connection" + ) + return remote_cluster_id + remote_super_pod_id = metadata.super_pod_id + remote_server_id = metadata.server_id + is_same_server = remote_server_id == self.local_agent_metadata.server_id + is_same_pod = remote_super_pod_id == self.local_agent_metadata.super_pod_id + if self.llm_datadist_role == LLMRole.PROMPT: + prefill_metadata = self.local_agent_metadata + decode_metadata = metadata + else: + prefill_metadata = metadata + decode_metadata = self.local_agent_metadata + comm_name = f"pd_comm_{prefill_metadata.device_ip}_{decode_metadata.device_ip}" + cluster_rank_info = { + prefill_metadata.cluster_id: 0, + decode_metadata.cluster_id: 1 + } + rank_table = {} + rank_table["version"] = "1.2" + rank_table["server_count"] = "1" if is_same_server else "2" + rank_table["status"] = "completed" + + # generate server_list for rank table + rank_table["server_list"] = [] # type: ignore[assignment] + decode_server_device_info = None + prefill_server_device_info = { + "device": [{ + k: v + for k, v in [( + "device_id", prefill_metadata.device_id + ), ("device_ip", prefill_metadata.device_ip + ), ("super_device_id", + prefill_metadata.super_device_id), ("rank_id", "0")] + if v is not None + }], + "server_id": + prefill_metadata.server_id + } + if is_same_server: + prefill_server_device_info["device"].append( # type: ignore[attr-defined] + { + k: v + for k, v in [( + "device_id", decode_metadata.device_id + ), ("device_ip", decode_metadata.device_ip + ), ("super_device_id", + decode_metadata.super_device_id), ("rank_id", "1")] + if v is not None + }) + else: + decode_server_device_info = { + "device": [{ + k: v + for k, v in [( + "device_id", decode_metadata.device_id + ), ("device_ip", decode_metadata.device_ip + ), ("super_device_id", + decode_metadata.super_device_id), ("rank_id", "1")] + if v is not None + }], + "server_id": + decode_metadata.server_id + } + rank_table["server_list"].append( # type: ignore[attr-defined] + prefill_server_device_info) + if decode_server_device_info is not None: + rank_table["server_list"].append( # type: ignore[attr-defined] + decode_server_device_info) + + if self.soc_info.is_a3: + # generate super_pod_list for rank table + super_pod_list = [] + prefill_super_pod_info = { + "super_pod_id": prefill_metadata.super_pod_id, + 
"server_list": [{ + "server_id": prefill_metadata.server_id + }], + } + if is_same_pod and not is_same_server: + prefill_super_pod_info[ + "server_list"].append( # type: ignore[attr-defined] + {"server_id": decode_metadata.server_id}) + super_pod_list.append(prefill_super_pod_info) + if not is_same_pod: + decode_super_pod_id = { + "super_pod_id": decode_metadata.super_pod_id, + "server_list": [{ + "server_id": decode_metadata.server_id + }], + } + super_pod_list.append(decode_super_pod_id) + rank_table[ + "super_pod_list"] = super_pod_list # type: ignore[assignment] + logger.info( + f"LLMDataDistCMgrConnectorWorker: try link with remote, comm id: {comm_name}" + ) + logger.info(f"rank table \n{rank_table}") + logger.info(f"comm name: {comm_name}") + logger.info(f"cluster rank info: {cluster_rank_info}") + comm_id = self.llm_datadist.link(comm_name, cluster_rank_info, + json.dumps(rank_table)) + while True: + ret = self.llm_datadist.query_register_mem_status(comm_id=comm_id) + if ret == llm_datadist.RegisterMemStatus.OK: + logger.info( + f"LLMDataDistCMgrConnectorWorker: Linking success, comm id: {comm_id}" + ) + break + elif ret == llm_datadist.RegisterMemStatus.FAILED: + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Linking failed, comm id: {comm_id}" + ) + time.sleep(1) + logger.info("Checking query_register_mem_status again") + self.linked_cluster.update({remote_cluster_id: comm_id}) + logger.info(f"cached linked cluster: {self.linked_cluster}") + logger.info( + f"Successfully build link with cluster id {remote_cluster_id} with cluster name {comm_name} !" + ) + return remote_cluster_id + + def remove_remote_agent(self, cluster_id: int): + if cluster_id not in self.linked_cluster: + logger.warning( + f"LLMDataDistCMgrConnectorWorker: Warning! Can't remove remote client with cluster id {cluster_id} for its not exist in linked_cluster list" + ) + comm_id = self.linked_cluster[cluster_id] + try: + self.llm_datadist.unlink(comm_id) + self.linked_cluster.pop(cluster_id) + except LLMException: + logger.error( + f"Try to remove remote client with cluster id {cluster_id} failed!, program won't terminate, but please carefully check your environment" + ) + logger.info( + f"Successfully remove remote client with cluster id {cluster_id} !" 
+ ) + + def connect_to_remote_agent(self, host: str, port: int) -> int: + url = f"tcp://{host}:{port}" + logger.debug(f"Querying metadata from url: {url}") + msg_encoder = msgspec.msgpack.Encoder() + msg_send = msg_encoder.encode(self.local_agent_metadata) + with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined] + logger.info("Try request remote metadata from socket......") + sock.send(msg_send) + metadata_bytes = sock.recv() + decoder = msgspec.msgpack.Decoder() + metadata = decoder.decode(metadata_bytes) + metadata = LLMDataDistCMgrAgentMetadata(**metadata) + logger.info(f"recving metadata: {metadata}") + cluster_id = self.add_remote_agent(metadata) + return cluster_id + + def _read_blocks( + self, + local_block_ids: list[int], + remote_block_ids: list[int], + remote_ip: str, + remote_port: int, + remote_engine_id: str, + request_id: str, + remote_tp_size: str, + ): + # if remote_ip not in self.linked_cluster: + tp_offset = self.tp_rank % int(remote_tp_size) + remote_cluster_id = self.connect_to_remote_agent( + remote_ip, remote_port + tp_offset) + num_local_blocks = len(local_block_ids) + if num_local_blocks == 0: + return + num_remote_blocks = len(remote_block_ids) + assert num_local_blocks <= num_remote_blocks + if num_local_blocks < num_remote_blocks: + remote_block_ids = remote_block_ids[-num_local_blocks:] + + logger.info(f"remote cluster id is: {remote_cluster_id}") + if self.use_mla: + remote_cache_key_k_normed = BlocksCacheKey( + cluster_id=remote_cluster_id, model_id=0) + remote_cache_key_k_pe = BlocksCacheKey( + cluster_id=remote_cluster_id, model_id=1) + logger.info("Try pull blocks from remote server") + try: + self.cache_manager.pull_blocks( + remote_cache_key_k_normed, + self.cache[0], # type: ignore[has-type] + remote_block_ids, + local_block_ids) + self.cache_manager.pull_blocks( + remote_cache_key_k_pe, + self.cache[1], # type: ignore[has-type] + remote_block_ids, + local_block_ids) + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type] + ) + except LLMException: + raise RuntimeError( + "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status" + ) + else: + remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id) + logger.info("Try pull blocks from remote server") + try: + self.cache_manager.pull_blocks( + remote_cache_key, + self.cache, # type: ignore[has-type] + remote_block_ids, + local_block_ids) + except (TypeError, ValueError): + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type] + ) + except LLMException: + raise RuntimeError( + "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status" + ) + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """Get the finished recving and sending requuests.""" + import copy + req_ids_to_ret = copy.deepcopy(self.finished_reqs) + self.finished_reqs.clear() + if self.llm_datadist_role == LLMRole.PROMPT: + return 
req_ids_to_ret, None + else: + return None, req_ids_to_ret + + +# adopt this from https://github.com/vllm-project/vllm/blob/main/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +@contextlib.contextmanager +def zmq_ctx(socket_type: Any, + addr: str) -> Iterator[zmq.Socket]: # type: ignore[name-defined] + """Context manager for a ZMQ socket""" + + ctx: Optional[zmq.Context] = None # type: ignore[name-defined] + try: + ctx = zmq.Context() # type: ignore[attr-defined] + + if socket_type == zmq.ROUTER: # type: ignore[attr-defined] + socket = ctx.socket(zmq.ROUTER) # type: ignore[attr-defined] + socket.bind(addr) + elif socket_type == zmq.REQ: # type: ignore[attr-defined] + socket = ctx.socket(zmq.REQ) # type: ignore[attr-defined] + socket.connect(addr) + else: + raise ValueError(f"Unexpected socket type: {socket_type}") + + yield socket + finally: + if ctx is not None: + ctx.destroy(linger=0) diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py deleted file mode 100644 index 2778a6ef27..0000000000 --- a/vllm_ascend/distributed/parallel_state.py +++ /dev/null @@ -1,77 +0,0 @@ -from typing import Optional - -import torch -from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group, - init_model_parallel_group) - -# vllm-ascend will maintain its own EP GroupCoordinator and ETP GroupCoordinator for -# customize parallel solution -_EP: Optional[GroupCoordinator] = None -_ETP: Optional[GroupCoordinator] = None - - -def get_ep_group() -> GroupCoordinator: - assert _EP is not None, ("expert model parallel group is not initialized") - return _EP - - -def get_etp_group() -> GroupCoordinator: - assert _ETP is not None, ( - "expert tensor parallel group is not initialized") - return _ETP - - -def model_parallel_initialized(): - return (_ETP is not None and _EP is not None) - - -def init_ascend_model_parallel( - expert_parallel_size: int = 1, - expert_tensor_parallel_size: int = 1, - world_size: Optional[int] = None, - backend: Optional[str] = None, -): - if model_parallel_initialized(): - return - assert torch.distributed.is_initialized() - world_size = world_size or torch.distributed.get_world_size() - backend = backend or torch.distributed.get_backend( - get_world_group().device_group) - num_expert_parallel_groups = expert_tensor_parallel_size - num_expert_tensor_parallel_groups = expert_parallel_size - - global _EP - group_ranks = [] - for i in range(num_expert_parallel_groups): - ranks = list(range(i, world_size, num_expert_parallel_groups)) - group_ranks.append(ranks) - - _EP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="ep") - - group_ranks = [] - global _ETP - for i in range(num_expert_tensor_parallel_groups): - ranks = list( - range(i * expert_tensor_parallel_size, - (i + 1) * expert_tensor_parallel_size)) - group_ranks.append(ranks) - - _ETP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="etp") - - -def destory_ascend_model_parallel(): - global _EP - if _EP: - _EP.destroy() - _EP = None - - global _ETP - if _ETP: - _ETP.destroy() - _ETP = None diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 02ecd6625b..27d0131720 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -116,6 +116,27 @@ # value to False to disable the optimized model. "USE_OPTIMIZED_MODEL": lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))), + # `LLMDataDistCMgrConnector` required variable. 
`DISAGGREGATED_PREFILL_RANK_TABLE_PATH` is + # used for llmdatadist to build the communication topology for kv cache transfer, it is + # a required variable if `LLMDataDistCMgrConnector` is used as kv connector for disaggregated + # pd. The rank table can be generated by adopting the script `gen_ranktable.sh` + # in vllm_ascend's example folder. + "DISAGGREGATED_PREFILL_RANK_TABLE_PATH": + lambda: os.getenv("DISAGGREGATED_PREFILL_RANK_TABLE_PATH", None), + # `LLMDataDistCMgrConnector` required variable. `VLLM_ASCEND_LLMDD_RPC_IP` is used as the + # rpc communication listening ip, which will be used to receive the agent metadata from the + # remote worker. + "VLLM_ASCEND_LLMDD_RPC_IP": + lambda: os.getenv("VLLM_ASCEND_LLMDD_RPC_IP", "0.0.0.0"), + # `LLMDataDistCMgrConnector` required variable. `VLLM_LLMDD_RPC_PORT` is used as the + # rpc communication listening port, which will be used to receive the agent metadata from the + # remote worker. + "VLLM_LLMDD_RPC_PORT": + lambda: int(os.getenv("VLLM_LLMDD_RPC_PORT", 5557)), + # Whether to enable mla_pa for deepseek mla decode, this flag will be removed after its available torch_npu is public accessible + # and the mla_pa will be the default path of deepseek decode path. + "VLLM_ASCEND_MLA_PA": + lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0)) } # end-env-vars-definition diff --git a/vllm_ascend/eplb/adaptor/abstract_adaptor.py b/vllm_ascend/eplb/adaptor/abstract_adaptor.py new file mode 100644 index 0000000000..8513b69ea0 --- /dev/null +++ b/vllm_ascend/eplb/adaptor/abstract_adaptor.py @@ -0,0 +1,39 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +from abc import ABC, abstractmethod + +class EplbAdaptor(): + + def __init__(self, **args): + pass + + @abstractmethod + def get_rank_expert_workload(self, num_moe_layers): + raise NotImplementedError + + @abstractmethod + def get_init_expert_map(self): + raise NotImplementedError + + @abstractmethod + def do_update_expert_map(self): + raise NotImplementedError + + @abstractmethod + def do_update_expert_weight(self): + raise NotImplementedError \ No newline at end of file diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py new file mode 100644 index 0000000000..585fcad7eb --- /dev/null +++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py @@ -0,0 +1,209 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+import os
+import json
+import torch
+import random
+import torch.distributed as dist
+import numpy as np
+
+from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor
+from vllm.logger import logger
+
+
+
+class VllmEplbAdaptor(EplbAdaptor):
+
+    def __init__(self, model, **args):
+        super().__init__(**args)
+        self.model = model
+        self.rank_id = dist.get_rank()
+        self.world_size = dist.get_world_size()
+        self.param_dict = dict(self.model.named_parameters())
+        self.num_dense_layers = self.model.config.first_k_dense_replace
+        self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
+        self.global_expert_num = self.model.config.n_routed_experts
+
+
+        # TODO: init self.expert_weight_names depending on the model type; only deepseek v3 w8a8 is supported here
+        self.expert_weight_names = ["w13_weight", "w2_weight", "w13_weight_scale", "w13_weight_offset",
+                                    "w2_weight_scale", "w2_weight_offset"]
+
+        self.expert_map_per_layer = dict()      # reference to the expert map on device, used for expert map updates
+        self.expert_map_per_layer_cpu = dict()  # CPU copy of the expert map to avoid frequent device synchronization
+        for layer_idx in range(self.num_moe_layers):
+            self.expert_map_per_layer[self.num_dense_layers + layer_idx] =\
+                self.model.get_expert_map(self.num_dense_layers + layer_idx)
+
+        # TODO: here we set the number of buffer tensors equal to the number of experts in each layer, which can be improved
+        num_buffer_tensor = torch.where(self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel()
+        self.buffer_tensor_list = [[] for _ in range(num_buffer_tensor)]
+        self.init_buffer_tensor(num_buffer_tensor)
+
+        self.expert_param_per_layer = dict()
+        self.init_expert_param_per_layer()
+
+        self.log2phy_map_per_layer = dict()
+        for layer_idx in range(self.num_moe_layers):
+            self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] =\
+                self.model.get_log2phy_map(self.num_dense_layers + layer_idx)
+
+        self.all_topk_ids = []
+
+    def init_buffer_tensor(self, num_buffer_tensor):
+        for name in self.expert_weight_names:
+            complete_name = "model.layers." + str(self.num_dense_layers) + ".mlp.experts." + name
+            expert_tensor = self.param_dict[complete_name].data[0:num_buffer_tensor]
+            buffer_tensors = torch.empty_like(expert_tensor)
+            for buffer_id in range(num_buffer_tensor):
+                self.buffer_tensor_list[buffer_id].append(buffer_tensors[buffer_id])
+
+    def init_expert_param_per_layer(self):
+        num_local_expert = self.param_dict["model.layers." + str(self.num_dense_layers) +\
+            ".mlp.experts." + self.expert_weight_names[0]].data.shape[0]
+        for moe_layer_id in range(self.num_moe_layers):
+            layer_idx = self.num_dense_layers + moe_layer_id
+            self.expert_param_per_layer[layer_idx] = list()
+            for local_expert_id in range(num_local_expert):
+                self.expert_param_per_layer[layer_idx].append(
+                    [self.param_dict["model.layers." + str(layer_idx) + ".mlp.experts."
+ name].data[local_expert_id] + for name in self.expert_weight_names] + ) + + # def collect_topk_ids(self, dummy_run=False): + # if dummy_run: + # return + # self.all_topk_ids.append(self.model.get_all_topk_ids(self.num_moe_layers)) + + def get_rank_expert_workload(self) -> torch.Tensor: + self.moe_load = self.model.get_all_moe_loads() + return self.moe_load + + def get_init_expert_map(self, num_moe_layers): + expert_map = self.model.get_all_expert_map(num_moe_layers) + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + + gathered = torch.empty((world_size, *expert_map.shape), # [W, L, E] + dtype=expert_map.dtype, + device=expert_map.device) + + dist.all_gather_into_tensor(gathered, expert_map) + all_maps = gathered.permute(1, 0, 2) + all_expert_maps = all_maps.cpu() + + for layer_idx in range(num_moe_layers): + self.expert_map_per_layer_cpu[self.num_dense_layers + layer_idx] = \ + all_expert_maps[layer_idx][self.rank_id] + + return all_expert_maps + + def get_init_expert_map_from_file(self, num_moe_layers, expert_map_path): + + try: + expert_map_tensor, layers_num, ranks_num = self._expert_file_to_tensor(expert_map_path) + expert_map_all = self.local2global(expert_map_tensor) + except (TypeError, FileNotFoundError, OSError): + expert_map_all = self.determine_expert_map_all() + + for layer_idx in range(num_moe_layers): + self.expert_map_per_layer_cpu[layer_idx+3] = \ + expert_map_all[layer_idx][self.rank_id] + return expert_map_all + + def _expert_file_to_tensor(self, expert_map_path: str): + with open(expert_map_path, "r") as f: + data = json.load(f) + layers_num = data["moe_layer_count"] + gpus_num = data["layer_list"][0]["device_count"] + + tensor_data = [] + for layer in data["layer_list"]: + device_data = [] + for device in layer["device_list"]: + device_data.append(device["device_expert"]) + tensor_data.append(device_data) + expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32) + return expert_map_tensor, layers_num, gpus_num + logger.error(f"failed to read expert_map_path: {expert_map_path}") + + def do_update_expert_map(self, layer_id, updated_expert_map): + self.expert_map_per_layer[layer_id].copy_(updated_expert_map) + self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map) + + def do_update_expert_weight(self, layer_id, local_expert_to_replace, buffer_tensor_id): + for expert_tensor, buffer_tensor in zip( + self.expert_param_per_layer[layer_id][local_expert_to_replace], + self.buffer_tensor_list[buffer_tensor_id] + ): + expert_tensor.copy_(buffer_tensor) + + def do_update_log2phy_map(self, layer_id, updated_log2phy_map): + if self.log2phy_map_per_layer[layer_id] is not None: + self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map[self.rank_id]) + + def local2global(self, + placement_local: torch.Tensor + ) -> torch.Tensor: + + L, G, E_local = placement_local.shape + device = placement_local.device + + max_id = torch.max(placement_local) + E_global = (max_id + 1).item() if max_id >= 0 else 0 + + if E_global == 0: + return torch.empty((L, G, 0), dtype=torch.long, device=device) + + placement_global = torch.full((L, G, E_global), + fill_value=-1, + dtype=torch.long, + device=device) + + valid = placement_local >= 0 + l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True) + gid_idx = placement_local[l_idx, g_idx, slot_idx] + + placement_global[l_idx, g_idx, gid_idx] = slot_idx + + return placement_global + + def determine_expert_map_all(self): + + local_num_experts = self.global_expert_num // self.world_size + + 
expert_map_all = torch.full( + (self.num_moe_layers, self.world_size, self.global_expert_num), + -1, + dtype=torch.int32 + ) + + for r in range(self.world_size): + if r < self.world_size - 1: + start = r * local_num_experts + end = (r + 1) * local_num_experts + local_count = local_num_experts + else: + start = r * local_num_experts + end = self.global_expert_num + local_count = self.global_expert_num - r * local_num_experts + + local_ids = torch.arange(local_count, dtype=torch.int32) + expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(self.num_moe_layers, -1) + + return expert_map_all \ No newline at end of file diff --git a/vllm_ascend/eplb/core/loader/abstract_loader.py b/vllm_ascend/eplb/core/loader/abstract_loader.py new file mode 100644 index 0000000000..b1bef11c5d --- /dev/null +++ b/vllm_ascend/eplb/core/loader/abstract_loader.py @@ -0,0 +1,24 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +from abc import ABC, abstractmethod + +class ExpertWeightLoader: + + @abstractmethod + def load_impl(self, old_expert_table, new_expert_table): + raise NotImplementedError \ No newline at end of file diff --git a/vllm_ascend/eplb/core/loader/device_transfer_loader.py b/vllm_ascend/eplb/core/loader/device_transfer_loader.py new file mode 100644 index 0000000000..579f653323 --- /dev/null +++ b/vllm_ascend/eplb/core/loader/device_transfer_loader.py @@ -0,0 +1,155 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+#
+import torch
+import torch.distributed as dist
+from enum import Enum
+
+from vllm.logger import logger
+from vllm_ascend.eplb.core.loader.abstract_loader import ExpertWeightLoader
+
+class ExpertWeightUpdateState(Enum):
+    WAITING = 0      # waiting for an updated expert_map from the EplbWorker
+    READY = 1        # send/recv ops staged; ready to launch the d2d expert weight transfer
+    TRANSFERING = 2  # d2d transfer launched; waiting to finish and to update expert_map in the model
+
+class D2DExpertWeightLoader(ExpertWeightLoader):
+
+    def __init__(self, eplb_adaptor):
+        self.comm_op_list = None
+        self.eplb_adaptor = eplb_adaptor
+
+        self.updated_expert_map = None
+        self.updated_log2phy_map = None
+        self.layer_id = -1  # layer id to be updated
+        self.state = ExpertWeightUpdateState.WAITING
+        self.recv_expert_list = []
+        self.mock_flag = True
+
+    def generate_expert_d2d_transfer_task(self, expert_send_info, expert_recv_info,
+                                          updated_expert_map, layer_id):
+        # A new d2d task cannot be accepted while the current send/recv and
+        # expert_map/weight update tasks are still in flight
+        if self.state != ExpertWeightUpdateState.WAITING:
+            logger.error("d2d weight update tasks are still on-going; cannot accept a new weight update task")
+            return
+
+        # If neither a send nor a receive task is needed for this layer on this rank, return
+        if not (expert_send_info or expert_recv_info):
+            return
+
+        self.updated_expert_map = updated_expert_map
+
+        self.layer_id = layer_id
+        self.comm_op_list = []
+        for send_info in expert_send_info:
+            dst_rank, global_expert_id_to_send = send_info
+            local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[layer_id][global_expert_id_to_send].item()
+            for src_tensor in self.eplb_adaptor.expert_param_per_layer[layer_id][local_expert_id]:
+                self.comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank))
+
+        buffer_tensor_id = 0
+        for recv_info in expert_recv_info:
+            recv_rank, global_expert_id_to_recv = recv_info
+            for buffer_tensor in self.eplb_adaptor.buffer_tensor_list[buffer_tensor_id]:
+                self.comm_op_list.append(dist.P2POp(dist.irecv, buffer_tensor, recv_rank))
+            local_expert_to_replace = self.updated_expert_map[global_expert_id_to_recv].item()
+            self.recv_expert_list.append((local_expert_to_replace, buffer_tensor_id))
+            buffer_tensor_id += 1
+
+        self.state = ExpertWeightUpdateState.READY
+
+    def set_log2phy_map(self, log2phy_map):
+        self.updated_log2phy_map = log2phy_map
+
+    def asyn_expert_weight_transfer(self, reqs):
+        # d2d send/recv tasks can only be launched once they have been parsed into self.comm_op_list
+        if self.state != ExpertWeightUpdateState.READY:
+            return
+
+        # launch the batched asynchronous send/recv for the d2d expert weight transfer
+        if self.comm_op_list:
+            ret_list = dist.batch_isend_irecv(self.comm_op_list)
+            reqs.extend(ret_list)
+
+        self.state = ExpertWeightUpdateState.TRANSFERING
+
+    def update_expert_map_and_weight(self, reqs, redundant_enable):
+        # expert_map and weights can only be updated after the send/recv tasks have been launched
+        if self.state != ExpertWeightUpdateState.TRANSFERING:
+            return
+
+        # Wait for the send/recv tasks to finish
+        for req in reqs:
+            req.wait()
+
+        if self.comm_op_list is not None:
+            self.comm_op_list = None
+
+        # update expert_map
+        self.eplb_adaptor.do_update_expert_map(self.layer_id, self.updated_expert_map)
+
+        # update log2phy_map
+        if redundant_enable:
+            self.eplb_adaptor.do_update_log2phy_map(self.layer_id, self.updated_log2phy_map)
+
+        # update expert weights
+        buffer_tensor_id = 0
+        for recv_expert_info in self.recv_expert_list:
+            local_expert_to_replace, buffer_tensor_id = recv_expert_info
+            
self.eplb_adaptor.do_update_expert_weight(self.layer_id, local_expert_to_replace, buffer_tensor_id) + + logger.info(f"[EPLB] finished update expert weight for layer: {self.layer_id}") + + self.recv_expert_list = [] + self.updated_expert_map = None + self.layer_id = -1 + self.state = ExpertWeightUpdateState.WAITING + + def generate_mock_update_info(self, rank_id): + if rank_id == 0: + expert_send_info = [(1, 0)] + expert_recv_info = [(1, 64)] + updated_expert_map_list = [-1] + [i for i in range(1, 64)] + [0] + [j for j in [-1] * 191] + updated_expert_map = torch.tensor(updated_expert_map_list) + layer_id = 3 + + if rank_id == 1: + expert_send_info = [(0, 64)] + expert_recv_info = [(0, 0)] + updated_expert_map_list = [0] + [k for k in [-1] * 63] + [i for i in range(1, 64)] + [j for j in [-1] * 129] + updated_expert_map = torch.tensor(updated_expert_map_list) + layer_id = 3 + + if rank_id == 2: + expert_send_info = [(3, 128)] + expert_recv_info = [(3, 192)] + updated_expert_map_list = [k for k in [-1] * 129] + [i for i in range(1, 64)] + [0] + [j for j in [-1] * 63] + updated_expert_map = torch.tensor(updated_expert_map_list) + layer_id = 3 + + if rank_id == 3: + expert_send_info = [(2, 192)] + expert_recv_info = [(2, 128)] + updated_expert_map_list = [k for k in [-1] * 128] + [0] + [k for k in [-1] * 64] + [i for i in range(1, 64)] + updated_expert_map = torch.tensor(updated_expert_map_list) + layer_id = 3 + + self.mock_flag = False + return (expert_send_info, expert_recv_info, updated_expert_map, layer_id) + + def load_impl(self, old_expert_table, new_expert_table): + raise NotImplementedError + diff --git a/vllm_ascend/eplb/core/policy/__init__.py b/vllm_ascend/eplb/core/policy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vllm_ascend/eplb/core/policy/dynamic_ep.py b/vllm_ascend/eplb/core/policy/dynamic_ep.py new file mode 100644 index 0000000000..c081191aab --- /dev/null +++ b/vllm_ascend/eplb/core/policy/dynamic_ep.py @@ -0,0 +1,337 @@ +# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. 
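+#
+# Illustrative usage sketch of the policy defined below; shapes follow the
+# rebalance_experts() implementation at the end of this file:
+#
+#     policy = DynamicEplb(DynamicConfig())
+#     changed, layer_priority, new_deployment = policy.rebalance_experts(
+#         current_expert_table,   # [layer_num, npu_num, experts_per_npu] physical expert ids
+#         expert_workload)        # [layer_num, npu_num, experts_per_npu] per-slot workload
+#     # changed is 1 only when the summed per-layer peak workload drops below
+#     # 95% of the original, otherwise 0.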
+from collections import defaultdict +import numpy as np + +from .eplb_policy import EplbPolicy, DynamicConfig + + +class DynamicTable: + # workload_table: + # 三维矩阵,[layer, gpus, experts_per_gpu_per_layer] -> value: 所在位置的热度 + # 大小为 层数 * 卡数 * 每层每卡的专家数量 + # 里面i, j, k的元素代表 第 i 层 第 j 张卡第 k 个专家的热度 + # 对于收集不到的专家,填为 -1 + workload_table = None + + # placement_table: + # 三维矩阵,[layer, gpus, experts_per_gpu_per_layer] -> value: 所在位置的物理专家id + # 大小为 层数 * 卡数 * 每层每卡的专家数量 + # 里面i, j, k的元素代表 第 i 层 第 j 张卡第 k 个专家的物理id + # 对于收集不到的专家,填为 -1 + placement_table = None + + +class DynamicEplb(EplbPolicy): + + def __init__(self, config: DynamicConfig): + super().__init__(config) + + @staticmethod + def add_redundant(current_expert_table, expert_workload, num_original_expert): + layer_num, npu_num, experts_per_npu = expert_workload.shape + workload_new = np.zeros((layer_num, num_original_expert)) + for layer_idx in range(layer_num): + workload_dict = defaultdict(int) + placement_layer = current_expert_table[layer_idx].copy() + workload_layer = expert_workload[layer_idx].copy() + for npu_idx in range(npu_num): + for expert_idx in range(experts_per_npu): + workload_dict[placement_layer[npu_idx][expert_idx]] += workload_layer[npu_idx][expert_idx] + for expert_idx in range(num_original_expert): + workload_new[layer_idx][expert_idx] = workload_dict[expert_idx] + return workload_new + + @staticmethod + # 热点专家拆分为冗余专家 + def original_compute_balanced_pack_redundancy(origin_weights, card_num, num_redundancy_expert): + # Step 1: Sort the items by weight in descending order (we are sorting by weight now) + # Sort based on the second element (the second value of each tuple) + route_expert_num = len(origin_weights) + route_expert_redundancy = [[] for _ in range(route_expert_num)] + for i in range(num_redundancy_expert): + sorted_indices = np.argsort([t[1] for t in origin_weights], kind='stable')[::-1] + weights = [origin_weights[idx] for idx in sorted_indices] + tmp_raw_weight = weights[0][1] * (len(route_expert_redundancy[weights[0][0]]) + 1) + route_expert_redundancy[weights[0][0]].append(route_expert_num + i) + avg_weight = tmp_raw_weight / (len(route_expert_redundancy[weights[0][0]]) + 1) + weights[0] = (weights[0][0], avg_weight) + origin_weights = weights + + # Step 2: Calculate the number of items per box + expert_num = route_expert_num + num_redundancy_expert + items_per_box = expert_num // card_num # Number of items per box + remaining_items = expert_num % card_num # Number of items per box + + # Step 3: Initialize card_num boxes with empty lists to store item IDs + boxes = [[] for _ in range(card_num)] + boxes_weights = [[] for _ in range(card_num)] + box_weights = [0] * card_num # To store the total weight of each box + box_counts = [0] * card_num # To store the number of items in each box + index = 0 + for i in range(route_expert_num): + redundancy_num = len(route_expert_redundancy[i]) + for _ in range(redundancy_num): + cur_weight = 0 + for item, weight in origin_weights: + if item == i: + cur_weight = weight + + boxes[index].append(i) + boxes_weights[index].append(cur_weight) + box_weights[index] += cur_weight + box_counts[index] += 1 + index += 1 + + sorted_indices = np.argsort([t[1] for t in origin_weights], kind='stable')[::-1] + origin_weights = [origin_weights[idx] for idx in sorted_indices] + # Step 4: Distribute items into boxes based on weight + for item_id, weight in origin_weights: + # Find the box with the least items but not full + min_box_index = -1 + for i in range(card_num): + if item_id in boxes[i]: + 
continue + # Only choose boxes that still have space (box_counts[i] < items_per_box) + if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]: + min_box_index = i + + # Place the item (id) into the selected box + boxes[min_box_index].append(item_id) + boxes_weights[min_box_index].append(weight) + box_weights[min_box_index] += weight + box_counts[min_box_index] += 1 + + # If there's an imbalance in the remaining items, reduce the "remaining_items" counter + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0: + remaining_items -= 1 + + # Step 5: Output each box's contents and total weight + result = [] + for i in range(card_num): + result.append({ + "box_index": i + 1, + "items": boxes[i], # List of item IDs in the box + "weight": boxes_weights[i], + "total_weight": box_weights[i], # Total weight in this box + "item_count": box_counts[i] # Number of items in the box + }) + + return result, boxes + + # 热点专家拆分为冗余专家 + @staticmethod + def compute_balanced_pack_redundancy(origin_weights, card_num, num_redundancy_expert): + route_expert_num = len(origin_weights) + route_expert_redundancy = [[] for _ in range(route_expert_num)] + for i in range(num_redundancy_expert): + sorted_indices = np.argsort([t[1] for t in origin_weights], kind='stable')[::-1] + weights = [origin_weights[idx] for idx in sorted_indices] + tmp_raw_weight = weights[0][1] * (len(route_expert_redundancy[weights[0][0]]) + 1) + route_expert_redundancy[weights[0][0]].append(route_expert_num + i) + avg_weight = tmp_raw_weight / (len(route_expert_redundancy[weights[0][0]]) + 1) + weights[0] = (weights[0][0], avg_weight) + origin_weights = weights + + expert_num = route_expert_num + num_redundancy_expert + if card_num == 0: + raise RuntimeError("card_num can not be 0.") + items_per_box = expert_num // card_num + remaining_items = expert_num % card_num + + boxes = [[] for _ in range(card_num)] + boxes_weights = [[] for _ in range(card_num)] + box_weights = [0] * card_num + box_counts = [0] * card_num + + all_weights = np.zeros((expert_num,), dtype='object') + all_weights[: route_expert_num] = origin_weights + + index = route_expert_num + for i in range(route_expert_num): + redundancy_num = len(route_expert_redundancy[i]) + for _ in range(redundancy_num): + for item, weight in origin_weights: + if item == i: + all_weights[index] = (item, weight) + index += 1 + + sorted_indices = np.argsort([t[1] for t in all_weights], kind='stable')[::-1] + all_weights = [all_weights[idx] for idx in sorted_indices] + for item_id, weight in all_weights: + min_box_index = -1 + for i in range(card_num): + if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]: + if item_id not in boxes[i]: + min_box_index = i + + boxes[min_box_index].append(item_id) + boxes_weights[min_box_index].append(weight) + box_weights[min_box_index] += weight + box_counts[min_box_index] += 1 + + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0: + remaining_items -= 1 + + result = [] + for i in range(card_num): + result.append({ + "box_index": i + 1, + "items": boxes[i], + "weight": boxes_weights[i], + "total_weight": box_weights[i], + "item_count": box_counts[i] + }) + + return result, boxes + + # 无冗余专家方案 + @staticmethod + def compute_balanced_pack(origin_weights, card_num): + sorted_indices = np.argsort([t[1] for t 
in origin_weights])[::-1] + weights = origin_weights[sorted_indices] + expert_num = len(weights) + if card_num == 0: + raise RuntimeError("card_num can not be 0.") + items_per_box = expert_num // card_num + remaining_items = expert_num % card_num + + boxes = [[] for _ in range(card_num)] + boxes_weights = [[] for _ in range(card_num)] + box_weights = [0] * card_num + box_counts = [0] * card_num + + for item_id, weight in weights: + min_box_index = -1 + for i in range(card_num): + if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]: + min_box_index = i + + boxes[min_box_index].append(item_id) + boxes_weights[min_box_index].append(weight) + box_weights[min_box_index] += weight + box_counts[min_box_index] += 1 + + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0: + remaining_items -= 1 + + result = [] + for i in range(card_num): + result.append({ + "box_index": i + 1, + "items": boxes[i], + "weight": boxes_weights[i], + "total_weight": box_weights[i], + "item_count": box_counts[i] + }) + + return result, boxes + + @staticmethod + def get_redundant_num(npu_num, counts): + redundant_num_each_npu = np.sum(counts - 1) + return redundant_num_each_npu + + @staticmethod + def calculate_max_heat_per_layer(workload_table, layer_num): + max_heat_per_layer = [] + for layer_idx in range(layer_num): + npu_heats_now = np.sum(workload_table[layer_idx], axis=1) + max_heat_per_layer.append(np.max(npu_heats_now)) + return max_heat_per_layer + + @staticmethod + def constraint_expert_local_exchange(current_expert_table, global_deployment): + for layer_id in range(len(global_deployment)): + for card_id in range(len(global_deployment[layer_id])): + current_list = [int(x) for x in current_expert_table[layer_id][card_id]] + new_list = [int(x) for x in global_deployment[layer_id][card_id]] + num = len(new_list) + + new_index = [-1] * num + new_result = [-1] * num + remaining_elements = [] + + for i in range(num): + flag = True + for j in range(num): + if new_list[i] == current_list[j] and new_index[j] == -1: + new_index[j] = 0 + new_result[j] = current_list[j] + flag = False + break + if flag: + remaining_elements.append(new_list[i]) + + index = 0 + for k in range(num): + if new_result[k] == -1: + new_result[k] = remaining_elements[index] + index += 1 + + global_deployment[layer_id][card_id] = new_result + + return global_deployment + + + def rebalance_experts(self, current_expert_table, expert_workload): + + info = DynamicTable() + info.workload_table = np.array(expert_workload) + info.placement_table = np.array(current_expert_table) + layer_num, num_npus, experts_per_npu= info.workload_table.shape + expert_ids, counts = np.unique(info.placement_table[0], return_counts=True) + num_redundancy_expert = self.get_redundant_num(num_npus, counts) + num_original_expert = len(expert_ids) + layer_workloads = self.add_redundant(info.placement_table, info.workload_table, num_original_expert) + max_heat_per_layer_before = self.calculate_max_heat_per_layer(info.workload_table, layer_num) + npu_heat_all_origin = sum(max_heat_per_layer_before) + + # 计算负载均衡,部署冗余专家 + layer_num = layer_workloads.shape[0] + expert_num = layer_workloads.shape[1] + # 校验专家数量、卡数量、冗余专家数量不能超过卡数量 + if num_original_expert != expert_num: + raise ValueError(f"原始专家数量 {num_original_expert} 必须等于 expert_num {expert_num}") + + if num_npus <= 0: + raise ValueError("NPUs 数量必须大于 0") + + if num_npus < num_redundancy_expert: + raise 
ValueError(f"Number of NPUs ({num_npus}) must be greater than or equal to the number of redundant experts ({num_redundancy_expert})")
+
+        # Number of experts deployed on each card, plus one redundant expert
+        global_deployment = [[[] for _ in range(num_npus)] for _ in range(layer_num)]
+        # Derive a placement policy for each layer, taking compute balance into account
+        max_heat_per_layer_after = np.zeros([layer_num])
+        for layer in range(layer_num):
+            # Collect expert ids and workloads for the current layer; workloads are normalized and each card gets one redundant expert
+            weights = np.zeros((expert_num,), dtype='object')
+            for expert_id, workload_weight in enumerate(layer_workloads[layer]):
+                weights[expert_id] = (expert_id, workload_weight)
+
+            # Compute a globally balanced placement policy for this layer
+            result, layer_deployment = self.original_compute_balanced_pack_redundancy(
+                weights, num_npus, num_redundancy_expert
+            )
+
+            global_deployment[layer] = layer_deployment
+            max_heat_per_layer_after[layer] = max(result, key=lambda x: x['total_weight'])['total_weight']
+
+        new_global_deployment = self.constraint_expert_local_exchange(current_expert_table, global_deployment)
+        # Compute per-layer priorities
+        layer_changed_ratio = []
+        for layer_idx in range(layer_num):
+            layer_changed_ratio.append(max_heat_per_layer_after[layer_idx] / max_heat_per_layer_before[layer_idx])
+
+        per_layer_priority = np.argsort(layer_changed_ratio)
+        npu_heat_all_after = sum(max_heat_per_layer_after)
+
+        change = 0
+        if npu_heat_all_after < 0.95 * npu_heat_all_origin:
+            change = 1
+
+        return change, per_layer_priority, np.array(new_global_deployment).tolist()
+
diff --git a/vllm_ascend/eplb/core/policy/dynamic_ep_v2.py b/vllm_ascend/eplb/core/policy/dynamic_ep_v2.py
new file mode 100644
index 0000000000..775cf5f71d
--- /dev/null
+++ b/vllm_ascend/eplb/core/policy/dynamic_ep_v2.py
@@ -0,0 +1,842 @@
+# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
+from collections import defaultdict
+import numpy as np
+from abc import abstractmethod
+
+
+class DynamicConfig:
+    placement_policy = None
+
+    max_transferred_expert_per_layer = 100
+    # Maximum number of experts that may be moved per layer on a single host
+
+    ep_worldsize = 64  # Number of dies the experts of the whole cluster are spread across
+    num_die_per_host = 8  # Number of dies per host
+
+
+class EplbPolicy:
+    def __init__(self, config: DynamicConfig):
+        self.config = config
+
+    @abstractmethod
+    def rebalance_experts(self, current_expert_table, expert_workload):
+        """
+        Take the current workloads and return expert replication and placement under the given constraints.
+        INPUT:
+        current_expert_table: [layerId, rankId, expert_num_i]
+        expert_workload = expert_table[layer0][rankId][expert_num_i]
+
+        RETURNED: (res, expert_table)
+        res:
+        1 -- table_changed
+        0 -- not_changed
+
+        expert_table: [layerId, rankId, expert_num_i]
+        expert_num_i --- [0, MaxExpertPerRank]
+        expertID = expert_table[layer0][rankId][expert_num_i]
+        array_values:
+        [0, 1, 2, 3, 248]
+        [4, 5, 6, 7, 254]
+        [8, 9, 10, 11, 71]
+        ...
+ [252, 253, 254, 255, 0] + """ + pass + +class DynamicTable: + # workload_table: + # 三维矩阵,[layer, gpus, experts_per_gpu_per_layer] -> value: 所在位置的热度 + # 大小为 层数 * 卡数 * 每层每卡的专家数量 + # 里面i, j, k的元素代表 第 i 层 第 j 张卡第 k 个专家的热度 + # 对于收集不到的专家,填为 -1 + workload_table = None + + # placement_table: + # 三维矩阵,[layer, gpus, experts_per_gpu_per_layer] -> value: 所在位置的物理专家id + # 大小为 层数 * 卡数 * 每层每卡的专家数量 + # 里面i, j, k的元素代表 第 i 层 第 j 张卡第 k 个专家的物理id + # 对于收集不到的专家,填为 -1 + placement_table = None + + +class DynamicEplbV2(EplbPolicy): + + def __init__(self, config: DynamicConfig): + super().__init__(config) + + @staticmethod + def add_redundant(current_expert_table, expert_workload, num_original_expert): + layer_num, npu_num, experts_per_npu = expert_workload.shape + workload_new = np.zeros((layer_num, num_original_expert)) + for layer_idx in range(layer_num): + workload_dict = defaultdict(int) + placement_layer = current_expert_table[layer_idx].copy() + workload_layer = expert_workload[layer_idx].copy() + for npu_idx in range(npu_num): + for expert_idx in range(experts_per_npu): + workload_dict[placement_layer[npu_idx][expert_idx]] += workload_layer[npu_idx][expert_idx] + for expert_idx in range(num_original_expert): + workload_new[layer_idx][expert_idx] = workload_dict[expert_idx] + return workload_new + + @staticmethod + # 热点专家拆分为冗余专家 + def original_compute_balanced_pack_redundancy(origin_weights, card_num, num_redundancy_expert): + # Step 1: Sort the items by weight in descending order (we are sorting by weight now) + # Sort based on the second element (the second value of each tuple) + route_expert_num = len(origin_weights) + route_expert_redundancy = [[] for _ in range(route_expert_num)] + for i in range(num_redundancy_expert): + sorted_indices = np.argsort([t[1] for t in origin_weights], kind='stable')[::-1] + weights = [origin_weights[idx] for idx in sorted_indices] + tmp_raw_weight = weights[0][1] * (len(route_expert_redundancy[weights[0][0]]) + 1) + route_expert_redundancy[weights[0][0]].append(route_expert_num + i) + avg_weight = tmp_raw_weight / (len(route_expert_redundancy[weights[0][0]]) + 1) + weights[0] = (weights[0][0], avg_weight) + origin_weights = weights + + # Step 2: Calculate the number of items per box + expert_num = route_expert_num + num_redundancy_expert + items_per_box = expert_num // card_num # Number of items per box + remaining_items = expert_num % card_num # Number of items per box + + # Step 3: Initialize card_num boxes with empty lists to store item IDs + boxes = [[] for _ in range(card_num)] + boxes_weights = [[] for _ in range(card_num)] + box_weights = [0] * card_num # To store the total weight of each box + box_counts = [0] * card_num # To store the number of items in each box + index = 0 + for i in range(route_expert_num): + redundancy_num = len(route_expert_redundancy[i]) + for _ in range(redundancy_num): + cur_weight = 0 + for item, weight in origin_weights: + if item == i: + cur_weight = weight + + boxes[index].append(i) + boxes_weights[index].append(cur_weight) + box_weights[index] += cur_weight + box_counts[index] += 1 + index += 1 + + sorted_indices = np.argsort([t[1] for t in origin_weights], kind='stable')[::-1] + origin_weights = [origin_weights[idx] for idx in sorted_indices] + # Step 4: Distribute items into boxes based on weight + for item_id, weight in origin_weights: + # Find the box with the least items but not full + min_box_index = -1 + for i in range(card_num): + # Only choose boxes that still have space (box_counts[i] < items_per_box) + if box_counts[i] < 
items_per_box or (box_counts[i] == items_per_box and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]: + min_box_index = i + + # Place the item (id) into the selected box + boxes[min_box_index].append(item_id) + boxes_weights[min_box_index].append(weight) + box_weights[min_box_index] += weight + box_counts[min_box_index] += 1 + + # If there's an imbalance in the remaining items, reduce the "remaining_items" counter + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0: + remaining_items -= 1 + + # Step 5: Output each box's contents and total weight + result = [] + for i in range(card_num): + result.append({ + "box_index": i + 1, + "items": boxes[i], # List of item IDs in the box + "weight": boxes_weights[i], + "total_weight": box_weights[i], # Total weight in this box + "item_count": box_counts[i] # Number of items in the box + }) + + return result, boxes + + # 热点专家拆分为冗余专家 + @staticmethod + def compute_balanced_pack_redundancy(origin_weights, card_num, num_redundancy_expert): + route_expert_num = len(origin_weights) + route_expert_redundancy = [[] for _ in range(route_expert_num)] + for i in range(num_redundancy_expert): + sorted_indices = np.argsort([t[1] for t in origin_weights], kind='stable')[::-1] + weights = [origin_weights[idx] for idx in sorted_indices] + tmp_raw_weight = weights[0][1] * (len(route_expert_redundancy[weights[0][0]]) + 1) + route_expert_redundancy[weights[0][0]].append(route_expert_num + i) + avg_weight = tmp_raw_weight / (len(route_expert_redundancy[weights[0][0]]) + 1) + weights[0] = (weights[0][0], avg_weight) + origin_weights = weights + + expert_num = route_expert_num + num_redundancy_expert + if card_num == 0: + raise RuntimeError("card_num can not be 0.") + items_per_box = expert_num // card_num + remaining_items = expert_num % card_num + + boxes = [[] for _ in range(card_num)] + boxes_weights = [[] for _ in range(card_num)] + box_weights = [0] * card_num + box_counts = [0] * card_num + + all_weights = np.zeros((expert_num,), dtype='object') + all_weights[: route_expert_num] = origin_weights + + index = route_expert_num + for i in range(route_expert_num): + redundancy_num = len(route_expert_redundancy[i]) + for _ in range(redundancy_num): + for item, weight in origin_weights: + if item == i: + all_weights[index] = (item, weight) + index += 1 + + sorted_indices = np.argsort([t[1] for t in all_weights], kind='stable')[::-1] + all_weights = [all_weights[idx] for idx in sorted_indices] + for item_id, weight in all_weights: + min_box_index = -1 + for i in range(card_num): + if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]: + if item_id not in boxes[i]: + min_box_index = i + + boxes[min_box_index].append(item_id) + boxes_weights[min_box_index].append(weight) + box_weights[min_box_index] += weight + box_counts[min_box_index] += 1 + + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0: + remaining_items -= 1 + + result = [] + for i in range(card_num): + result.append({ + "box_index": i + 1, + "items": boxes[i], + "weight": boxes_weights[i], + "total_weight": box_weights[i], + "item_count": box_counts[i] + }) + + return result, boxes + + # 无冗余专家方案 + @staticmethod + def compute_balanced_pack(origin_weights, card_num): + sorted_indices = np.argsort([t[1] for t in origin_weights])[::-1] + weights = origin_weights[sorted_indices] + expert_num = len(weights) + if 
card_num == 0: + raise RuntimeError("card_num can not be 0.") + items_per_box = expert_num // card_num + remaining_items = expert_num % card_num + + boxes = [[] for _ in range(card_num)] + boxes_weights = [[] for _ in range(card_num)] + box_weights = [0] * card_num + box_counts = [0] * card_num + + for item_id, weight in weights: + min_box_index = -1 + for i in range(card_num): + if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]: + min_box_index = i + + boxes[min_box_index].append(item_id) + boxes_weights[min_box_index].append(weight) + box_weights[min_box_index] += weight + box_counts[min_box_index] += 1 + + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0: + remaining_items -= 1 + + result = [] + for i in range(card_num): + result.append({ + "box_index": i + 1, + "items": boxes[i], + "weight": boxes_weights[i], + "total_weight": box_weights[i], + "item_count": box_counts[i] + }) + + return result, boxes + + @staticmethod + def get_redundant_num(npu_num, counts): + redundant_num_each_npu = np.sum(counts - 1) + return redundant_num_each_npu + + @staticmethod + def calculate_max_heat_per_layer(workload_table, layer_num): + max_heat_per_layer = [] + for layer_idx in range(layer_num): + npu_heats_now = np.sum(workload_table[layer_idx], axis=1) + max_heat_per_layer.append(np.max(npu_heats_now)) + return max_heat_per_layer + + @staticmethod + def calculate_initial_imbalance(global_deployment, new_layer_workloads): + + device_num = global_deployment.shape[1] + layer_imbalance = [] + expert_num = np.zeros_like(new_layer_workloads) + # 基于部署做更新负载 + for layer_id, layer in enumerate(global_deployment): + for device in layer: + for expert_id in device: + expert_num[layer_id][expert_id] += 1 + + for layer_id, layer in enumerate(global_deployment): + cur_layer_max_workload = 0 + total_workload = 0 + for box in layer: + box_workload = 0 + for expert_id in box: + update_workload = new_layer_workloads[layer_id][expert_id] / expert_num[layer_id][expert_id] + box_workload += update_workload + total_workload += update_workload + if cur_layer_max_workload < box_workload: + cur_layer_max_workload = box_workload + + cur_layer_imbalance = cur_layer_max_workload / (total_workload / device_num) + layer_imbalance.append(cur_layer_imbalance) + + return layer_imbalance + + @staticmethod + def compute_redundant_assignments(base_experts, num_redundant_experts, num_experts): + """ + 计算每个基础专家需要分配的冗余专家,并动态调整专家权重 + 返回冗余分配表和更新后的基础专家权重列表 + """ + redundant_assignments = [[] for _ in range(num_experts)] + current_weights = base_experts.copy() + + for i in range(num_redundant_experts): + # 按权重降序排序(使用稳定排序保持相同权重的顺序) + sorted_indices = np.argsort([w for _, w in current_weights], kind='stable')[::-1] + sorted_weights = [current_weights[i] for i in sorted_indices] + + # 选择当前权重最高的专家 + target_expert = sorted_weights[0] + expert_id, original_weight = target_expert + + # 计算添加冗余后的新平均权重 + current_redundancy = len(redundant_assignments[expert_id]) + new_avg_weight = original_weight * (current_redundancy + 1) / (current_redundancy + 2) + + # 更新分配表和权重列表 + redundant_assignments[expert_id].append(num_experts + i) + current_weights[sorted_indices[0]] = (expert_id, new_avg_weight) + + sorted_indices = np.argsort([w for _, w in current_weights], kind='stable')[::-1] + sorted_weights = [current_weights[i] for i in sorted_indices] + + return redundant_assignments, sorted_weights + + @staticmethod + def 
prepare_expert_list(base_experts, redundant_assignments, num_redundant_experts): + """ + 生产冗余专家的完整列表,并按权重降序排序 + """ + redundant_expert_list = np.empty(num_redundant_experts, dtype=object) + + # 填充冗余专家(使用对应基础专家的当前权重) + index = 0 + num_experts = len(redundant_assignments) + for expert_id in range(num_experts): + for _ in redundant_assignments[expert_id]: + redundant_expert_list[index] = (expert_id, next(w for eid, w in base_experts if eid == expert_id)) + index += 1 + + # 按权重降序排序 + sorted_indices = np.argsort([w for _, w in redundant_expert_list], kind='stable')[::-1] + return [redundant_expert_list[i] for i in sorted_indices] + + @staticmethod + def non_redundant_expert_information(origin_deployment, updated_weights, num_radundant_experts): + + device_num = len(origin_deployment) + + device_assignments = [[] for _ in range(device_num)] + device_weights = [[] for _ in range(device_num)] + device_loads = [0] * device_num + device_counts = [0] * device_num + if num_radundant_experts: + start_id = 1 + else: + start_id = 0 + + # 统计卡上非冗余专家信息 + for box_id, box in enumerate(origin_deployment): + for i in range(start_id, len(box)): + device_assignments[box_id].append(box[i]) + cur_weight = next(weight for expert_id, weight in updated_weights if expert_id == box[i]) + device_weights[box_id].append(cur_weight) + device_loads[box_id] += cur_weight + device_counts[box_id] += 1 + + return device_assignments, device_weights, device_loads, device_counts + + @staticmethod + def recomputing_weight(layer_workloads, device_assignments, device_weights, device_loads): + # 统计专家出现次数 + num_all_experts = [0] * len(layer_workloads) + num_devices = len(device_assignments) + for device_id in range(num_devices): + num_expert_per_npu = len(device_assignments[device_id]) + for idx in range(num_expert_per_npu): + num_all_experts[idx] += device_assignments[device_id][idx] + + for device_id in range(num_devices): + num_expert_per_npu = len(device_weights[device_id]) + total_weight = 0.0 + for idx in range(num_expert_per_npu): + expert_id = device_assignments[device_id][idx] + if num_all_experts[expert_id] == 0: + print("Error: Division by zero") + device_weights[device_id][idx] = layer_workloads[expert_id] / num_all_experts[expert_id] + total_weight += device_weights[device_id][idx] + device_loads[device_id] = total_weight + + return device_weights, device_loads + + @staticmethod + def distribute_redun_experts(self, layer_workloads, device_assignments, device_weights, device_loads, device_counts, redundant_expert_list, + items_per_device, expert_form_device, num_experts): + + num_devices = len(device_assignments) + com_between_devices = [{} for _ in range(num_devices)] + + for expert_id, weight in redundant_expert_list: + # 寻找最优设备(满足容量限制且负载最小) + candidate = -1 + for dev_id in range(num_devices): + # 保证设备内节点不同 + if expert_id in device_assignments[dev_id]: + continue + # 检查容量限制 + if device_counts[dev_id] < items_per_device: + # 选择负载最小的候选设备 + if candidate == -1 or device_loads[dev_id] < device_loads[candidate]: + candidate = dev_id + if candidate != -1: + # 分配专家到选定的设备 + device_assignments[candidate].insert(0, expert_id) + device_weights[candidate].insert(0, weight) + device_loads[candidate] += weight + device_counts[candidate] += 1 + + communication_box_index = expert_form_device[expert_id] + com_between_devices[candidate][communication_box_index] = expert_id + # 极端情况下存在冗余专家没装箱 导致箱子有空位 随机填入专家 待优化 + flag = False + for dev_id in range(num_devices): + # 检查容量限制 + if device_counts[dev_id] < items_per_device: + # 遍历合适的专家 + for 
expert_id in range(num_experts): + if expert_id not in device_assignments[dev_id]: + flag = True + # 随机初始化一个权重 + weight = 0.0 + # 和该专家相关的卡权重发生变化 待修改 + device_assignments[dev_id].insert(0, expert_id) + device_weights[dev_id].insert(0, weight) + device_loads[dev_id] += weight + device_counts[dev_id] += 1 + + communication_box_index = expert_form_device[expert_id] + com_between_devices[dev_id][communication_box_index] = expert_id + break + + if flag: + device_weights, device_loads = self.recomputing_weight(layer_workloads, device_assignments, device_weights, device_loads) + + return device_assignments, device_weights, device_loads, device_counts, com_between_devices + + @staticmethod + def redundancy_again(self, layer_workloads, origin_weights, num_redundant_experts, origin_deployment, expert_form_device, num_node, + is_node_redundant): + + # 每张卡上专家数量 + expert_num_per_device = origin_deployment.shape[1] + + num_experts = len(origin_weights) + if is_node_redundant: + num_experts = num_experts * num_node + + # 根据新负载重新计算冗余专家 + redundant_assignments, updated_weights = self.compute_redundant_assignments(origin_weights, + num_redundant_experts, + num_experts) + + # 收集冗余专家信息并排序 + redundant_expert_list = self.prepare_expert_list(updated_weights, redundant_assignments, num_redundant_experts) + + # 收集重新计算冗余后卡上非冗余专家信息 + device_assignments, device_weights, device_loads, device_counts = self.non_redundant_expert_information( + origin_deployment, updated_weights, num_redundant_experts) + + # 新计算的冗余专家进行分配 + device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.distribute_redun_experts( + self, + layer_workloads, + device_assignments, + device_weights, + device_loads, + device_counts, + redundant_expert_list, + expert_num_per_device, + expert_form_device, + num_experts) + + + return device_assignments, device_weights, device_loads, device_counts, com_between_devices + + @staticmethod + def generate_allocation_report(device_assignments, device_weights, device_loads, device_counts): + """ + 生成最终分配报告并计算最大负载 + """ + report = [] + max_load = 0.0 + + for dev_id in range(len(device_assignments)): + current_load = device_loads[dev_id] + max_load = max(max_load, current_load) + + report.append({ + "device_id": dev_id + 1, + "assigned_experts": device_assignments[dev_id], + "expert_weights": device_weights[dev_id], + "total_load": current_load, + "expert_count": device_counts[dev_id] + }) + + return report, max_load + + @staticmethod + def exchange_expert(cur_exchange_index, next_exchange_index, cur_device_id, next_device_id, cur_layer_result, + com_between_devices): + + cur_device_deployment = cur_layer_result[cur_device_id]['assigned_experts'] + next_device_deployment = cur_layer_result[next_device_id]['assigned_experts'] + + cur_device_weight = cur_layer_result[cur_device_id]['expert_weights'] + next_device_weight = cur_layer_result[next_device_id]['expert_weights'] + + # 两张卡上对应的两个专家进行交换 + cur_expert_id = cur_device_deployment[cur_exchange_index] + next_expert_id = next_device_deployment[next_exchange_index] + cur_device_deployment[cur_exchange_index] = next_expert_id + next_device_deployment[next_exchange_index] = cur_expert_id + + cur_expert_weight = cur_device_weight[cur_exchange_index] + next_expert_weight = next_device_weight[next_exchange_index] + cur_device_weight[cur_exchange_index] = next_expert_weight + next_device_weight[next_exchange_index] = cur_expert_weight + + cur_layer_result[cur_device_id]['total_load'] += next_expert_weight - cur_expert_weight + 
cur_layer_result[next_device_id]['total_load'] += cur_expert_weight - next_expert_weight + + # 记录这两卡进行了通信 + com_between_devices[cur_device_id][next_device_id] = next_expert_id + com_between_devices[next_device_id][cur_device_id] = cur_expert_id + + @staticmethod + # 分层调整冗余专家 + def redundant_expert_deployment(self, layer_workloads, original_deployment, expert_form_device, node_num, + is_node_redundant): + device_num, per_device_expert_num = original_deployment.shape + route_expert_num = layer_workloads.shape[0] + redundancy_expert_num = per_device_expert_num * device_num - route_expert_num + per_node_device_num = device_num // node_num + per_node_route_expert_num = per_node_device_num * (per_device_expert_num - 1) + per_node_redun_expert_num = redundancy_expert_num // node_num + + weights = np.zeros((route_expert_num,), dtype='object') + for expert_id, workload_weight in enumerate(layer_workloads): + weights[expert_id] = (expert_id, int(workload_weight)) + + if is_node_redundant: + + device_assignments = [] + device_weights = [] + device_loads = [] + device_counts = [] + com_between_devices = [] + + for node_id in range(node_num): + cur_node_weights = weights[ + node_id * per_node_route_expert_num: (node_id + 1) * per_node_route_expert_num] + cur_original_deployment = original_deployment[ + node_id * per_node_device_num: (node_id + 1) * per_node_device_num] + + cur_device_assignments, cur_device_weights, cur_device_loads, cur_device_counts, cur_com_between_devices = self.redundancy_again( + self, + layer_workloads, + cur_node_weights, + per_node_redun_expert_num, + cur_original_deployment, + expert_form_device, + node_num, + is_node_redundant) + device_assignments += cur_device_assignments + device_weights += cur_device_weights + device_loads += cur_device_loads + device_counts += cur_device_counts + com_between_devices += cur_com_between_devices + + else: + device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.redundancy_again( + self, + layer_workloads, + weights, + redundancy_expert_num, + original_deployment, + expert_form_device, + node_num, + is_node_redundant) + # 生成报告 + report, max_load = self.generate_allocation_report(device_assignments, device_weights, device_loads, + device_counts) + + return report, max_load, com_between_devices + + @staticmethod + def two_device_exchange_experts(cur_device_result, exchange_device_result, cur_exchanged_expert_id, + next_exchanged_expert_id, ave_workload, increment, num_redundancy_expert, cur_org_placement, next_org_placement): + + cur_device_weight = cur_device_result['expert_weights'] + next_device_weight = exchange_device_result['expert_weights'] + + cur_device_expert_id = cur_device_result['assigned_experts'] + next_device_expert_id = exchange_device_result['assigned_experts'] + + cur_device_total_weight = int(cur_device_result['total_load']) + next_device_total_weight = int(exchange_device_result['total_load']) + max_weight = max(cur_device_total_weight, next_device_total_weight) + + cur_exchange_index = -1 + next_exchange_index = -1 + + redun = False + if num_redundancy_expert != 0: + redun = True + + for index, weight in enumerate(cur_device_weight): + for next_index, next_weight in enumerate(next_device_weight): + # 跳过冗余专家 + if (index == 0 or next_index == 0) and redun : + continue + # 交换专家限制卡内专家不同 + change_flag = True + if ((cur_device_expert_id[index] in next_device_expert_id or next_device_expert_id[next_index] in cur_device_expert_id) or + (cur_org_placement[0] == 
next_device_expert_id[next_index] or next_org_placement[0] == cur_device_expert_id[index])): + change_flag = False + # 选择的专家不能是参与过交换的 + if (cur_device_expert_id[index] not in cur_exchanged_expert_id) and ( + next_device_expert_id[next_index] not in next_exchanged_expert_id) and change_flag: + cur_total_weight_after_exchange = cur_device_total_weight - weight + next_weight + next_total_weight_after_exchange = next_device_total_weight - next_weight + weight + exchange_max_weight = max(cur_total_weight_after_exchange, next_total_weight_after_exchange) + if exchange_max_weight < max_weight and (max_weight - exchange_max_weight) >= ( + ave_workload * increment): + max_weight = exchange_max_weight + cur_exchange_index = index + next_exchange_index = next_index + + return cur_exchange_index, next_exchange_index + + @staticmethod + def expert_exchange_between_devices(self, ave_workload, increment, cur_layer_result, com_between_devices, num_redundancy_expert, + org_placement_table, node_idx=0, per_node_device_num=0, is_node_redundant=False): + + if is_node_redundant: + # 拿出当前节点内设备的信息 + cur_devices_result = cur_layer_result[node_idx * per_node_device_num:(node_idx + 1) * per_node_device_num] + else: + # 拿取所有设备信息 + cur_devices_result = cur_layer_result + + devices_total_weight = [] + for device in cur_devices_result: + devices_total_weight.append((int(device['total_load']), device['device_id'] - 1)) + + # 当迭代次数超过100或负载最大的设备无法进行调整时退出 + exchange_frequency = 100 + while exchange_frequency > 0: + exchange_frequency -= 1 + + # 根据负载从小到大排序 + devices_total_weight.sort(key=lambda x: x[0]) + # 负载最大的设备id + max_weight_device_id = devices_total_weight[-1][1] + + exchange = False + # 按照负载从小到大依次取卡 + for index in range(0, len(devices_total_weight) - 1): + min_weight_device_id = devices_total_weight[index][1] + # 两个节点没有进行过通信 + if min_weight_device_id not in com_between_devices[max_weight_device_id]: + # 找到设备中交换过的专家id,(除了冗余之外通信过的id) + set_cur_com_expert_id = set(com_between_devices[max_weight_device_id].values()) + set_next_com_expert_id = set(com_between_devices[min_weight_device_id].values()) + if num_redundancy_expert != 0: + set_cur_device_expert_id = set(cur_layer_result[max_weight_device_id]['assigned_experts'][1:]) + set_next_device_expert_id = set(cur_layer_result[min_weight_device_id]['assigned_experts'][1:]) + else: + set_cur_device_expert_id = set(cur_layer_result[max_weight_device_id]['assigned_experts']) + set_next_device_expert_id = set(cur_layer_result[min_weight_device_id]['assigned_experts']) + + cur_exchanged_expert_id = set_cur_com_expert_id & set_cur_device_expert_id + next_exchanged_expert_id = set_next_com_expert_id & set_next_device_expert_id + + cur_exchange_index, next_exchange_index = self.two_device_exchange_experts( + cur_layer_result[max_weight_device_id], + cur_layer_result[min_weight_device_id], + cur_exchanged_expert_id, + next_exchanged_expert_id, + ave_workload, + increment, + num_redundancy_expert, + org_placement_table[max_weight_device_id], + org_placement_table[min_weight_device_id]) + + # 有符合条件的专家进行交换 + if cur_exchange_index != -1: + self.exchange_expert(cur_exchange_index, + next_exchange_index, + max_weight_device_id, + min_weight_device_id, + cur_layer_result, + com_between_devices) + + devices_total_weight[-1] = ( + cur_layer_result[max_weight_device_id]['total_load'], max_weight_device_id) + devices_total_weight[index] = ( + cur_layer_result[min_weight_device_id]['total_load'], min_weight_device_id) + exchange = True + break + + if not exchange: + break + + @staticmethod + 
def exchange_experts(self, layer_result, layer_com_between_devices, num_nodes, device_num, is_node_redundant, + ave_workload, increment, num_redundancy_expert, org_placement_table): + + global_deployment = [] + + if is_node_redundant: + per_node_device_num = device_num // num_nodes + for node_idx in range(num_nodes): + self.expert_exchange_between_devices(self, ave_workload, increment, layer_result, + layer_com_between_devices, num_redundancy_expert, + org_placement_table, node_idx, per_node_device_num, is_node_redundant) + else: + self.expert_exchange_between_devices(self, ave_workload, increment, layer_result, layer_com_between_devices, num_redundancy_expert, org_placement_table) + + max_workload = 0 + for box in layer_result: + global_deployment.append(box['assigned_experts']) + if max_workload < box['total_load']: + max_workload = box['total_load'] + + global_deployment = np.array(global_deployment) + + return global_deployment, max_workload + + @staticmethod + def count_elements(self, lst): + count = 0 + for item in lst: + if isinstance(item, list): + count += self.count_elements(self, item) + else: + count += 1 + return count + + def rebalance_experts(self, current_expert_table, expert_workload): + # 输入:当前专家部署信息和对应的负载信息,形状为layer_num, num_npus, experts_per_npu + info = DynamicTable() + info.workload_table = expert_workload.numpy() + info.placement_table = current_expert_table.numpy() + layer_num, num_npus, experts_per_npu = info.workload_table.shape + expert_ids, counts = np.unique(info.placement_table[0], return_counts=True) + num_redundancy_expert = self.get_redundant_num(num_npus, counts) + num_original_expert = len(expert_ids) + # 负载信息转化为 58 * 256 + layer_workloads = self.add_redundant(info.placement_table, info.workload_table, num_original_expert) + max_heat_per_layer_before = self.calculate_max_heat_per_layer(info.workload_table, layer_num) + npu_heat_all_origin = sum(max_heat_per_layer_before) + + # 计算负载均衡,部署冗余专家 + num_node = num_npus / 8 + layer_num = layer_workloads.shape[0] + expert_num = layer_workloads.shape[1] + expert_from_device = np.zeros((layer_num, num_original_expert)) + # 校验专家数量、卡数量、冗余专家数量不能超过卡数量 + if num_original_expert != expert_num: + raise ValueError(f"原始专家数量 {num_original_expert} 必须等于 expert_num {expert_num}") + + if num_npus <= 0: + raise ValueError("NPUs 数量必须大于 0") + + if num_npus < num_redundancy_expert: + raise ValueError(f"NPUs 数量 {num_npus} 必须大于或等于冗余专家数量 {num_redundancy_expert}") + + # 每个卡部署的专家数量 一个冗余专家 + global_deployment = [[[] for _ in range(num_npus)] for _ in range(layer_num)] + # 统计更换数据集后的初始58层不均衡度 + layer_initial_imbalance = self.calculate_initial_imbalance(info.placement_table, layer_workloads) + # 遍历获得每一层的放置策略,考虑计算均衡 + max_heat_per_layer_after = np.zeros([layer_num]) + sum_num = 0 + for layer in range(layer_num): + # 不均衡度小于特定阈值不调整 + if layer_initial_imbalance[layer] < 1.1: + global_deployment[layer] = info.placement_table[layer] + continue + + ave_workload = np.sum(layer_workloads[layer]) / num_npus + for device_id, device in enumerate(info.placement_table[layer]): + for index, expert_id in enumerate(device): + if index != 0: + expert_from_device[layer][expert_id] = device_id + + # 调整冗余专家 + result, max_workload, com_between_devices = self.redundant_expert_deployment(self, layer_workloads[layer], + info.placement_table[layer], + expert_from_device[layer], + num_node, False) + # 交换专家 + global_deployment[layer], new_max_workload = self.exchange_experts(self, result, com_between_devices, + num_node, num_npus, False, ave_workload, + 0.05, 
num_redundancy_expert, info.placement_table[layer])
+
+            for device_id in range(num_npus):
+                com_between_devices[device_id] = {int(key): int(value) for key, value in
+                                                  com_between_devices[device_id].items()}
+                sum_num += self.count_elements(self, com_between_devices[device_id])
+
+            max_heat_per_layer_after[layer] = max(result, key=lambda x: x['total_load'])['total_load']
+
+        # Compute per-layer priorities
+        layer_changed_ratio = []
+        for layer_idx in range(layer_num):
+            layer_changed_ratio.append(max_heat_per_layer_after[layer_idx] / max_heat_per_layer_before[layer_idx])
+
+        per_layer_priority = np.argsort(layer_changed_ratio)
+        npu_heat_all_after = sum(max_heat_per_layer_after)
+
+        change = 0
+        if npu_heat_all_after < 0.95 * npu_heat_all_origin:
+            change = 1
+
+        return change, per_layer_priority, np.array(global_deployment).tolist()
\ No newline at end of file
diff --git a/vllm_ascend/eplb/core/policy/eplb_policy.py b/vllm_ascend/eplb/core/policy/eplb_policy.py
new file mode 100644
index 0000000000..1de60c348d
--- /dev/null
+++ b/vllm_ascend/eplb/core/policy/eplb_policy.py
@@ -0,0 +1,42 @@
+# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+from abc import abstractmethod
+
+
+class DynamicConfig:
+    placement_policy = None
+
+    max_transferred_expert_per_layer = 100
+    # Maximum number of experts that may be moved per layer on a single host
+
+    ep_worldsize = 64  # Number of dies the experts of the whole cluster are spread across
+    num_die_per_host = 8  # Number of dies per host
+
+
+class EplbPolicy:
+    def __init__(self, config: DynamicConfig):
+        self.config = config
+
+    @abstractmethod
+    def rebalance_experts(self, current_expert_table, expert_workload):
+        """
+        Take the current workloads and return expert replication and placement under the given constraints.
+        INPUT:
+        current_expert_table: [layerId, rankId, expert_num_i]
+        expert_workload = expert_table[layer0][rankId][expert_num_i]
+
+        RETURNED: (res, expert_table)
+        res:
+        1 -- table_changed
+        0 -- not_changed
+
+        expert_table: [layerId, rankId, expert_num_i]
+        expert_num_i --- [0, MaxExpertPerRank]
+        expertID = expert_table[layer0][rankId][expert_num_i]
+        array_values:
+        [0, 1, 2, 3, 248]
+        [4, 5, 6, 7, 254]
+        [8, 9, 10, 11, 71]
+        ...
+        [252, 253, 254, 255, 0]
+        """
+        pass
diff --git a/vllm_ascend/eplb/core/policy/mock_load_balance.py b/vllm_ascend/eplb/core/policy/mock_load_balance.py
new file mode 100644
index 0000000000..6626d3fb5c
--- /dev/null
+++ b/vllm_ascend/eplb/core/policy/mock_load_balance.py
@@ -0,0 +1,29 @@
+# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+import copy
+import random
+import torch
+
+from .eplb_policy import EplbPolicy, DynamicConfig
+
+random.seed(42)
+
+class MockLoadBalance(EplbPolicy):
+    def __init__(self, config: DynamicConfig):
+        super().__init__(config)
+
+    def rebalance_experts(self, current_expert_table, expert_workload):
+        new_table = copy.deepcopy(current_expert_table)
+        num_layers = len(current_expert_table)
+        num_card = len(current_expert_table[0])
+
+        for i in range(num_layers):
+            # Pick two cards (random sampling is disabled; fixed indices are used)
+            # indices = random.sample(range(num_card), 2)
+            indices = [3, 1]
+
+            # Swap the redundant experts between the two cards
+            expert_id_to_exchange = new_table[i][indices[0]][-1].clone()
+            new_table[i][indices[0]][-1] = new_table[i][indices[1]][-1]
+            new_table[i][indices[1]][-1] = expert_id_to_exchange
+
+        return 1, [-i for i in range(num_layers)], new_table
\ No newline at end of file
diff --git a/vllm_ascend/eplb/core/policy/policy_factory.py b/vllm_ascend/eplb/core/policy/policy_factory.py
new file mode 100644
index 0000000000..7ebd048ff9
--- /dev/null
+++ b/vllm_ascend/eplb/core/policy/policy_factory.py
@@ -0,0 +1,27 @@
+# Copyright Huawei Technologies Co., Ltd. 2023-2024. 
All rights reserved. +from .eplb_policy import EplbPolicy, DynamicConfig +from .mock_load_balance import MockLoadBalance +from .dynamic_ep import DynamicEplb +from .dynamic_ep_v2 import DynamicEplbV2 + + + +class PolicyFactory: + @staticmethod + def generate_policy(policy_type: int, config: DynamicConfig) -> EplbPolicy: + policy = { + # Constraint applying Dynamic EPLB policy V2: + # If there exists redundant expert: + # only one redundant expert can be placed in one NPU and its physical expert index must be 0 + + # Applying bipartite d2d expert weight update composing + 0:MockLoadBalance, # MockLoadBalance + 1:DynamicEplb, # Dynamic EPLB policy + 2:DynamicEplbV2, # Dynamic EPLB policy V2 + + # Applying greedy d2d expert weight update composing + 3:MockLoadBalance, # MockLoadBalance + 4:DynamicEplb, # Dynamic EPLB policy + 5:DynamicEplbV2, # Dynamic EPLB policy V2 + } + return policy.get(policy_type, MockLoadBalance)(config) diff --git a/vllm_ascend/eplb/core/worker/eplb_worker.py b/vllm_ascend/eplb/core/worker/eplb_worker.py new file mode 100644 index 0000000000..c4aa86a4ad --- /dev/null +++ b/vllm_ascend/eplb/core/worker/eplb_worker.py @@ -0,0 +1,408 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
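+#
+# Overview of the worker defined below: EplbWorker.do_update() reads the current
+# expert_maps and moe_load from the shared dict, asks the policy for a new
+# placement, validates it with check_expert_placement(), derives per-rank
+# send/recv instructions (bipartite or greedy composition depending on the
+# policy type) and packs them with pack_update_info(); EplbProcess.worker_process()
+# runs this loop in a subprocess driven by planner_q / block_update_q.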
+# + +import time +import numpy as np +import networkx as nx +import torch +import torch_npu +import logging +import torch.distributed as dist +from multiprocessing import Process, Queue, Manager +from abc import ABC, abstractmethod +from vllm.logger import logger + +from vllm_ascend.eplb.core.policy.policy_factory import PolicyFactory, DynamicConfig +from vllm_ascend.eplb.tool.eplb_utils import ExpertMapUtils + + +class EplbWorker: + + def __init__(self, shared_dict, policy_type, enable_d2d: bool = True, redundant_enable=0): + self.policy_type = policy_type + self.policy = PolicyFactory.generate_policy(policy_type, DynamicConfig()) + self.shared_dict = shared_dict + self.old_expert_maps = None + self.enable_d2d = enable_d2d + self.redundant_enable = redundant_enable + self.rank_id = dist.get_rank() + + def do_update(self): + # put data in to queue + # in process self.policy.generate_policy() + # get epxert table && tensor + + # async stream + # D2D + # H2D + + # Get initial expert_map + if self.old_expert_maps is None: + self.old_expert_maps = self.get_init_expert_maps() + self.num_local_experts = self.old_expert_maps.max() + 1 + + # Get MOE load information + load_info = self.fetch_and_sum_load_info() + if load_info is None: + return + + #根据负载信息,获取更新后的专家表 + old_placement = self.global2local(self.old_expert_maps, self.num_local_experts) + changed, priority, new_placement = self.calculate_rebalance_experts(load_info, old_placement) + + if not torch.is_tensor(new_placement): + new_placement = torch.tensor(new_placement) + self.check_expert_placement(old_placement, new_placement) + new_expert_maps = self.local2global(new_placement) + self.update_expert_map(new_expert_maps) + logger.debug(f"[EPLB Process new_map differs, performing D2D") + + update_info = self.compose_expert_update_info_bipartite(new_expert_maps, self.old_expert_maps)\ + if self.policy_type <= 2 else self.compose_expert_update_info_greedy(new_expert_maps, self.old_expert_maps) + self.old_expert_maps = new_expert_maps + logger.info("EPLB Process compute complete") + + packed_update_info = self.pack_update_info(update_info) + + return packed_update_info + + def check_expert_placement(self, old_placement, new_placement): + num_layers = old_placement.shape[0] + num_ranks = old_placement.shape[1] + + for layer_id in range(num_layers): + # check if any logical expert is not placed on any rank + if torch.unique(new_placement[layer_id]).numel() < torch.unique(old_placement[layer_id]).numel(): + logger.error(f"There exists expert not placed on any rank in layer {layer_id}") + new_placement[layer_id] = old_placement[layer_id] + continue + + for rank_id in range(num_ranks): + new_placement_check = new_placement[layer_id][rank_id] + old_placement_check = old_placement[layer_id][rank_id] + + # check if same logical experts are placed on the same NPU + if new_placement_check.numel() != torch.unique(new_placement_check).numel(): + logger.error(f"Replicated experts are placed on the same NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid") + new_placement[layer_id] = old_placement[layer_id] + break + + # check if there is any experts movement inside one NPU + expert_not_move = torch.isin(new_placement_check, old_placement_check) + if not torch.equal(new_placement_check[expert_not_move], old_placement_check[expert_not_move]): + logger.error(f"There exists expert movement inside NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid") + new_placement[layer_id] = old_placement[layer_id] + break + + def 
compose_expert_update_info_bipartite(self, updated_expert_maps_org, current_expert_maps_org): + # transform numpy array to torch tensor + updated_expert_maps = updated_expert_maps_org.clone() + current_expert_maps = current_expert_maps_org.clone() + updated_expert_maps = np.array(updated_expert_maps) + current_expert_maps = np.array(current_expert_maps) + + num_layers = current_expert_maps.shape[0] + num_ranks = current_expert_maps.shape[1] + num_experts = current_expert_maps.shape[2] + + for layer_id in range(num_layers): + updated_expert_maps_this_layer = updated_expert_maps[layer_id] + current_expert_maps_this_layer = current_expert_maps[layer_id] + updated_expert_maps_this_layer_org = updated_expert_maps_org[layer_id] + + expert_send_info_this_layer = dict() + expert_recv_info_this_layer = dict() + + # Guard Clause: if there is no expert weight update, avoid subsequent processing + if (np.equal(updated_expert_maps_this_layer, + current_expert_maps_this_layer)).all(): + yield (expert_send_info_this_layer, expert_recv_info_this_layer, + updated_expert_maps_this_layer_org, layer_id) + + # Parse expert_ids each rank needs to receive from other ranks + dst_rank_indices, experts_to_recv = np.where((current_expert_maps_this_layer == -1) + & (updated_expert_maps_this_layer != -1)) + + # record src ranks for potential transfer + src_ranks_set = dict() + for idx in range(len(dst_rank_indices)): + expert_id = experts_to_recv[idx].item() + if expert_id not in src_ranks_set: + src_ranks_set[expert_id] = np.where( + current_expert_maps_this_layer[:, expert_id] != -1)[0] + + # loop until all experts are scheduled + while len(dst_rank_indices) > 0: + # construct bipartite graph + graph_expert_update = nx.Graph() + for idx in range(len(dst_rank_indices)): + dst_rank_id = dst_rank_indices[idx].item() + expert_id = experts_to_recv[idx].item() + # add src ranks + src_rank_ids = src_ranks_set[expert_id] + graph_expert_update.add_nodes_from(src_rank_ids, bipartite=0) + # add dest rank + graph_expert_update.add_node(str(dst_rank_id), bipartite=1) + # add edges + for src_rank_id in src_rank_ids: + graph_expert_update.add_edge(src_rank_id, str(dst_rank_id)) + + # graph may not be connected + connected_components = list(nx.connected_components(graph_expert_update)) + all_matches = {} + # matching in this loop + for i, component in enumerate(connected_components): + subgraph = graph_expert_update.subgraph(component) + component_matching = nx.bipartite.maximum_matching(subgraph) + all_matches.update(component_matching) + + for src_rank, dst_rank in all_matches.items(): + dst_rank = int(dst_rank) + assert src_rank != dst_rank + if graph_expert_update.nodes[src_rank]['bipartite'] == 0: + # currently not scheduled experts in rank dst_rank + experts_v = experts_to_recv[np.where( + dst_rank_indices == dst_rank)] + # src: src_rank, dest: dst_rank, expert: expert_id + expert_id = np.intersect1d(experts_v, np.where( + current_expert_maps_this_layer[src_rank] != -1))[0] + + # record send/rcv pairs + if src_rank not in expert_send_info_this_layer: + expert_send_info_this_layer[src_rank] = [] + if dst_rank not in expert_recv_info_this_layer: + expert_recv_info_this_layer[dst_rank] = [] + expert_send_info_this_layer[src_rank].append((dst_rank, expert_id)) + expert_recv_info_this_layer[dst_rank].append((src_rank, expert_id)) + + remove_index = np.where(np.logical_and( + dst_rank_indices == dst_rank, experts_to_recv == expert_id)) + + # update + dst_rank_indices = np.delete( + dst_rank_indices, remove_index) + experts_to_recv 
= np.delete(experts_to_recv, remove_index) + + yield (expert_send_info_this_layer, expert_recv_info_this_layer, + updated_expert_maps_this_layer_org, layer_id) + + # TODO: Here only expert weight exchange is considered, need to be extended to cover other weight update cases + def compose_expert_update_info_greedy(self, updated_expert_maps, current_expert_maps): + num_layers = current_expert_maps.shape[0] + num_ranks = current_expert_maps.shape[1] + num_experts = current_expert_maps.shape[2] + + for layer_id in range(num_layers): + updated_expert_maps_this_layer = updated_expert_maps[layer_id] + current_expert_maps_this_layer = current_expert_maps[layer_id] + + expert_send_info_this_layer = dict() + expert_recv_info_this_layer = dict() + + # Guard Clause: if there is no expert weight update, avoid subsequent processing + if torch.equal(updated_expert_maps_this_layer, current_expert_maps_this_layer): + yield (expert_send_info_this_layer, expert_recv_info_this_layer, updated_expert_maps_this_layer, layer_id) + + # Parse expert_ids each rank needs to receive from other ranks + dst_rank_indices, experts_to_recv = torch.where((current_expert_maps_this_layer == -1) \ + & (updated_expert_maps_this_layer != -1)) + + # Parse expert_ids each rank needs to send to other ranks + src_rank_indices, experts_to_send = torch.where((current_expert_maps_this_layer != -1) \ + & (updated_expert_maps_this_layer == -1)) + + for idx in range(len(dst_rank_indices)): + dst_rank_id = dst_rank_indices[idx].item() + expert_id = experts_to_recv[idx].item() + if dst_rank_id not in expert_recv_info_this_layer: + expert_recv_info_this_layer[dst_rank_id] = [] + + if not torch.isin(torch.tensor(expert_id), experts_to_send).any(): + # if expert_id is not sent out from any npu, it will be copied from one npu holding this expert + candidate_src_rank_indices = torch.where(current_expert_maps_this_layer[:, expert_id] != -1)[0] + else: + candidate_src_rank_indices = src_rank_indices[experts_to_send == expert_id] + + # TODO: improve the selection criterion for the npu sending expert_id, e.g. by considering intra-node vs. inter-node transfer + src_rank_id = candidate_src_rank_indices[0].item() + if src_rank_id not in expert_send_info_this_layer: + expert_send_info_this_layer[src_rank_id] = [] + + expert_send_info_this_layer[src_rank_id].append((dst_rank_id, expert_id)) + expert_recv_info_this_layer[dst_rank_id].append((src_rank_id, expert_id)) + + yield (expert_send_info_this_layer, expert_recv_info_this_layer, updated_expert_maps_this_layer, layer_id) + + + def calculate_rebalance_experts(self, load_info, old_placement): + """ + Compute new_map via the rebalance_experts method of the policy instance. + """ + if self.old_expert_maps is None: + return False, None, None + + changed, priority, new_map = self.policy.rebalance_experts(old_placement, load_info) + return changed, priority, new_map + + def get_init_expert_maps(self): + """ + Read the initial expert_map from shared_dict. + """ + return self.shared_dict.get("expert_maps", None) + + def fetch_and_sum_load_info(self): + """ + Each time the subprocess is awakened, read the latest moe_load + (shape: [num_moe_layers, num_experts_per_layer]) from shared_dict.
+ """ + return self.shared_dict.get("moe_load", None) + + def update_expert_map(self, expert_maps): + + self.shared_dict["expert_maps"] = expert_maps + + def global2local(self, + placement: torch.Tensor, + E_local: int + ) -> tuple[torch.Tensor, torch.Tensor]: + + L, G, _ = placement.shape + device = placement.device + + pt_local = torch.full((L, G, E_local), + fill_value=-1, + dtype=torch.long, + device=device) + + valid = placement >= 0 + l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True) + + slot_idx = placement[l_idx, g_idx, k_idx] + + pt_local[l_idx, g_idx, slot_idx] = k_idx + + return pt_local + + + def local2global(self, + placement_local: torch.Tensor + ) -> torch.Tensor: + + L, G, E_local = placement_local.shape + device = placement_local.device + + max_id = torch.max(placement_local) + E_global = (max_id + 1).item() if max_id >= 0 else 0 + + if E_global == 0: + return torch.empty((L, G, 0), dtype=torch.long, device=device) + + placement_global = torch.full((L, G, E_global), + fill_value=-1, + dtype=torch.long, + device=device) + + valid = placement_local >= 0 + l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True) + gid_idx = placement_local[l_idx, g_idx, slot_idx] + + placement_global[l_idx, g_idx, gid_idx] = slot_idx + + return placement_global + + def pack_update_info(self, update_info_generator): + """ + Pack a list of update info tuples for efficient IPC. + """ + send_all = [] + recv_all = [] + maps = [] + log2phy_all = [] + layer_ids = [] + + for send_info, recv_info, new_expert_map, layer_id in update_info_generator: + + send_info_this_rank = send_info[self.rank_id] if self.rank_id in send_info else [] + recv_info_this_rank = recv_info[self.rank_id] if self.rank_id in recv_info else [] + send_all.append(send_info_this_rank) + recv_all.append(recv_info_this_rank) + + maps.append(new_expert_map[self.rank_id].numpy().tolist()) + + if self.redundant_enable: + log2phy_map = ExpertMapUtils.generate_log2phy_map(new_expert_map) + log2phy_all.append(log2phy_map[self.rank_id].numpy().tolist()) + else: + log2phy_all.append([]) + + layer_ids.append(layer_id) + + return list(zip(send_all, recv_all, maps, log2phy_all, layer_ids)) + +class EplbProcess: + def __init__(self, shared_dict, planner_q, block_update_q, redundant_enable, policy_type: int = 0, enable_d2d: bool = True): + """ + Args: + shared_dict: Cross-process shared dict returned by Manager().dict() + policy_type: Integer passed to PolicyFactory.generate_policy + enable_d2d: Whether to enable D2D loading + """ + self.shared_dict = shared_dict + self.policy_type = policy_type + self.enable_d2d = enable_d2d + self.planner_q = planner_q + self.block_update_q = block_update_q + self.redundant_enable = redundant_enable + + # Create EplbWorker instance + self.worker = EplbWorker(self.shared_dict, self.policy_type, self.enable_d2d, self.redundant_enable) + + + def worker_process(self, planner_q, block_update_q): + """ + Subprocess entry: bind to specified NPU, loop waiting for planner_q to wake up, call do_update, then notify main process update is complete. + """ + while True: + try: + + planner_q.get() + + packed_update_info = self.worker.do_update() + + while True: + if not block_update_q.empty(): + continue + block_update_q.put(packed_update_info) + break + + except Exception as e: + logger.warning(f"[EPLB subprocess Exiting due to error: {e}", exc_info=True) + break + + def _launch_process(self): + """ + Use spawn method to launch subprocess and return (planner_q, block_update_q, proc). 
+ """ + proc = Process( + target=self.worker_process, + args=(self.planner_q,self.block_update_q), + daemon=True + ) + + proc.start() + return proc + diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py new file mode 100644 index 0000000000..02c03c7933 --- /dev/null +++ b/vllm_ascend/eplb/eplb_updator.py @@ -0,0 +1,253 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +import torch +import numpy +from typing import Dict, List +import torch.distributed as dist +import vllm.envs as envs +from multiprocessing import Queue, Manager + +from vllm.logger import logger +from vllm_ascend.eplb.core.worker.eplb_worker import EplbProcess +from vllm_ascend.eplb.core.loader.device_transfer_loader import D2DExpertWeightLoader +from vllm_ascend.eplb.tool.eplb_utils import ExpertMapUtils + +class EplbUpdator: + + def __init__(self, expert_map_path): + self.init_eplb(expert_map_path) + + def set_adaptor(self, adaptor): + self.adaptor = adaptor + self.eplb_loader = D2DExpertWeightLoader(eplb_adaptor=self.adaptor) + self.num_moe_layers = self.adaptor.num_moe_layers + self.global_expert_num = self.adaptor.global_expert_num + + def init_eplb(self, expert_map_path): + self.num_expert_load_gather = 10 + self.periodic_load_gather = True + self.redundant_enable = (expert_map_path is not None) + self.num_iterations_eplb_update: torch.int64 = 130 + self.expert_map_path = expert_map_path + + try: + if not envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING: + self.num_expert_load_gather = self.num_iterations_eplb_update + self.periodic_load_gather = False + except Exception as e: + self.num_expert_load_gather = self.num_iterations_eplb_update + self.periodic_load_gather = False + + self.expert_map_initialized = False + self.gate_eplb = True + + self.reqs = [] + self.update_info_all = [] + + self.cur_iterations: torch.int64 = 0 + + self.num_wait_worker_iterations: torch.int64 = 20 + + self.planner_block_queue = Queue() + self.block_update_queue = Queue(maxsize=1) + + self.manager = Manager() + self.shared_dict = self.manager.dict({ + # 当前rank_id的专家表[num_layers,num_experts] + "expert_map": None, + # 热度负载信息 [num_layers, world_size, num_experts] + "moe_load": None, + # 所有的专家表[num_layers, world_size, num_experts] + "expert_maps": None, + }) + + self.eplb = EplbProcess( + shared_dict = self.shared_dict, + planner_q = self.planner_block_queue, + block_update_q = self.block_update_queue, + redundant_enable = self.redundant_enable, + policy_type = 1, + enable_d2d = True + ) + + self.eplb_process = self.eplb._launch_process() + + logger.info(f"[ModelRunner] Launched EPLB process (pid={self.eplb_process.pid})") + + def update_iteration(self): + self.cur_iterations += 1 + if self.cur_iterations == (self.num_iterations_eplb_update +\ + self.num_wait_worker_iterations + self.num_moe_layers): + if not self.gate_eplb: + self.cur_iterations = 0 + + def get_update_info_flag(self): + return 
self.cur_iterations == (self.num_iterations_eplb_update + self.num_wait_worker_iterations) + + def wakeup_eplb_worker_flag(self): + return self.cur_iterations == (self.num_iterations_eplb_update - 1) + + def update_expert_weight_flag(self): + weight_update_counter = self.cur_iterations - (self.num_iterations_eplb_update + self.num_wait_worker_iterations) + return (weight_update_counter >= 0 and weight_update_counter < self.num_moe_layers) + + def get_init_expert_map(self): + try: + if not self.expert_map_initialized: + self.shared_dict["expert_maps"] = self.adaptor.get_init_expert_map_from_file(self.num_moe_layers, self.expert_map_path) + self.expert_map_initialized = True + except Exception as e: + logger.warning(f"[ModelRunner] Failed to load the initial expert map: {e}", exc_info=True) + + def wakeup_eplb_worker(self): + self.planner_block_queue.put(1) + + def forward_before(self): + if self.update_expert_weight_flag(): + (expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = self.update_info_all.pop(0) + rank_id = torch.distributed.get_rank() + if self.redundant_enable: + log2phy_map_this_rank = torch.from_numpy(numpy.array(log2phy_map)) + self.eplb_loader.set_log2phy_map(log2phy_map_this_rank) + updated_expert_map_this_rank = torch.from_numpy(numpy.array(updated_expert_map)) + #logger.info(f"check update info, layer = {layer_id}, send = {expert_send_info_this_rank}, recv = {expert_recv_info_this_rank}") + self.eplb_loader.generate_expert_d2d_transfer_task(expert_send_info, expert_recv_info, + updated_expert_map_this_rank, layer_id + self.adaptor.num_dense_layers) + + # set asynchronous stream for d2d expert weight update + self.reqs = [] + self.eplb_loader.asyn_expert_weight_transfer(self.reqs) + + def take_update_info_from_eplb_process(self): + # Some iterations after the EPLB process was triggered, fetch the update info it produced + if self.get_update_info_flag(): + self.update_info_all = self.block_update_queue.get() + + + def forward_end(self): + if self.wakeup_eplb_worker_flag(): + moe_load = self.compute_and_set_moe_load(is_clear=True) + self.wakeup_eplb_worker() + + if self.update_expert_weight_flag(): + self.eplb_loader.update_expert_map_and_weight(self.reqs, self.redundant_enable) + + self.update_iteration() + + def compute_and_set_moe_load(self, is_clear=False): + local_load = self.adaptor.get_rank_expert_workload() + + self._gather_buffer = None + if dist.is_initialized(): + self.world_size = dist.get_world_size() + self.device = local_load.device + if self._gather_buffer is None: + shape = (self.world_size, *local_load.shape) + self._gather_buffer = torch.empty(shape, + dtype=local_load.dtype, + device=self.device) + + dist.all_gather_into_tensor(self._gather_buffer, local_load) + + moe_load = self._gather_buffer.permute(1, 0, 2) + self.shared_dict["moe_load"] = moe_load.cpu() + logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}") + else: + moe_load = local_load.unsqueeze(1) + self.shared_dict["moe_load"] = moe_load.cpu() + logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}") + self.adaptor.model.clear_all_moe_loads() + return moe_load + + def warm_up_eplb(self): + + self.get_init_expert_map() + self.compute_and_set_moe_load() + + src_tensor = torch.empty((1,), device=self.device) + self_rank = dist.get_rank() + + comm_op_list = [] + + for dst_rank in range(self.world_size): + if dst_rank == self_rank: + continue + comm_op_list.append( + dist.P2POp(dist.isend, src_tensor, dst_rank) + ) + + for
src_rank in range(self.world_size): + if src_rank == self_rank: + continue + comm_op_list.append( + dist.P2POp(dist.irecv, src_tensor, src_rank) + ) + if comm_op_list: + reqs = dist.batch_isend_irecv(comm_op_list) + + for req in reqs: + req.wait() + + def unpack_update_batch(self, packed_update_info): + """ + Unpack the IPC batch back into original update_info_list. + """ + send_all, recv_all, stacked_maps, stacked_log2phy, layer_id_tensor = packed_update_info + + maps = stacked_maps.unbind(0) + layer_ids = layer_id_tensor.tolist() + + if self.redundant_enable: + log2phy_list = stacked_log2phy.unbind(0) + else: + log2phy_list = [None] * len(maps) + + _zip = zip + _send = send_all + _recv = recv_all + _maps = maps + _l2p = log2phy_list + _lids = layer_ids + + recovered = [ + (_s, _r, _m, _lp, _lid) + for _s, _r, _m, _lp, _lid + in _zip(_send, _recv, _maps, _l2p, _lids) + ] + return recovered + + def get_expert_load(self) -> tuple: + expert_maps = self.shared_dict["expert_maps"] + moe_load = self.shared_dict["moe_load"] # Tensor [L, W, global_experts_num] + num_local_experts = expert_maps.max() + 1 + return moe_load, expert_maps, num_local_experts + + def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int): + logger.info(f" start update {self.num_expert_load_gather=}, {self.num_iterations_eplb_update}...") + self.num_expert_load_gather = num_expert_load_gather + self.num_iterations_eplb_update = num_iterations + logger.info(f" update {self.num_expert_load_gather=}, {self.num_iterations_eplb_update} success...") + + def shutdown(self): + """ + Clean up the EPLB process. + """ + if self.eplb_process.is_alive(): + self.eplb_process.terminate() + self.eplb_process.join() + logger.info("[ModelRunner] EPLB process terminated") diff --git a/vllm_ascend/eplb/tool/eplb_utils.py b/vllm_ascend/eplb/tool/eplb_utils.py new file mode 100644 index 0000000000..156f7a9b9d --- /dev/null +++ b/vllm_ascend/eplb/tool/eplb_utils.py @@ -0,0 +1,114 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# + +import torch +import random + +class ExpertMapUtils(): + + @classmethod + def generate_index_dicts(cls, tensor_2d): + dict_list = [] + current_idx = 0 + + for row in tensor_2d: + value_to_index = {} + for i in range(row.size(0)): + value = row[i].item() + value_to_index[value] = current_idx + i + dict_list.append(value_to_index) + current_idx += row.size(0) + + return dict_list + + @classmethod + def generate_log2phy_map(cls, expert_map): + num_local_experts = expert_map.max() + 1 + log2phy_map = expert_map.clone() + num_ranks, num_global_expert = log2phy_map.shape + + row_indices = torch.arange(num_ranks).view(-1, 1).expand(num_ranks,\ + num_global_expert) * num_local_experts + log2phy_map[log2phy_map != -1] += row_indices[log2phy_map != -1] + + for idx in range(num_global_expert): + positive_rank_idx = torch.where(log2phy_map[:, idx] != -1)[0] + negative_rank_idx = torch.where(log2phy_map[:, idx] == -1)[0] + num_rank_holding_expert = positive_rank_idx.size(0) + + if num_rank_holding_expert == 1: + log2phy_map[negative_rank_idx, idx] = torch.full((num_ranks - 1,), + log2phy_map[positive_rank_idx, idx].item(), + dtype=log2phy_map.dtype) + else: + random_list = [random.choice(log2phy_map[positive_rank_idx, idx]) + for _ in range(num_ranks - num_rank_holding_expert)] + log2phy_map[negative_rank_idx, idx] = torch.tensor(random_list,\ + dtype=log2phy_map.dtype) + + return log2phy_map + + @classmethod + def global2local(cls, + placement: torch.Tensor, + E_local: int + ) -> tuple[torch.Tensor, torch.Tensor]: + + G, _ = placement.shape + device = placement.device + + pt_local = torch.full(( G, E_local), + fill_value=-1, + dtype=torch.long, + device=device) + + valid = placement >= 0 + g_idx, k_idx = valid.nonzero(as_tuple=True) + slot_idx = placement[g_idx, k_idx] + + pt_local[g_idx, slot_idx] = k_idx + + return pt_local + + @classmethod + def global2local_load(self, + workload: torch.Tensor, + placement: torch.Tensor, + E_local: int + ) -> tuple[torch.Tensor, torch.Tensor]: + L, G, _ = placement.shape + device = placement.device + + wt_local = torch.full((L, G, E_local), + fill_value=-1, + dtype=workload.dtype, + device=device) + pt_local = torch.full((L, G, E_local), + fill_value=-1, + dtype=torch.long, + device=device) + + valid = placement >= 0 + l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True) + + slot_idx = placement[l_idx, g_idx, k_idx] + values = workload[l_idx, g_idx, k_idx] + + wt_local[l_idx, g_idx, slot_idx] = values + pt_local[l_idx, g_idx, slot_idx] = k_idx + + return wt_local, pt_local \ No newline at end of file diff --git a/vllm_ascend/eplb/tool/generate_map.py b/vllm_ascend/eplb/tool/generate_map.py new file mode 100644 index 0000000000..b498e73a06 --- /dev/null +++ b/vllm_ascend/eplb/tool/generate_map.py @@ -0,0 +1,65 @@ +import numpy as np +import json +import argparse + + +def split_and_insert(n, k, m): + ''' + n: expert num + k: card num + m: redundant expert num, make sure m%k==0 + ''' + + A = np.arange(n) + + B = np.random.choice(n, size=m, replace=False) + + groups = np.array_split(A, k) + + for j in range(m // k): + for i in range(k): + groups[i] = np.append(groups[i], B[i + j * k]) + return np.concatenate(groups) + + +def random_generation(n_layer=58, n_expert=256, start_layer_idx=0, device_count=128, n_redundant=128, output_name=""): + expert_data = {} + expert_data["moe_layer_count"] = n_layer + layer_list = [] + for i in range(n_layer): + layer = {"layer_id": start_layer_idx + i, "device_count": device_count} + random_placement = split_and_insert(n_expert, 
device_count, n_redundant) + device_list = [] + step = random_placement.shape[0] // device_count + for j in range(device_count): + device = {} + device["device_id"] = j + device["device_expert"] = random_placement[j * step: (j + 1) * step].tolist() + device_list.append(device) + layer["device_list"] = device_list + layer_list.append(layer) + + expert_data["layer_list"] = layer_list + json_file_path = output_name + + with open(json_file_path, "w") as f: + json.dump(expert_data, f, indent=4) + + print(f"JSON file generated: {json_file_path}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="python generate_map.py --n_layers 2 --n_experts 256 --card_num 8 --n_redundant 8 --output expert_map.json") + parser.add_argument("--n_layers", type=int, required=True) + parser.add_argument("--n_experts", type=int, required=True) + parser.add_argument("--card_num", type=int, required=True) + parser.add_argument("--n_redundant", type=int, default=0) + parser.add_argument("--output", type=str, default="expert_map.json") + args = parser.parse_args() + + n_layers = args.n_layers + n_experts = args.n_experts + card_num = args.card_num + n_redundant = args.n_redundant + output = args.output + + random_generation(n_layers, n_experts, 0, card_num, n_redundant, output) diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 490cd4ed5e..d85572b32c 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -35,14 +35,19 @@ def register_model(): ModelRegistry.register_model( "DeepseekV2ForCausalLM", "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM") + + ModelRegistry.register_model( + "DeepseekV3ForCausalLM", + "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM") + else: ModelRegistry.register_model( "DeepseekV2ForCausalLM", "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM") - ModelRegistry.register_model( - "DeepseekV3ForCausalLM", - "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM") + ModelRegistry.register_model( + "DeepseekV3ForCausalLM", + "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM") ModelRegistry.register_model( "Qwen3MoeForCausalLM", diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index 6ab0837e37..000bd39ed5 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -25,38 +25,33 @@ # # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py # """Inference-only DeepseekV2/DeepseekV3 model.""" -from typing import Any, Dict, List, Optional, Union +from typing import List, Optional, Union import torch import torch.distributed as dist import torch_npu # noqa: F401 -import vllm.envs as envs from torch import nn from transformers import PretrainedConfig -from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_pp_group, +from vllm.attention import AttentionMetadata +from vllm.config import (CacheConfig, ModelConfig, VllmConfig, + get_current_vllm_config) +from vllm.distributed import (get_ep_group, get_pp_group, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import get_dp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear) 
+from vllm.model_executor.layers.linear import (ReplicatedLinear, + UnquantizedLinearMethod) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.models.deepseek_v2 import \ DeepseekV2ForCausalLM # noqa: E501 -from vllm.model_executor.models.deepseek_v2 import \ - yarn_get_mscale # noqa: E501 -from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention, - DeepseekV2DecoderLayer, - DeepseekV2MLAAttention) +from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer from vllm.model_executor.models.utils import ( PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -64,7 +59,9 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP +from vllm_ascend.ascend_forward_context import FusedMoEState +from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLAAttention, + CustomDeepseekV2MLP) from vllm_ascend.multistream.base import MSEventKey from vllm_ascend.multistream.context import ( advance_step_multistream_layer_context, get_multistream_comm_context, @@ -74,8 +71,9 @@ from vllm_ascend.multistream.metadata import (MultiStreamConfig, MultiStreamStepMetadata, make_multistream_metadata_ds) -from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.quantization.w8a8_dynamic import ( + AscendW8A8DynamicLinearMethod, apply_mlp) from vllm_ascend.utils import dispose_tensor VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO @@ -83,16 +81,48 @@ class CustomDeepseekDBOMLP(CustomDeepseekV2MLP): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__(hidden_size=hidden_size, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + quant_config=quant_config, + prefix=prefix) + self.is_dynamic_quant = not isinstance( + self.gate_up_proj.quant_method, + UnquantizedLinearMethod) and isinstance( + self.gate_up_proj.quant_method.quant_method, + AscendW8A8DynamicLinearMethod) + def _forward_ms_mlp(self, x): current_ms_metadata = get_multistream_comm_context() assert current_ms_metadata is not None gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - current_ms_metadata.before_comm_event.record() - with torch.npu.stream(current_ms_metadata.comm_stream): - current_ms_metadata.before_comm_event.wait() + if self.is_dynamic_quant: + x, dynamic_scale = self.act_fn(gate_up) + x = torch_npu.npu_quant_matmul( + x, + self.down_proj.weight, + self.down_proj.weight_scale, + pertoken_scale=dynamic_scale, + output_dtype=torch.bfloat16, + ) + if self.down_proj.reduce_results and self.down_proj.tp_size > 1: + current_ms_metadata.before_comm_event.record() + with torch.npu.stream(current_ms_metadata.comm_stream): + current_ms_metadata.before_comm_event.wait() + x = tensor_model_parallel_all_reduce(x) + current_ms_metadata.after_comm_event.record() + else: + x = self.act_fn(gate_up) x, _ = self.down_proj(x) - 
current_ms_metadata.after_comm_event.record() return x @@ -163,7 +193,10 @@ def __init__( self.tp_group = get_tp_group().device_group self.tp_rank = get_tp_group().rank_in_group - + self.kv_consumer = None + transfer_config = get_current_vllm_config().kv_transfer_config + if transfer_config is not None: + self.kv_consumer = transfer_config.kv_role == "kv_consumer" self.params_dtype = torch.get_default_dtype() ascend_config = get_ascend_config() @@ -173,39 +206,34 @@ def forward( self, hidden_states: torch.Tensor, attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + forward_context = get_forward_context() if attn_metadata is None: - attn_metadata = get_forward_context().attn_metadata + attn_metadata = forward_context.attn_metadata + # when profile runs, force experts to load balanced tokens # to avoid high memory consumption on a single rank. - # TODO: need a better flag to indicate whether in profile run or not. - if attn_metadata is None: - # for profile run - is_prefill = True - enable_force_load_balance = True - else: - is_prefill = attn_metadata.num_prefills > 0 - enable_force_load_balance = False - if hasattr(attn_metadata, 'with_prefill_across_dp'): - is_prefill = is_prefill or attn_metadata.with_prefill_across_dp + enable_force_load_balance = forward_context.in_profile_run - old_hidden_states = hidden_states.clone() + is_prefill = forward_context.with_prefill + # If this node is kv_consumer, we force the moe always runs in decode path to make sure + # the behaviour aligned between dummy_run and normal model_execute. + if self.kv_consumer: + is_prefill = False # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - hidden_states = self.experts( + experts_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits, is_prefill=is_prefill, top_k=CustomDeepseekDBOMoE.top_k, enable_force_load_balance=enable_force_load_balance, - ) * self.routed_scaling_factor - - if self.n_shared_experts is not None: - shared_output = self.shared_experts(old_hidden_states) + shared_experts=self.shared_experts) - if shared_output is not None: - hidden_states = hidden_states + shared_output + hidden_states = ( + experts_hidden_states[0] * self.routed_scaling_factor + + experts_hidden_states[1]) return hidden_states @@ -225,199 +253,6 @@ def _forward_ms_op_gate( router_logits, _ = self.gate(hidden_states) return router_logits - def _forward_ms_op_tp_allgather( - self, - hidden_states: torch.Tensor, - chunk_hidden_states: torch.Tensor, - num_tokens: int = 0, - ): - current_ms_metadata = get_multistream_comm_context() - if current_ms_metadata is None: - dist.all_gather(list(chunk_hidden_states), hidden_states, - self.tp_group) - final_hidden_states = torch.cat(chunk_hidden_states, dim=0) - if num_tokens > 0: - final_hidden_states = final_hidden_states[:-num_tokens] - else: - current_ms_metadata.before_comm_event.record() - with torch.npu.stream(current_ms_metadata.comm_stream): - current_ms_metadata.before_comm_event.wait() - dist.all_gather(list(chunk_hidden_states), hidden_states, - self.tp_group) - final_hidden_states = torch.cat(chunk_hidden_states, dim=0) - if num_tokens > 0: - final_hidden_states = final_hidden_states[:-num_tokens] - current_ms_metadata.after_comm_event.record() - return final_hidden_states - - -class CustomDeepseekDBOMLAAttention(DeepseekV2MLAAttention): - - def __init__( - self, - config: PretrainedConfig, - hidden_size: int, - num_heads: int, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - v_head_dim: int, -
q_lora_rank: Optional[int], - kv_lora_rank: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - nn.Module.__init__(self) - self.hidden_size = hidden_size - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim - self.v_head_dim = v_head_dim - - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - - self.num_heads = num_heads - tp_size = get_tensor_model_parallel_world_size() - assert num_heads % tp_size == 0 - self.num_local_heads = num_heads // tp_size - - self.scaling = self.qk_head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - if self.q_lora_rank is not None: - self.q_a_proj = ReplicatedLinear(self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_a_proj") - self.q_a_layernorm = RMSNorm(self.q_lora_rank, - eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear(q_lora_rank, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj") - else: - self.q_proj = ColumnParallelLinear(self.hidden_size, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj") - - self.kv_a_proj_with_mqa = ReplicatedLinear( - self.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa") - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, - eps=config.rms_norm_eps) - self.kv_b_proj = ColumnParallelLinear( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_b_proj") - self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") - - if rope_scaling: - rope_scaling["rope_type"] = 'deepseek_yarn' - self.rotary_emb = get_rope(qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - is_neox_style=False) - if rope_scaling: - mscale_all_dim = rope_scaling.get("mscale_all_dim", False) - scaling_factor = rope_scaling["factor"] - mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) - self.scaling = self.scaling * mscale * mscale - - # In the MLA backend, kv_cache includes both k_c and - # pe (i.e. decoupled position embeddings). In particular, - # the concat_and_cache_mla op requires - # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) - # i.e. 
- # kv_lora_rank + qk_rope_head_dim == head_size - self.mla_attn = Attention( - num_heads=self.num_local_heads, - head_size=self.kv_lora_rank + self.qk_rope_head_dim, - scale=self.scaling, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_mla=True, - # MLA Args - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - qk_head_dim=self.qk_head_dim, - v_head_dim=self.v_head_dim, - rotary_emb=self.rotary_emb, - q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, - kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, - kv_a_layernorm=self.kv_a_layernorm, - kv_b_proj=self.kv_b_proj, - o_proj=self.o_proj, - ) - - self.prefix = prefix - self.debug_layer_idx = int(self.prefix.split(".")[-2]) - - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - if self.q_lora_rank is not None: - ckq = self.q_a_proj(hidden_states)[0] - hidden_states_or_q_c = self.q_a_layernorm(ckq) - else: - hidden_states_or_q_c = hidden_states - if self.torchair_graph_enabled: - forward_kwargs = {} - if envs.VLLM_USE_V1: - output_shape = hidden_states.shape - output = torch.empty(output_shape, - dtype=hidden_states_or_q_c.dtype, - device=hidden_states_or_q_c.device) - forward_kwargs['output'] = output - - output = self.mla_attn.impl.forward(self.mla_attn, - hidden_states_or_q_c, - hidden_states, None, kv_cache, - attn_metadata, - **forward_kwargs) - if envs.VLLM_USE_V1: - output = output.view(-1, output_shape[-1]) - return output - else: - kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) - return self.mla_attn(hidden_states_or_q_c, - kv_c_normed, - k_pe, - output_shape=hidden_states.shape) - class CustomDeepseekDBODecoderLayer(DeepseekV2DecoderLayer): @@ -440,10 +275,7 @@ def __init__( layer_idx = int(prefix.split(sep='.')[-1]) self.layer_idx = layer_idx # TODO: enable mla in vllm-ascend - if model_config.use_mla: - attn_cls = CustomDeepseekDBOMLAAttention - else: - attn_cls = DeepseekV2Attention + attn_cls = CustomDeepseekV2MLAAttention self.self_attn = attn_cls( config=config, hidden_size=self.hidden_size, @@ -461,6 +293,10 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.self_attn", ) + self.tp_size = get_tensor_model_parallel_world_size() + self.dp_size = get_dp_group().world_size + self.tp_group = get_tp_group().device_group + self.global_num_experts = config.n_routed_experts if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace @@ -566,7 +402,26 @@ def _forward_ms_layer( shared_outputs = [] router_logits = [] chunk_hidden_states = [] - + chunk_router_logits = [] + topk_weights = [] + topk_ids = [] + num_moe_tokens = [] + original_shapes = [] + expanded_row_idx = [] + scatter_size_list = [] + gather_size_list = [] + local_expert_idx = [] + scatter_sizes = [] + expanded_expert_idx = [] + sorted_local_expert_idx = [] + sorted_idx = [] + + global_num_experts = len( + self.mlp.experts.expert_map + ) if self.mlp.experts.expert_map is not None else self.global_num_experts + ep_group = get_ep_group() + local_num_experts = global_num_experts // 
ep_group.world_size + fused_moe_state = get_forward_context().fused_moe_state # block 1 : attention # block 2 : attn tp communication # the attn computation of microbatch 1 can be overlapped with the moe @@ -631,88 +486,221 @@ def _forward_ms_layer( # when profile runs, force experts to load balanced tokens # to avoid high memory consumption on a single rank. # TODO: need a better flag to indicate whether in profile run or not. - if attn_metadata[i] is None: - # for profile run - is_prefill = True - enable_force_load_balance = True - else: - is_prefill = attn_metadata[i].num_prefills > 0 - enable_force_load_balance = False - - if self.mlp.tp_size > 1: - num_token, _ = hidden_states[i].shape - padded_num_tokens = (self.mlp.tp_size - num_token % - self.mlp.tp_size) % self.mlp.tp_size - if padded_num_tokens > 0: - hidden_states[i] = nn.functional.pad( - hidden_states[i], (0, 0, 0, padded_num_tokens)) - chunk_hidden_state = torch.tensor_split(hidden_states[i], - self.mlp.tp_size, - dim=0) - chunk_hidden_states.append(chunk_hidden_state) - local_hidden_states = chunk_hidden_state[self.mlp.tp_rank] - else: - local_hidden_states = hidden_states[i] - - router_logit = self.mlp._forward_ms_op_gate(local_hidden_states) + router_logit = self.mlp._forward_ms_op_gate(hidden_states[i]) router_logits.append(router_logit) if CustomDeepseekDBOMoE.top_k: real_top_k = CustomDeepseekDBOMoE.top_k else: real_top_k = self.mlp.experts.top_k + if (self.tp_size > 1 + and fused_moe_state != FusedMoEState.AllGather): + if num_tokens[i] < self.tp_size: + hidden_states[i] = nn.functional.pad( + hidden_states[i], + (0, 0, 0, self.tp_size - num_tokens[i])) + router_logits[i] = nn.functional.pad( + router_logits[i], + (0, 0, 0, self.tp_size - num_tokens[i])) + chunk_hidden_state = torch.tensor_split(hidden_states[i], + self.tp_size, + dim=0) + chunk_hidden_states.append(chunk_hidden_state) + chunk_router_logit = torch.tensor_split(router_logits[i], + self.tp_size, + dim=0) + chunk_router_logits.append(chunk_router_logit) + tp_rank = get_tensor_model_parallel_rank() + hidden_states[i] = chunk_hidden_states[i][tp_rank] + router_logits[i] = chunk_router_logits[i][tp_rank] + + if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather: + if attn_metadata[i] is not None: + max_num_tokens_across_dp = attn_metadata[ + i].max_tokens_across_dp + if num_tokens[i] < max_num_tokens_across_dp: + hidden_states[i] = nn.functional.pad( + hidden_states[i], + (0, 0, 0, + max_num_tokens_across_dp - num_tokens[i])) + router_logits[i] = nn.functional.pad( + router_logits[i], + (0, 0, 0, + max_num_tokens_across_dp - num_tokens[i])) + hidden_states[i] = get_dp_group().all_gather( + hidden_states[i], 0) + router_logits[i] = get_dp_group().all_gather( + router_logits[i], 0) + + # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern + if global_num_experts == 256: + topk_weight, topk_id, _ = torch_npu.npu_moe_gating_top_k( + router_logits[i], + k=real_top_k, # top-k, currently 8 + bias=self.mlp.experts.e_score_correction_bias, + k_group=self.mlp.experts.topk_group, # fix: 4 + group_count=self.mlp.experts.num_expert_group, # fix 8 + group_select_mode=1, # 0: max within the group; 1: topk2.sum (fixed) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + # out_flag=False, # todo new api; whether to emit the third output + # y2_flag=False, # old api; whether to emit the third output + routed_scaling_factor=1, + eps=float(1e-20)) + else: + topk_weight, topk_id = self.mlp.experts.select_experts( + hidden_states=hidden_states[i],
router_logits=router_logits[i], + top_k=real_top_k, + use_grouped_topk=self.mlp.experts.use_grouped_topk, + renormalize=self.mlp.experts.renormalize, + topk_group=self.mlp.experts.topk_group, + num_expert_group=self.mlp.experts.num_expert_group, + custom_routing_function=self.mlp.experts. + custom_routing_function, + scoring_func=self.mlp.experts.scoring_func, + e_score_correction_bias=self.mlp.experts. + e_score_correction_bias, + ) + topk_weight = topk_weight.to(hidden_states[i].dtype) + topk_weights.append(topk_weight) + topk_ids.append(topk_id) + original_shape = hidden_states[i].shape + original_shapes.append(original_shape) + if len(original_shapes[i]) == 3: + hidden_states[i] = hidden_states[i].view( + -1, hidden_states[i].shape[-1]) + num_token, _ = hidden_states[i].shape + num_moe_tokens.append(num_token) + device = hidden_states[i].device + + row_idx_len = num_moe_tokens[i] * real_top_k + row_idx = (torch.arange(0, + row_idx_len, + dtype=torch.int32, + device=device).view(real_top_k, + -1).permute( + 1, 0).contiguous()) + hidden_states[ + i], expanded_row_idx_i, expanded_expert_idx_i = torch_npu.npu_moe_init_routing( + hidden_states[i], + row_idx=row_idx, + expert_idx=topk_ids[i], + active_num=num_moe_tokens[i]) + expanded_row_idx.append(expanded_row_idx_i) + expanded_expert_idx.append(expanded_expert_idx_i) - hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp( - local_hidden_states, router_logits[i], is_prefill, real_top_k, - enable_force_load_balance) - - # the following kernels will be submitted to the comm stream to overlap the computation of the - # moe computation of next microbatch and the attn computation of next layer context = MultiStreamStepMetadata( comm_stream=ms_metadata.communicate_stream, before_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.FFN_COM_FINISH], + MSEventKey.MOE_ALL_TO_ALL], after_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_AFTER_COMM], + MSEventKey.MOE_ALL_TO_ALL_FINISH], ) context.before_comm_event.record() with torch.npu.stream(ms_metadata.communicate_stream): context.before_comm_event.wait() - if self.mlp.experts.reduce_results and ( - self.mlp.experts.tp_size > 1 - or self.mlp.experts.ep_size > 1): - hidden_states[i] = tensor_model_parallel_all_reduce( - hidden_states[i]) - hidden_states[ - i] = hidden_states[i] * self.mlp.routed_scaling_factor + global_expert_tokens = torch.bincount( + expanded_expert_idx[i], minlength=global_num_experts) + scatter_size = global_expert_tokens.view( + ep_group.world_size, -1).sum(-1) + scatter_sizes.append(scatter_size) + gather_sizes = torch.empty_like(scatter_sizes[i]) + dist.all_to_all_single(gather_sizes, + scatter_sizes[i], + group=ep_group.device_group) + scatter_size_list_i = scatter_sizes[i].cpu().tolist() + gather_size_list_i = gather_sizes.cpu().tolist() + scatter_size_list.append(scatter_size_list_i) + gather_size_list.append(gather_size_list_i) + expanded_expert_idx[ + i] = expanded_expert_idx[i] % local_num_experts + hidden_states[i] = ep_group.all_to_all(hidden_states[i], 0, 0, + scatter_size_list[i], + gather_size_list[i]) + local_expert_idx_i = ep_group.all_to_all( + expanded_expert_idx[i], 0, 0, scatter_size_list[i], + gather_size_list[i]) + local_expert_idx.append(local_expert_idx_i) + + sorted_local_expert_idx_i, sorted_idx_i = torch.sort( + local_expert_idx[i]) + sorted_local_expert_idx.append(sorted_local_expert_idx_i) + sorted_idx.append(sorted_idx_i) context.after_comm_event.record() + for i in range(num_micro_batchs): + 
ms_metadata.try_wait_event(layer_index, i, + MSEventKey.MOE_ALL_TO_ALL_FINISH) + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + sorted_local_expert_idx[i], local_num_experts).to(torch.int64) + group_list_type = 0 + hidden_states[i] = hidden_states[i][sorted_idx[i]] + hidden_states[i] = apply_mlp( + hidden_states[i], + self.mlp.experts.w13_weight, + self.mlp.experts.w13_weight_scale, #17 + self.mlp.experts.w2_weight, + self.mlp.experts.w2_weight_scale, + expert_tokens, #16 + group_list_type=group_list_type, + w1_scale_bias=None, + w2_scale_bias=None) + + resorted_idx = torch.argsort(sorted_idx[i]) + hidden_states[i] = hidden_states[i][resorted_idx] + hidden_states[i] = ep_group.all_to_all(hidden_states[i], 0, 0, + gather_size_list[i], + scatter_size_list[i]) + + hidden_states[i] = torch_npu.npu_moe_finalize_routing( + hidden_states[i], + skip1=None, + skip2=None, + bias=None, + scales=topk_weights[i], + expanded_src_to_dst_row=expanded_row_idx[i], + export_for_source_row=topk_ids[i], + ) + if len(original_shapes[i]) == 3: + hidden_states[i] = hidden_states[i].view(original_shapes[i]) + + # the following kernels will be submitted to the comm stream to overlap the computation of the + # moe computation of next microbatch and the attn computation of next layer context = MultiStreamStepMetadata( comm_stream=ms_metadata.communicate_stream, before_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_AFTER_COMM], + MSEventKey.FFN_COM_FINISH], after_comm_event=ms_metadata.ms_events[layer_index][i][ MSEventKey.FFN_AR_FINISH], ) - with set_multistream_context(context, i): - if self.mlp.tp_size > 1: - hidden_states[i] = self.mlp._forward_ms_op_tp_allgather( - hidden_states[i], chunk_hidden_states[i], - padded_num_tokens) + context.before_comm_event.record() with torch.npu.stream(ms_metadata.communicate_stream): + context.before_comm_event.wait() + if (self.tp_size > 1 + and fused_moe_state != FusedMoEState.AllGather): + dist.all_gather(list(chunk_hidden_states[i]), + hidden_states[i], self.tp_group) + hidden_states[i] = torch.cat(chunk_hidden_states[i], dim=0) + if num_tokens[i] < self.tp_size: + hidden_states[i] = hidden_states[i][:num_tokens[i]] + elif self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather: + hidden_states[ + i] = dist._functional_collectives.reduce_scatter_tensor( + hidden_states[i], + "sum", + scatter_dim=0, + group=get_dp_group().device_group) + hidden_states[i] = hidden_states[i][:num_tokens[i]] + if self.tp_size > 1 and fused_moe_state == FusedMoEState.AllGather: + hidden_states[i] = tensor_model_parallel_all_reduce( + hidden_states[i]) # last if shared_outputs[i] is not None: - hidden_states[i] = hidden_states[i] + shared_outputs[i] + hidden_states[i] = hidden_states[ + i] * self.routed_scaling_factor + shared_outputs[i] hidden_states[i] = hidden_states[i].view( num_tokens[i], hidden_dims[i]) - if isinstance(self.mlp, CustomDeepseekDBOMLP - ) and hidden_states[i].dtype == torch.float16: - # Fix FP16 overflow - # Scaling the DeepseekV2MLP output, it is the input of - # input_layernorm of next decoder layer. - # The scaling of DeepseekV2MOE output would be done in the forward - # of DeepseekV2MOE - hidden_states[i] *= 1. 
/ self.routed_scaling_factor context.after_comm_event.record() return hidden_states, residual @@ -767,9 +755,7 @@ def _forward_ms_op_post_attn_layernorm( class CustomDeepseekDBOModel(nn.Module): - fall_back_to_pt_during_load = False - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config @@ -835,6 +821,7 @@ def forward( attn_metadata: Optional[AttentionMetadata] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + graph_enable: Optional[bool] = True ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -848,10 +835,12 @@ def forward( residual = intermediate_tensors["residual"] num_normal_layers = (self.first_k_dense_replace - if VLLM_ASCEND_ENABLE_DBO and self.can_run_ms() - else self.end_layer - self.start_layer) + if VLLM_ASCEND_ENABLE_DBO and not graph_enable + and self.can_run_ms() else self.end_layer - + self.start_layer) - for i in range(self.start_layer, self.start_layer + num_normal_layers): + moe_start_layer = self.start_layer + num_normal_layers + for i in range(self.start_layer, min(moe_start_layer, self.end_layer)): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, residual, @@ -859,8 +848,7 @@ def forward( self.start_layer] if kv_caches is not None else None, attn_metadata) - moe_start_layer = self.start_layer + num_normal_layers - if moe_start_layer != self.end_layer: + if moe_start_layer < self.end_layer: # if we enable multistream/dbo, process sparse layers here hidden_states, residual = self._forward_ms_layers( positions=positions, @@ -881,34 +869,18 @@ def forward( def can_run_ms(self): attn_metadata = get_forward_context().attn_metadata - # support mla attention and V1 engine at present - if not self.use_mla or not envs.VLLM_USE_V1: - return False # enable prefill overlap - if attn_metadata is None or attn_metadata.num_prefills == 0: - return False - else: - [token_index, seq_index - ] = compute_split_seq_index(attn_metadata.query_lens, - attn_metadata.attn_state, - attn_metadata.num_decode_tokens) - if token_index == 0 or seq_index == 0 or seq_index == len( - attn_metadata.query_lens): - return False - # check whether the total tokens exceed the threshold - if self.multistream_config is None or attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split: + if attn_metadata is None or attn_metadata.num_prefills == 0 or not attn_metadata.enable_dbo_across_dp: return False return True - def _forward_ms_layers( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: torch.Tensor, - moe_start_layer: int, - kv_caches: Optional[List[torch.Tensor]] = None, - is_prefill: bool = False, - ): + def _forward_ms_layers(self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + moe_start_layer: int, + kv_caches: Optional[List[torch.Tensor]] = None, + is_prefill: bool = False): if moe_start_layer == self.end_layer: return hidden_states, residual @@ -970,8 +942,9 @@ def forward( attn_metadata: Optional[AttentionMetadata] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + graph_enable: Optional[bool] = True ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, 
intermediate_tensors, - inputs_embeds) + inputs_embeds, graph_enable) return hidden_states diff --git a/vllm_ascend/models/deepseek_mtp.py b/vllm_ascend/models/deepseek_mtp.py index 979a6099f1..400c7a0acf 100644 --- a/vllm_ascend/models/deepseek_mtp.py +++ b/vllm_ascend/models/deepseek_mtp.py @@ -28,8 +28,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import \ - VocabParallelEmbedding +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.models.deepseek_mtp import ( DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer, SharedHead) @@ -40,6 +40,20 @@ from .deepseek_v2 import CustomDeepseekV2DecoderLayer +class CustomDeepSeekShareHead(SharedHead): + + def __init__(self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + nn.Module.__init__(self) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "head")) + + class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer): def __init__( @@ -61,7 +75,10 @@ def __init__( self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) - self.shared_head = SharedHead(config=config, quant_config=quant_config) + self.shared_head = CustomDeepSeekShareHead(config=config, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "shared_head")) self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix, model_config, cache_config, diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index e96b2e9847..6e215b6b81 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -25,7 +25,7 @@ # # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py # """Inference-only DeepseekV2/DeepseekV3 model.""" -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import torch import torch_npu @@ -33,11 +33,11 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_pp_group, +from vllm.config import (CacheConfig, ModelConfig, VllmConfig, + get_current_vllm_config) +from vllm.distributed import (get_dp_group, get_pp_group, get_tensor_model_parallel_world_size, get_tp_group) -from vllm.distributed.parallel_state import get_dp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -65,7 +65,6 @@ from vllm.sequence import IntermediateTensors from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.quantization.quant_config import AscendLinearMethod from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod @@ -285,7 +284,10 @@ def __init__( self.tp_group = get_tp_group().device_group self.tp_rank = get_tp_group().rank_in_group - 
self.ep_group = get_ep_group() + self.kv_consumer = None + transfer_config = get_current_vllm_config().kv_transfer_config + if transfer_config is not None: + self.kv_consumer = transfer_config.kv_role == "kv_consumer" self.params_dtype = torch.get_default_dtype() @@ -293,23 +295,25 @@ def forward( self, hidden_states: torch.Tensor, attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + forward_context = get_forward_context() if attn_metadata is None: - attn_metadata = get_forward_context().attn_metadata + attn_metadata = forward_context.attn_metadata + # when profile runs, force experts to load balanced tokens # to avoid high memory consumption on a single rank. - # TODO: need a better flag to indicate whether in profile run or not. - if attn_metadata is None: - # for profile run - is_prefill = True - enable_force_load_balance = True - else: - is_prefill = attn_metadata.num_prefills > 0 - enable_force_load_balance = False - if hasattr(attn_metadata, 'with_prefill_across_dp'): - is_prefill = is_prefill or attn_metadata.with_prefill_across_dp + enable_force_load_balance = forward_context.in_profile_run + + is_prefill = forward_context.with_prefill + # If this node is kv_consumer, we force the moe always runs in decode path to make sure + # the behaviour aligned between dummy_run and normal model_execute. + if self.kv_consumer: + is_prefill = False # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) + if self.enable_multistream_moe: + router_logits = None + else: + router_logits, _ = self.gate(hidden_states) experts_hidden_states = self.experts( hidden_states=hidden_states, @@ -318,6 +322,7 @@ def forward( top_k=CustomDeepseekV2MoE.top_k, enable_force_load_balance=enable_force_load_balance, shared_experts=self.shared_experts, + gate=self.gate if self.enable_multistream_moe else None, ) hidden_states = ( @@ -477,7 +482,8 @@ def forward( hidden_states_or_q_c = self.q_a_layernorm(ckq) else: hidden_states_or_q_c = hidden_states - if self.torchair_graph_enabled: + is_mtp_model = attn_metadata is not None and attn_metadata.is_mtp_model + if self.torchair_graph_enabled and not is_mtp_model: forward_kwargs = {} if envs.VLLM_USE_V1: output_shape = hidden_states.shape @@ -727,13 +733,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config + self.num_dense_layers = self.config.first_k_dense_replace + self.num_moe_layers = self.config.num_hidden_layers - self.num_dense_layers + self.model = CustomDeepseekV2Model(vllm_config=vllm_config, prefix=maybe_prefix( prefix, "model")) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) @@ -755,6 +766,39 @@ def forward( inputs_embeds) return hidden_states + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + weights = filter(lambda x: ".module." not in x[0], weights) + # weights = ((name, data) for name, data in weights if ".module." 
not in name) + loaded_params = super().load_weights(weights) + + return loaded_params + + def get_expert_map(self, layer_id): + return self.model.layers[layer_id].mlp.experts.get_map() + + def get_log2phy_map(self, layer_id): + return self.model.layers[layer_id].mlp.experts.get_log2phy_map() + + def get_all_expert_map(self, num_moe_layers): + all_loads = [] + for layer_id in range(num_moe_layers): + load_tensor = self.get_expert_map(3+layer_id) # (num_experts_per_layer,) + all_loads.append(load_tensor) + + return torch.stack(all_loads, dim=0) + + def get_all_moe_loads(self): + all_moe_loads = torch.stack( + [self.model.layers[layer_id + self.num_dense_layers].mlp.experts.moe_load \ + for layer_id in range(self.num_moe_layers)], + dim=0 + ) + return all_moe_loads + + def clear_all_moe_loads(self): + for layer_id in range(self.num_moe_layers): + self.model.layers[layer_id + self.num_dense_layers].mlp.experts.clear_moe_load() class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM): pass diff --git a/vllm_ascend/multistream/base.py b/vllm_ascend/multistream/base.py index fba58b460e..420839cde6 100644 --- a/vllm_ascend/multistream/base.py +++ b/vllm_ascend/multistream/base.py @@ -14,6 +14,8 @@ class MSEventKey(Enum): MOE_SE_COMM_FINISH = 6 MOE_SE_COMP_FINISH = 7 MOE_GATE_FINISH = 8 + MOE_ALL_TO_ALL = 9 + MOE_ALL_TO_ALL_FINISH = 10 @dataclass diff --git a/vllm_ascend/multistream/metadata.py b/vllm_ascend/multistream/metadata.py index b521d3f85f..e451f15f26 100644 --- a/vllm_ascend/multistream/metadata.py +++ b/vllm_ascend/multistream/metadata.py @@ -170,6 +170,8 @@ def make_multistream_metadata_ds( MSEventKey.MOE_SE_COMM_FINISH, MSEventKey.MOE_SE_COMP_FINISH, MSEventKey.MOE_GATE_FINISH, + MSEventKey.MOE_ALL_TO_ALL, + MSEventKey.MOE_ALL_TO_ALL_FINISH, ] return MultiStreamMetadata( calculate_stream=torch.npu.current_stream(), diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index 430f57b03a..fd32a18abb 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -96,10 +96,12 @@ def model_input_split_v1_mla_attn( seq_lens = attn_metadata.prefill.seq_lens if attn_metadata.num_prefills > 0 else attn_metadata.decode.seq_lens [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(seq_lens, seq_index) - query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1] - query_start_loc_post = deepcopy( - attn_metadata.query_start_loc[seq_index:] - ) - attn_metadata.query_start_loc[seq_index] + query_start_loc_pre = query_start_loc_post = None + if attn_metadata.query_start_loc is not None: + query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1] + query_start_loc_post = deepcopy( + attn_metadata.query_start_loc[seq_index:] + ) - attn_metadata.query_start_loc[seq_index] [block_table_pre, block_table_post] = split_attn_tensor_type(attn_metadata.block_tables, seq_index) @@ -223,7 +225,7 @@ def model_input_split_v1_mla_attn( attn_mask=attn_mask_pre, prefill=prefill_pre, decode=decode_pre, - with_prefill_across_dp=attn_metadata.with_prefill_across_dp, + enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, ) attention_metadata_post = _metadata_cls( num_actual_tokens=attn_metadata.num_actual_tokens - token_index, @@ -240,6 +242,6 @@ def model_input_split_v1_mla_attn( attn_state=attn_state_post, prefill=prefill_post, decode=decode_post, - with_prefill_across_dp=attn_metadata.with_prefill_across_dp, + enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, ) return [attention_metadata_pre, attention_metadata_post] 
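Note on the `query_start_loc` guard added in ms_split.py above: when the tensor is present, the post-split half has to be re-based so its offsets start at zero again. A minimal sketch of that arithmetic with made-up offsets (plain torch, no NPU required):

```python
import torch
from copy import deepcopy

# Cumulative query offsets for 4 requests, as stored in query_start_loc.
query_start_loc = torch.tensor([0, 3, 7, 12, 20])
seq_index = 2  # split the batch after the second request

# First micro-batch keeps the original prefix of offsets.
query_start_loc_pre = query_start_loc[:seq_index + 1]
# Second micro-batch copies the suffix and subtracts the split offset so its
# offsets start at 0 again, mirroring model_input_split_v1_mla_attn above.
query_start_loc_post = deepcopy(query_start_loc[seq_index:]) - query_start_loc[seq_index]

print(query_start_loc_pre.tolist())   # [0, 3, 7]
print(query_start_loc_post.tolist())  # [0, 5, 13]
```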
diff --git a/vllm_ascend/ops/attention.py b/vllm_ascend/ops/attention.py index 8037c9545b..05600aee7a 100644 --- a/vllm_ascend/ops/attention.py +++ b/vllm_ascend/ops/attention.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import List, Optional, Tuple import torch from vllm.model_executor.layers.linear import ColumnParallelLinear @@ -37,7 +37,7 @@ def vanilla_chunked_prefill( scale: float, alibi_slopes: Optional[torch.Tensor], causal: bool = True, -) -> None: +) -> torch.Tensor: num_query_heads = query.shape[1] head_dim = value_cache.shape[3] num_kv_heads = value_cache.shape[2] @@ -138,7 +138,8 @@ def vanilla_chunked_prefill( def vanilla_chunked_prefill_mla( output: torch.Tensor, # (num_tokens, num_heads, v_head_dim) query: torch.Tensor, # (num_tokens, num_heads, nope_dim + rope_dim) - kv_cache: torch.Tensor, # (num_blocks, block_size, latent_kv) + kv_cache: Tuple[ + torch.Tensor], # [nope, rope] (num_blocks, block_size, latent_kv) block_tables: torch.Tensor, # (batch_size, max_num_blocks_per_seq) query_lens: torch.Tensor, # (batch_size) context_lens: torch.Tensor, # (batch_size) @@ -152,22 +153,25 @@ def vanilla_chunked_prefill_mla( alibi_slopes: Optional[torch.Tensor], causal: bool = True) -> None: batch_size = block_tables.size(0) + assert len(kv_cache) > 1 assert query_lens.size(0) == batch_size num_heads = query.size(1) - block_size = kv_cache.size(1) - latent_kv_dim = kv_cache.size(3) - rope_dim + nope_cache = kv_cache[0] + rope_cache = kv_cache[1] + block_size = nope_cache.size(1) + latent_kv_dim = nope_cache.size(-1) max_num_blocks_per_seq = block_tables.size(1) batch_size = query_lens.size(0) - kv_cache = kv_cache.squeeze() - # select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim] - cache_kv_c_pe = kv_cache[block_tables].view( - batch_size, max_num_blocks_per_seq * block_size, - latent_kv_dim + rope_dim)[:, :max_context_len, :] - # get kv_c and k_pe + nope_cache = nope_cache.squeeze() + # select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim] and get kv_c and k_pe # cached_kv_c: [batch_size, max_context_len, latent_kv] # cached_k_pe: [batch_size, max_context_len, rope_dim] - cache_kv_c = cache_kv_c_pe[:, :, :latent_kv_dim] - cache_k_pe = cache_kv_c_pe[:, :, latent_kv_dim:] + cache_kv_c = nope_cache[block_tables].view( + batch_size, max_num_blocks_per_seq * block_size, + latent_kv_dim)[:, :max_context_len, :] + cache_k_pe = rope_cache[block_tables].view( + batch_size, max_num_blocks_per_seq * block_size, + rope_dim)[:, :max_context_len, :] # get k_rope and v # k_nope: [batch_size, max_context_len, num_heads, nope_dim] # value: [batch_size, max_context_len, num_heads, v_head_dim] @@ -258,8 +262,8 @@ def vanilla_chunked_prefill_mla( attn_output = (attn_output[q_mask].view([-1, num_heads, v_head_dim]).to(output.dtype)) - output = output.view([-1, num_heads, v_head_dim]) - output.copy_(attn_output[:query.size(0) - num_add_query]) + attn_output = attn_output.view_as(output) + output.copy_(attn_output) return attn_output diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 05daf69f79..e0819210d8 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -15,6 +15,7 @@ # This file is a part of the vllm-ascend project. 
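For reference, the paged-cache gather that `vanilla_chunked_prefill_mla` now performs on the separate nope/rope caches boils down to advanced indexing with `block_tables` followed by a flattening view. A self-contained sketch with toy shapes (values are random and the dimensions are arbitrary, chosen only for illustration):

```python
import torch

num_blocks, block_size, latent_kv_dim, rope_dim = 8, 4, 16, 8
nope_cache = torch.randn(num_blocks, block_size, latent_kv_dim)
rope_cache = torch.randn(num_blocks, block_size, rope_dim)

# Two sequences, each owning up to 3 cache blocks.
block_tables = torch.tensor([[0, 1, 2], [3, 4, 5]])
batch_size, max_num_blocks_per_seq = block_tables.shape
max_context_len = 10

# Indexing with block_tables gathers each sequence's blocks; the view flattens
# (blocks, block_size) into one token axis, then the slice trims the padding.
cache_kv_c = nope_cache[block_tables].view(
    batch_size, max_num_blocks_per_seq * block_size,
    latent_kv_dim)[:, :max_context_len, :]
cache_k_pe = rope_cache[block_tables].view(
    batch_size, max_num_blocks_per_seq * block_size,
    rope_dim)[:, :max_context_len, :]

print(cache_kv_c.shape)  # torch.Size([2, 10, 16])
print(cache_k_pe.shape)  # torch.Size([2, 10, 8])
```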
# Adapted from vllm/tests/kernels/test_moe.py +import math import os from typing import Any, Callable, List, Optional, Tuple, Union @@ -26,7 +27,8 @@ from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import get_dp_group, get_tp_group +from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, + get_tp_group) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEParallelConfig, MoEConfig, UnquantizedFusedMoEMethod, @@ -36,10 +38,10 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group +from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer -from vllm_ascend.utils import (FusedMoEState, dispose_tensor, - get_fused_moe_state, npu_stream_switch, +from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, + get_ascend_soc_version, npu_stream_switch, npu_wait_tensor) MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER @@ -118,9 +120,24 @@ def fused_experts_with_mc2( top_k: int, expert_map: torch.Tensor = None, moe_all_to_all_group_name: Optional[str] = None, - shared_experts: Optional[Any] = None + shared_experts: Optional[Any] = None, + is_torchair: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - global_bs = 0 + quant_mode = 0 + ep_group = get_ep_group() + ep_rank_id = ep_group.rank_in_group + ep_world_size = ep_group.world_size + tp_world_size = get_tp_group().world_size + + # NOTE: `global_bs` should be equal to `max_num_tokens_across_dp` * `ep_world_size`, + # and `max_num_tokens_across_dp` has been split into `tp_world_size` parts before. 
+ global_bs = math.ceil(get_forward_context().max_tokens_across_dp / + tp_world_size) * ep_world_size + + # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine + need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 + or is_torchair) + moe_expert_num = len(expert_map) kwargs_mc2 = { "x": hidden_states, @@ -131,27 +148,20 @@ def fused_experts_with_mc2( "global_bs": global_bs, } - rank = torch.distributed.get_rank() - - quant_mode = 0 - ep_group = get_ep_group().device_group - local_rank = torch.distributed.get_rank(group=ep_group) - all_to_all_group_size = torch.distributed.get_world_size(ep_group) - - tp_size = get_etp_group().world_size - tp_rank = rank % tp_size - stage1_kwargs = { "scales": None, "quant_mode": quant_mode, "group_ep": moe_all_to_all_group_name, - "ep_world_size": all_to_all_group_size, - "ep_rank_id": local_rank, - # "group_tp": self.moe_rs_group_name, - "group_tp": moe_all_to_all_group_name, - "tp_world_size": tp_size, - "tp_rank_id": tp_rank, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, } + if need_extra_args: + stage1_kwargs.update({ + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + kwargs_mc2.update(stage1_kwargs) output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) @@ -204,20 +214,22 @@ def fused_experts_with_mc2( "expert_shard_type": 0, "shared_expert_rank_num": 0, "moe_expert_num": moe_expert_num, - "global_bs": 0, + "global_bs": global_bs, } tp_recv_counts = output[5] stage3_kwargs = { "ep_send_counts": ep_recv_counts, "group_ep": moe_all_to_all_group_name, - "ep_world_size": all_to_all_group_size, - "ep_rank_id": local_rank, - "tp_send_counts": tp_recv_counts, - # "group_tp": self.moe_rs_group_name, - "group_tp": moe_all_to_all_group_name, - "tp_world_size": tp_size, - "tp_rank_id": tp_rank, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, } + if need_extra_args: + stage3_kwargs.update({ + "tp_send_counts": tp_recv_counts, + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) kwargs_mc2.update(stage3_kwargs) hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) @@ -847,17 +859,14 @@ def __init__(self, moe: MoEConfig = None): super().__init__(moe=moe) vllm_config = get_current_vllm_config() - self.ep_group = get_ep_group() - self.ep_size = self.ep_group.world_size self.global_batch_size = vllm_config.scheduler_config.max_num_seqs - self.local_batch_size = self.global_batch_size // self.ep_size self.max_model_len = vllm_config.model_config.max_model_len ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled try: - device_group = self.ep_group.device_group + device_group = get_ep_group().device_group # TODO: Try local_rank = ep_group.rank_in_group local_rank = torch.distributed.get_rank(group=device_group) backend = device_group._get_backend(torch.device("npu")) @@ -933,8 +942,7 @@ def apply( if enable_force_load_balance: topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) - fused_moe_state = get_fused_moe_state(self.ep_group.world_size, - is_prefill) + fused_moe_state = get_forward_context().fused_moe_state if fused_moe_state == FusedMoEState.MC2: return fused_experts_with_mc2( hidden_states=x, @@ -945,7 +953,8 @@ def apply( top_k=top_k, expert_map=expert_map, moe_all_to_all_group_name=self.moe_all_to_all_group_name, - shared_experts=shared_experts) + shared_experts=shared_experts, + 
is_torchair=self.torchair_graph_enabled) elif fused_moe_state == FusedMoEState.AllGather: return fused_experts(hidden_states=x, w1=layer.w13_weight, @@ -1046,8 +1055,14 @@ def __init__( self.log2phy = None self.global_redundant_expert_num = 0 + # TODO: if this is not need for dynamic eplb with redundant expert, remove this + # self.log2phy = torch.full((self.ep_size, self.global_num_experts), + # -1, + # dtype=torch.int32) + ascend_config = get_ascend_config() expert_map_path = ascend_config.expert_map_path + self.dynamic_eplb = ascend_config.dynamic_eplb if expert_map_path and os.path.exists(expert_map_path): # moe expert load balance expert_load_balancer = ExpertLoadBalancer(expert_map_path, @@ -1055,17 +1070,15 @@ def __init__( self.local_num_experts, self.expert_map = \ expert_load_balancer.get_rank_placement_map( self.moe_instance_id, - get_ep_group().rank_in_group) + self.ep_rank) self.log2phy = expert_load_balancer.get_rank_log2phy_map( - self.moe_instance_id, - get_ep_group().rank_in_group) + self.moe_instance_id, self.ep_rank) self.global_redundant_expert_num = \ expert_load_balancer.get_global_redundant_expert_num() else: # Create a tensor of size num_experts filled with -1 self.local_num_experts, self.expert_map = determine_expert_map( - self.ep_size, - get_ep_group().rank_in_group, self.global_num_experts) + self.ep_size, self.ep_rank, self.global_num_experts) self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.enable_multistream_moe = \ @@ -1095,6 +1108,10 @@ def __init__( local_num_experts = torch.sum(self.expert_map != -1) \ if self.expert_map is not None else num_experts + self.moe_load = None + if self.dynamic_eplb: + self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64) + moe_quant_params = { "num_experts": local_num_experts, "hidden_size": hidden_size, @@ -1108,7 +1125,6 @@ def __init__( in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")): moe_quant_params["intermediate_size_full"] = intermediate_size - self.ep_group = get_ep_group() # NOTE: self.tp_group is not expert_tp_group self.tp_group = get_tp_group().device_group self.quant_method.create_weights(layer=self, **moe_quant_params) @@ -1119,7 +1135,8 @@ def forward(self, is_prefill: bool, enable_force_load_balance: bool = False, top_k: Optional[int] = None, - shared_experts: Optional[Any] = None): + shared_experts: Optional[Any] = None, + gate: Optional[Any] = None): assert self.quant_method is not None if top_k: @@ -1129,8 +1146,21 @@ def forward(self, num_tokens, hidden_size = hidden_states.shape - fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size, - is_prefill) + fused_moe_state = get_forward_context().fused_moe_state + # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel. 
+ quantized_x_for_share, dynamic_scale_for_share = None, None + from vllm_ascend.quantization.w8a8_dynamic import \ + AscendW8A8DynamicFusedMoEMethod + if self.enable_multistream_moe: + assert gate is not None + router_logits, _ = gate(hidden_states) + if isinstance(self.quant_method.quant_method, + AscendW8A8DynamicFusedMoEMethod + ) and fused_moe_state == FusedMoEState.MC2: + with npu_stream_switch("moe_secondary", 0): + quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant( + hidden_states) + if shared_experts: if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: shared_hidden_states = shared_experts(hidden_states) @@ -1154,21 +1184,20 @@ def forward(self, if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather: # NOTE: When in torchair graph, it has been padded in model_runner_v1 if not self.torchair_graph_enabled or is_prefill: - attn_metadata = get_forward_context().attn_metadata - if attn_metadata is not None: - max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp - if num_tokens < max_num_tokens_across_dp: - hidden_states = nn.functional.pad( - hidden_states, - (0, 0, 0, max_num_tokens_across_dp - num_tokens)) - router_logits = nn.functional.pad( - router_logits, - (0, 0, 0, max_num_tokens_across_dp - num_tokens)) + max_num_tokens_across_dp = get_forward_context( + ).max_tokens_across_dp + if num_tokens < max_num_tokens_across_dp: + hidden_states = nn.functional.pad( + hidden_states, + (0, 0, 0, max_num_tokens_across_dp - num_tokens)) + router_logits = nn.functional.pad( + router_logits, + (0, 0, 0, max_num_tokens_across_dp - num_tokens)) hidden_states = get_dp_group().all_gather(hidden_states, 0) router_logits = get_dp_group().all_gather(router_logits, 0) # Matrix multiply. 
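The padding logic just above exists because `all_gather` over the data-parallel group requires every rank to contribute a tensor of identical shape. A CPU-only sketch of the padding step (the collective itself is only described in a comment, and all sizes are made up):

```python
import torch
import torch.nn.functional as F

# This DP rank holds 3 tokens; the largest rank in the DP group holds 5.
num_tokens, hidden_size, n_experts = 3, 8, 4
max_num_tokens_across_dp = 5
hidden_states = torch.randn(num_tokens, hidden_size)
router_logits = torch.randn(num_tokens, n_experts)

pad = max_num_tokens_across_dp - num_tokens
if pad > 0:
    # (0, 0, 0, pad) pads nothing on the hidden dim and appends `pad` rows of
    # zeros on the token dim, matching the nn.functional.pad calls above.
    hidden_states = F.pad(hidden_states, (0, 0, 0, pad))
    router_logits = F.pad(router_logits, (0, 0, 0, pad))

print(hidden_states.shape)  # torch.Size([5, 8])
print(router_logits.shape)  # torch.Size([5, 4])
# In the layer itself the padded tensors are then concatenated across DP ranks
# via get_dp_group().all_gather(..., 0) before the expert computation.
```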
- e_hidden_states = self.quant_method.apply( + e_hidden_states, expert_token_num, group_list_type = self.quant_method.apply( layer=self, x=hidden_states, router_logits=router_logits, @@ -1188,11 +1217,16 @@ def forward(self, global_redundant_expert_num=self.global_redundant_expert_num, shared_experts=shared_experts if self.torchair_graph_enabled and self.enable_multistream_moe and not is_prefill else None, + quantized_x_for_share=quantized_x_for_share, + dynamic_scale_for_share=dynamic_scale_for_share, ) if shared_experts: if isinstance(e_hidden_states, tuple): e_hidden_states, shared_hidden_states = e_hidden_states + if self.dynamic_eplb: + self.moe_load += expert_token_num if group_list_type else \ + torch.cat([expert_token_num[:1], expert_token_num[1:] - expert_token_num[:-1]]) if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather: dist.all_gather(list(chunk_hidden_states), e_hidden_states, @@ -1249,3 +1283,17 @@ def _forward_ms_fused_moe_comp( enable_force_load_balance=enable_force_load_balance) return hidden_states + + def update_map(self,new_expert_map): + self.expert_map = new_expert_map + + def get_map(self): + return self.expert_map + + def get_log2phy_map(self): + return self.log2phy + + def clear_moe_load(self): + self.moe_load.zero_() + + diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 39a4c1cfe8..f55ab8e0cb 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -80,10 +80,7 @@ def native_rope_deepseek_forward(self, positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - max_seq_len: Optional[int] = None): - if max_seq_len is not None and max_seq_len > self.max_seq_len: - _set_cos_sin_cache(self, max_seq_len, query.device, query.dtype) + offsets: Optional[torch.Tensor] = None): if len(key.shape) == 2: key = key[:, None, :] # Note: we implement the non neox_style method with shuffle the last dim and neox style @@ -198,8 +195,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): return q_embed, k_embed -def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len +def _set_cos_sin_cache(self, max_seq_len, device, dtype): dim = self.rotary_dim freq_extra = 1.0 / (self.base**( @@ -219,9 +215,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask self.register_buffer("inv_freq", inv_freq, persistent=False) - t = torch.arange(seq_len * self.scaling_factor, - device=device, - dtype=torch.float32) + t = torch.arange(max_seq_len, device=device, dtype=torch.float32) freqs = torch.outer(t, inv_freq) cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale @@ -266,11 +260,10 @@ def deepseek_rope_init_func( super(DeepseekScalingRotaryEmbedding, self).__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) - self.max_seq_len = max_position_embeddings - _set_cos_sin_cache(self, - max_position_embeddings, - dtype=dtype, - device="npu") + + # NOTE: For ascend friendly computing, reorder sin and cos cache + self.max_seq_len = max_position_embeddings * scaling_factor + _set_cos_sin_cache(self, self.max_seq_len, dtype=dtype, device="npu") RotaryEmbedding.forward_oot = rope_forward_oot diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index d817f9063e..ae87010359 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -47,7 +47,16 @@ # 
Related PR (if no, explain why):
# Future Plan:
# Remove those patch when vllm merged them
-# 2. `vllm.config.ParallelConfig.get_next_dp_init_port`
+# 2. `vllm.v1.engine.core.DPEngineCoreProc._init_data_parallel`
+# Why:
+# There is a bug in the ASCEND_RT_VISIBLE_DEVICES handling.
+# How:
+# The ASCEND_RT_VISIBLE_DEVICES related code is dropped.
+# Related PR (if no, explain why):
+# No, this is a bug in vllm-ascend
+# Future Plan:
+# Remove this patch once the ASCEND_RT_VISIBLE_DEVICES bug is fixed.
+# 3. `vllm.config.ParallelConfig.get_next_dp_init_port`
# Why:
# vllm doesn't support get port from environment.
# How:
@@ -56,7 +65,7 @@
# Need a PR to vllm to support get port from environment.
# Future Plan:
# Remove those patch when vllm merged them
-# 3. `vllm.config.ParallelConfig.ParallelConfig.stateless_init_dp_group`
+# 4. `vllm.config.ParallelConfig.ParallelConfig.stateless_init_dp_group`
# Why:
# vLLM use gloo backend by default to initialize stateless dp process gourp, but we want to use hccl here to
# get better performance
@@ -65,7 +74,19 @@
# Related PR (if no, explain why):
# Need a PR to vllm to support more backend.
# Future Plan:
-# Remove those patch when vllm support more backend.
+# Remove those patch when vllm merged them
+#
+# ** File: platform/patch_common/patch_scheduler.py**
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 1. `vllm.v1.core.sched.scheduler.Scheduler.destroy_model_parallel()`
+# Why:
+# vLLM transfers a KV block only once the block has been completely filled. However, this behaviour may cause the decode node
+# to perform prefill. To make the decode node work as expected, we always transfer all data whether or not the block is filled.
+# How:
+# num_computed_tokens shall always equal the token count of the request during scheduling.
+# Related PR (if no, explain why): https://github.com/vllm-project/vllm/pull/17751 (nixl implementation)
+# Future Plan:
+# No plan; we will maintain this patch until vllm changes this behaviour
#
# * Worker Patch:
# ===============
@@ -100,18 +121,6 @@
# Future Plan:
# Revert it when the related pr is merged in vllm and vllm-ascend.
#
-# 2. `vllm.spec_decode.multi_step_worker.MultiStepWorker.set_include_gpu_probs_tensor` and
-# `vllm.spec_decode.multi_step_worker.MultiStepWorker.set_should_modify_greedy_probs_inplace`
-# Why:
-# vLLM `Remove Sampler from Model Code` so vllm-ascend needs adapt to this change.
-# How:
-# Use vLLM 0.8.4 method to patch it.
-# Related PR (if no, explain why):
-# - https://github.com/vllm-project/vllm/pull/15195
-# - https://github.com/vllm-project/vllm-ascend/pull/395
-# Future Plan:
-# Remove it when we identify the reasons clearly.
-#
# ** File: worker/patch_common/patch_spec_decode_worker.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.spec_decode.spec_decode_worker.SpecDecodeWorker.create_worker`
diff --git a/vllm_ascend/patch/platform/patch_common/patch_distributed.py b/vllm_ascend/patch/platform/patch_common/patch_distributed.py
index 86515df86e..d971922840 100644
--- a/vllm_ascend/patch/platform/patch_common/patch_distributed.py
+++ b/vllm_ascend/patch/platform/patch_common/patch_distributed.py
@@ -17,8 +17,6 @@
# Adapted from vllm/model_executor/models/qwen2_vl.py
# This file is a part of the vllm-ascend project.
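To make the patch_scheduler.py rationale above concrete: if only completely filled KV blocks are transferred, the decode node is left to recompute the trailing partial block. The toy helper below is purely illustrative (it is not vLLM scheduler code) and only shows the block-count arithmetic under an assumed block size:

```python
import math

def transferred_blocks(num_tokens: int, block_size: int,
                       only_full_blocks: bool) -> int:
    """Hypothetical helper: how many KV blocks would be sent for a request."""
    if only_full_blocks:
        return num_tokens // block_size
    return math.ceil(num_tokens / block_size)

block_size, prompt_len = 16, 50  # 3 full blocks plus a partial block of 2 tokens

# Unpatched view: the partial block stays behind, so the decode node must
# prefill those trailing tokens itself.
print(transferred_blocks(prompt_len, block_size, only_full_blocks=True))   # 3
# Patched view: num_computed_tokens equals the full request length, so every
# block that holds prompt tokens is transferred.
print(transferred_blocks(prompt_len, block_size, only_full_blocks=False))  # 4
```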
-import vllm -import vllm.distributed import vllm.envs as envs from torch.distributed import ProcessGroup from vllm.config import ParallelConfig @@ -26,25 +24,6 @@ stateless_init_torch_distributed_process_group -def ascend_destroy_model_parallel(): - """Set the groups to none and destroy them.""" - from vllm.distributed.parallel_state import _DP, _PP, _TP - if _TP: - _TP.destroy() - _TP = None - - if _PP: - _PP.destroy() - _PP = None - - if _DP: - _DP.destroy() - _DP = None - from vllm_ascend.distributed.parallel_state import \ - destory_ascend_model_parallel - destory_ascend_model_parallel() - - def parallel_config_get_dp_port(self) -> int: """ We might need to initialize process groups in multiple @@ -78,6 +57,5 @@ def stateless_init_dp_group(self) -> "ProcessGroup": return dp_group -vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port ParallelConfig.stateless_init_dp_group = stateless_init_dp_group diff --git a/vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py b/vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py index ca87729540..53ce312676 100644 --- a/vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py +++ b/vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py @@ -88,20 +88,4 @@ def sampler_output( return filtered_model_outputs, True -def set_include_gpu_probs_tensor(self) -> None: - # Need include_gpu_probs_tensor for MultiSteoWorker - if hasattr(self.model_runner.model, "sampler"): - self.model_runner.model.sampler.include_gpu_probs_tensor = True - self.model_runner.sampler.include_gpu_probs_tensor = True - - -def set_should_modify_greedy_probs_inplace(self) -> None: - if hasattr(self.model_runner.model, "sampler"): - self.model_runner.model.sampler.should_modify_greedy_probs_inplace = ( - True) - self.model_runner.sampler.should_modify_greedy_probs_inplace = True - - MultiStepWorker.sampler_output = torch.inference_mode()(sampler_output) -MultiStepWorker.set_include_gpu_probs_tensor = set_include_gpu_probs_tensor -MultiStepWorker.set_should_modify_greedy_probs_inplace = set_should_modify_greedy_probs_inplace diff --git a/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py b/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py index 66e7aa56b2..d271e65bfc 100644 --- a/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py +++ b/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py @@ -57,11 +57,6 @@ def create_worker( ngram_prompt_lookup_min = ( draft_worker_kwargs.pop("ngram_prompt_lookup_min")) - # TODO(Yizhou): A quick fix, must be refactored ASAP - draft_worker_kwargs["vllm_config"].parallel_config.expert_parallel_size = 1 - draft_worker_kwargs[ - "vllm_config"].parallel_config.expert_tensor_parallel_size = 1 - draft_model_config = draft_worker_kwargs["vllm_config"].model_config draft_parallel_config: ParallelConfig = draft_worker_kwargs[ 'vllm_config'].parallel_config @@ -72,6 +67,13 @@ def create_worker( proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, ngram_prompt_lookup_max) else: + # TODO(Yizhou): A quick fix, must be refactored ASAP + # ngram need not this fix. 
+ draft_worker_kwargs[ + "vllm_config"].parallel_config.expert_parallel_size = 1 + draft_worker_kwargs[ + "vllm_config"].parallel_config.expert_tensor_parallel_size = 1 + draft_tp = draft_parallel_config.tensor_parallel_size target_tp = scorer_worker.parallel_config.tensor_parallel_size diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index b9233da05d..08abc08cdd 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -125,17 +125,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config = vllm_config.cache_config if parallel_config: - # Default value for expert tensor parallel size - parallel_config.expert_tensor_parallel_size = parallel_config.tensor_parallel_size - - # NOTE: When enable_expert_parallel is True, we follow vLLM convention: - # ep_size = world_size, which means expert_tensor_parallel_size must be 1 if parallel_config.enable_expert_parallel: parallel_config.expert_tensor_parallel_size = 1 - # NOTE: When enable_expert_parallel is False and param `asceend_config.expert_tensor_parallel_size` - # is configured, use ascend_config - elif ascend_config.expert_tensor_parallel_size > 0: - parallel_config.expert_tensor_parallel_size = ascend_config.expert_tensor_parallel_size + else: + parallel_config.expert_tensor_parallel_size = parallel_config.world_size_across_dp # Calculate expert parallel size based on world size parallel_config.expert_parallel_size = ( @@ -177,8 +170,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "PIECEWISE compilation enabled on NPU. use_inductor not supported - " "using only ACL Graph mode") compilation_config.use_inductor = False - compilation_config.splitting_ops.extend( - ["vllm.unified_ascend_attention_with_output"]) + if not compilation_config.full_cuda_graph: + compilation_config.splitting_ops.extend( + ["vllm.unified_ascend_attention_with_output"]) update_aclgraph_sizes(vllm_config) if parallel_config and parallel_config.worker_cls == "auto": diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 3567dba355..1b06a4294a 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -34,6 +34,8 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.vocab_parallel_embedding import ( + UnquantizedEmbeddingMethod, VocabParallelEmbedding) from vllm.model_executor.parameter import PerTensorScaleParameter from vllm.model_executor.utils import set_weight_attrs @@ -104,6 +106,12 @@ def get_quant_method(self, layer: torch.nn.Module, return AscendUnquantizedFusedMoEMethod() return AscendFusedMoEMethod(self, prefix, self.packed_modules_mapping) + elif isinstance(layer, VocabParallelEmbedding): + if self.is_layer_skipped_ascend(prefix, + self.packed_modules_mapping): + return UnquantizedEmbeddingMethod() + return AscendEmbeddingMethod(self, prefix, + self.packed_modules_mapping) return None def is_layer_skipped_ascend( @@ -194,6 +202,17 @@ def create_weights( layer.register_parameter(perchannel_name, param) set_weight_attrs(param, extra_weight_attrs) + pergroup_dict = self.quant_method.get_pergroup_param( + input_size_per_partition, output_size_per_partition, params_dtype) + for pergroup_name, pergroup_param in pergroup_dict.items(): + param = torch.nn.Parameter(pergroup_param, requires_grad=False) + set_weight_attrs(param, 
{"output_dim": 0}) + layer.register_parameter(pergroup_name, param) + set_weight_attrs(param, extra_weight_attrs) + if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name: + setattr(param, "input_dim", 1) + param.input_dim = 1 + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if hasattr(self.quant_method, "process_weights_after_loading"): self.quant_method.process_weights_after_loading(layer) @@ -305,6 +324,10 @@ def create_weights( param = torch.nn.Parameter(param_value, requires_grad=False) layer.register_parameter(param_key, param) set_weight_attrs(param, extra_weight_attrs) + if "weight_scale_second" in param_key or "weight_offset_second" in param_key: + setattr(param, "quant_method", + FusedMoeWeightScaleSupported.GROUP.value) + param.quant_method = FusedMoeWeightScaleSupported.GROUP.value def apply( self, @@ -337,3 +360,20 @@ def apply( def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if hasattr(self.quant_method, "process_weights_after_loading"): self.quant_method.process_weights_after_loading(layer) + + +class AscendEmbeddingMethod(AscendLinearMethod): + """Embedding method for Ascend quantization. + + This class calls AscendQuantizer to search a specific quantization + implementations supported on ascend hardware for Embedding methods. + + Args: + quant_config: The Ascend quantization config. + """ + + def __init__(self, quant_config: AscendQuantConfig, prefix: str, + packed_modules_mapping: Dict[str, Any]) -> None: + self.quantizer = AscendQuantizer.get_quantizer( + quant_config.quant_description, prefix, packed_modules_mapping) + self.quant_method = self.quantizer.build_linear_method() diff --git a/vllm_ascend/quantization/quantizer.py b/vllm_ascend/quantization/quantizer.py index ea1297bf35..d27914139f 100644 --- a/vllm_ascend/quantization/quantizer.py +++ b/vllm_ascend/quantization/quantizer.py @@ -24,6 +24,8 @@ from .func_wrapper import (wrapper_load_model, wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init) +from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod, + AscendW4A8DynamicLinearMethod) from .w8a8 import AscendW8A8LinearMethod from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, AscendW8A8DynamicLinearMethod) @@ -263,6 +265,17 @@ def get_quantizer(cls, f"{list(SUPPORT_ASCEND_QUANTIZER_TYPE.keys())}") +class W4A8DYNAMICQuantizer(VLLMAscendQuantizer): + + @staticmethod + def build_linear_method(): + return AscendW4A8DynamicLinearMethod() + + @staticmethod + def build_moe_method(): + return AscendW4A8DynamicFusedMoEMethod() + + class W8A8Quantizer(VLLMAscendQuantizer): @staticmethod @@ -282,6 +295,7 @@ def build_moe_method(): SUPPORT_ASCEND_QUANTIZER_TYPE = { + "W4A8_DYNAMIC": W4A8DYNAMICQuantizer, "W8A8": W8A8Quantizer, "W8A8_DYNAMIC": W8A8DYNAMICQuantizer, } diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py new file mode 100644 index 0000000000..227b6b680a --- /dev/null +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -0,0 +1,377 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, Callable, Dict, Optional + +import numpy as np +import torch +import torch_npu +from vllm.config import get_current_vllm_config +from vllm.distributed import get_ep_group +from vllm.forward_context import get_forward_context + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import FusedMoEState +from vllm_ascend.ops.fused_moe import select_experts +from vllm_ascend.quantization.w8a8_dynamic import (fused_experts_with_all2all, + fused_experts_with_mc2) + + +class AscendW4A8DynamicLinearMethod: + """Linear method for Ascend W4A8_DYNAMIC + """ + + def __init__(self): + self.transpose_weight = True + self.group_size = get_current_vllm_config( + ).quant_config.quant_description.get("group_size", 256) + + @staticmethod + def get_weight(input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + params_dict = { + "weight": torch.empty(output_size, input_size, dtype=torch.int8) + } + return params_dict + + @staticmethod + def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + + @staticmethod + def get_perchannel_param(output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + + def get_pergroup_param(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + params_dict = {} + params_dict["weight_scale"] = torch.empty(output_size, + 1, + dtype=params_dtype) + params_dict["weight_offset"] = torch.empty(output_size, + 1, + dtype=params_dtype) + params_dict["weight_scale_second"] = torch.empty(output_size, + input_size // + self.group_size, + dtype=params_dtype) + params_dict["weight_offset_second"] = torch.empty(output_size, + input_size // + self.group_size, + dtype=params_dtype) + return params_dict + + @staticmethod + def process_scale_second(weight: torch.Tensor, scale: torch.Tensor, + per_group_scale: torch.Tensor): + k, n = weight.shape + group_num, n = per_group_scale.shape + weight_high = weight.to(torch.float32).reshape( + group_num, -1, n) * per_group_scale.reshape(group_num, 1, n) + weight_high = weight_high.reshape(k, n) + bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0) + antiquant_scale = (scale * per_group_scale).reshape(group_num, n) + return antiquant_scale.npu(), bias + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + tp_rank: Optional[int] = None, + ) -> torch.Tensor: + return torch_npu.npu_weight_quant_batchmatmul( + x, + layer.weight, + antiquant_scale=layer.weight_scale_second.to(x.dtype), + antiquant_group_size=self.group_size, + ) + + def process_weights_after_loading(self, layer: torch.nn.Module): + if self.transpose_weight: + layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + layer.weight_scale.data = layer.weight_scale.data.flatten().to( + torch.float32) + layer.weight_offset.data = layer.weight_offset.data.flatten() + layer.weight_scale_second.data, scale_bias = self.process_scale_second( + layer.weight.data, + layer.weight_scale.data, + layer.weight_scale_second.data.transpose(0, 
1).contiguous(), + ) + param = torch.nn.Parameter(scale_bias, requires_grad=False) + layer.register_parameter("weight_scale_bias", param) + layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( + layer.weight.data.to(torch.int32)) + + +class AscendW4A8DynamicFusedMoEMethod: + """FusedMoe method for Ascend W4A8_DYNAMIC. + """ + + def __init__(self): + self.transpose_weight = True + + self.ep_group = get_ep_group() + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + try: + device_group = self.ep_group.device_group + # TODO: Try local_rank = ep_group.rank_in_group + local_rank = torch.distributed.get_rank(group=device_group) + backend = device_group._get_backend(torch.device("npu")) + self.moe_all_to_all_group_name = backend.get_hccl_comm_name( + local_rank) + except AttributeError: + self.moe_all_to_all_group_name = "" + + @staticmethod + def get_weight(num_experts: int, intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + param_dict["w13_weight"] = torch.empty(num_experts, + 2 * + intermediate_size_per_partition, + hidden_sizes, + dtype=torch.int8) + param_dict["w2_weight"] = torch.empty(num_experts, + hidden_sizes, + intermediate_size_per_partition, + dtype=torch.int8) + return param_dict + + @staticmethod + def get_dynamic_quant_param(num_experts: int, + intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + config = get_current_vllm_config() + group_size = config.quant_config.quant_description.get( + "group_size", 256) + + param_dict["w13_weight_scale"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=params_dtype) + + param_dict["w13_weight_offset"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=params_dtype) + + param_dict["w13_weight_scale_second"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // group_size, + dtype=params_dtype) + + param_dict["w13_weight_offset_second"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // group_size, + dtype=params_dtype) + + param_dict["w2_weight_scale"] = torch.empty(num_experts, + hidden_sizes, + 1, + dtype=params_dtype) + param_dict["w2_weight_offset"] = torch.empty(num_experts, + hidden_sizes, + 1, + dtype=params_dtype) + param_dict["w2_weight_scale_second"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // group_size, + dtype=params_dtype) + param_dict["w2_weight_offset_second"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // group_size, + dtype=params_dtype) + + return param_dict + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + is_prefill: bool = True, + enable_force_load_balance: bool = True, + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + **kwargs, + ) -> torch.Tensor: + assert router_logits.shape[ + 1] == global_num_experts, "Number of 
global experts mismatch"
+
+        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
+        if global_num_experts == 256:
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+                router_logits,
+                k=top_k,  # topk is currently set to 8
+                bias=e_score_correction_bias,
+                k_group=topk_group,  # fix: 4
+                group_count=num_expert_group,  # fix 8
+                group_select_mode=1,  # 0: max within the group; 1: sum of the group's top-2 (fixed)
+                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
+                # out_flag=False, # todo new api; whether to output the third output
+                # y2_flag=False, # old api; whether to output the third output
+                routed_scaling_factor=1,
+                eps=float(1e-20))
+        else:
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                top_k=top_k,
+                use_grouped_topk=use_grouped_topk,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                custom_routing_function=custom_routing_function,
+                scoring_func=scoring_func,
+                e_score_correction_bias=e_score_correction_bias,
+            )
+
+        # this is a naive implementation for experts load balance so as
+        # to avoid accumulating too much tokens on a single rank.
+        # currently it is only activated when doing profile runs.
+        if enable_force_load_balance:
+            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
+
+        topk_weights = topk_weights.to(x.dtype)
+
+        fused_moe_state = get_forward_context().fused_moe_state
+        if fused_moe_state == FusedMoEState.MC2:
+            return fused_experts_with_mc2(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                w1_scale=layer.w13_weight_scale_second,
+                w2_scale=layer.w2_weight_scale_second,
+                w1_scale_bias=layer.w13_scale_bias,
+                w2_scale_bias=layer.w2_scale_bias,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                expert_map=expert_map,
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                log2phy=log2phy,
+                global_redundant_expert_num=global_redundant_expert_num,
+                shared_experts=shared_experts,
+                is_torchair=self.torchair_graph_enabled)
+        else:
+            # The current implementation of deepseek moe splits hidden_states
+            # according to tp_size before they are fed into the fused_moe module.
+            # Therefore, all2all is needed no matter how dp/tp is set so as to
+            # dispatch/combine tokens.
+ return fused_experts_with_all2all( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + w1_scale=layer.w13_weight_scale_second, + w2_scale=layer.w2_weight_scale_second, + w1_scale_bias=layer.w13_scale_bias, + w2_scale_bias=layer.w2_scale_bias, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + ep_group=self.ep_group, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + ) + + def process_scale(self, weight: torch.Tensor, scale, per_group_scale): + group_num, k, n = weight.shape + per_group_scale = per_group_scale.reshape(group_num, -1, n) + group_num, quantgroup_num, n = per_group_scale.shape + weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \ + per_group_scale.reshape([group_num, quantgroup_num, 1, n]) + weight_high = weight_high.reshape([group_num, k, n]) + bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1) + scale_fp32 = (scale * per_group_scale).to(torch.float16).to( + torch.float32) + scale_fp32_np = scale_fp32.cpu().numpy() + scale_fp32_np.dtype = np.uint32 + sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2), + dtype=np.uint32) + + sscale_uint64[..., ::2] = scale_fp32_np + + sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(), + dtype=np.int64).copy() + sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape( + group_num, quantgroup_num, n) + sscale_uint64_tensor = sscale_uint64_tensor.npu() + return sscale_uint64_tensor, bias + + def process_weights_after_loading(self, layer): + if self.transpose_weight: + layer.w13_weight.data = layer.w13_weight.data.transpose( + 1, 2).contiguous() + layer.w2_weight.data = layer.w2_weight.data.transpose( + 1, 2).contiguous() + layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose( + 1, 2).contiguous() + layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose( + 1, 2).contiguous() + layer.w13_weight_offset.data = layer.w13_weight_offset.data.view( + layer.w13_weight_offset.data.shape[0], -1) + layer.w2_weight_offset.data = layer.w2_weight_offset.data.view( + layer.w2_weight_offset.data.shape[0], -1) + layer.w13_weight_scale_second.data = layer.w13_weight_scale_second.data.transpose( + 1, 2).contiguous() + layer.w2_weight_scale_second.data = layer.w2_weight_scale_second.data.transpose( + 1, 2).contiguous() + + layer.w13_weight_scale_second.data, bias = self.process_scale( + layer.w13_weight, layer.w13_weight_scale.data, + layer.w13_weight_scale_second.data) + param = torch.nn.Parameter(bias, requires_grad=False) + layer.register_parameter("w13_scale_bias", param) + layer.w2_weight_scale_second.data, bias1 = self.process_scale( + layer.w2_weight, layer.w2_weight_scale.data, + layer.w2_weight_scale_second.data) + param = torch.nn.Parameter(bias1, requires_grad=False) + layer.register_parameter("w2_scale_bias", param) + + layer.w13_weight.data = torch_npu.npu_quantize( + layer.w13_weight.data.to(torch.float32), + torch.tensor([1.]).npu(), None, torch.quint4x2, -1, False) + layer.w2_weight.data = torch_npu.npu_quantize( + layer.w2_weight.data.to(torch.float32), + torch.tensor([1.]).npu(), None, torch.quint4x2, -1, False) diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index db23cb024d..28925034c1 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -20,6 +20,9 @@ import torch import torch_npu +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ + 
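The scale folding in `process_scale_second` above can be checked on CPU with toy shapes. The sketch below mirrors the same reshapes and reductions (the trailing `.npu()` transfer is dropped so it runs anywhere); treating the factor 8 as the int4 zero-point offset folded into a per-output-channel bias is an assumption, not something stated in the patch:

```python
import torch

group_size = 4
k, n = 8, 6                    # (input_size, output_size) after the transpose
group_num = k // group_size    # one second-level scale row per input-channel group

weight = torch.randint(-8, 8, (k, n)).to(torch.float32)
scale = torch.rand(n)                       # first-level per-output-channel scale
per_group_scale = torch.rand(group_num, n)  # second-level per-group scale

# Re-create the high-precision weight by applying the per-group scale ...
weight_high = weight.reshape(group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
weight_high = weight_high.reshape(k, n)
# ... fold the (assumed) int4 offset of 8 into a per-output-channel bias ...
bias = 8 * (weight_high * scale).sum(dim=0)
# ... and combine both scale levels into the anti-quant scale used at runtime.
antiquant_scale = (scale * per_group_scale).reshape(group_num, n)

print(antiquant_scale.shape)  # torch.Size([2, 6])
print(bias.shape)             # torch.Size([6])
```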
def quant_per_tensor(in_tensor: torch.Tensor, input_scale: torch.Tensor, input_offset: torch.Tensor): @@ -37,6 +40,8 @@ class AscendW8A8LinearMethod: def __init__(self) -> None: # aclnn quant matmul requires to transpose matrix B, set to true by default. self.transpose_weight = True + ascend_config = get_ascend_config() + self.enable_weight_nz_layout = ascend_config.enable_weight_nz_layout @staticmethod def get_weight( @@ -77,6 +82,10 @@ def get_perchannel_param( dtype=params_dtype) return params_dict + def get_pergroup_param(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + @staticmethod def apply( layer: torch.nn.Module, @@ -110,6 +119,9 @@ def process_weights_after_loading(self, layer): requires_grad=False).to(layer.aclnn_input_scale.dtype) if self.transpose_weight: layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() - layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29) + if self.enable_weight_nz_layout: + # cast quantized weight tensors in NZ layout for higher inference speed + layer.weight.data = torch_npu.npu_format_cast( + layer.weight.data, ACL_FORMAT_FRACTAL_NZ) layer.weight_scale.data = torch.flatten(layer.weight_scale.data) layer.weight_offset.data = torch.flatten(layer.weight_offset.data) diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 372c29bca7..d9738b9b55 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -15,19 +15,96 @@ # limitations under the License. # -from typing import Any, Callable, Dict, Optional, Tuple, Union +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.distributed as dist import torch_npu -from vllm.distributed import GroupCoordinator +from vllm.distributed import GroupCoordinator, get_ep_group, get_tp_group +from vllm.forward_context import get_forward_context +import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.distributed.parallel_state import get_ep_group +from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.ops.fused_moe import select_experts -from vllm_ascend.utils import (FusedMoEState, dispose_tensor, - get_fused_moe_state, npu_stream_switch, - npu_wait_tensor) +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, + dispose_tensor, get_ascend_soc_version, + npu_stream_switch, npu_wait_tensor) + + +def apply_mlp_decode(hidden_states_wrapper: List[torch.Tensor], + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + group_list: torch.Tensor, + dynamic_scale: torch.Tensor = None, + group_list_type: int = 1) -> torch.Tensor: + """ + apply MLP: gate_up_proj -> swiglu -> down_proj + Args: + hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size). + w1: expert weights1 with shape + (num_experts, hidden_size, intermediate_size * 2) + w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2) + w2: expert weights2 with shape + (num_experts, intermediate_size, hidden_size) + w2_scale: weights2 scale with shape (num_experts, hidden_size) + group_list: number of tokens for each expert, follow cumsum mode, and + with shape (num_experts). 
+ transpose_weight: + w1: (num_experts, intermediate_size * 2, hidden_size) -> + (num_experts, hidden_size, intermediate_size * 2) + w2: (num_experts, hidden_size, intermediate_size) -> + (num_experts, intermediate_size, hidden_size) + Returns: + hidden_states: output hidden states after MLP. + """ + + assert len(hidden_states_wrapper) == 1 + hidden_states = hidden_states_wrapper.pop() + if dynamic_scale is None: + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( + hidden_states) + else: + pertoken_scale = dynamic_scale + + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=3, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=torch.int32)[0] + + # act_fn: swiglu + hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( + x=hidden_states, + weight_scale=w1_scale, + activation_scale=pertoken_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=group_list, + activate_left=True, + quant_mode=1, + ) + + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + scale=[w2_scale], + per_token_scale=[swiglu_out_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=w2_scale.dtype)[0] + return hidden_states def apply_mlp(hidden_states: torch.Tensor, @@ -37,7 +114,9 @@ def apply_mlp(hidden_states: torch.Tensor, w2_scale: torch.Tensor, group_list: torch.Tensor, dynamic_scale: torch.Tensor = None, - group_list_type: int = 1) -> torch.Tensor: + group_list_type: int = 1, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None) -> torch.Tensor: """ apply MLP: gate_up_proj -> swiglu -> down_proj @@ -71,17 +150,31 @@ def apply_mlp(hidden_states: torch.Tensor, else: pertoken_scale = dynamic_scale + bias1, bias2 = None, None + _output_dtype = w2_scale.dtype + + if w1_scale_bias is not None: + if group_list_type == 0: + group_list = torch.cat( + [group_list[:1], torch.diff(group_list, dim=0)]) + group_list_type = 1 + bias1 = [w1_scale_bias] + bias2 = [w2_scale_bias] + # TODO w4a8 scene: dynamic acquisition of dtype in the future + _output_dtype = torch.bfloat16 + # gmm1: gate_up_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], weight=[w1], scale=[w1_scale], + bias=bias1, per_token_scale=[pertoken_scale], split_item=2, group_list_type=group_list_type, group_type=0, group_list=group_list, - output_dtype=w2_scale.dtype)[0] + output_dtype=_output_dtype)[0] # act_fn: swiglu hidden_states = torch_npu.npu_swiglu(hidden_states) @@ -93,12 +186,13 @@ def apply_mlp(hidden_states: torch.Tensor, x=[hidden_states], weight=[w2], scale=[w2_scale], + bias=bias2, per_token_scale=[swiglu_out_scale], split_item=2, group_list_type=group_list_type, group_type=0, group_list=group_list, - output_dtype=w2_scale.dtype)[0] + output_dtype=_output_dtype)[0] return hidden_states @@ -117,11 +211,33 @@ def fused_experts_with_mc2( log2phy: torch.Tensor = None, global_redundant_expert_num: int = 0, shared_experts: Optional[Any] = None, + is_torchair: bool = False, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - if log2phy is not None: + if log2phy: topk_ids = log2phy[topk_ids] - global_bs = 0 - moe_expert_num = len(expert_map) + global_redundant_expert_num + 
quant_mode = 2 + ep_group = get_ep_group() + ep_rank_id = ep_group.rank_in_group + ep_world_size = ep_group.world_size + tp_world_size = get_tp_group().world_size + + # NOTE: `global_bs` should be equal to `max_num_tokens_across_dp` * `ep_world_size`, + # and `max_num_tokens_across_dp` has been split into `tp_world_size` parts before. + global_bs = math.ceil(get_forward_context().max_tokens_across_dp / + tp_world_size) * ep_world_size + + # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine + need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 + or is_torchair) + + if (expert_map is not None): + moe_expert_num = len(expert_map) + global_redundant_expert_num + else: + moe_expert_num = global_redundant_expert_num # hidden_states = hidden_states.bfloat16() kwargs_mc2 = { "x": hidden_states, @@ -130,53 +246,43 @@ def fused_experts_with_mc2( "shared_expert_rank_num": 0, "moe_expert_num": moe_expert_num, "global_bs": global_bs, - "expert_scales": topk_weights.to(torch.float32), } - rank = torch.distributed.get_rank() - - quant_mode = 2 - ep_group = get_ep_group().device_group - local_rank = torch.distributed.get_rank(group=ep_group) - all_to_all_group_size = torch.distributed.get_world_size(ep_group) - - world_szie = torch.distributed.get_world_size() - tp_size = world_szie // all_to_all_group_size - tp_rank = rank % tp_size - stage1_kwargs = { "scales": None, "quant_mode": quant_mode, "group_ep": moe_all_to_all_group_name, - "ep_world_size": all_to_all_group_size, - "ep_rank_id": local_rank, - # "group_tp": self.moe_rs_group_name, - "group_tp": moe_all_to_all_group_name, - "tp_world_size": tp_size, - "tp_rank_id": tp_rank, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, } + if need_extra_args: + stage1_kwargs.update({ + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) kwargs_mc2.update(stage1_kwargs) output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) # comm_stream.wait_stream(torch.npu.current_stream()) - expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts, _, expand_scales = output[ - 0:7] + expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[ + 0:5] if shared_experts is not None: with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(hidden_states, topk_weights) - shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) - npu_wait_tensor(shared_gate_up[0], expand_x) - shared_act = shared_experts.act_fn(shared_gate_up) + npu_wait_tensor(quantized_x_for_share, expand_x) + shared_act_out = shared_experts.act_fn( + (quantized_x_for_share, dynamic_scale_for_share)) + shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1] # `expand_x` will be disposed in the `apply_mlp` function - down_out_list = apply_mlp(expand_x, - w1, - w1_scale, - w2, - w2_scale, - expert_token_nums, - dynamic_scale=dynamic_scale) + down_out_list = apply_mlp_decode([expand_x], + w1, + w1_scale, + w2, + w2_scale, + expert_token_nums, + dynamic_scale=dynamic_scale) # moeCombine kwargs_mc2 = { @@ -187,8 +293,7 @@ def fused_experts_with_mc2( "expert_shard_type": 0, "shared_expert_rank_num": 0, "moe_expert_num": moe_expert_num, - "global_bs": 0, - "expand_scales": expand_scales, + "global_bs": global_bs, } tp_recv_counts = torch.empty(1, dtype=torch.int32, @@ -196,44 +301,47 @@ def fused_experts_with_mc2( stage3_kwargs = { "ep_send_counts": ep_recv_counts, "group_ep": moe_all_to_all_group_name, - "ep_world_size": 
all_to_all_group_size, - "ep_rank_id": local_rank, - "tp_send_counts": tp_recv_counts, - # "group_tp": self.moe_rs_group_name, - "group_tp": moe_all_to_all_group_name, - "tp_world_size": tp_size, - "tp_rank_id": tp_rank, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, } + if need_extra_args: + stage3_kwargs.update({ + "tp_send_counts": tp_recv_counts, + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) kwargs_mc2.update(stage3_kwargs) hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) - + group_list_type = 1 if shared_experts is None: - return hidden_states + return hidden_states, expert_token_nums, group_list_type else: with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(shared_act[0], down_out_list) - shared_output, _ = shared_experts.down_proj(shared_act) - return hidden_states, shared_output + npu_wait_tensor(shared_act, down_out_list) + shared_output, _ = shared_experts.down_proj( + (shared_act, swiglu_out_scale)) + return hidden_states, shared_output, expert_token_nums, group_list_type # currently expert parallelism implemented with all2all # is under-optimized. -def fused_experts_with_all2all( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - ep_group: GroupCoordinator = None, - log2phy: torch.Tensor = None, - global_redundant_expert_num: int = 0, -): - if log2phy is not None: +def fused_experts_with_all2all(hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None, + ep_group: GroupCoordinator = None, + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None): + if log2phy: topk_ids = log2phy[topk_ids] original_shape = hidden_states.shape if len(original_shape) == 3: @@ -311,7 +419,9 @@ def fused_experts_with_all2all( w2, w2_scale, expert_tokens, #16 - group_list_type=group_list_type) + group_list_type=group_list_type, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias) if expert_map is not None: resorted_idx = torch.argsort(sorted_idx) @@ -343,7 +453,7 @@ def fused_experts_with_all2all( ) if len(original_shape) == 3: final_hidden_states = final_hidden_states.view(original_shape) - return final_hidden_states + return final_hidden_states, expert_tokens, group_list_type def fused_experts(hidden_states: torch.Tensor, @@ -457,7 +567,7 @@ def fused_experts(hidden_states: torch.Tensor, if len(original_shape) == 3: final_hidden_states = final_hidden_states.view(original_shape) - return final_hidden_states + return final_hidden_states, expert_tokens, group_list_type class AscendW8A8DynamicLinearMethod: @@ -466,6 +576,8 @@ class AscendW8A8DynamicLinearMethod: def __init__(self): self.transpose_weight = True + ascend_config = get_ascend_config() + self.enable_weight_nz_layout = ascend_config.enable_weight_nz_layout @staticmethod def get_weight(input_size: int, output_size: int, @@ -493,6 +605,10 @@ def get_perchannel_param( dtype=params_dtype) return params_dict + def get_pergroup_param(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + @staticmethod def apply( layer: torch.nn.Module, @@ -527,8 +643,10 @@ def apply( def 
process_weights_after_loading(self, layer): if self.transpose_weight: layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() - # cast quantized weight tensors in NZ format (29) for higher inference speed - layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29) + if self.enable_weight_nz_layout: + # cast quantized weight tensors in NZ layout for higher inference speed + layer.weight.data = torch_npu.npu_format_cast( + layer.weight.data, ACL_FORMAT_FRACTAL_NZ) layer.weight_scale.data = layer.weight_scale.data.flatten() layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32) layer.weight_offset.data = layer.weight_offset.data.flatten() @@ -545,6 +663,7 @@ def __init__(self): ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_weight_nz_layout = ascend_config.enable_weight_nz_layout try: device_group = self.ep_group.device_group @@ -618,6 +737,8 @@ def apply( log2phy: torch.Tensor = None, global_redundant_expert_num: int = 0, shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, **kwargs, ) -> torch.Tensor: assert router_logits.shape[ @@ -652,6 +773,16 @@ def apply( e_score_correction_bias=e_score_correction_bias, ) + fused_moe_state = get_forward_context().fused_moe_state + shared_gate_up, shared_dequant_scale = None, None + if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(quantized_x_for_share, router_logits) + share_up_out, _ = shared_experts.gate_up_proj( + (quantized_x_for_share, dynamic_scale_for_share)) + shared_gate_up, shared_dequant_scale = share_up_out[ + 0], share_up_out[1] + # this is a naive implementation for experts load balance so as # to avoid accumulating too much tokens on a single rank. # currently it is only activated when doing profile runs. 
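Editor's note on the `apply_mlp` change a few hunks above: when `w1_scale_bias` is set and the incoming `group_list` is cumulative (`group_list_type == 0`), it is first converted to per-expert token counts before the grouped matmuls run. A minimal, self-contained sketch of that conversion with made-up token counts (illustrative only, not part of the patch):

import torch

# Hypothetical per-expert token counts for four experts.
tokens_per_expert = torch.tensor([3, 0, 5, 2])

# group_list_type == 0: cumulative counts, as some dispatch paths emit them.
cumulative = torch.cumsum(tokens_per_expert, dim=0)        # tensor([ 3,  3,  8, 10])

# Conversion used in apply_mlp: keep the first entry, diff the rest.
per_group = torch.cat([cumulative[:1], torch.diff(cumulative, dim=0)])

assert torch.equal(per_group, tokens_per_expert)           # back to per-expert counts

After this conversion the patch forces `group_list_type = 1`, so both grouped matmuls receive plain counts rather than prefix sums.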
@@ -660,14 +791,12 @@ def apply( topk_weights = topk_weights.to(x.dtype) - fused_moe_state = get_fused_moe_state(self.ep_group.world_size, - is_prefill) if fused_moe_state == FusedMoEState.MC2: return fused_experts_with_mc2( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, - w1_scale=layer.w13_weight_scale, + w1_scale=layer.w13_weight_scale_fp32, w2_scale=layer.w2_weight_scale, topk_weights=topk_weights, topk_ids=topk_ids, @@ -676,7 +805,11 @@ def apply( moe_all_to_all_group_name=self.moe_all_to_all_group_name, log2phy=log2phy, global_redundant_expert_num=global_redundant_expert_num, - shared_experts=shared_experts) + shared_experts=shared_experts, + is_torchair=self.torchair_graph_enabled, + quantized_x_for_share=shared_gate_up, + dynamic_scale_for_share=shared_dequant_scale, + **kwargs) elif fused_moe_state == FusedMoEState.AllGather: return fused_experts(hidden_states=x, w1=layer.w13_weight, @@ -713,8 +846,16 @@ def process_weights_after_loading(self, layer): 1, 2).contiguous() layer.w2_weight.data = layer.w2_weight.data.transpose( 1, 2).contiguous() + if self.enable_weight_nz_layout: + # cast quantized weight tensors in NZ layout for higher inference speed + layer.w13_weight.data = torch_npu.npu_format_cast( + layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.w2_weight.data = torch_npu.npu_format_cast( + layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ) layer.w13_weight_scale.data = layer.w13_weight_scale.data.view( layer.w13_weight_scale.data.shape[0], -1) + layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to( + torch.float32) layer.w13_weight_offset.data = layer.w13_weight_offset.data.view( layer.w13_weight_offset.data.shape[0], -1) layer.w2_weight_scale.data = layer.w2_weight_scale.data.view( diff --git a/vllm_ascend/soc_info.py b/vllm_ascend/soc_info.py new file mode 100644 index 0000000000..ac1317e8e1 --- /dev/null +++ b/vllm_ascend/soc_info.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + +import torch_npu + + +@dataclass +class NPUSocInfo: + is_a3: bool = False + + def __post_init__(self): + torch_npu.npu._lazy_init() + self.soc_version = torch_npu._C._npu_get_soc_version() + if self.soc_version in (250, 251, 252, 253, 254, 255): + self.is_a3 = True diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index eeab287906..f7ca0aba2e 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -20,12 +20,13 @@ import atexit import math from contextlib import contextmanager, nullcontext +from dataclasses import dataclass from enum import Enum from threading import Lock -from typing import TYPE_CHECKING, List, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple import torch -import torch_npu # noqa: F401 +import torch_npu import torchair # type: ignore[import] # noqa: F401 from packaging.version import InvalidVersion, Version from torch_npu.npu.streams import Event @@ -57,6 +58,9 @@ CUSTOM_OP_ENABLED = None +ACL_FORMAT_ND = 2 +ACL_FORMAT_FRACTAL_NZ = 29 + def try_register_lib(lib_name: str, lib_info: str = ""): import importlib @@ -168,6 +172,27 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: original_sizes, compilation_config.cudagraph_capture_sizes = \ compilation_config.cudagraph_capture_sizes, None + if compilation_config.full_cuda_graph: + max_num_seqs = vllm_config.scheduler_config.max_num_seqs + truncated_sizes = [x for x in original_sizes if x <= max_num_seqs] + compilation_config.init_with_cudagraph_sizes(truncated_sizes) + + warning_message = """\033[91m + 
********************************************************************************** + * WARNING: You have enabled the *full graph* feature. + * This is an early experimental stage and may involve various unknown issues. + * A known problem is that capturing too many batch sizes can lead to OOM + * (Out of Memory) errors or inference hangs. If you encounter such issues, + * consider reducing `gpu_memory_utilization` or manually specifying a smaller + * batch size for graph capture. + * For more details, please refer to: + * https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs + **********************************************************************************\033[0m + """ + + logger.warning(warning_message) + return + # Calculate parallel configuration factor num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers parallel_config = vllm_config.parallel_config @@ -278,19 +303,58 @@ def npu_wait_tensor(self: torch.Tensor, return _npu_wait_tensor(self, dependency) if enabled else self -# TODO(zzzzwwjj): move this into forward_context -class FusedMoEState(Enum): - AllGather = 0 - All2All = 1 - MC2 = 2 +class AscendSocVersion(Enum): + A2 = 0 + A3 = 1 + MAX = 2 + + +_ascend_soc_version = None -# TODO(zzzzwwjj): add soc_version to choose branch -def get_fused_moe_state(ep_size: int, with_prefill: bool): - if ep_size == 1: - return FusedMoEState.AllGather - # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. - elif ep_size < 16 or with_prefill: - return FusedMoEState.All2All +def init_ascend_soc_version(): + soc_version = torch_npu.npu.get_soc_version() + global _ascend_soc_version + if 220 <= soc_version <= 225: + _ascend_soc_version = AscendSocVersion.A2 + elif 250 <= soc_version <= 255: + _ascend_soc_version = AscendSocVersion.A3 else: - return FusedMoEState.MC2 + _ascend_soc_version = AscendSocVersion.MAX + + +def get_ascend_soc_version(): + global _ascend_soc_version + assert _ascend_soc_version is not None + return _ascend_soc_version + + +@dataclass +class GraphParams: + events: dict[int, list[torch.npu.ExternalEvent]] + workspaces: dict[int, torch.Tensor] + handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]] + attn_params: dict[int, list[tuple]] + + +_graph_params: Optional[GraphParams] = None + + +def set_graph_params(aclgraph_capture_sizes: set[int]): + global _graph_params + if _graph_params is not None: + raise ValueError("Graph parameters have already been set!") + _graph_params = GraphParams( + {size: [] + for size in aclgraph_capture_sizes}, + {size: None + for size in aclgraph_capture_sizes}, + {size: [] + for size in aclgraph_capture_sizes}, + {size: [] + for size in aclgraph_capture_sizes}, + ) + + +def get_graph_params(): + return _graph_params diff --git a/vllm_ascend/worker/draft_model_runner.py b/vllm_ascend/worker/draft_model_runner.py index 1306b1e160..bfd513d5fe 100644 --- a/vllm_ascend/worker/draft_model_runner.py +++ b/vllm_ascend/worker/draft_model_runner.py @@ -18,7 +18,6 @@ from typing import List, Optional import torch -from vllm.forward_context import set_forward_context from vllm.logger import logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MultiModalKwargs @@ -27,6 +26,7 @@ ModelRunnerInputBase, ModelRunnerWrapperBase) +from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.attention import AscendMetadata # A flag to enable debug prints for the updated input tensors @@ -51,12 +51,6 @@ class 
TP1DraftModelRunner(ModelRunnerWrapperBase): """ def __init__(self, model_runner: ModelRunnerBase): - if hasattr( - model_runner, - "return_hidden_states") and model_runner.return_hidden_states: - raise ValueError( - "return_hidden_states is not supported for TP1DraftModelRunner." - ) super().__init__(model_runner) self.indices_of_seq_with_bonus_tokens = None @@ -211,6 +205,9 @@ def execute_model( if self.prompt_adapter_config is not None: raise ValueError("TP1DraftModelRunner has no support for " "prompt_adapter_config") + if model_input.inputs_embeds is not None: + raise ValueError("TP1DraftModelRunner has no support for " + "inputs_embeds") if model_input.multi_modal_kwargs: raise ValueError( "TP1DraftModelRunner has no support for multi_modal_kwargs" @@ -264,14 +261,15 @@ def execute_model( spec_step_idx = kwargs.get("spec_step_idx", step) model_execute_kwargs["spec_step_idx"] = spec_step_idx compute_logits_kwargs["spec_step_idx"] = spec_step_idx - with set_forward_context(model_input.attn_metadata, - self.vllm_config): + with set_ascend_forward_context(model_input.attn_metadata, + self.vllm_config): if model_input.attn_metadata is not None: model_input.attn_metadata.input_positions = model_input.input_positions hidden_states = model_executable( input_ids=model_input.input_tokens, + inputs_embeds=None, positions=model_input.input_positions, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, @@ -293,6 +291,9 @@ def execute_model( ) outputs.append(output) + if self.return_hidden_states and is_fallback: + output.hidden_states = hidden_states + if model_input.attn_metadata.num_prefills == 0 \ and self.indices_of_seq_with_bonus_tokens is not None: assert output.sampled_token_ids is not None diff --git a/vllm_ascend/worker/model_runner.py b/vllm_ascend/worker/model_runner.py index 48c5d4b68f..7846f655d1 100644 --- a/vllm_ascend/worker/model_runner.py +++ b/vllm_ascend/worker/model_runner.py @@ -35,7 +35,6 @@ from vllm.core.scheduler import SchedulerOutputs from vllm.distributed import broadcast_tensor_dict, get_dp_group, get_pp_group from vllm.distributed.kv_transfer import get_kv_transfer_group -from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import logger from vllm.lora.layers import LoRAMapping @@ -66,6 +65,7 @@ _init_sampling_metadata_from_tensor_dict) from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import set_ascend_forward_context if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -1431,8 +1431,12 @@ def execute_model( model_forward_start.record() if not bypass_model_exec: - with set_forward_context(model_input.attn_metadata, - self.vllm_config, virtual_engine): + with set_ascend_forward_context( + model_input.attn_metadata, + self.vllm_config, + virtual_engine, + with_prefill=prefill_meta is not None, + in_profile_run=self.in_profile_run): if model_input.attn_metadata is not None: model_input.attn_metadata.input_positions = model_input.input_positions if self.torchair_graph_enabled: diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 89f30bc43c..68293e0b58 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -17,7 +17,9 @@ # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py # +import copy import gc +import math import os import time import types @@ -32,13 +34,18 @@ import 
torch._dynamo.cache_size import torch.distributed as dist import torch.nn as nn +import torchair from torch.distributed import ReduceOp +from torchair import patch_for_hcom from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import get_dp_group, get_pp_group -from vllm.forward_context import set_forward_context +from vllm.forward_context import get_forward_context from vllm.inputs import INPUT_REGISTRY from vllm.logger import logger from vllm.model_executor.layers.fused_moe import FusedMoE @@ -69,14 +76,21 @@ scatter_mm_placeholders) from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.attention import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata +from vllm_ascend.attention.utils import \ + AscendCommonAttentionMetadata as CommonAttentionMetadata +from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.platform import NPUPlatform from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler -from vllm_ascend.utils import ProfileExecuteDuration, vllm_version_is +from vllm_ascend.utils import ProfileExecuteDuration from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer +from vllm_ascend.eplb.eplb_updator import EplbUpdator +from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor +from vllm_ascend.eplb.core.loader.device_transfer_loader import D2DExpertWeightLoader + if TYPE_CHECKING: import xgrammar as xgr # type: ignore[import-untyped] from vllm.v1.core.sched.output import SchedulerOutput @@ -132,6 +146,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config self.lora_config = vllm_config.lora_config + self.parallel_config = vllm_config.parallel_config self.scheduler_config = vllm_config.scheduler_config self.speculative_config = vllm_config.speculative_config ascend_config = get_ascend_config() @@ -149,12 +164,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.max_num_tokens = self.scheduler_config.max_num_batched_tokens self.max_num_reqs = self.scheduler_config.max_num_seqs - self.graph_block_tables = np.zeros( - (self.vllm_config.scheduler_config.max_num_seqs, - (self.model_config.max_model_len + self.block_size - 1) // - self.block_size), - dtype=np.int32) - # Model-related. self.num_attn_layers = self.model_config.get_num_layers_by_block_type( vllm_config.parallel_config, LayerBlockType.attention) @@ -208,8 +217,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): # Set up speculative decoding. 
self.use_spec_decode = False self.spec_attn_mask = None + self.actual_seq_q_lens = [] + self.spec_token_num = 0 + self.decode_token_per_req = 1 if self.speculative_config: self.use_spec_decode = True + self.spec_token_num = self.speculative_config.num_speculative_tokens + assert self.spec_token_num > 0 self.spec_attn_mask = torch.triu(torch.ones(2048, 2048, dtype=torch.bool), @@ -222,6 +236,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.device) # type: ignore elif self.speculative_config.method == 'deepseek_mtp': self.drafter = MtpProposer(self.vllm_config, self) + self.decode_token_per_req = 1 + self.spec_token_num + self.actual_seq_q_lens = [ + len for len in + range(self.decode_token_per_req, self.max_num_tokens + + 1, self.decode_token_per_req) + ] else: raise ValueError("Unknown speculative decoding method: " f"{self.speculative_config.method}") @@ -243,6 +263,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.seq_lens = torch.zeros(self.max_num_reqs, dtype=torch.int32, device=self.device) + self.slot_mapping = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=self.device) + self.query_lens = torch.zeros(self.max_num_reqs, + dtype=torch.int32, + device=self.device) # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: Optional[IntermediateTensors] = None @@ -343,15 +369,19 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes + self.use_ring_mla = ascend_config.chunked_prefill_for_mla if ascend_config.torchair_graph_config.graph_batch_sizes_init: self.init_torchair_graph_batch_sizes() - if len(self.torchair_graph_batch_sizes) == 0: - # TODO(zzzzwwjj): check torchair_graph_batch_sizes init code - self.torchair_graph_batch_sizes = [ - self.scheduler_config.max_num_seqs - ] + self.check_torchair_graph_batch_sizes() + + # graph_block_tables shape: [num_request, cell(max_model_len / block_size)] + self.graph_block_tables = np.zeros( + (self.torchair_graph_batch_sizes[-1] // self.decode_token_per_req, + (self.model_config.max_model_len + self.block_size - 1) // + self.block_size), + dtype=np.int32) torch._dynamo.cache_size.config.cache_size_limit += len( self.torchair_graph_batch_sizes) @@ -362,6 +392,21 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_rank = vllm_config.parallel_config.data_parallel_rank + # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True + self.in_profile_run = False + + # kv role + self.is_kv_producer = False + if vllm_config.kv_transfer_config is not None: + self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer + + #EPLB + self.dynamic_eplb = ascend_config.dynamic_eplb + if self.dynamic_eplb == True: + self.eplb_adaptor = None + self.is_eplb_warmuped = False + self.eplb_updator = EplbUpdator(ascend_config.expert_map_path) + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler output. 
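Editor's note on the speculative-decoding setup above: with `deepseek_mtp`, each decode request contributes `1 + num_speculative_tokens` query tokens, and `actual_seq_q_lens` enumerates the resulting cumulative token positions. A small sketch with assumed values (not taken from the PR):

# Assumed values for illustration.
spec_token_num = 1                                   # one MTP draft token per request
decode_token_per_req = 1 + spec_token_num            # 2 query tokens per decode request
max_num_tokens = 8                                   # assumed max batched tokens

actual_seq_q_lens = list(
    range(decode_token_per_req, max_num_tokens + 1, decode_token_per_req))
print(actual_seq_q_lens)                             # [2, 4, 6, 8]

These per-request token counts are also why `graph_block_tables` above is sized by `torchair_graph_batch_sizes[-1] // decode_token_per_req` requests rather than by raw token counts.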
@@ -420,33 +465,19 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: generator.manual_seed(sampling_params.seed) else: generator = None - if vllm_version_is("0.9.1"): - self.requests[req_id] = CachedRequestState( - req_id=req_id, - prompt_token_ids=new_req_data.prompt_token_ids, - mm_inputs=new_req_data.mm_inputs, - mm_positions=new_req_data.mm_positions, - sampling_params=sampling_params, - generator=generator, - block_ids=new_req_data.block_ids, - num_computed_tokens=new_req_data.num_computed_tokens, - output_token_ids=[], - lora_request=new_req_data.lora_request, - ) - else: - self.requests[req_id] = CachedRequestState( - req_id=req_id, - prompt_token_ids=new_req_data.prompt_token_ids, - mm_inputs=new_req_data.mm_inputs, - mm_positions=new_req_data.mm_positions, - sampling_params=sampling_params, - pooling_params=None, - generator=generator, - block_ids=new_req_data.block_ids, - num_computed_tokens=new_req_data.num_computed_tokens, - output_token_ids=[], - lora_request=new_req_data.lora_request, - ) + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + mm_inputs=new_req_data.mm_inputs, + mm_positions=new_req_data.mm_positions, + sampling_params=sampling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + lora_request=new_req_data.lora_request, + ) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -569,6 +600,16 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Append to the end. req_index = None self.input_batch.add_request(req_state, req_index) + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, ()) + if spec_token_ids: + req_index = self.input_batch.num_reqs - 1 + start_index = len(req_state.prompt_token_ids) + len( + req_state.output_token_ids) + end_token_index = start_index + len(spec_token_ids) + self.input_batch.token_ids_cpu[ + req_index, start_index:end_token_index] = spec_token_ids + self.input_batch.num_tokens[req_index] = end_token_index # Condense the batched states if there are empty indices. 
if removed_req_indices: @@ -578,16 +619,45 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.refresh_sampling_metadata() def _get_forward_metadata_across_dp( - self, total_num_scheduled_tokens: int, - with_prefill: bool) -> tuple[int, bool]: + self, num_tokens: int, with_prefill: bool, enable_dbo: bool + ) -> tuple[int, Optional[torch.Tensor], bool, bool]: + if self.dp_size == 1: + return num_tokens, None, with_prefill, enable_dbo + forward_metadata = torch.tensor( - [total_num_scheduled_tokens, with_prefill], + [num_tokens, with_prefill, not enable_dbo], device="cpu", dtype=torch.int32) dist.all_reduce(forward_metadata, op=ReduceOp.MAX, group=get_dp_group().cpu_group) - return int(forward_metadata[0]), bool(forward_metadata[1] > 0) + num_tokens_across_dp = torch.tensor([forward_metadata[0]] * + self.dp_size, + device="cpu", + dtype=torch.int32) + return forward_metadata[0].item(), num_tokens_across_dp, bool( + forward_metadata[1]), not bool(forward_metadata[2]) + + def _check_dbo_is_valid(self, query_lens: torch.Tensor, + attn_state: AscendAttentionState, + num_tokens: int) -> bool: + # do the checks for dp + dbo + if attn_state in [ + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding + ]: + return False + # considering the case that one dp rank may enable dbo while others may not + if not self.vllm_config.model_config.use_mla or not envs_ascend.VLLM_ASCEND_ENABLE_DBO: + return False + # TODO: remove it if token-level microbatch is enabled + [token_index, + seq_index] = compute_split_seq_index(query_lens, attn_state, + num_tokens) + if token_index == 0 or seq_index == 0 or seq_index == len( + query_lens) or num_tokens < 256: + return False + return True def get_model(self) -> nn.Module: return self.model @@ -776,7 +846,8 @@ def _process_reqs( scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> tuple[SpecDecodeMetadata, torch.Tensor, SpecDecodeMetadata, - torch.Tensor, int, torch.Tensor]: + torch.Tensor, int, torch.Tensor, Optional[set[str]], + Optional[set[str]]]: # Check input valid total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 @@ -844,6 +915,7 @@ def _process_reqs( self.mrope_positions_cpu[:, :total_num_scheduled_tokens], non_blocking=True) + self.positions[total_num_scheduled_tokens:num_input_tokens].zero_() self.positions[:total_num_scheduled_tokens].copy_( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) positions = self.positions[:num_input_tokens] @@ -872,6 +944,9 @@ def _process_reqs( # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache. elif np.all(num_scheduled_tokens == 1): attn_state = AscendAttentionState.DecodeOnly + if self.speculative_config and self.speculative_config.method == 'deepseek_mtp': + # SpecDecoding now supports seq_len=1 and seq_len=2 + attn_state = AscendAttentionState.SpecDecoding # Speculative decoding. elif np.all(num_valid_tokens == 1): attn_state = AscendAttentionState.SpecDecoding @@ -881,11 +956,14 @@ def _process_reqs( else: attn_state = AscendAttentionState.PrefillCacheHit - attn_mask = self._make_attention_mask(seq_lens=seq_lens, - query_lens=num_scheduled_tokens, - position=positions, - attn_state=attn_state) - self.attn_mask = attn_mask + # NOTE: when use ring_mla, attn_mask don't need to generate here. 
+ if not self.use_ring_mla or attn_state == AscendAttentionState.PrefillNoCache: + attn_mask = self._make_attention_mask( + seq_lens=seq_lens, + query_lens=num_scheduled_tokens, + position=positions, + attn_state=attn_state) + self.attn_mask = attn_mask self.attn_state = attn_state # type: ignore extra_builder_kwargs = {} @@ -896,36 +974,52 @@ def _process_reqs( self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], non_blocking=True) + self.slot_mapping[:total_num_scheduled_tokens].copy_( + self.slot_mapping_cpu[:total_num_scheduled_tokens], + non_blocking=True) # Fill unused with -1. Needed for reshape_and_cache + self.slot_mapping[total_num_scheduled_tokens:].fill_(-1) self.seq_lens[num_reqs:].fill_(0) self.query_start_loc[num_reqs + 1:].fill_(-1) query_start_loc = self.query_start_loc[:num_reqs + 1] - seq_lens = self.seq_lens[:num_reqs] + # Use host tensor, other wise error: tensor.hostData is null common_attn_metadata = CommonAttentionMetadata( - query_start_loc=query_start_loc, seq_lens=seq_lens) + query_start_loc=query_start_loc, + seq_lens=self.seq_lens_cpu[:num_reqs]) + self.seq_lens_list = self.seq_lens_np.tolist()[:num_input_tokens] with_prefill = attn_state not in [ AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ] + enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), + attn_state, + total_num_scheduled_tokens) + num_tokens_across_dp = None - if self.dp_size > 1: - max_num_tokens, with_prefill = self._get_forward_metadata_across_dp( - total_num_scheduled_tokens, with_prefill) - extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens - extra_builder_kwargs['with_prefill_across_dp'] = with_prefill - - # Add graph_pad_size here + padded_num_tokens = total_num_scheduled_tokens if self.torchair_graph_enabled and not with_prefill: - if self.dp_size > 1: - padded_batch_size = self.select_torchair_padded_batch_size( - max_num_tokens) - else: - padded_batch_size = self.select_torchair_padded_batch_size( - total_num_scheduled_tokens) - graph_pad_size = padded_batch_size - total_num_scheduled_tokens + padded_num_tokens = self.select_torchair_padded_batch_size( + total_num_scheduled_tokens) + (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill, + enable_dbo) = self._get_forward_metadata_across_dp( + padded_num_tokens, with_prefill, enable_dbo) + extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo + + # TODO(zzzzwwjj): this code need to refactor afterwards. 
+ self.with_prefill = with_prefill + # Add num_token_pad_size and num_reqs_pad_size here for torchair graph mode + if self.torchair_graph_enabled and not with_prefill: + num_token_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens + num_reqs_pad_size = ( + padded_num_tokens_across_dp // self.decode_token_per_req - + num_reqs) + assert num_token_pad_size >= 0 and num_reqs_pad_size >= 0 - extra_builder_kwargs['graph_pad_size'] = graph_pad_size + extra_builder_kwargs['num_token_pad_size'] = num_token_pad_size + extra_builder_kwargs['num_reqs_pad_size'] = num_reqs_pad_size + self.num_reqs_pad_size = num_reqs_pad_size + self.extra_builder_kwargs = extra_builder_kwargs if self.vllm_config.model_config.use_mla: attn_metadata = self.attn_metadata_builder.build( # type: ignore @@ -941,6 +1035,7 @@ def _process_reqs( num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, + common_attn_metadata=common_attn_metadata, common_prefix_len=None, **extra_builder_kwargs, ) @@ -956,10 +1051,7 @@ def _process_reqs( # Copy the tensors to the NPU. self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) - input_ids = self.input_ids[:num_input_tokens] - # prepare the MRoPE for mllm if using multimodal - num_input_tokens = total_num_scheduled_tokens # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order if self.is_multimodal_model: @@ -973,51 +1065,56 @@ def _process_reqs( # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. - input_ids = self.input_ids[:num_input_tokens] + input_ids = self.input_ids[:total_num_scheduled_tokens] if mm_embeds: inputs_embeds = self.model.get_input_embeddings( input_ids, mm_embeds) else: inputs_embeds = self.model.get_input_embeddings(input_ids) # TODO(woosuk): Avoid the copy. Optimize. - self.inputs_embeds[:num_input_tokens].copy_(inputs_embeds) + self.inputs_embeds[:total_num_scheduled_tokens].copy_( + inputs_embeds) inputs_embeds = self.inputs_embeds[:num_input_tokens] input_ids = None else: # For text-only models, we use token ids as input. # While it is possible to use embeddings as input just like the # multimodal models, it is not desirable for performance since - # then the embedding layer is not included in the CUDA graph. + # then the embedding layer is not included in the ACL Graph. input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None if self.uses_mrope: positions = self.mrope_positions[:, :num_input_tokens] - else: - positions = self.positions[:num_input_tokens] if self.torchair_graph_enabled and not with_prefill: - input_ids = self.input_ids[:padded_batch_size] - positions = self.positions[:padded_batch_size] + input_ids = self.input_ids[:padded_num_tokens_across_dp] + positions = self.positions[:padded_num_tokens_across_dp] # Run forward pass - with set_forward_context(attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens): + # TODO(zzzzwwjj): check param `num_tokens_across_dp` later. 
+ with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=padded_num_tokens_across_dp, + num_tokens_across_dp=num_tokens_across_dp, + with_prefill=with_prefill): with ProfileExecuteDuration().capture_async("forward"): + self.maybe_setup_kv_connector(scheduler_output) model_kwargs = {} if self.torchair_graph_enabled: model_kwargs["kv_caches"] = self.kv_caches model_kwargs["attn_metadata"] = attn_metadata + if envs_ascend.VLLM_ASCEND_ENABLE_DBO and with_prefill: + model_kwargs["graph_enable"] = False # type: ignore if self.torchair_graph_enabled and not with_prefill: compiled_model = self._get_torchair_lazy_compiled_model( - padded_batch_size) + padded_num_tokens_across_dp) hidden_states = compiled_model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, - **model_kwargs, - ) + **model_kwargs) else: assert self.model is not None hidden_states = self.model( @@ -1025,9 +1122,11 @@ def _process_reqs( positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, - **model_kwargs, - ) + **model_kwargs) + self.maybe_wait_for_kv_save() + finished_sending, finished_recving = self.get_finished_kv_transfer( + scheduler_output) use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: @@ -1052,7 +1151,8 @@ def _process_reqs( sample_indices = spec_decode_metadata.logits_indices return (attn_metadata, hidden_states, spec_decode_metadata, positions, - total_num_scheduled_tokens, sample_indices) + total_num_scheduled_tokens, sample_indices, finished_sending, + finished_recving) def _calc_spec_decode_metadata( self, @@ -1219,16 +1319,27 @@ def execute_model( scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, torch.Tensor]: + with ProfileExecuteDuration().capture_async( "prepare input and forward"): self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: - # Return empty ModelRunnerOuptut if there's no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT + if not has_kv_transfer_group(): + logger.debug( + "skip this step for we receive the data from remote disaggregate prefill node" + ) + # Return empty ModelRunnerOuptut if there's no work to do. 
+ return EMPTY_MODEL_RUNNER_OUTPUT + if self.dynamic_eplb: + self.eplb_updator.forward_before() + return self.kv_connector_no_forward(scheduler_output) (attn_metadata, hidden_states, spec_decode_metadata, positions, - num_scheduled_tokens, - sample_indices) = (self._process_reqs(scheduler_output, - intermediate_tensors)) + num_scheduled_tokens, sample_indices, finished_sending, + finished_recving) = (self._process_reqs(scheduler_output, + intermediate_tensors)) + + if self.dynamic_eplb: + self.eplb_updator.take_update_info_from_eplb_process() with ProfileExecuteDuration().capture_async("post process"): logits = self.model.compute_logits(hidden_states[sample_indices], @@ -1319,25 +1430,19 @@ def execute_model( hidden_states, attn_metadata, ) - if vllm_version_is("0.9.1"): - model_runner_output = ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=spec_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict={}, - ) - else: - model_runner_output = ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=spec_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict={}, - pooler_output=[], - ) + if has_kv_transfer_group(): + get_kv_transfer_group().clear_connector_metadata() + + model_runner_output = ModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=valid_sampled_token_ids, + spec_token_ids=spec_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict={}, + finished_sending=finished_sending, + finished_recving=finished_recving, + ) durations = ProfileExecuteDuration().pop_captured_sync() if durations: @@ -1349,8 +1454,55 @@ def execute_model( logger.info("Profile execute duration [%s]:%s", captured_name, " ".join(dr_str)) + if self.dynamic_eplb: + self.eplb_updator.forward_end() + return model_runner_output + def kv_connector_no_forward( + self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: + # TODO(zzzzwwjj): Check whether `set_ascend_forward_context` has influence with kv_connector or not. + with set_ascend_forward_context(None, self.vllm_config): + self.maybe_setup_kv_connector(scheduler_output) + finsihed_sending, finished_recving = ( + self.get_finished_kv_transfer(scheduler_output)) + # For the case of no forward caused by receiving remote kv, + # one round of dummy inference is necessary + # to prevent hang over the collective calls. + if not finsihed_sending and not finished_recving: + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.finished_sending = finsihed_sending + output.finished_recving = finished_recving + return output + + @staticmethod + def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): + # Update KVConnector with the KVConnector metadata forward(). 
+ if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + assert scheduler_output.kv_connector_metadata is not None + kv_connector.bind_connector_metadata( + scheduler_output.kv_connector_metadata) + + kv_connector.start_load_kv(get_forward_context()) + + @staticmethod + def maybe_wait_for_kv_save() -> None: + if has_kv_transfer_group(): + get_kv_transfer_group().wait_for_save() + + @staticmethod + def get_finished_kv_transfer( + scheduler_output: "SchedulerOutput", + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + if has_kv_transfer_group(): + return get_kv_transfer_group().get_finished( + scheduler_output.finished_req_ids) + return None, None + def _profile_multimodal(self) -> None: # TODO: handle encoder-decoder models once we support them. # NOTE: Currently model is profiled with a single non-text @@ -1438,15 +1590,34 @@ def _profile_multimodal(self) -> None: def _dummy_run( self, num_tokens: int, - is_compile: bool = False, - with_prefill: bool = True, + skip_attn: bool = True, + with_prefill: bool = False, + is_torchair_compile: bool = False, + attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, + is_profile_run: bool = False, ) -> torch.Tensor: + if self.torchair_graph_enabled and not with_prefill: + num_tokens = self.select_torchair_padded_batch_size(num_tokens) + + # For kv producer, with prefill always true + if self.is_kv_producer: + with_prefill = True + # Padding for DP + (num_tokens, num_tokens_across_dp, with_prefill, + enable_dbo) = self._get_forward_metadata_across_dp( + num_tokens, with_prefill, False) + # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. assert num_tokens <= self.scheduler_config.max_num_batched_tokens max_num_reqs = self.scheduler_config.max_num_seqs - num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens + num_reqs = math.ceil(num_tokens / self.decode_token_per_req) + if with_prefill: + num_reqs = min(num_tokens, max_num_reqs) + else: + num_reqs = (num_tokens + self.decode_token_per_req - + 1) // self.decode_token_per_req min_tokens_per_req = num_tokens // num_reqs num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs num_scheduled_tokens_list[-1] += num_tokens % num_reqs @@ -1454,6 +1625,26 @@ def _dummy_run( assert len(num_scheduled_tokens_list) == num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + + # NOTE: If torchair graph mode and not with_prefill, + # we can't skip_attn, it will cause graph recompile. 
+ if self.torchair_graph_enabled and not with_prefill: + attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy( + num_reqs=num_reqs, num_actual_tokens=1) + elif skip_attn: + attn_metadata = None + else: + attn_metadata = self.attn_metadata_builder.build_dummy_metadata( + num_actual_tokens=num_tokens, + num_reqs=num_reqs, + num_scheduled_tokens=num_scheduled_tokens, + attn_state=attn_state, + ) + + + if not is_torchair_compile and not is_profile_run and self.dynamic_eplb: + self.eplb_updator.forward_before() + with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): model = self.model @@ -1483,14 +1674,17 @@ def _dummy_run( for k, v in self.intermediate_tensors.items() }) - with set_forward_context(None, - self.vllm_config, - num_tokens=num_tokens): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + with_prefill=with_prefill, + in_profile_run=self.in_profile_run): + model_kwargs = {} if self.torchair_graph_enabled and not with_prefill: - attn_metadata = self.attn_metadata_builder.build_dummy( - num_reqs=num_tokens, num_actual_tokens=1) # Only mark static while compiling - if is_compile: + if is_torchair_compile: torch._dynamo.mark_static(input_ids) torch._dynamo.mark_static(positions) torch._dynamo.mark_static( @@ -1505,21 +1699,43 @@ def _dummy_run( torch._dynamo.mark_static(kv[1]) compiled_model = self._get_torchair_lazy_compiled_model( num_tokens) + model_kwargs["kv_caches"] = self.kv_caches + model_kwargs["attn_metadata"] = attn_metadata + if envs_ascend.VLLM_ASCEND_ENABLE_DBO: + model_kwargs["graph_enable"] = True # type: ignore hidden_states = compiled_model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=None, - kv_caches=self.kv_caches, - attn_metadata=attn_metadata, + **model_kwargs, ) else: + if envs_ascend.VLLM_ASCEND_ENABLE_DBO: + model_kwargs["graph_enable"] = False # type: ignore hidden_states = model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) - return hidden_states + inputs_embeds=inputs_embeds, + **model_kwargs) + if self.speculative_config and self.speculative_config.method == "deepseek_mtp": + assert isinstance(self.drafter, MtpProposer) + self.drafter.dummy_run(num_reqs, with_prefill=with_prefill) + if is_profile_run and self.dynamic_eplb: + self.model.clear_all_moe_loads() + if not is_torchair_compile and not is_profile_run and self.dynamic_eplb: + self.eplb_updator.take_update_info_from_eplb_process() + self.eplb_updator.forward_end() + return hidden_states + + @contextmanager + def set_in_profile_run(self): + self.in_profile_run = True + try: + yield + finally: + self.in_profile_run = False def profile_run(self) -> None: # FIXME Profile with multimodal encoder & encoder cache. @@ -1547,7 +1763,10 @@ def profile_run(self) -> None: # TODO: call maybe_profile_with_lora() # Trigger compilation for general shape. 
- hidden_states = self._dummy_run(self.max_num_tokens) + with self.set_in_profile_run(): + hidden_states = self._dummy_run(self.max_num_tokens, + with_prefill=True, + is_profile_run=True) if get_pp_group().is_last_rank: hidden_states = hidden_states[logit_indices] @@ -1560,6 +1779,20 @@ def profile_run(self) -> None: self.encoder_cache.clear() gc.collect() + def do_get_expert_load(self) -> tuple: + return self.eplb_updator.get_expert_load() + + def do_update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int): + return self.eplb_updator.update_expert_load_statistical_period(num_expert_load_gather, num_iterations) + + def eplb_warmup(self): + #EPLB + if self.dynamic_eplb and not self.is_eplb_warmuped: + self.is_eplb_warmuped = True + self.eplb_adaptor = VllmEplbAdaptor(model=self.model) + self.eplb_updator.set_adaptor(self.eplb_adaptor) + self.eplb_updator.warm_up_eplb() + def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) @@ -1578,9 +1811,9 @@ def load_model(self) -> None: m.consumed_memory / float(2**30)) def _get_torchair_lazy_compiled_model(self, batch_size: int): - if batch_size < 0 or batch_size > self.max_num_reqs: + if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]: raise ValueError( - f"Bad graph batch size:{batch_size}! max_num_reqs:{self.max_num_reqs}" + f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}" ) compiled_model = self.torchair_compiled_models.get( @@ -1590,9 +1823,6 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int): if compiled_model: return compiled_model - import torchair # type: ignore - from torchair import patch_for_hcom # type: ignore - patch_for_hcom() config = torchair.CompilerConfig() config.experimental_config.frozen_parameter = True @@ -1640,9 +1870,14 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ - import torch_npu kv_caches: Dict[str, torch.Tensor] = {} + def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor: + data_ptr = tensor.data_ptr() + aligned_addr = (data_ptr + alignment - 1) // alignment * alignment + offset = (aligned_addr - data_ptr) // tensor.element_size() + return tensor[int(offset):] + self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, max_model_len=self.model_config.max_model_len, @@ -1653,17 +1888,22 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: block_sizes=[self.cache_config.block_size], ) - kv_cache_sizes = {} - for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - assert len(kv_cache_tensor.shared_by) == 1, ( - "KV cache tensor shared by multiple layers is not supported in " - "NPU.") - kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size + if not vllm_version_is("0.9.0"): + kv_cache_sizes = {} + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + assert len(kv_cache_tensor.shared_by) == 1, ( + "KV cache tensor shared by multiple layers is not supported in " + "NPU.") + kv_cache_sizes[ + kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size for kv_cache_group in kv_cache_config.kv_cache_groups: kv_cache_spec = kv_cache_group.kv_cache_spec for layer_name in kv_cache_group.layer_names: - tensor_size = kv_cache_sizes[layer_name] + if vllm_version_is("0.9.0"): + tensor_size = kv_cache_config.tensors[layer_name].size + else: + tensor_size = kv_cache_sizes[layer_name] assert 
tensor_size % kv_cache_spec.page_size_bytes == 0 num_blocks = tensor_size // kv_cache_spec.page_size_bytes @@ -1675,6 +1915,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: # different GPUs, and `kv_cache_config.num_blocks` is set to # the min of all `num_blocks`. Verify it here. assert num_blocks >= kv_cache_config.num_blocks + alignment = 2 * 1024 * 1024 # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may # encounter OOM issue if isinstance(kv_cache_spec, FullAttentionSpec): @@ -1682,29 +1923,51 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype - if self.torchair_graph_enabled: - layer_kv_cache_nope = torch.zeros( - kv_cache_shape[:-1] + - (self.model_config.hf_text_config.kv_lora_rank, ), - dtype=self.dtype, - pin_memory=True, - device=self.device) - layer_kv_cache_pe = torch.zeros( - kv_cache_shape[:-1] + - (self.model_config.hf_text_config.qk_rope_head_dim, - ), - dtype=self.dtype, - pin_memory=True, - device=self.device) - kv_caches[layer_name] = (layer_kv_cache_nope, - layer_kv_cache_pe) - torch_npu.npu_format_cast(kv_caches[layer_name][0], 2) - torch_npu.npu_format_cast(kv_caches[layer_name][1], 2) + if self.model_config.is_deepseek_mla: + # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory + # address should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but + # we found there are also some exceptions during test, so we manual align those memory here, this part + # of code may consume 2M * 2 * elem_size memory every layer. + num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape + rope_dim = self.model_config.hf_text_config.qk_rope_head_dim + nope_dim = head_size - rope_dim + nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim + nope_allocate_shape_alignment = nope_allocate_shape + alignment + rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim + rope_allocate_shape_alignment = rope_allocate_shape + alignment + nope_cache_shape = (num_blocks, block_size, + num_kv_heads, nope_dim) + rope_cache_shape = (num_blocks, block_size, + num_kv_heads, rope_dim) + nope_cache = torch.zeros(nope_allocate_shape_alignment, + dtype=dtype, + device=self.device) + rope_cache = torch.zeros(rope_allocate_shape_alignment, + dtype=dtype, + device=self.device) + nope_cache = align_memory( + nope_cache, alignment)[:nope_allocate_shape].view( + nope_cache_shape) + rope_cache = align_memory( + rope_cache, alignment)[:rope_allocate_shape].view( + rope_cache_shape) + kv_caches[layer_name] = (nope_cache, rope_cache) else: - kv_caches[layer_name] = torch.zeros(kv_cache_shape, - dtype=dtype, - device=self.device) - torch_npu.npu_format_cast(kv_caches[layer_name], 2) + num_caches = kv_cache_shape[0] + kv_cache_list = [] + for i in range(num_caches): + cache_shape = kv_cache_shape[1:] + cache_size = math.prod(cache_shape) + cache_size_aligned = cache_size + alignment + kv_cache = torch.zeros(cache_size_aligned, + dtype=dtype, + device=self.device) + kv_cache = align_memory( + kv_cache, + alignment)[:cache_size].view(cache_shape) + kv_cache_list.append(kv_cache) + kv_caches[layer_name] = kv_cache_list + # torch_npu.npu_format_cast(kv_caches[layer_name], 2) else: # TODO: add new branches when introducing more types of # KV cache specs. 
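Editor's note on the kv-cache allocation above: each cache buffer is over-allocated by 2 MB and then trimmed to a 2 MB-aligned start address so it can be registered with llmdatadist. A host-side sketch of the `align_memory` arithmetic with a made-up pointer (illustrative only, not the PR's code):

alignment = 2 * 1024 * 1024            # 2 MB alignment required for memory registration
data_ptr = 0x7F0000100010              # hypothetical start address of the raw buffer
element_size = 2                       # e.g. a 16-bit kv-cache dtype

aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
offset_elems = (aligned_addr - data_ptr) // element_size

# The usable cache view starts offset_elems elements into the buffer, which is
# why every allocation above adds `alignment` bytes of slack before slicing.
print(hex(aligned_addr), offset_elems)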
@@ -1715,6 +1978,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.vllm_config.compilation_config.static_forward_context, self.kv_caches) + if has_kv_transfer_group(): + get_kv_transfer_group().register_kv_caches(kv_caches) + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each @@ -1773,24 +2039,25 @@ def capture_model(self) -> None: reversed(torchair_graph_batch_sizes)): for _ in range(self.vllm_config.compilation_config. cudagraph_num_of_warmups): - self._dummy_run(num_tokens, - is_compile=True, - with_prefill=False) - self._dummy_run(num_tokens, - is_compile=True, - with_prefill=False) + # NOTE: when in torchair graph and not with_prefill, + # we don't need to set `skip_attn=False` + self._dummy_run(num_tokens, is_torchair_compile=True) + self._dummy_run(num_tokens, is_torchair_compile=True) logger.info("Batchsize %d is compiled successfully: %d/%d.", num_tokens, idx + 1, graph_num) elif self.use_aclgraph: # Trigger ACL graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. + # TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode with graph_capture(device=self.device): + skip_attn = not self.vllm_config.compilation_config.full_cuda_graph + # TODO: Make sure passing attn_state to _dummy_run in the future for num_tokens in reversed(self.aclgraph_batch_sizes): for _ in range(self.vllm_config.compilation_config. cudagraph_num_of_warmups): - self._dummy_run(num_tokens) - self._dummy_run(num_tokens) + self._dummy_run(num_tokens, skip_attn=skip_attn) + self._dummy_run(num_tokens, skip_attn=skip_attn) else: logger.info("Skipping NPU graph capture for eager mode.") return @@ -1885,6 +2152,7 @@ def _generate_mtp_token_ids( cu_num_tokens, token_indices = self.drafter.prepare_inputs( attn_metadata.query_start_loc, num_rejected_tokens, + force_one_token=True, ) target_token_ids = self.input_ids[token_indices] target_positions = positions[token_indices] @@ -1916,8 +2184,43 @@ def init_torchair_graph_batch_sizes(self): start_graph_batch_size *= 2 def select_torchair_padded_batch_size(self, batch_size: int): - selected_batch_size = self.max_num_reqs for padded_batch_size in self.torchair_graph_batch_sizes: - if batch_size <= padded_batch_size < selected_batch_size: - selected_batch_size = padded_batch_size - return selected_batch_size + if batch_size <= padded_batch_size: + # we treat batch_size as num of requests + return padded_batch_size + raise ValueError( + f"cur batch_size is invalid, torchair_graph_batch_sizes is " + f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}." 
+ ) + + def check_torchair_graph_batch_sizes(self): + # return graph_batch_sizes according to the number of tokens + # first pad according to the number of requests + if len(self.torchair_graph_batch_sizes) == 0: + self.torchair_graph_batch_sizes = [1, self.max_num_reqs] + else: + self.torchair_graph_batch_sizes = sorted( + self.torchair_graph_batch_sizes) + while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs: + self.torchair_graph_batch_sizes.pop() + if len(self.torchair_graph_batch_sizes) == 0: + logger.warning( + "torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]" + ) + self.torchair_graph_batch_sizes = [1, self.max_num_reqs] + if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs: + self.torchair_graph_batch_sizes.append(self.max_num_reqs) + + # NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size` + tp_size = self.parallel_config.tensor_parallel_size + if self.parallel_config.enable_expert_parallel: + new_graph_batch_sizes = [] + for graph_batch_size in self.torchair_graph_batch_sizes: + cur_graph_batch_size = (graph_batch_size + tp_size - + 1) // tp_size * tp_size + # `graph_batch_size` need to be divisible by `self.decode_token_per_req` + cur_graph_batch_size = cur_graph_batch_size * self.decode_token_per_req + if cur_graph_batch_size not in new_graph_batch_sizes and \ + cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens: + new_graph_batch_sizes.append(cur_graph_batch_size) + self.torchair_graph_batch_sizes = new_graph_batch_sizes diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index ba8406fa0a..04a7d617b5 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -1,15 +1,19 @@ import torch +import vllm.envs as envs_vllm from vllm.attention.layer import Attention from vllm.config import (VllmConfig, get_layers_from_vllm_config, set_current_vllm_config) -from vllm.forward_context import set_forward_context from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader.utils import ( process_weights_after_loading, set_default_torch_dtype) from vllm.v1.sample.metadata import SamplingMetadata -from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata +from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.utils import \ + AscendCommonAttentionMetadata as CommonAttentionMetadata from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP +from vllm_ascend.utils import ProfileExecuteDuration # FIXME(woosuk): The logic here is duplicated with the main sampling code. 
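Editor's note on `check_torchair_graph_batch_sizes` above: with expert parallelism enabled, each graph batch size is rounded up to a multiple of the TP size and then scaled by `decode_token_per_req`, and anything beyond `max_num_batched_tokens` is dropped. A sketch with hypothetical sizes (not the PR's defaults):

import math

tp_size = 4                                  # hypothetical tensor-parallel size
decode_token_per_req = 2                     # MTP with one speculative token
max_num_batched_tokens = 64                  # hypothetical scheduler limit
torchair_graph_batch_sizes = [1, 8, 16]      # sizes expressed in requests

new_graph_batch_sizes = []
for graph_batch_size in torchair_graph_batch_sizes:
    padded = math.ceil(graph_batch_size / tp_size) * tp_size   # divisible by tp_size
    padded *= decode_token_per_req                             # requests -> tokens
    if padded not in new_graph_batch_sizes and padded <= max_num_batched_tokens:
        new_graph_batch_sizes.append(padded)

print(new_graph_batch_sizes)                 # [8, 16, 32]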
@@ -61,13 +65,26 @@ def __init__( vllm_config.speculative_config.num_speculative_tokens) self.block_size = vllm_config.cache_config.block_size self.runner = runner + # persistent buffers for graph + self.input_ids = torch.zeros(self.runner.max_num_tokens, + dtype=torch.int32, + device=self.runner.device) + self.positions = torch.zeros(self.runner.max_num_tokens, + dtype=torch.int64, + device=self.runner.device) + self.hidden_states = torch.zeros( + (self.runner.max_num_tokens, self.runner.hidden_size), + dtype=self.runner.dtype, + device=self.runner.device) + self.is_mtp_torchair_ready = False @staticmethod def prepare_inputs( - # [batch_size + 1] - cu_target_query_lens: torch.Tensor, - # [batch_size] - num_rejected_tokens: torch.Tensor, + # [batch_size + 1] + cu_target_query_lens: torch.Tensor, + # [batch_size] + num_rejected_tokens: torch.Tensor, + force_one_token: bool = False ) -> tuple[torch.Tensor, torch.Tensor]: # cu_target_query_lens: [0, a, a + b, a + b + c] # num_rejected_tokens: [n1, n2, n3] @@ -76,32 +93,39 @@ def prepare_inputs( # token_indices: [0, 1, ..., a - n1 - 1, # a, a + 1, ..., a + b - n2 - 1, # a + b, a + b + 1, ..., a + b + c - n3 - 1] - # [0, a, a + b, a + b + c] -> [a, b, c] query_len_per_req = (cu_target_query_lens[1:] - cu_target_query_lens[:-1]) # [a, b, c] -> [a - n1, b - n2, c - n3] num_tokens_per_req = query_len_per_req - num_rejected_tokens + if force_one_token: + # enable force_one_token means we only focus on the last token position of each request + # token_indices: [batch_size] + cu_num_tokens = torch.arange(cu_target_query_lens.size(0), + device=cu_target_query_lens.device, + dtype=torch.int32) + relative_index = query_len_per_req - num_rejected_tokens - 1 + token_indices = cu_target_query_lens[:-1] + relative_index + else: + cu_num_tokens = torch.empty_like(cu_target_query_lens) + torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) + cu_num_tokens[0] = 0 + + # FIXME(woosuk): Avoid synchronization. + num_tokens = cu_num_tokens[-1].item() + token_indices = torch.empty( + num_tokens, + dtype=torch.int32, + device=cu_num_tokens.device, + ) - cu_num_tokens = torch.empty_like(cu_target_query_lens) - torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) - cu_num_tokens[0] = 0 - - # FIXME(woosuk): Avoid synchronization. - num_tokens = cu_num_tokens[-1].item() - token_indices = torch.empty( - num_tokens, - dtype=torch.int32, - device=cu_num_tokens.device, - ) - - BLOCK_SIZE = 1024 - prepare_input_kernel( - token_indices, - cu_target_query_lens, - cu_num_tokens, - block_size=BLOCK_SIZE, - ) + BLOCK_SIZE = 1024 + prepare_input_kernel( + token_indices, + cu_target_query_lens, + cu_num_tokens, + block_size=BLOCK_SIZE, + ) return cu_num_tokens, token_indices def propose( @@ -126,13 +150,12 @@ def propose( batch_size = next_token_ids.shape[0] last_token_indices = cu_num_tokens[1:] - 1 - input_ids = torch.empty_like(target_token_ids) # Shift the input ids by one token. # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - input_ids[:-1] = target_token_ids[1:] + self.input_ids[:num_tokens - 1] = target_token_ids[1:] # Replace the last token with the next token. 
         # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
-        input_ids[last_token_indices] = next_token_ids
+        self.input_ids[last_token_indices] = next_token_ids
 
         query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
         max_query_len = query_lens.max().item()
@@ -152,6 +175,23 @@ def propose(
         #     input_batch=self.runner.input_batch,
         #     scheduler_output=self.runner.scheduler_output,
         # )
+        extra_builder_kwargs = self.runner.extra_builder_kwargs
+
+        is_running_torchair = self.runner.torchair_graph_enabled and \
+            not self.runner.with_prefill and self.is_mtp_torchair_ready
+
+        if is_running_torchair:
+            if num_tokens == 1:
+                self.runner.attn_state = AscendAttentionState.DecodeOnly
+            num_reqs_pad_size = self.runner.num_reqs_pad_size
+            extra_builder_kwargs['num_reqs_pad_size'] = num_reqs_pad_size
+            # Assume one token per request
+            extra_builder_kwargs['num_token_pad_size'] = num_reqs_pad_size
+            num_input_tokens = self.runner.num_reqs_pad_size
+        else:
+            extra_builder_kwargs['num_token_pad_size'] = -1
+            extra_builder_kwargs['num_reqs_pad_size'] = 0
+            num_input_tokens = num_tokens
 
         attn_metadata = self.runner.attn_metadata_builder.build(
             num_reqs=batch_size,
@@ -159,14 +199,52 @@ def propose(
             max_query_len=max_query_len,
             common_prefix_len=0,
             common_attn_metadata=common_attn_metadata,
-        )
-
-        with set_forward_context(attn_metadata, self.vllm_config):
-            hidden_states = self.model(
-                input_ids=input_ids,
-                positions=target_positions,
-                previous_hidden_states=target_hidden_states,
-            )
+            is_mtp_model=True,
+            **extra_builder_kwargs)
+
+        self.positions[:num_tokens] = target_positions
+        self.hidden_states[:num_tokens] = target_hidden_states
+
+        # Assuming force_one_token is enabled, each prefill request's query_len is 1
+        if attn_metadata.prefill is not None:
+            attn_metadata.prefill.query_lens[:] = 1
+
+        with set_ascend_forward_context(attn_metadata,
+                                        self.vllm_config,
+                                        num_tokens=num_input_tokens):
+            with ProfileExecuteDuration().capture_async('mtp_forward'):
+                model_kwargs = {}
+                model_kwargs["attn_metadata"] = attn_metadata
+                if self.runner.torchair_graph_enabled:
+                    model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
+                if is_running_torchair:
+                    torch._dynamo.mark_static(self.input_ids)
+                    torch._dynamo.mark_static(self.positions)
+                    torch._dynamo.mark_static(attn_metadata.decode.block_table)
+                    torch._dynamo.mark_static(
+                        attn_metadata.decode.input_positions)
+                    torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                    torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
+                    for kv in self.runner.kv_caches:
+                        assert isinstance(kv,
+                                          tuple), "kv_cache must be a tuple"
+                        torch._dynamo.mark_static(kv[0])
+                        torch._dynamo.mark_static(kv[1])
+                    hidden_states = self.torchair_compiled_model(
+                        input_ids=self.input_ids[:num_input_tokens],
+                        positions=self.positions[:num_input_tokens],
+                        previous_hidden_states=self.
+                        hidden_states[:num_input_tokens],
+                        inputs_embeds=None,
+                        **model_kwargs)
+                else:
+                    hidden_states = self.model(
+                        input_ids=self.input_ids[:num_input_tokens],
+                        positions=self.positions[:num_input_tokens],
+                        previous_hidden_states=self.
+ hidden_states[:num_input_tokens], + attn_metadata=attn_metadata, + kv_caches=self.runner.kv_caches[-1:]) sample_hidden_states = hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) draft_token_ids = logits.argmax(dim=-1) @@ -202,6 +280,49 @@ def load_model(self) -> None: self.model)) process_weights_after_loading(self.model, draft_model_config, target_device) + if self.runner.torchair_graph_enabled and self.is_mtp_torchair_ready: + import torchair # type: ignore + from torchair import patch_for_hcom # type: ignore + + patch_for_hcom() + config = torchair.CompilerConfig() + config.experimental_config.frozen_parameter = True + config.experimental_config.tiling_schedule_optimize = True + torch.npu.set_compile_mode(jit_compile=False) + if not self.runner.use_cached_npu_graph: + npu_backend = torchair.get_npu_backend(compiler_config=config) + self.torchair_compiled_model = torch.compile( + self.model, + dynamic=True, + fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=npu_backend) + else: + self.torchair_compiled_model = torchair.inference.cache_compile( + self.model.forward, + dynamic=True, + fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + config=config, + ge_cache=False) + + @torch.inference_mode() + def dummy_run( + self, + num_tokens: int, + with_prefill: bool = False, + ) -> None: + if self.runner.torchair_graph_enabled and not with_prefill: + attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy( + num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True) + else: + attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy( + num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True) + with set_ascend_forward_context(None, + self.vllm_config, + num_tokens=num_tokens): + self.model(input_ids=self.input_ids[:num_tokens], + positions=self.positions[:num_tokens], + previous_hidden_states=self.hidden_states[:num_tokens], + attn_metadata=attn_metadata) # TODO Using torch instead of triton may result in poor performance diff --git a/vllm_ascend/worker/pooling_model_runner.py b/vllm_ascend/worker/pooling_model_runner.py index e1262fb0a2..5047a0f106 100644 --- a/vllm_ascend/worker/pooling_model_runner.py +++ b/vllm_ascend/worker/pooling_model_runner.py @@ -21,13 +21,13 @@ import torch from vllm.distributed import get_pp_group -from vllm.forward_context import set_forward_context from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MultiModalKwargs from vllm.pooling_params import PoolingParams from vllm.sequence import (IntermediateTensors, SequenceData, SequenceGroupMetadata) +from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.worker.model_runner import (ModelInputForNPU, ModelInputForNPUBuilder, NPUModelRunnerBase) @@ -142,8 +142,8 @@ def execute_model( if model_input.token_types is not None: cross_enc_kwargs["token_type_ids"] = model_input.token_types - with set_forward_context(model_input.attn_metadata, self.vllm_config, - virtual_engine): + with set_ascend_forward_context(model_input.attn_metadata, + self.vllm_config, virtual_engine): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index e78cc3f1cf..80f7c4a78d 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -49,9 +49,8 @@ from vllm_ascend.ascend_config import 
init_ascend_config from vllm_ascend.device_allocator.camem import CaMemAllocator -from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.platform import NPUPlatform -from vllm_ascend.utils import try_register_lib +from vllm_ascend.utils import init_ascend_soc_version, try_register_lib from vllm_ascend.worker.model_runner import NPUModelRunner from vllm_ascend.worker.pooling_model_runner import NPUPoolingModelRunner @@ -218,6 +217,7 @@ def init_device(self) -> None: else: raise RuntimeError( f"Not support device type: {self.device_config.device}") + init_ascend_soc_version() # Initialize the distributed environment. self._init_worker_distributed_environment(self.vllm_config, self.rank, self.distributed_init_method, @@ -545,11 +545,6 @@ def _init_worker_distributed_environment( ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) - init_ascend_model_parallel( - parallel_config.expert_parallel_size, - parallel_config.expert_tensor_parallel_size, - parallel_config.world_size_across_dp, - ) ensure_kv_transfer_initialized(vllm_config) diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 6fe84a4580..b062e5cf9e 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -40,9 +40,8 @@ from vllm_ascend.ascend_config import init_ascend_config from vllm_ascend.device_allocator.camem import CaMemAllocator -from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.platform import NPUPlatform -from vllm_ascend.utils import try_register_lib +from vllm_ascend.utils import init_ascend_soc_version, try_register_lib from vllm_ascend.worker.model_runner_v1 import NPUModelRunner @@ -75,6 +74,9 @@ def __init__( is_driver_worker=is_driver_worker) # Try to import mindie_turbo to accelerate vLLM inference. + local_dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local + world_size = self.vllm_config.parallel_config.world_size + self.local_rank_across_dp = local_dp_rank * world_size + self.local_rank try_register_lib( "mindie_turbo", "MindIE Turbo is installed. vLLM inference will be accelerated with MindIE Turbo." @@ -125,6 +127,7 @@ def init_device(self): info = f"Not support device type: {self.device_config.device}" logger.error(info) raise RuntimeError(info) + init_ascend_soc_version() # Initialize the distributed environment. self._init_worker_distributed_environment() # Set random seed. @@ -192,6 +195,7 @@ def load_model(self) -> None: self.model_runner.load_model() def compile_or_warm_up_model(self) -> None: + self.model_runner.eplb_warmup() warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() if not self.model_config.enforce_eager: warmup_sizes = [ @@ -201,12 +205,18 @@ def compile_or_warm_up_model(self) -> None: for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) self.model_runner._dummy_run(size) + if not self.model_config.enforce_eager: self.model_runner.capture_model() # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. 
set_random_seed(self.model_config.seed) + def get_expert_load(self) -> tuple: + return self.model_runner.do_get_expert_load() + def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int): + self.model_runner.do_update_expert_load_statistical_period(num_expert_load_gather, num_iterations) + def get_model(self) -> nn.Module: return self.model_runner.get_model() @@ -245,22 +255,10 @@ def pin_lora(self, lora_id: int) -> bool: return self.model_runner.pin_lora(lora_id) def execute_dummy_batch(self) -> None: - runner = self.model_runner - max_num_tokens = 1 - with_prefill = False - if runner.dp_size > 1: - max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp( - max_num_tokens, with_prefill) - if runner.torchair_graph_enabled and not with_prefill: - max_num_tokens = runner.select_torchair_padded_batch_size( - max_num_tokens) - runner._dummy_run(max_num_tokens, - is_compile=False, - with_prefill=with_prefill) + self.model_runner._dummy_run(1) def _init_worker_distributed_environment(self) -> None: """Initialize the distributed environment.""" - parallel_config = self.vllm_config.parallel_config set_custom_all_reduce( not self.parallel_config.disable_custom_all_reduce) init_distributed_environment(self.parallel_config.world_size, @@ -269,11 +267,6 @@ def _init_worker_distributed_environment(self) -> None: ensure_model_parallel_initialized( self.parallel_config.tensor_parallel_size, self.parallel_config.pipeline_parallel_size) - init_ascend_model_parallel( - parallel_config.expert_parallel_size, - parallel_config.expert_tensor_parallel_size, - parallel_config.world_size_across_dp, - ) ensure_kv_transfer_initialized(self.vllm_config) def _init_profiler(self):
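In the prepare_inputs(force_one_token=True) path added to mtp_proposer_v1.py, each request contributes exactly one index: the position of its last accepted token, computed as cu_target_query_lens[:-1] + (query_len - num_rejected - 1), while cu_num_tokens simply becomes arange(num_requests + 1). A small concrete check of that arithmetic with plain torch (the tensor values are invented for illustration):

import torch

# Three requests with query lengths [3, 2, 4] -> cu_target_query_lens = [0, 3, 5, 9]
cu_target_query_lens = torch.tensor([0, 3, 5, 9], dtype=torch.int32)
num_rejected_tokens = torch.tensor([1, 0, 2], dtype=torch.int32)

query_len_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1]  # [3, 2, 4]
relative_index = query_len_per_req - num_rejected_tokens - 1              # [1, 1, 1]
token_indices = cu_target_query_lens[:-1] + relative_index                # [1, 4, 6]

# Each request keeps exactly one position: its last accepted token.
print(token_indices.tolist())  # [1, 4, 6]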
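The torchair decode path in propose() writes each step's tokens into persistent input_ids/positions/hidden_states buffers and calls torch._dynamo.mark_static on them, so the compiled graph sees stable storage instead of fresh tensors every call. Below is a minimal CPU-only sketch of that persistent-buffer pattern; it uses the stock torch.compile backend as a stand-in for the torchair backend, and the model is a toy matmul rather than the draft model.

import torch

MAX_TOKENS, HIDDEN = 64, 32

# Persistent buffer allocated once; every step copies live data into a slice.
input_buf = torch.zeros(MAX_TOKENS, HIDDEN)
weight = torch.randn(HIDDEN, HIDDEN)


def step(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x @ weight)


# Mark the persistent buffer static so dynamo treats its storage as fixed
# instead of re-guarding on a brand-new tensor every call.
torch._dynamo.mark_static(input_buf)
compiled_step = torch.compile(step, dynamic=True)

new_tokens = torch.randn(8, HIDDEN)  # this step's actual inputs
input_buf[:8] = new_tokens           # fill the head of the persistent buffer
out = compiled_step(input_buf[:8])   # run on the (padded) slice, not a new tensor
print(out.shape)                     # torch.Size([8, 32])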
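worker_v1.py now derives a rank that is unique across data-parallel groups as local_dp_rank * world_size + local_rank. A toy illustration of that flattening (the group sizes are invented; world_size here means ranks per DP group):

def local_rank_across_dp(local_dp_rank: int, world_size: int, local_rank: int) -> int:
    # Flatten (dp_group, rank_in_group) into a single collision-free id.
    return local_dp_rank * world_size + local_rank

# Two DP groups of 4 ranks each -> ids 0..7 with no collisions.
print([local_rank_across_dp(dp, 4, r) for dp in range(2) for r in range(4)])
# [0, 1, 2, 3, 4, 5, 6, 7]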