@@ -24,7 +24,7 @@
 from contextlib import contextmanager, nullcontext
 from enum import Enum
 from threading import Lock
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch_npu  # noqa: F401
@@ -60,6 +60,8 @@
 _SLEEP_MODE_ENABLED = None
 _CURRENT_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False
+_MIN_DP_BUFFER_SIZE = 50
+_MIN_MC2_BUFFER_SIZE = 512


 def is_310p():
@@ -634,3 +636,84 @@ def npu_stream_switch(target_stream: torch.npu.Stream,
         return nullcontext()
     assert target_stream is not None
     return torch.npu.stream(target_stream)
+
+
+def create_hccl_pg_options(group_name: str):
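+    """Build ProcessGroupHCCL options, filling hccl_config for known group names."""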
+    options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
+    hccl_config = get_hccl_config_for_pg_options(group_name)
+    if hccl_config is not None:
+        options.hccl_config = hccl_config
+    return options
+
+
+def get_hccl_config_for_pg_options(group_name: str) -> Optional[Dict]:
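+    """Return the HCCL config dict for a known group name, or None to keep defaults."""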
+    hccl_config_map = {
+        "dp": {
+            "hccl_buffer_size": calculate_dp_buffer_size()
+        },
+        "mc2": {
+            "hccl_buffer_size": calculate_mc2_buffer_size()
+        },
+    }
+    return hccl_config_map.get(group_name)
+
+
+def calculate_dp_buffer_size() -> int:
+    """
+    Formula for the dp buffer size (in MB):
+    (dp_size + 2) int32 values, where the extra 2 are the with_prefill
+    and enable_dbo flags.
+    """
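+    # Illustrative example (assuming dp_size = 16):
+    # ceil((16 + 2) * 4 / 2**20) = 1 MB, clamped to the 50 MB floor.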
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
+    dp_size = vllm_config.parallel_config.data_parallel_size
+    int32_size = torch.iinfo(torch.int32).bits // 8
+    dp_buffer_size = math.ceil((dp_size + 2) * int32_size / (1024 * 1024))
+    return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE)
+
+
+def calculate_mc2_buffer_size() -> int:
+    """
+    Formula for the mc2 buffer size (in bytes, rounded up to MB):
+    2 * (local_routed_expert_num * max_bs_per_rank * ep_world_size
+         * align512(align32(2 * H) + 64)
+         + (K + shared_expert_num) * max_bs_per_rank * align512(2 * H))
+    """
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
+    dp_size = vllm_config.parallel_config.data_parallel_size
+    tp_size = vllm_config.parallel_config.tensor_parallel_size
+    # Since the mc2 process group is only used with enable_expert_parallel,
+    # the non-EP case does not need to be considered.
+    ep_world_size = dp_size * tp_size
+    # FIXME: find a cleaner way to obtain local_routed_expert_num and
+    # redundant_expert_num (default: 2).
+    global_routed_expert_num = getattr(vllm_config.model_config.hf_config,
+                                       'n_routed_experts', 1)
+    local_routed_expert_num = math.ceil(
+        global_routed_expert_num / ep_world_size) + 2
+    # Tokens are passed to shared experts without mc2 ops.
+    shared_expert_num = 0
+
+    max_bs_per_rank = math.ceil(vllm_config.scheduler_config.max_num_seqs /
+                                tp_size)
+    # Account for MTP speculative tokens; this should be aligned with
+    # vLLM's speculative-decoding config handling later.
+    mtp_spec_token_num = (
+        vllm_config.speculative_config.num_speculative_tokens + 1 if
+        (vllm_config.speculative_config is not None
+         and vllm_config.speculative_config.method == 'deepseek_mtp') else 1)
+    max_bs_per_rank *= mtp_spec_token_num
+    H = vllm_config.model_config.hf_config.hidden_size
+    K = getattr(vllm_config.model_config.hf_config, 'num_experts_per_tok', 1)
+
+    aligned_2H_32 = _round_up(2 * H, 32)
+    aligned_2H_512 = _round_up(2 * H, 512)
+
+    # local_routed_expert_num * max_bs_per_rank * ep_world_size
+    #     * align512(align32(2 * H) + 64)
+    part1_base = aligned_2H_32 + 64
+    aligned_part1 = _round_up(part1_base, 512)
+    local_expert_component = local_routed_expert_num * max_bs_per_rank * ep_world_size * aligned_part1
+
+    # (K + shared_expert_num) * max_bs_per_rank * align512(2 * H)
+    shared_expert_component = (
+        K + shared_expert_num) * max_bs_per_rank * aligned_2H_512
+    mc2_buffer_size = math.ceil(
+        2 * (local_expert_component + shared_expert_component) / (1024 * 1024))
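+    # Illustrative example (assuming H=7168, K=8, 256 routed experts,
+    # max_num_seqs=256, tp=8, dp=2 -> ep_world_size=16, no MTP):
+    # local_routed_expert_num = 18, max_bs_per_rank = 32, so the size is
+    # 2 * (18*32*16*14848 + 8*32*14336) / 2**20 = 268 MB -> clamped to 512 MB.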
+    return max(mc2_buffer_size, _MIN_MC2_BUFFER_SIZE)
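
Note: a possible call site for create_hccl_pg_options (a hypothetical sketch,
not part of this diff; torch.distributed.new_group does accept a pg_options
argument, and the snippet assumes the "hccl" backend from torch_npu has
already been initialized via init_process_group):

    import torch.distributed as dist

    # Request options for the data-parallel group; "dp" maps to the buffer
    # size computed by calculate_dp_buffer_size() above.
    pg_options = create_hccl_pg_options("dp")
    # Example: a 4-rank DP group on the initialized "hccl" backend.
    dp_group = dist.new_group(ranks=list(range(4)),
                              backend="hccl",
                              pg_options=pg_options)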