Skip to content

Commit 18ad23e

Browse files
author
tanqingshan (A)
committed
[EPLB] ut for EPLB
Signed-off-by: tanqingshan (A) <t50050625@china.huawei.com>
1 parent 76844ee commit 18ad23e

File tree

7 files changed

+503
-0
lines changed

7 files changed

+503
-0
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import pytest
2+
from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor
3+
4+
5+
class DummyAdaptor(EplbAdaptor):
    """Concrete ``EplbAdaptor`` stand-in whose methods return canned values.

    Each override either returns a fixed string or echoes its arguments
    back in a dict, so the tests below can check the call plumbing.
    """

    def __init__(self, **kwargs):
        """Forward every keyword argument to the base class and expose them as ``self.args``."""
        super().__init__(**kwargs)
        self.args = kwargs

    def get_rank_expert_workload(self):
        """Return the fixed marker string ``"workload"``."""
        return "workload"

    def get_init_expert_map(self, num_moe_layers):
        """Return a one-entry dict holding the requested layer count."""
        return dict(layers=num_moe_layers)

    def do_update_expert_map(self, layer_id, updated_expert_map):
        """Echo both arguments back in a dict."""
        echoed = {"layer_id": layer_id}
        echoed["map"] = updated_expert_map
        return echoed

    def do_update_expert_weight(self, layer_id, local_expert_to_replace, buffer_tensor_id):
        """Echo all three arguments back in a dict."""
        keys = ("layer_id", "replace", "buffer")
        values = (layer_id, local_expert_to_replace, buffer_tensor_id)
        return dict(zip(keys, values))
25+
26+
27+
def test_base_class_methods_raise():
    """Each method called on a bare ``EplbAdaptor`` must raise ``NotImplementedError``."""
    base = EplbAdaptor()
    # Calls are deferred behind lambdas so each one runs inside its own
    # ``pytest.raises`` context, in the listed order.
    deferred_calls = (
        lambda: base.get_rank_expert_workload(),
        lambda: base.get_init_expert_map(1),
        lambda: base.do_update_expert_map(1, {}),
        lambda: base.do_update_expert_weight(1, "x", "y"),
    )
    for invoke in deferred_calls:
        with pytest.raises(NotImplementedError):
            invoke()
37+
38+
39+
def test_dummy_adaptor_init_and_args():
    """Keyword arguments given to the constructor are readable through ``args``."""
    stored = DummyAdaptor(test_arg=123).args
    assert stored["test_arg"] == 123
42+
43+
44+
def test_get_rank_expert_workload():
    """``DummyAdaptor.get_rank_expert_workload`` returns its fixed marker string."""
    assert DummyAdaptor().get_rank_expert_workload() == "workload"
48+
49+
50+
def test_get_init_expert_map():
    """``DummyAdaptor.get_init_expert_map`` returns a dict carrying the layer count."""
    layer_count = 5
    init_map = DummyAdaptor().get_init_expert_map(layer_count)
    assert isinstance(init_map, dict)
    assert init_map["layers"] == layer_count
55+
56+
57+
def test_do_update_expert_map():
    """``DummyAdaptor.do_update_expert_map`` echoes the layer id and the map."""
    new_map = {"expert": 1}
    outcome = DummyAdaptor().do_update_expert_map(2, new_map)
    assert outcome["layer_id"] == 2
    assert outcome["map"] == new_map
63+
64+
65+
def test_do_update_expert_weight():
    """``DummyAdaptor.do_update_expert_weight`` echoes all three arguments."""
    outcome = DummyAdaptor().do_update_expert_weight(1, "expertA", "bufferX")
    expected = {"layer_id": 1, "replace": "expertA", "buffer": "bufferX"}
    # Checked key by key, in the same order as the individual assertions.
    for key, value in expected.items():
        assert outcome[key] == value
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# test_policy_abstract.py
2+
import pytest
3+
from vllm_ascend.eplb.core.policy.policy_abstract import DynamicConfig, EplbPolicy
4+
5+
class DummyPolicy(EplbPolicy):
    """Trivial ``EplbPolicy`` subclass used to exercise the base-class plumbing."""

    def rebalance_experts(self, current_expert_table, expert_workload):
        """Return the constant ``1`` paired with the unchanged input table.

        ``expert_workload`` is accepted but not used.
        """
        unchanged_table = current_expert_table
        return (1, unchanged_table)
8+
9+
def test_dynamic_config_attributes():
    """A default-constructed ``DynamicConfig`` carries the expected attribute values."""
    cfg = DynamicConfig()
    assert cfg.placement_policy is None
    expected_values = {
        "max_transferred_expert_per_layer": 100,
        "ep_worldsize": 64,
        "num_die_per_host": 8,
    }
    for attr_name, expected in expected_values.items():
        assert getattr(cfg, attr_name) == expected
15+
16+
def test_eplb_policy_init_and_method():
    """``DummyPolicy`` keeps the config it was built with and echoes the expert table."""
    cfg = DynamicConfig()
    dummy = DummyPolicy(cfg)
    assert dummy.config == cfg

    table = [[0, 1, 2]]
    outcome = dummy.rebalance_experts(table, [10])
    status, returned_table = outcome
    assert status == 1
    assert returned_table == table
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import pytest
2+
import numpy as np
3+
from unittest.mock import patch
4+
from vllm_ascend.eplb.core.policy.policy_dynamic_ep import DynamicEplb
5+
6+
class TestDynamicEplb:
    """Unit tests for the helpers and ``rebalance_experts`` of ``DynamicEplb``."""

    def test_add_redundant_basic(self):
        """``add_redundant`` yields the per-expert totals asserted below."""
        placement = np.array([[[0, 1], [1, 0]]])
        workload = np.array([[[2, 3], [4, 1]]])
        original_expert_count = 2
        summed = DynamicEplb.add_redundant(placement, workload, original_expert_count)
        assert np.array_equal(summed, np.array([[2 + 1, 3 + 4]]))

    def test_get_redundant_num(self):
        """``get_redundant_num(3, [2, 1, 3])`` is expected to be 3."""
        occurrence_counts = np.array([2, 1, 3])
        assert DynamicEplb.get_redundant_num(3, occurrence_counts) == 3

    def test_calculate_max_heat_per_layer(self):
        """Two layers of 2x2 workloads are expected to give ``[7, 4]``."""
        two_layer_workload = np.array([[[1, 2], [3, 4]], [[2, 2], [1, 1]]])
        heats = DynamicEplb.calculate_max_heat_per_layer(two_layer_workload, 2)
        assert heats == [7, 4]

    def test_constraint_expert_local_exchange(self):
        """The constrained deployment is expected to equal the current one for these inputs."""
        current_deployment = [[[0, 1], [2, 3]]]
        global_deployment = [[[1, 0], [3, 2]]]
        constrained = DynamicEplb.constraint_expert_local_exchange(
            current_deployment, global_deployment)
        assert constrained == [[[0, 1], [2, 3]]]

    def test_compute_balanced_pack_redundancy_normal(self):
        """First returned item is a list of length 2 for two cards."""
        weights = [(0, 10), (1, 20)]
        packed, _ = DynamicEplb.compute_balanced_pack_redundancy(weights, 2, 1)
        assert isinstance(packed, list)
        assert len(packed) == 2

    def test_compute_balanced_pack_redundancy_card0(self):
        """A card count of 0 is expected to raise ``RuntimeError``."""
        weights = [(0, 10)]
        with pytest.raises(RuntimeError):
            DynamicEplb.compute_balanced_pack_redundancy(weights, 0, 0)

    def test_compute_balanced_pack_normal(self):
        """First returned item is a list of length 2 for two cards."""
        weights = np.array([(0, 10), (1, 20)], dtype=object)
        packed, _ = DynamicEplb.compute_balanced_pack(weights, 2)
        assert isinstance(packed, list)
        assert len(packed) == 2

    def test_compute_balanced_pack_card0(self):
        """A card count of 0 is expected to raise ``RuntimeError``."""
        weights = np.array([(0, 10)], dtype=object)
        with pytest.raises(RuntimeError):
            DynamicEplb.compute_balanced_pack(weights, 0)

    def test_original_compute_balanced_pack_redundancy(self):
        """First returned item is a list of length 2 for two cards."""
        weights = [(0, 5), (1, 10)]
        packed, _ = DynamicEplb.original_compute_balanced_pack_redundancy(weights, 2, 1)
        assert isinstance(packed, list)
        assert len(packed) == 2

    def test_rebalance_experts_normal(self):
        """``rebalance_experts`` returns a flag, an ndarray and a table-shaped list."""
        placement = np.array([[[0, 1], [1, 0]]])
        load = np.array([[[2, 3], [4, 1]]])
        balancer = DynamicEplb(config=None)
        changed, priority, deployment = balancer.rebalance_experts(placement, load)
        assert changed in (0, 1)
        assert isinstance(priority, np.ndarray)
        assert isinstance(deployment, list)
        assert np.array(deployment).shape == placement.shape

    def test_rebalance_experts_exceptions(self):
        """Three invalid setups are each expected to raise ``ValueError``."""
        balancer = DynamicEplb(config=None)

        # Case 1: add_redundant is patched to report three experts for a
        # two-expert placement table (author's intent: num_original_expert
        # != expert_num).
        placement = np.array([[[0, 1], [1, 0]]])
        load = np.array([[[2, 3], [4, 1]]])
        fake_totals = np.array([[1, 2, 3]])
        with patch.object(DynamicEplb, 'add_redundant', return_value=fake_totals), \
                pytest.raises(ValueError):
            balancer.rebalance_experts(placement, load)

        # Case 2: empty tables (author's intent: num_npus <= 0).
        # NOTE(review): np.array([[]]) has shape (1, 0), i.e. it is 2-D, not a
        # 3-D (layer, npu, expert) table - confirm the ValueError really
        # comes from the NPU-count check rather than from shape handling.
        empty_placement = np.array([[]])
        empty_load = np.array([[]])
        with pytest.raises(ValueError):
            balancer.rebalance_experts(empty_placement, empty_load)

        # Case 3: one NPU while get_redundant_num is patched to return 2
        # (author's intent: num_npus < num_redundancy_expert).
        single_npu_placement = np.array([[[0, 0]]])
        single_npu_load = np.array([[[1, 1]]])
        with patch.object(DynamicEplb, 'get_redundant_num', return_value=2), \
                pytest.raises(ValueError):
            balancer.rebalance_experts(single_npu_placement, single_npu_load)
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import pytest
2+
import numpy as np
3+
from collections import defaultdict
4+
from vllm_ascend.eplb.core.policy.policy_dynamic_ep_v2 import (
5+
DynamicConfig,
6+
DynamicEplbV2,
7+
DynamicTable
8+
)
9+
10+
11+
@pytest.fixture
def config():
    """Provide a ``DynamicConfig`` built with no constructor arguments."""
    default_config = DynamicConfig()
    return default_config
14+
15+
16+
@pytest.fixture
def policy(config):
    """Provide a ``DynamicEplbV2`` built from the ``config`` fixture."""
    policy_under_test = DynamicEplbV2(config)
    return policy_under_test
19+
20+
21+
def test_safe_operations(policy):
    """The three ``safe_*`` helpers give the expected value and raise on a zero divisor."""
    # (helper, operands, expected result) - evaluated in this order.
    cases = (
        (policy.safe_divide, (10, 2), 5),
        (policy.safe_exact_divide, (10, 3), 3),
        (policy.safe_mod, (10, 3), 1),
    )
    for operation, operands, expected in cases:
        assert operation(*operands) == expected
        with pytest.raises(ZeroDivisionError):
            operation(1, 0)
36+
37+
38+
def test_add_redundant():
    """``DynamicEplbV2.add_redundant`` returns a (1, 2) array with the expected totals."""
    load = np.array([[[1, 2], [3, 4]]])
    layout = np.array([[[0, 1], [0, 1]]])
    totals = DynamicEplbV2.add_redundant(layout, load, 2)
    assert totals.shape == (1, 2)
    # Expected per-expert sums: expert 0 -> 1 + 3, expert 1 -> 2 + 4.
    assert np.all(totals[0] == [4, 6])
44+
45+
46+
def test_get_redundant_num():
    """``get_redundant_num(3, [1, 2, 1])`` is expected to be 1."""
    occurrence_counts = np.array([1, 2, 1])
    redundant = DynamicEplbV2.get_redundant_num(3, occurrence_counts)
    assert redundant == 1
49+
50+
51+
def test_calculate_max_heat_per_layer():
    """Two layers of 2x2 workloads are expected to give ``[7, 15]``."""
    two_layer_load = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    heats = DynamicEplbV2.calculate_max_heat_per_layer(two_layer_load, 2)
    assert heats == [7, 15]
55+
56+
57+
def test_calculate_initial_imbalance(policy):
    """``calculate_initial_imbalance`` returns a one-element list for a single layer."""
    layout = np.array([[[0, 1], [0, 1]]])
    per_expert_load = np.array([[1, 1]])
    imbalance = policy.calculate_initial_imbalance(layout, per_expert_load)
    assert isinstance(imbalance, list)
    assert len(imbalance) == 1
63+
64+
65+
def test_compute_redundant_assignments(policy):
    """Both values returned by ``compute_redundant_assignments`` have length 2."""
    experts = [(0, 10), (1, 5)]
    outcome = policy.compute_redundant_assignments(
        experts, num_redundant_experts=2, num_experts=2)
    redundant, sorted_weights = outcome
    assert len(redundant) == 2
    assert len(sorted_weights) == 2
71+
72+
73+
def test_prepare_expert_list():
    """``prepare_expert_list`` returns a one-element list for these inputs."""
    experts = [(0, 10), (1, 5)]
    redundancy_slots = [[2], []]
    prepared = DynamicEplbV2.prepare_expert_list(experts, redundancy_slots, 1)
    assert isinstance(prepared, list)
    assert len(prepared) == 1
79+
80+
81+
def test_non_redundant_expert_information():
    """Checks the first assignment row and the first load of the returned 4-tuple."""
    deployment = np.array([[0, 1]])
    weights_in = [(0, 10), (1, 5)]
    redundant_positions = [[]]
    info = DynamicEplbV2.non_redundant_expert_information(
        deployment, weights_in, redundant_positions)
    # Exactly four values are expected; only the first and third are checked.
    assignments, _weights, loads, _counts = info
    assert assignments[0] == [0, 1]
    assert loads[0] == 15
89+
90+
91+
def test_recomputing_initial_weight(policy):
    """Checks entry 0 of both values returned by ``recomputing_initial_weight``."""
    workloads = [10, 5]
    assignments = [[0, 1]]
    outcome = policy.recomputing_initial_weight(workloads, assignments)
    cur_layer_workload, num_all_experts = outcome
    assert cur_layer_workload[0] == 10
    assert num_all_experts[0] == 1
98+
99+
100+
def test_safe_divide_zero_edge_case(policy):
    """A zero numerator divided by a non-zero divisor gives 0."""
    for divisor in (1, 5):
        assert policy.safe_divide(0, divisor) == 0
103+
104+
105+
106+
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
3+
from vllm_ascend.eplb.core.policy.policy_factory import PolicyFactory
4+
from vllm_ascend.eplb.core.policy.policy_abstract import DynamicConfig
5+
from vllm_ascend.eplb.core.policy.policy_random import RandomLoadBalance
6+
from vllm_ascend.eplb.core.policy.policy_dynamic_ep import DynamicEplb
7+
from vllm_ascend.eplb.core.policy.policy_dynamic_ep_v2 import DynamicEplbV2
8+
9+
10+
@pytest.fixture
def dummy_config():
    """Provide a ``DynamicConfig`` built with no constructor arguments."""
    default_config = DynamicConfig()
    return default_config
13+
14+
15+
@pytest.mark.parametrize("policy_type, expected_class", [
    (0, RandomLoadBalance),
    (1, DynamicEplb),
    (2, DynamicEplbV2),
    (999, RandomLoadBalance),
])
def test_generate_policy(policy_type, expected_class, dummy_config):
    """``PolicyFactory.generate_policy`` returns an instance of the class paired with each type code."""
    produced = PolicyFactory.generate_policy(policy_type, dummy_config)
    assert isinstance(produced, expected_class)

0 commit comments

Comments
 (0)