From fa7a1f020c9271e4d6d4ca91710112121fd1f0e1 Mon Sep 17 00:00:00 2001 From: MengqingCao Date: Tue, 24 Jun 2025 01:24:35 +0000 Subject: [PATCH] [Dist][EP] Remove ETP/EP maintained in vllm-ascend Signed-off-by: MengqingCao --- docs/source/tutorials/multi_npu_moge.md | 2 + .../configuration/additional_config.md | 2 - examples/run_dp_attention_etp16.sh | 22 -- examples/run_dp_attention_etp16_benmark.sh | 57 ----- .../long_term/accuracy/accuracy_multicard.py | 61 +---- tests/e2e/multicard/test_expert_parallel.py | 30 +++ .../multicard/test_fused_moe_allgather_ep.py | 2 - .../e2e/multicard/test_torchair_graph_mode.py | 1 + tests/ut/distributed/test_parallel_state.py | 208 ------------------ tests/ut/test_ascend_config.py | 3 - tests/ut/test_platform.py | 25 --- vllm_ascend/ascend_config.py | 2 - vllm_ascend/distributed/parallel_state.py | 77 ------- vllm_ascend/models/deepseek_v2.py | 3 +- vllm_ascend/models/pangu_moe.py | 16 +- vllm_ascend/ops/common_fused_moe.py | 1 + vllm_ascend/ops/fused_moe.py | 37 ++-- vllm_ascend/patch/__init__.py | 12 +- .../patch_common/patch_distributed.py | 22 -- vllm_ascend/platform.py | 18 -- vllm_ascend/quantization/w8a8.py | 2 +- vllm_ascend/quantization/w8a8_dynamic.py | 2 +- vllm_ascend/utils.py | 2 - vllm_ascend/worker/worker_v1.py | 7 - 24 files changed, 66 insertions(+), 548 deletions(-) delete mode 100644 examples/run_dp_attention_etp16.sh delete mode 100644 examples/run_dp_attention_etp16_benmark.sh create mode 100644 tests/e2e/multicard/test_expert_parallel.py delete mode 100644 tests/ut/distributed/test_parallel_state.py delete mode 100644 vllm_ascend/distributed/parallel_state.py diff --git a/docs/source/tutorials/multi_npu_moge.md b/docs/source/tutorials/multi_npu_moge.md index e152197ddb..f2839de028 100644 --- a/docs/source/tutorials/multi_npu_moge.md +++ b/docs/source/tutorials/multi_npu_moge.md @@ -48,6 +48,7 @@ Run the following script to start the vLLM server on Multi-NPU: ```bash vllm serve /path/to/pangu-pro-moe-model \ --tensor-parallel-size 4 \ +--enable-expert-parallel \ --trust-remote-code \ --enforce-eager ``` @@ -145,6 +146,7 @@ if __name__ == "__main__": llm = LLM(model="/path/to/pangu-pro-moe-model", tensor_parallel_size=4, + enable_expert_parallel=True, distributed_executor_backend="mp", max_model_len=1024, trust_remote_code=True, diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index d58ac5ac85..df01430df1 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -28,7 +28,6 @@ The following table lists the additional configuration options available in vLLM |-------------------------------| ---- |------|-----------------------------------------------------------------------------------------------| | `torchair_graph_config` | dict | `{}` | The config options for torchair graph mode | | `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler | -| `expert_tensor_parallel_size` | str | `0` | Expert tensor parallel size the model to use. | | `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. | | `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. | | `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. 
| @@ -75,7 +74,6 @@ An example of additional configuration is as follows: "enabled": True, "enable_chunked_prefill": True, }, - "expert_tensor_parallel_size": 1, "refresh": False, } ``` diff --git a/examples/run_dp_attention_etp16.sh b/examples/run_dp_attention_etp16.sh deleted file mode 100644 index 5d87879a11..0000000000 --- a/examples/run_dp_attention_etp16.sh +++ /dev/null @@ -1,22 +0,0 @@ -export TASK_QUEUE_ENABLE=1 -source /usr/local/Ascend/ascend-toolkit/set_env.sh -source /usr/local/Ascend/nnal/atb/set_env.sh -export ASCEND_LAUNCH_BLOCKING=0 -export VLLM_VERSION=0.9.1 - -nohup python -m vllm.entrypoints.openai.api_server --model=/mnt/deepseek/DeepSeek-R1-W8A8-VLLM \ - --served-model-name auto \ - --quantization ascend \ - --trust-remote-code \ - --distributed-executor-backend=mp \ - --port 8006 \ - -tp=8 \ - -dp=2 \ - --max-num-seqs 24 \ - --max-model-len 32768 \ - --max-num-batched-tokens 32768 \ - --block-size 128 \ - --no-enable-prefix-caching \ - --additional-config '{"torchair_graph_config":{"enabled":true,"use_cached_graph":true,"graph_batch_sizes":[24]},"ascend_scheduler_config":{"enabled":true},"expert_tensor_parallel_size":16}' \ - --gpu-memory-utilization 0.96 &> run.log & -disown \ No newline at end of file diff --git a/examples/run_dp_attention_etp16_benmark.sh b/examples/run_dp_attention_etp16_benmark.sh deleted file mode 100644 index bdd1fb858c..0000000000 --- a/examples/run_dp_attention_etp16_benmark.sh +++ /dev/null @@ -1,57 +0,0 @@ -#!/bin/bash -# Concurrency array -concurrency_array=(48) -#best rate -rate_array=(0.7) - -# Result file -result_file="benchmark_results.txt" -echo "Benchmark Results" > $result_file -echo "===================" >> $result_file - -# Loop through all combinations -for concurrency in "${concurrency_array[@]}"; do - for rate in "${rate_array[@]}"; do - echo "Testing with concurrency=$concurrency, rate=$rate" - echo "" >> $result_file - echo "Concurrency: $concurrency, Request Rate: $rate" >> $result_file - echo "-------------------" >> $result_file - - # Run benchmark test - python /mnt/deepseek/vllm/benchmarks/benchmark_serving.py \ - --backend vllm \ - --trust-remote-code \ - --model auto \ - --tokenizer /mnt/deepseek/DeepSeek-R1-W8A8-VLLM \ - --dataset-name random \ - --random-input-len 4096 \ - --random-output-len 1536 \ - --ignore-eos \ - --num-prompts 400 \ - --max-concurrency $concurrency \ - --request-rate $rate \ - --metric-percentiles 90 \ - --base-url http://localhost:8006 2>&1 | tee -a $result_file - - # Wait for system cool down - sleep 30 - done -done - -# Analyze results -echo "Analysis Results" > analysis_results.txt -echo "=================" >> analysis_results.txt - -# Extract and analyze TPOT data -echo "TPOT Analysis:" >> analysis_results.txt -grep "Mean TPOT" $result_file | awk -F':' '{ - printf "Concurrency %s, Rate %s: %s ms\n", $1, $2, $3 -}' >> analysis_results.txt - -# Extract and analyze throughput data -echo -e "\nThroughput Analysis:" >> analysis_results.txt -grep "Output token throughput" $result_file | awk -F':' '{ - printf "Concurrency %s, Rate %s: %s tokens/s\n", $1, $2, $3 -}' >> analysis_results.txt - -echo "Testing completed. 
Results saved in $result_file and analysis in analysis_results.txt" diff --git a/tests/e2e/long_term/accuracy/accuracy_multicard.py b/tests/e2e/long_term/accuracy/accuracy_multicard.py index 94e3724258..9dd77a9bbc 100644 --- a/tests/e2e/long_term/accuracy/accuracy_multicard.py +++ b/tests/e2e/long_term/accuracy/accuracy_multicard.py @@ -36,7 +36,7 @@ # pre-trained model path on Hugging Face. # Qwen/Qwen2.5-0.5B-Instruct: accuracy test for DP. -# Qwen/Qwen3-30B-A3B: accuracy test for EP and ETP. +# Qwen/Qwen3-30B-A3B: accuracy test for EP. # deepseek-ai/DeepSeek-V2-Lite: accuracy test for TP. MODEL_NAME = ["Qwen/Qwen3-30B-A3B", "deepseek-ai/DeepSeek-V2-Lite"] @@ -200,62 +200,3 @@ def test_lm_eval_accuracy_dp(model, max_tokens): except subprocess.TimeoutExpired: server_proc.kill() server_proc.wait() - - -@pytest.mark.parametrize("max_tokens", [10]) -@pytest.mark.parametrize("model", ["Qwen/Qwen3-30B-A3B"]) -def test_lm_eval_accuracy_etp(model, max_tokens): - log_file = open("accuracy_etp.log", "a+") - cmd = [ - "vllm", "serve", model, "--max_model_len", "4096", - "--tensor_parallel_size", "4", "--enforce_eager", - "--enable_expert_parallel", "--additional_config", - '{"expert_tensor_parallel_size": "4"}' - ] - server_proc = subprocess.Popen(cmd, - stdout=log_file, - stderr=subprocess.DEVNULL) - - try: - for _ in range(300): - try: - r = requests.get(HEALTH_URL, timeout=1) - if r.status_code == 200: - break - except requests.exceptions.RequestException: - pass - time.sleep(1) - else: - log_file.flush() - log_file.seek(0) - log_content = log_file.read() - pytest.fail( - f"vLLM serve did not become healthy after 300s: {HEALTH_URL}\n" - f"==== vLLM Serve Log Start ===\n{log_content}\n==== vLLM Serve Log End ===" - ) - - prompt = "bejing is a" - payload = { - "prompt": prompt, - "max_tokens": max_tokens, - "sampling_params": { - "temperature": 0.0, - "top_p": 1.0, - "seed": 123 - } - } - resp = requests.post(COMPLETIONS_URL, json=payload, timeout=30) - resp.raise_for_status() - data = resp.json() - - generated = data["choices"][0]["text"].strip() - expected = "city in china. 
it is the capital city of" - assert generated == expected, f"Expected `{expected}`, got `{generated}`" - - finally: - server_proc.send_signal(signal.SIGINT) - try: - server_proc.wait(timeout=10) - except subprocess.TimeoutExpired: - server_proc.kill() - server_proc.wait() diff --git a/tests/e2e/multicard/test_expert_parallel.py b/tests/e2e/multicard/test_expert_parallel.py new file mode 100644 index 0000000000..87bcbaf4be --- /dev/null +++ b/tests/e2e/multicard/test_expert_parallel.py @@ -0,0 +1,30 @@ +import pytest + +from tests.e2e.conftest import VllmRunner +from tests.e2e.model_utils import check_outputs_equal + + +@pytest.mark.parametrize("model_name", ["deepseek-ai/DeepSeek-V2-Lite-Chat"]) +def test_e2e_ep_correctness(model_name): + example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + max_tokens = 5 + + with VllmRunner(model_name, tensor_parallel_size=2) as vllm_model: + tp_output = vllm_model.generate_greedy(example_prompts, max_tokens) + + with VllmRunner(model_name, + tensor_parallel_size=2, + enable_expert_parallel=True) as vllm_model: + ep_output = vllm_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=ep_output, + outputs_1_lst=tp_output, + name_0="ep_output", + name_1="tp_output", + ) diff --git a/tests/e2e/multicard/test_fused_moe_allgather_ep.py b/tests/e2e/multicard/test_fused_moe_allgather_ep.py index 221d33f0d1..273008f006 100644 --- a/tests/e2e/multicard/test_fused_moe_allgather_ep.py +++ b/tests/e2e/multicard/test_fused_moe_allgather_ep.py @@ -50,7 +50,6 @@ def test_generate_with_allgather(): "enabled": True, "chunked_prefill_enabled": False, }, - "expert_tensor_parallel_size": 1 }) as vllm_model: vllm_model.generate(example_prompts, sampling_params) @@ -74,6 +73,5 @@ def test_generate_with_alltoall(): "enabled": True, "chunked_prefill_enabled": False, }, - "expert_tensor_parallel_size": 1 }) as vllm_model: vllm_model.generate(example_prompts, sampling_params) diff --git a/tests/e2e/multicard/test_torchair_graph_mode.py b/tests/e2e/multicard/test_torchair_graph_mode.py index d363560dd0..9d83d98f32 100644 --- a/tests/e2e/multicard/test_torchair_graph_mode.py +++ b/tests/e2e/multicard/test_torchair_graph_mode.py @@ -123,6 +123,7 @@ def _pangu_torchair_test_fixture( distributed_executor_backend="mp", enforce_eager=False, additional_config=additional_config, + enable_expert_parallel=True, ) as vllm_model: # use greedy sampler to make sure the generated results are fix vllm_output = vllm_model.generate_greedy(example_prompts, 5) diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py deleted file mode 100644 index b00eeb90a0..0000000000 --- a/tests/ut/distributed/test_parallel_state.py +++ /dev/null @@ -1,208 +0,0 @@ -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. 
-# - -from unittest.mock import MagicMock, patch - -import pytest -from vllm.distributed.parallel_state import GroupCoordinator - -import vllm_ascend -from tests.ut.base import TestBase -from vllm_ascend.distributed.parallel_state import ( - destory_ascend_model_parallel, get_ep_group, get_etp_group, - init_ascend_model_parallel, model_parallel_initialized) - - -class TestParallelState(TestBase): - - @patch('vllm_ascend.distributed.parallel_state._EP', - new_callable=lambda: MagicMock(spec=GroupCoordinator)) - def test_get_ep_group_when_initialized(self, mock_ep): - # Act - result = get_ep_group() - - # Assert - assert isinstance(result, GroupCoordinator) - - @patch('vllm_ascend.distributed.parallel_state._EP', None) - def test_get_ep_group_when_not_initialized(self): - # Act & Assert - with pytest.raises(AssertionError) as excinfo: - get_ep_group() - assert "expert model parallel group is not initialized" in str( - excinfo.value) - - @patch('vllm_ascend.distributed.parallel_state._ETP', - new_callable=lambda: MagicMock(spec=GroupCoordinator)) - def test_get_etp_group_when_initialized(self, mock_etp): - # Act - result = get_etp_group() - - # Assert - assert isinstance(result, GroupCoordinator) - - @patch('vllm_ascend.distributed.parallel_state._ETP', None) - def test_get_etp_group_when_not_initialized(self): - # Act & Assert - with pytest.raises(AssertionError) as excinfo: - get_etp_group() - assert "expert tensor parallel group is not initialized" in str( - excinfo.value) - - @patch('vllm_ascend.distributed.parallel_state._ETP', None) - @patch('vllm_ascend.distributed.parallel_state._EP', None) - def test_model_parallel_initialized_when_both_none(self): - # Act & Assert - assert not model_parallel_initialized() - - @patch('vllm_ascend.distributed.parallel_state._ETP', - new_callable=lambda: MagicMock(spec=GroupCoordinator)) - @patch('vllm_ascend.distributed.parallel_state._EP', None) - def test_model_parallel_initialized_when_ep_none(self, mock_etp): - # Act & Assert - assert not model_parallel_initialized() - - @patch('vllm_ascend.distributed.parallel_state._ETP', None) - @patch('vllm_ascend.distributed.parallel_state._EP', - new_callable=lambda: MagicMock(spec=GroupCoordinator)) - def test_model_parallel_initialized_when_etp_none(self, mock_ep): - # Act & Assert - assert not model_parallel_initialized() - - @patch('vllm_ascend.distributed.parallel_state._ETP', - new_callable=lambda: MagicMock(spec=GroupCoordinator)) - @patch('vllm_ascend.distributed.parallel_state._EP', - new_callable=lambda: MagicMock(spec=GroupCoordinator)) - def test_model_parallel_initialized_when_etp_initialized( - self, mock_ep, mock_etp): - # Act & Assert - assert model_parallel_initialized() - - @patch('vllm_ascend.distributed.parallel_state._ETP', - new_callable=lambda: MagicMock(spec=GroupCoordinator)) - @patch('vllm_ascend.distributed.parallel_state._EP', - new_callable=lambda: MagicMock(spec=GroupCoordinator)) - def test_destroy_when_both_exist(self, mock_ep, mock_etp): - # Act - destory_ascend_model_parallel() - # Assert - mock_ep.destroy.assert_called_once() - mock_etp.destroy.assert_called_once() - assert vllm_ascend.distributed.parallel_state._ETP is None - assert vllm_ascend.distributed.parallel_state._EP is None - - @patch('vllm_ascend.distributed.parallel_state._ETP', None) - @patch('vllm_ascend.distributed.parallel_state._EP', - new_callable=lambda: MagicMock()) - def test_destory_ascend_model_parallel_when_etp_none(self, mock_ep): - # Act - destory_ascend_model_parallel() - # Assert - 
mock_ep.destroy.assert_called_once() - assert vllm_ascend.distributed.parallel_state._EP is None - assert vllm_ascend.distributed.parallel_state._ETP is None - - @patch('vllm_ascend.distributed.parallel_state._ETP', - new_callable=lambda: MagicMock()) - @patch('vllm_ascend.distributed.parallel_state._EP', None) - def test_destory_ascend_model_parallel_when_ep_none(self, mock_etp): - # Act - destory_ascend_model_parallel() - # Assert - mock_etp.destroy.assert_called_once() - assert vllm_ascend.distributed.parallel_state._ETP is None - assert vllm_ascend.distributed.parallel_state._EP is None - - @patch('vllm_ascend.distributed.parallel_state._ETP', None) - @patch('vllm_ascend.distributed.parallel_state._EP', None) - def test_destory_ascend_model_parallel_when_both_none(self): - # Act - destory_ascend_model_parallel() - # Assert - assert vllm_ascend.distributed.parallel_state._ETP is None - assert vllm_ascend.distributed.parallel_state._EP is None - - @patch('torch.distributed.is_initialized', return_value=True) - @patch('torch.distributed.get_world_size', return_value=8) - @patch('vllm_ascend.distributed.parallel_state.get_world_group', - return_value=MagicMock(device_group='npu:0', local_rank=0)) - @patch('torch.distributed.get_backend', return_value='hccl') - @patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group') - @patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', - return_value=False) - def test_init_ascend_model_parallel_normal_case( - self, mock_mp_init, mock_init_group, mock_get_backend, - mock_world_group, mock_get_world_size, mock_is_init): - """Test normal initialization with default parameters""" - # Act - init_ascend_model_parallel() - # Assert - mock_init_group.assert_any_call([[0, 1, 2, 3, 4, 5, 6, 7]], - 0, - 'hccl', - group_name="ep") - mock_init_group.assert_any_call([[0]], 0, 'hccl', group_name="etp") - self.assertIsNotNone(vllm_ascend.distributed.parallel_state._EP) - self.assertIsNotNone(vllm_ascend.distributed.parallel_state._ETP) - - @patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', - return_value=True) - def test_init_ascend_model_parallel_skip_if_initialized( - self, mock_mp_init): - """Test skipping when model parallel already initialized""" - with patch.object(vllm_ascend.distributed.parallel_state, - '_EP') as mock_ep, patch.object( - vllm_ascend.distributed.parallel_state, - '_ETP') as mock_etp: - # Act - init_ascend_model_parallel() - # Assert - mock_ep.assert_not_called() - mock_etp.assert_not_called() - - @patch('torch.distributed.is_initialized', return_value=False) - def test_init_ascend_model_parallel_assert_dist_not_init( - self, mock_is_init): - """Test assertion when distributed not initialized""" - # Act & Assert - with self.assertRaises(AssertionError): - init_ascend_model_parallel() - - @patch('torch.distributed.is_initialized', return_value=True) - @patch('torch.distributed.get_world_size', return_value=8) - @patch('vllm_ascend.distributed.parallel_state.get_world_group', - return_value=MagicMock(device_group='npu:0', local_rank=1)) - @patch('torch.distributed.get_backend', return_value='hccl') - @patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group') - @patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', - return_value=False) - def test_init_ascend_model_parallel_custom_params( - self, mock_mp_init, mock_init_group, mock_get_backend, - mock_world_group, mock_get_world_size, mock_is_init): - """Test initialization with custom parallel 
sizes""" - # Act - init_ascend_model_parallel(expert_parallel_size=2, - expert_tensor_parallel_size=4, - world_size=8, - backend='hccl') - #Assert - mock_init_group.assert_any_call([[0, 4], [1, 5], [2, 6], [3, 7]], - 1, - 'hccl', - group_name="ep") - mock_init_group.assert_any_call([[0, 1, 2, 3], [4, 5, 6, 7]], - 1, - 'hccl', - group_name="etp") diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index f5a28b4fd4..34a5cca3f8 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -42,7 +42,6 @@ def test_init_ascend_config_without_additional_config(self): test_vllm_config = VllmConfig() # No additional config given, check the default value here. ascend_config = init_ascend_config(test_vllm_config) - self.assertEqual(ascend_config.expert_tensor_parallel_size, 0) self.assertIsNone(ascend_config.expert_map_path) torchair_graph_config = ascend_config.torchair_graph_config @@ -75,12 +74,10 @@ def test_init_ascend_config_with_additional_config(self): "ascend_scheduler_config": { "enabled": True }, - "expert_tensor_parallel_size": 1, "expert_map_path": "test_expert_map_path", "refresh": True } ascend_config = init_ascend_config(test_vllm_config) - self.assertEqual(ascend_config.expert_tensor_parallel_size, 1) self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path") torchair_graph_config = ascend_config.torchair_graph_config diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py index f7dc40e635..5d732b665a 100644 --- a/tests/ut/test_platform.py +++ b/tests/ut/test_platform.py @@ -28,7 +28,6 @@ def setUp(self): self.mock_vllm_config.speculative_config = None self.mock_ascend_config = MagicMock() - self.mock_ascend_config.expert_tensor_parallel_size = 0 self.mock_ascend_config.torchair_graph_config.enabled = False self.mock_ascend_config.ascend_scheduler_config.enabled = False @@ -253,30 +252,6 @@ def test_check_and_update_config_basic_config_update( mock_init_ascend.assert_called_once_with(self.mock_vllm_config) mock_check_ascend.assert_called_once() - @patch("vllm_ascend.utils.is_310p", return_value=False) - @patch("vllm_ascend.ascend_config.check_ascend_config") - @patch("vllm_ascend.ascend_config.init_ascend_config") - def test_check_and_update_config_expert_parallel_enabled( - self, mock_init_ascend, mock_check_ascend, mock_is_310p): - mock_init_ascend.return_value = self.mock_ascend_config - self.mock_vllm_config.parallel_config.enable_expert_parallel = True - self.mock_vllm_config.parallel_config.tensor_parallel_size = 2 - self.mock_vllm_config.parallel_config.world_size_across_dp = 4 - - from vllm_ascend import platform - - importlib.reload(platform) - - self.platform.check_and_update_config(self.mock_vllm_config) - - self.assertEqual( - self.mock_vllm_config.parallel_config.expert_tensor_parallel_size, - 1) - self.assertEqual( - self.mock_vllm_config.parallel_config.expert_parallel_size, - self.mock_vllm_config.parallel_config.world_size_across_dp, - ) - @patch("vllm_ascend.utils.is_310p", return_value=False) @patch("vllm_ascend.ascend_config.check_ascend_config") @patch("vllm_ascend.ascend_config.init_ascend_config") diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 50b0e83618..4bc6e88839 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -44,8 +44,6 @@ def __init__(self, vllm_config): self.ascend_scheduler_config = AscendSchedulerConfig( ascend_scheduler_config) - self.expert_tensor_parallel_size = int( - 
additional_config.get("expert_tensor_parallel_size", 0)) self.expert_map_path = additional_config.get("expert_map_path", None) self.chunked_prefill_for_mla = additional_config.get( "chunked_prefill_for_mla", False) diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py deleted file mode 100644 index 2778a6ef27..0000000000 --- a/vllm_ascend/distributed/parallel_state.py +++ /dev/null @@ -1,77 +0,0 @@ -from typing import Optional - -import torch -from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group, - init_model_parallel_group) - -# vllm-ascend will maintain its own EP GroupCoordinator and ETP GroupCoordinator for -# customize parallel solution -_EP: Optional[GroupCoordinator] = None -_ETP: Optional[GroupCoordinator] = None - - -def get_ep_group() -> GroupCoordinator: - assert _EP is not None, ("expert model parallel group is not initialized") - return _EP - - -def get_etp_group() -> GroupCoordinator: - assert _ETP is not None, ( - "expert tensor parallel group is not initialized") - return _ETP - - -def model_parallel_initialized(): - return (_ETP is not None and _EP is not None) - - -def init_ascend_model_parallel( - expert_parallel_size: int = 1, - expert_tensor_parallel_size: int = 1, - world_size: Optional[int] = None, - backend: Optional[str] = None, -): - if model_parallel_initialized(): - return - assert torch.distributed.is_initialized() - world_size = world_size or torch.distributed.get_world_size() - backend = backend or torch.distributed.get_backend( - get_world_group().device_group) - num_expert_parallel_groups = expert_tensor_parallel_size - num_expert_tensor_parallel_groups = expert_parallel_size - - global _EP - group_ranks = [] - for i in range(num_expert_parallel_groups): - ranks = list(range(i, world_size, num_expert_parallel_groups)) - group_ranks.append(ranks) - - _EP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="ep") - - group_ranks = [] - global _ETP - for i in range(num_expert_tensor_parallel_groups): - ranks = list( - range(i * expert_tensor_parallel_size, - (i + 1) * expert_tensor_parallel_size)) - group_ranks.append(ranks) - - _ETP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="etp") - - -def destory_ascend_model_parallel(): - global _EP - if _EP: - _EP.destroy() - _EP = None - - global _ETP - if _ETP: - _ETP.destroy() - _ETP = None diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index bfa86f0ee2..e26859b14c 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -39,7 +39,7 @@ tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, tensor_model_parallel_reduce_scatter) -from vllm.distributed.parallel_state import get_dp_group +from vllm.distributed.parallel_state import get_dp_group, get_ep_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -69,7 +69,6 @@ from vllm.sequence import IntermediateTensors from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.quantization.quant_config import AscendLinearMethod from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod diff --git a/vllm_ascend/models/pangu_moe.py 
b/vllm_ascend/models/pangu_moe.py index 609c86f361..0d2d9a653e 100644 --- a/vllm_ascend/models/pangu_moe.py +++ b/vllm_ascend/models/pangu_moe.py @@ -30,8 +30,8 @@ from vllm.distributed import (divide, get_pp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import (get_dp_group, get_tp_group, - get_world_group) +from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, + get_tp_group, get_world_group) from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul @@ -58,7 +58,6 @@ from vllm.sequence import IntermediateTensors from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p logger = init_logger(__name__) @@ -93,7 +92,7 @@ def __init__( # Divide the weight matrix along the last dimension. output_size = sum(output_sizes) self.output_sizes = output_sizes - self.tp_size = get_world_group().world_size + self.tp_size = get_tp_group().world_size self.input_size_per_partition = input_size self.output_size_per_partition = divide(output_size, self.tp_size) self.output_partition_sizes = [self.output_size_per_partition] @@ -144,8 +143,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, assert loaded_shard_id < len(self.output_sizes) - tp_rank = get_world_group().rank_in_group - tp_size = get_world_group().world_size + tp_rank = get_tp_group().rank_in_group + tp_size = get_tp_group().world_size if output_dim is not None: shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size shard_size = self.output_sizes[loaded_shard_id] // tp_size @@ -204,7 +203,7 @@ def __init__( group=None, ): # Divide the weight matrix along the first dimension. 
- self.group = group if group is not None else get_world_group() + self.group = group if group is not None else get_tp_group() self.tp_rank = self.group.rank_in_group self.tp_size = self.group.world_size self.input_size_per_partition = divide(input_size, self.tp_size) @@ -357,7 +356,7 @@ def pangu_group8_topk( num_tokens = scores.shape[0] router_scale = _ROUTER_SCALE.squeeze( # type: ignore ) - + # TODO: support disable expert parallel ep_size = get_ep_group().world_size local_num_experts = global_num_experts // ep_size local_num_group = topk // ep_size @@ -464,6 +463,7 @@ def __init__( custom_routing_function=topk_wrapper(num_voted_experts), prefix=f"{prefix}.experts", ) + self.use_ep = self.experts.use_ep self.gate = ReplicatedLinear( config.hidden_size, diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index cdd18bd6ae..c857e8b420 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -85,6 +85,7 @@ def forward_oot( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, + moe_parallel_config=self.moe.moe_parallel_config, topk_weights=topk_weights, topk_ids=topk_ids, top_k=top_k, diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 1221d8984d..1111f22f2a 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -26,7 +26,8 @@ from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import get_dp_group, get_tp_group +from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, + get_tp_group) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import \ FusedMoEConfig # isort: skip @@ -41,7 +42,6 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.communication_op import \ data_parallel_reduce_scatter -from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.utils import (FusedMoEState, dispose_tensor, get_all_reduce_merge_state, get_fused_moe_state, @@ -124,6 +124,7 @@ def fused_experts_with_mc2( topk_weights: torch.Tensor, topk_ids: torch.Tensor, top_k: int, + moe_parallel_config: FusedMoEParallelConfig, expert_map: torch.Tensor = None, moe_all_to_all_group_name: Optional[str] = None, shared_experts: Optional[Any] = None @@ -142,22 +143,20 @@ def fused_experts_with_mc2( rank = torch.distributed.get_rank() quant_mode = 0 - ep_group = get_ep_group().device_group - local_rank = torch.distributed.get_rank(group=ep_group) - all_to_all_group_size = torch.distributed.get_world_size(ep_group) + ep_rank_id = moe_parallel_config.ep_rank + ep_world_size = moe_parallel_config.ep_size - tp_size = get_etp_group().world_size - tp_rank = rank % tp_size + tp_world_size = moe_parallel_config.tp_size + tp_rank = rank % tp_world_size stage1_kwargs = { "scales": None, "quant_mode": quant_mode, "group_ep": moe_all_to_all_group_name, - "ep_world_size": all_to_all_group_size, - "ep_rank_id": local_rank, - # "group_tp": self.moe_rs_group_name, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, "group_tp": moe_all_to_all_group_name, - "tp_world_size": tp_size, + "tp_world_size": tp_world_size, "tp_rank_id": tp_rank, } kwargs_mc2.update(stage1_kwargs) @@ -217,12 +216,12 @@ def fused_experts_with_mc2( stage3_kwargs = { "ep_send_counts": 
ep_recv_counts, "group_ep": moe_all_to_all_group_name, - "ep_world_size": all_to_all_group_size, - "ep_rank_id": local_rank, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, "tp_send_counts": tp_recv_counts, # "group_tp": self.moe_rs_group_name, "group_tp": moe_all_to_all_group_name, - "tp_world_size": tp_size, + "tp_world_size": tp_world_size, "tp_rank_id": tp_rank, } kwargs_mc2.update(stage3_kwargs) @@ -560,6 +559,7 @@ def fused_experts_moge( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + moe_parallel_config: FusedMoEParallelConfig, topk_weights: torch.Tensor, topk_ids: torch.Tensor, top_k: int, @@ -581,7 +581,7 @@ def fused_experts_moge( Returns: hidden_states: Hidden states after routing. """ - ep_size = get_ep_group().world_size + ep_size = moe_parallel_config.ep_size local_num_experts = global_num_experts // ep_size local_num_group = top_k // ep_size @@ -982,7 +982,7 @@ def __init__(self, moe: FusedMoEConfig = None): vllm_config = get_current_vllm_config() self.ep_group = get_ep_group() - self.ep_size = self.ep_group.world_size + self.ep_size = self.moe.moe_parallel_config.ep_size self.global_batch_size = vllm_config.scheduler_config.max_num_seqs self.local_batch_size = self.global_batch_size // self.ep_size self.max_model_len = vllm_config.model_config.max_model_len @@ -1074,13 +1074,14 @@ def apply( if enable_force_load_balance: topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) - fused_moe_state = get_fused_moe_state(self.ep_group.world_size, - is_prefill, is_deepseek_v3_r1) + fused_moe_state = get_fused_moe_state(self.ep_size, is_prefill, + is_deepseek_v3_r1) if fused_moe_state == FusedMoEState.MC2: return fused_experts_with_mc2( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, + moe_parallel_config=self.moe.moe_parallel_config, topk_weights=topk_weights, topk_ids=topk_ids, top_k=top_k, diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index e60344836d..a64eb0f0e9 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -37,17 +37,7 @@ # ================= # ** File: platform/patch_common/patch_distributed.py** # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# 1. `vllm.distributed.parallel_state.destroy_model_parallel()` -# Why: -# vllm dose not support outside platform maintain its own `CoordinatorGroup`, vllm-ascend maintain EP and ETP -# inside of the repo, and needs a common interface to destroy them, this patch add the interface of destroy -# platform owned `CoordinatorGroup` to make sure all the CoordinateGroup can be properly destroyed -# How: -# Call `vllm_ascend.distributed.parallel_state method `destroy_platform_model_parallel` to destroy all the `CoordinateGroup` -# Related PR (if no, explain why): -# Future Plan: -# Remove those patch when vllm merged them -# 2. `vllm.config.ParallelConfig.get_next_dp_init_port` +# 1. `vllm.config.ParallelConfig.get_next_dp_init_port` # Why: # vllm doesn't support get port from environment. # How: diff --git a/vllm_ascend/patch/platform/patch_common/patch_distributed.py b/vllm_ascend/patch/platform/patch_common/patch_distributed.py index d244016076..a1e5f00f9b 100644 --- a/vllm_ascend/patch/platform/patch_common/patch_distributed.py +++ b/vllm_ascend/patch/platform/patch_common/patch_distributed.py @@ -18,33 +18,12 @@ # This file is a part of the vllm-ascend project. 
import torch -import vllm -import vllm.distributed import vllm.envs as envs from vllm.config import ParallelConfig from vllm_ascend.utils import is_310p -def ascend_destroy_model_parallel(): - """Set the groups to none and destroy them.""" - from vllm.distributed.parallel_state import _DP, _PP, _TP - if _TP: - _TP.destroy() - _TP = None - - if _PP: - _PP.destroy() - _PP = None - - if _DP: - _DP.destroy() - _DP = None - from vllm_ascend.distributed.parallel_state import \ - destory_ascend_model_parallel - destory_ascend_model_parallel() - - def parallel_config_get_dp_port(self) -> int: """ We might need to initialize process groups in multiple @@ -62,7 +41,6 @@ def parallel_config_get_dp_port(self) -> int: return port -vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index f13ed4994f..303dee7028 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -131,24 +131,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if kv_cache_dtype is not None: vllm_config.cache_config.cache_dtype = kv_cache_dtype - if parallel_config: - # Default value for expert tensor parallel size - parallel_config.expert_tensor_parallel_size = parallel_config.tensor_parallel_size - - # NOTE: When enable_expert_parallel is True, we follow vLLM convention: - # ep_size = world_size, which means expert_tensor_parallel_size must be 1 - if parallel_config.enable_expert_parallel: - parallel_config.expert_tensor_parallel_size = 1 - # NOTE: When enable_expert_parallel is False and param `asceend_config.expert_tensor_parallel_size` - # is configured, use ascend_config - elif ascend_config.expert_tensor_parallel_size > 0: - parallel_config.expert_tensor_parallel_size = ascend_config.expert_tensor_parallel_size - - # Calculate expert parallel size based on world size - parallel_config.expert_parallel_size = ( - parallel_config.world_size_across_dp // - parallel_config.expert_tensor_parallel_size) - if model_config is None: logger.warning("Model config is missing. 
This may indicate " "that we are running a test case") diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index edd42e53bf..8fec79fc64 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -20,9 +20,9 @@ import torch import torch_npu from vllm.attention.backends.abstract import AttentionType +from vllm.distributed.parallel_state import get_ep_group from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index a0c90ab399..0093578053 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -21,10 +21,10 @@ import torch.distributed as dist import torch_npu from vllm.distributed import GroupCoordinator +from vllm.distributed.parallel_state import get_ep_group import vllm_ascend.envs as envs from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.ops.fused_moe import select_experts from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, FusedMoEState, dispose_tensor, get_fused_moe_state, diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 634e13cb9e..6642280f61 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -311,8 +311,6 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: parallel_factor = 1 + sum(size > 1 for size in [ parallel_config.data_parallel_size_local, parallel_config.tensor_parallel_size, - parallel_config.expert_parallel_size, - parallel_config.expert_tensor_parallel_size, ]) # Calculate maximum supported batch sizes considering model architecture diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index df03d508e4..2ad04abfbb 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -41,7 +41,6 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config from vllm_ascend.device_allocator.camem import CaMemAllocator -from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.platform import NPUPlatform from vllm_ascend.utils import (check_kv_cache_bytes_cache_exist, check_torchair_cache_exist, @@ -308,18 +307,12 @@ def execute_dummy_batch(self) -> None: def _init_worker_distributed_environment(self) -> None: """Initialize the distributed environment.""" - parallel_config = self.vllm_config.parallel_config init_distributed_environment(self.parallel_config.world_size, self.rank, self.distributed_init_method, self.local_rank, "hccl") ensure_model_parallel_initialized( self.parallel_config.tensor_parallel_size, self.parallel_config.pipeline_parallel_size) - init_ascend_model_parallel( - parallel_config.expert_parallel_size, - parallel_config.expert_tensor_parallel_size, - parallel_config.world_size_across_dp, - ) ensure_kv_transfer_initialized(self.vllm_config) def _init_profiler(self):
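
Migration note for downstream code: with the Ascend-maintained EP/ETP groups removed, callers that previously imported `get_ep_group`/`get_etp_group` from `vllm_ascend.distributed.parallel_state` switch to vLLM's upstream group helpers, and per-layer MoE parallel sizes now come from `FusedMoEConfig.moe_parallel_config`. Below is a minimal sketch of the import change; the `ep_world_size` helper is illustrative only (not part of this patch) and assumes the distributed environment has already been initialized by the worker.

```python
# Sketch: replacing the removed Ascend-maintained EP/ETP helpers.
#
# Old (removed by this patch):
#   from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
#   ep_size  = get_ep_group().world_size
#   etp_size = get_etp_group().world_size

# New: use vLLM's own EP group, which is set up when expert parallelism is enabled.
from vllm.distributed.parallel_state import get_ep_group


def ep_world_size() -> int:
    # Illustrative helper: read the EP world size from vLLM's group coordinator.
    # No init_ascend_model_parallel() call is needed in the worker any more.
    return get_ep_group().world_size
```

At the serving level, expert parallelism is now enabled with the standard `--enable-expert-parallel` flag (or `enable_expert_parallel=True` when constructing `LLM`), replacing the removed `expert_tensor_parallel_size` additional-config option; inside the fused MoE path, `moe_parallel_config.ep_size`, `ep_rank`, and `tp_size` supply the values the ETP group used to provide.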