1 change: 1 addition & 0 deletions tests/ut/torchair/ops/test_torchair_fused_moe.py
@@ -72,6 +72,7 @@ def mock_dist_env(mocker: MockerFixture):
return_value=MagicMock(
torchair_graph_config=MagicMock(enabled=False),
enable_multistream_moe=False,
+enable_shared_expert_dp=False,
expert_map_path=None
)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.determine_expert_map',
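
Note on the test change above: an attribute that is not set explicitly on a MagicMock is auto-created as a (truthy) child mock, so the new enable_shared_expert_dp flag has to be pinned to False or the mocked ascend config would silently enable the new code paths. A minimal, self-contained illustration (plain unittest.mock, no vLLM imports):

from unittest.mock import MagicMock

cfg = MagicMock(torchair_graph_config=MagicMock(enabled=False))
# Unspecified attributes are auto-created child mocks, and bool(MagicMock()) is True.
assert bool(cfg.enable_shared_expert_dp)

cfg = MagicMock(torchair_graph_config=MagicMock(enabled=False),
                enable_shared_expert_dp=False)
# With the explicit keyword the flag really reads as False in the unit test.
assert not cfg.enable_shared_expert_dp
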
6 changes: 5 additions & 1 deletion vllm_ascend/spec_decode/mtp_proposer.py
@@ -60,6 +60,8 @@ def __init__(
self.torchair_compiled_models = {} # type: ignore
self.torchair_graph_enabled = get_ascend_config(
).torchair_graph_config.enabled
+self.enable_shared_expert_dp = get_ascend_config(
+).enable_shared_expert_dp
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
@@ -79,7 +81,9 @@ def load_model(self, model) -> None:
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
-if self.torchair_graph_enabled:
+if self.torchair_graph_enabled or (
+        self.enable_shared_expert_dp
+        and self.vllm_config.model_config.use_mla):
self.model = TorchairDeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
else:
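
A minimal sketch of the draft-model selection rule added above, reduced to a pure function; the flag and config names mirror the diff, while the function itself is only illustrative:

def use_torchair_mtp_model(torchair_graph_enabled: bool,
                           enable_shared_expert_dp: bool,
                           use_mla: bool) -> bool:
    # TorchairDeepSeekMTP is now chosen not only when the torchair graph is
    # enabled, but also for shared-expert DP combined with MLA attention.
    return torchair_graph_enabled or (enable_shared_expert_dp and use_mla)


assert use_torchair_mtp_model(False, True, True)       # new eager shared-expert-dp path
assert not use_torchair_mtp_model(False, True, False)  # without MLA the default MTP model is kept
assert use_torchair_mtp_model(True, False, False)      # torchair graph path unchanged
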
2 changes: 1 addition & 1 deletion vllm_ascend/torchair/models/torchair_deepseek_v2.py
@@ -814,7 +814,7 @@ def forward(

attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None and isinstance(attn_metadata, dict):
-attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
+attn_metadata = next(iter(attn_metadata.values()), None)
if attn_metadata is not None:
num_tokens = attn_metadata.num_actual_tokens
else:
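
The lookup change above drops a hardcoded layer key in favor of taking whichever per-layer entry exists. A small stand-alone example of the pattern (the dict contents here are made up):

attn_metadata_per_layer = {
    "model.layers.0.self_attn.attn": {"num_actual_tokens": 128},
}

# Hardcoded key: raises KeyError if layer 0 is renamed or lives on another
# pipeline stage.
# meta = attn_metadata_per_layer["model.layers.0.self_attn.attn"]

# Key-agnostic lookup: any entry works, and an empty dict degrades to None,
# which the surrounding `if attn_metadata is not None:` check already handles.
meta = next(iter(attn_metadata_per_layer.values()), None)
num_tokens = meta["num_actual_tokens"] if meta is not None else 0
assert num_tokens == 128
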
5 changes: 5 additions & 0 deletions vllm_ascend/torchair/ops/torchair_fused_moe.py
@@ -803,6 +803,7 @@ def __init__(self, moe: FusedMoEConfig = None):

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

try:
device_group = get_mc2_group().device_group
@@ -884,6 +885,8 @@ def apply(
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)

fused_moe_state = get_forward_context().fused_moe_state
+if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
+    fused_moe_state = FusedMoEState.All2All

if fused_moe_state == FusedMoEState.MC2:
return torchair_fused_experts_with_mc2(
@@ -1155,6 +1158,8 @@ def forward(self,
forward_context = get_forward_context()
fused_moe_state = forward_context.fused_moe_state
mc2_mask = forward_context.mc2_mask
+if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
+    fused_moe_state = FusedMoEState.All2All
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
quantized_x_for_share, dynamic_scale_for_share = None, None
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
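
Both guards added above implement the same downgrade: with shared-expert DP enabled, an MC2 dispatch is rewritten to plain all-to-all before any branch is taken. A self-contained sketch of that pattern (the enum below is a stand-in, not the vllm_ascend FusedMoEState):

from enum import Enum, auto


class FusedMoEState(Enum):
    MC2 = auto()
    All2All = auto()
    AllGather = auto()


def resolve_fused_moe_state(state: FusedMoEState,
                            enable_shared_expert_dp: bool) -> FusedMoEState:
    # Shared-expert DP does not take the MC2 (merged dispatch/combine) path,
    # so fall back to the all-to-all expert dispatch instead.
    if enable_shared_expert_dp and state is FusedMoEState.MC2:
        return FusedMoEState.All2All
    return state


assert resolve_fused_moe_state(FusedMoEState.MC2, True) is FusedMoEState.All2All
assert resolve_fused_moe_state(FusedMoEState.MC2, False) is FusedMoEState.MC2
assert resolve_fused_moe_state(FusedMoEState.AllGather, True) is FusedMoEState.AllGather
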
3 changes: 3 additions & 0 deletions vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
@@ -823,6 +823,7 @@ def __init__(self):

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

try:
device_group = get_mc2_group().device_group
@@ -936,6 +937,8 @@ def apply(
)

fused_moe_state = get_forward_context().fused_moe_state
+if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
+    fused_moe_state = FusedMoEState.All2All
shared_gate_up, shared_dequant_scale = None, None
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
with npu_stream_switch("moe_secondary", 0):
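
The same downgrade in the quantized path has a knock-on effect visible right below it: the secondary-stream shared-expert prefetch only runs while the state is still MC2, so under shared-expert DP it is skipped and shared_gate_up / shared_dequant_scale remain None. A condensed sketch (stand-in enum, no NPU stream calls):

from enum import Enum, auto


class FusedMoEState(Enum):
    MC2 = auto()
    All2All = auto()


def shared_expert_prefetch(state: FusedMoEState, enable_shared_expert_dp: bool,
                           has_shared_experts: bool) -> bool:
    if enable_shared_expert_dp and state is FusedMoEState.MC2:
        state = FusedMoEState.All2All  # downgrade added in this diff
    # Mirrors `if shared_experts is not None and fused_moe_state == FusedMoEState.MC2`.
    return has_shared_experts and state is FusedMoEState.MC2


assert not shared_expert_prefetch(FusedMoEState.MC2, True, True)   # prefetch skipped
assert shared_expert_prefetch(FusedMoEState.MC2, False, True)      # unchanged otherwise
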
4 changes: 2 additions & 2 deletions vllm_ascend/torchair/torchair_model_runner.py
@@ -87,8 +87,8 @@ def _sync_metadata_across_dp(
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
"""Override from NPUModelRunner to pad num_tokens"""
if self.enable_shared_expert_dp:
-return super()._sync_metadata_across_dp(num_tokens, with_prefill,
-                                        enable_dbo)
+# Padding is not required for shared_expert_dp cases in eager mode.
+return num_tokens, None, with_prefill, enable_dbo
if self.dp_size == 1:
if not with_prefill:
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
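
A sketch of the new early return above as a free function, to make the behavior change explicit; the signature is only an approximation of the runner method, and torch is needed solely for the optional padding tensor in the return type:

from typing import Optional, Tuple

import torch


def sync_metadata_across_dp(num_tokens: int, with_prefill: bool, enable_dbo: bool,
                            enable_shared_expert_dp: bool
                            ) -> Tuple[int, Optional[torch.Tensor], bool, bool]:
    if enable_shared_expert_dp:
        # Eager shared-expert-dp no longer defers to the parent implementation:
        # there is no torchair graph to satisfy, so num_tokens is returned
        # unpadded and without a cross-DP token tensor.
        return num_tokens, None, with_prefill, enable_dbo
    raise NotImplementedError("padded torchair / DP sync path elided in this sketch")
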
39 changes: 19 additions & 20 deletions vllm_ascend/torchair/torchair_worker.py
@@ -35,26 +35,25 @@ def determine_available_memory(self) -> int:
ascend_config = get_ascend_config()
if ascend_config.enable_shared_expert_dp:
return available_kv_cache_memory
-if ascend_config.torchair_graph_config.use_cached_kv_cache_bytes and check_kv_cache_bytes_cache_exist(
-):
-    old_kv_cache_bytes = read_kv_cache_bytes_from_file(
-        torch.distributed.get_rank())
-    if 0 < old_kv_cache_bytes <= available_kv_cache_memory:
-        logger.info(
-            f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}"
-        )
-        self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes
-        return old_kv_cache_bytes
-    else:
-        logger.info(
-            "Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache"
-        )
-        delete_torchair_cache_file()
-bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE
-available_kv_cache_memory -= bytes_floating_tolerance
-logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}")
-self.model_runner.new_kv_cache_bytes = available_kv_cache_memory
-
+if ascend_config.torchair_graph_config.use_cached_kv_cache_bytes:
+    if check_kv_cache_bytes_cache_exist():
+        old_kv_cache_bytes = read_kv_cache_bytes_from_file(
+            torch.distributed.get_rank())
+        if 0 < old_kv_cache_bytes <= available_kv_cache_memory:
+            logger.info(
+                f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}"
+            )
+            self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes
+            return old_kv_cache_bytes
+        else:
+            logger.info(
+                "Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache"
+            )
+            delete_torchair_cache_file()
+    bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE
+    available_kv_cache_memory -= bytes_floating_tolerance
+    logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}")
+    self.model_runner.new_kv_cache_bytes = available_kv_cache_memory
return available_kv_cache_memory

def init_device(self):
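
A compact sketch of the restructured cache-reuse decision above; the helper calls are stubbed out as plain arguments, and the tolerance default is an assumption standing in for VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE:

from typing import Optional


def determine_available_memory(available_kv_cache_memory: int,
                               use_cached_kv_cache_bytes: bool,
                               cached_kv_cache_bytes: Optional[int],
                               tolerance_mb: int = 64) -> int:
    if use_cached_kv_cache_bytes:
        if cached_kv_cache_bytes is not None:
            if 0 < cached_kv_cache_bytes <= available_kv_cache_memory:
                # A persisted value that still fits is reused as-is.
                return cached_kv_cache_bytes
            # Stale or oversized cache entry: discard it and fall through.
        # Shave a small tolerance off the fresh measurement so the value can be
        # cached and reused across runs; with caching disabled (the outer if),
        # the measurement is returned untouched.
        available_kv_cache_memory -= tolerance_mb * 1024 * 1024
    return available_kv_cache_memory


assert determine_available_memory(1 << 30, True, 512 * 1024 * 1024) == 512 * 1024 * 1024
assert determine_available_memory(1 << 30, False, None) == 1 << 30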