
Commit 07cc624

rebase main and fix

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
1 parent: f02f59b

File tree

2 files changed, +158 -157 lines changed

tests/ut/ops/test_vocab_parallel_embedding.py

Lines changed: 157 additions & 156 deletions

@@ -18,165 +18,166 @@
 
 import torch
 
-from vllm_ascend.ops.vocab_parallel_embedding import (AscendLogitsProcessor,
-                                                      AscendParallelLMHead)
+from vllm_ascend.ops.vocab_parallel_embedding import (
+    AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding)
 
 VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128
 
-# class TestCustomVocabParallelEmbedding(unittest.TestCase):
-
-#     def setUp(self):
-#         self.num_embeddings = 50
-#         self.embedding_dim = 10
-#         self.org_num_embeddings = 40
-#         self.padding_size = 8
-
-#     def _create_layer(self):
-#         # Patch methods and dependencies for VocabParallelEmbedding
-#         mock_group = MagicMock()
-#         mock_group.world_size = 2
-#         mock_group.rank_in_group = 0
-#         with patch("vllm_ascend.ops.vocab_parallel_embedding.get_tp_group", return_value=mock_group), \
-#              patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \
-#              patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", return_value=2), \
-#              patch("vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size", side_effect=lambda x, y: x + y), \
-#              patch("vllm.model_executor.layers.vocab_parallel_embedding.divide", side_effect=lambda x, y: x // y):
-
-#             # Create an instance of VocabParallelEmbedding
-#             layer = AscendVocabParallelEmbedding(
-#                 num_embeddings=self.num_embeddings,
-#                 embedding_dim=self.embedding_dim,
-#                 org_num_embeddings=self.org_num_embeddings,
-#                 padding_size=self.padding_size,
-#                 quant_config=None,  # Mock quantization config
-#                 prefix="")
-
-#         layer.shard_indices = MagicMock()
-#         layer.shard_indices.org_vocab_start_index = 10
-#         layer.shard_indices.org_vocab_end_index = 20
-#         layer.shard_indices.num_org_vocab_padding = 5
-#         layer.shard_indices.added_vocab_start_index = 30
-#         layer.shard_indices.added_vocab_end_index = 40
-
-#         # Mock the quantization method
-#         layer.quant_method.embedding = MagicMock(
-#             side_effect=lambda _, x: torch.randn(x.shape[0], self.
-#                                                  embedding_dim))
-#         return layer
-
-#     def test_get_masked_input_and_mask(self):
-#         """Test the mask and offset calculation helper function."""
-#         layer = self._create_layer()
-
-#         input_ = torch.tensor([5, 15, 25, 35, 45])
-
-#         masked_input, mask = layer._get_masked_input_and_mask(
-#             input_,
-#             org_vocab_start_index=10,
-#             org_vocab_end_index=20,
-#             num_org_vocab_padding=5,
-#             added_vocab_start_index=30,
-#             added_vocab_end_index=40)
-
-#         expected_mask = torch.tensor([True, False, True, False, True])
-#         self.assertTrue(
-#             torch.equal(mask, expected_mask),
-#             f"Mask mismatch. Expected {expected_mask}, got {mask}")
-
-#         expected_masked = torch.tensor([0, 5, 0, 20, 0])
-#         self.assertTrue(
-#             torch.equal(masked_input, expected_masked),
-#             f"Masked input mismatch. Expected {expected_masked}, got {masked_input}"
-#         )
-
-#     def test_forward_with_tp_size_1(self):
-#         """Test forward pass without tensor parallelism."""
-#         # Create a fresh mock embedding with tp_size=1
-#         layer = self._create_layer()
-#         layer.tp_size = 1
-#         layer.quant_method.embedding = MagicMock(
-#             return_value=torch.randn(3, layer.embedding_dim))
-
-#         input_ = torch.tensor([1, 2, 3])
-
-#         with patch(
-#                 "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
-#                 side_effect=lambda x: x) as mock_reduce_tp1:
-#             output = layer.forward(input_)
-
-#         # Should just pass through without masking
-#         layer.quant_method.embedding.assert_called_once_with(
-#             layer, input_.long())
-#         self.assertEqual(output.shape, (3, layer.embedding_dim))
-
-#         # Verify all_reduce was called once
-#         mock_reduce_tp1.assert_called_once()
-
-#     def test_forward_with_tp(self):
-#         layer = self._create_layer()
-#         layer.tp_size = 2
-
-#         input_ = torch.tensor([15, 35])  # one org vocab, one added vocab
-
-#         with patch(
-#                 "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
-#                 side_effect=lambda x: x) as mock_reduce_tp:
-#             # Call the forward method
-#             output = layer.forward(input_)
-
-#         # Check that masking was applied correctly
-#         layer.quant_method.embedding.assert_called_once()
-#         called_input = layer.quant_method.embedding.call_args[0][1]
-#         expected_input = torch.tensor([5, 20])  # after offset calculation
-#         self.assertTrue(torch.all(called_input == expected_input))
-
-#         # Check that all reduce was called
-#         mock_reduce_tp.assert_called_once()
-#         self.assertEqual(output.shape, (2, self.embedding_dim))
-
-#     def test_forward_with_invalid_vocab(self):
-#         """Test that invalid vocab indices are properly masked out."""
-#         # Create a fresh embedding layer
-#         layer = self._create_layer()
-#         input_ = torch.tensor([5, 15, 25, 35, 45])  # includes invalid cases
-#         # Create predictable mock output
-#         mock_output = torch.randn(5, self.embedding_dim)
-#         layer.quant_method.embedding = MagicMock(
-#             return_value=mock_output.clone())
-
-#         # Patch tensor_model_parallel_all_reduce to mock its behavior
-#         with patch(
-#                 "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
-#                 side_effect=lambda x: x):
-#             # Call the forward method
-#             output = layer.forward(input_)
-#         # Check that invalid positions (0, 2, 4) were zeroed out
-#         self.assertTrue(torch.all(output[0] == 0))
-#         self.assertTrue(torch.all(output[2] == 0))
-#         self.assertTrue(torch.all(output[4] == 0))
-#         self.assertTrue(torch.all(output[1] == mock_output[1]))
-#         self.assertTrue(torch.all(output[3] == mock_output[3]))
-#         self.assertEqual(output.shape, (5, self.embedding_dim))
-
-#     def test_output_shape(self):
-#         """Test that output shape is correct."""
-#         # Create a fresh embedding layer
-#         layer = self._create_layer()
-
-#         test_cases = [
-#             (torch.tensor([15]), (1, self.embedding_dim)),
-#             (torch.tensor([15, 35]), (2, self.embedding_dim)),
-#             (torch.tensor([15, 35, 16, 36]), (4, self.embedding_dim)),
-#         ]
-
-#         for input_, expected_shape in test_cases:
-#             with self.subTest(input=input_):
-#                 with patch(
-#                         "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
-#                         side_effect=lambda x: x):
-#                     # Call the forward method
-#                     output = layer.forward(input_)
-#                 self.assertEqual(output.shape, expected_shape)
+
+class TestCustomVocabParallelEmbedding(unittest.TestCase):
+
+    def setUp(self):
+        self.num_embeddings = 50
+        self.embedding_dim = 10
+        self.org_num_embeddings = 40
+        self.padding_size = 8
+
+    def _create_layer(self):
+        # Patch methods and dependencies for VocabParallelEmbedding
+        mock_group = MagicMock()
+        mock_group.world_size = 2
+        mock_group.rank_in_group = 0
+        with patch("vllm_ascend.ops.vocab_parallel_embedding.get_tp_group", return_value=mock_group), \
+             patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \
+             patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", return_value=2), \
+             patch("vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size", side_effect=lambda x, y: x + y), \
+             patch("vllm.model_executor.layers.vocab_parallel_embedding.divide", side_effect=lambda x, y: x // y):
+
+            # Create an instance of VocabParallelEmbedding
+            layer = AscendVocabParallelEmbedding(
+                num_embeddings=self.num_embeddings,
+                embedding_dim=self.embedding_dim,
+                org_num_embeddings=self.org_num_embeddings,
+                padding_size=self.padding_size,
+                quant_config=None,  # Mock quantization config
+                prefix="")
+
+        layer.shard_indices = MagicMock()
+        layer.shard_indices.org_vocab_start_index = 10
+        layer.shard_indices.org_vocab_end_index = 20
+        layer.shard_indices.num_org_vocab_padding = 5
+        layer.shard_indices.added_vocab_start_index = 30
+        layer.shard_indices.added_vocab_end_index = 40
+
+        # Mock the quantization method
+        layer.quant_method.embedding = MagicMock(
+            side_effect=lambda _, x: torch.randn(x.shape[0], self.
+                                                 embedding_dim))
+        return layer
+
+    def test_get_masked_input_and_mask(self):
+        """Test the mask and offset calculation helper function."""
+        layer = self._create_layer()
+
+        input_ = torch.tensor([5, 15, 25, 35, 45])
+
+        masked_input, mask = layer._get_masked_input_and_mask(
+            input_,
+            org_vocab_start_index=10,
+            org_vocab_end_index=20,
+            num_org_vocab_padding=5,
+            added_vocab_start_index=30,
+            added_vocab_end_index=40)
+
+        expected_mask = torch.tensor([True, False, True, False, True])
+        self.assertTrue(
+            torch.equal(mask, expected_mask),
+            f"Mask mismatch. Expected {expected_mask}, got {mask}")
+
+        expected_masked = torch.tensor([0, 5, 0, 20, 0])
+        self.assertTrue(
+            torch.equal(masked_input, expected_masked),
+            f"Masked input mismatch. Expected {expected_masked}, got {masked_input}"
+        )
+
+    def test_forward_with_tp_size_1(self):
+        """Test forward pass without tensor parallelism."""
+        # Create a fresh mock embedding with tp_size=1
+        layer = self._create_layer()
+        layer.tp_size = 1
+        layer.quant_method.embedding = MagicMock(
+            return_value=torch.randn(3, layer.embedding_dim))
+
+        input_ = torch.tensor([1, 2, 3])
+
+        with patch(
+                "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
+                side_effect=lambda x: x) as mock_reduce_tp1:
+            output = layer.forward(input_)
+
+        # Should just pass through without masking
+        layer.quant_method.embedding.assert_called_once_with(
+            layer, input_.long())
+        self.assertEqual(output.shape, (3, layer.embedding_dim))
+
+        # Verify all_reduce was called once
+        mock_reduce_tp1.assert_called_once()
+
+    def test_forward_with_tp(self):
+        layer = self._create_layer()
+        layer.tp_size = 2
+
+        input_ = torch.tensor([15, 35])  # one org vocab, one added vocab
+
+        with patch(
+                "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
+                side_effect=lambda x: x) as mock_reduce_tp:
+            # Call the forward method
+            output = layer.forward(input_)
+
+        # Check that masking was applied correctly
+        layer.quant_method.embedding.assert_called_once()
+        called_input = layer.quant_method.embedding.call_args[0][1]
+        expected_input = torch.tensor([5, 20])  # after offset calculation
+        self.assertTrue(torch.all(called_input == expected_input))
+
+        # Check that all reduce was called
+        mock_reduce_tp.assert_called_once()
+        self.assertEqual(output.shape, (2, self.embedding_dim))
+
+    def test_forward_with_invalid_vocab(self):
+        """Test that invalid vocab indices are properly masked out."""
+        # Create a fresh embedding layer
+        layer = self._create_layer()
+        input_ = torch.tensor([5, 15, 25, 35, 45])  # includes invalid cases
+        # Create predictable mock output
+        mock_output = torch.randn(5, self.embedding_dim)
+        layer.quant_method.embedding = MagicMock(
+            return_value=mock_output.clone())
+
+        # Patch tensor_model_parallel_all_reduce to mock its behavior
+        with patch(
+                "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
+                side_effect=lambda x: x):
+            # Call the forward method
+            output = layer.forward(input_)
+        # Check that invalid positions (0, 2, 4) were zeroed out
+        self.assertTrue(torch.all(output[0] == 0))
+        self.assertTrue(torch.all(output[2] == 0))
+        self.assertTrue(torch.all(output[4] == 0))
+        self.assertTrue(torch.all(output[1] == mock_output[1]))
+        self.assertTrue(torch.all(output[3] == mock_output[3]))
+        self.assertEqual(output.shape, (5, self.embedding_dim))
+
+    def test_output_shape(self):
+        """Test that output shape is correct."""
+        # Create a fresh embedding layer
+        layer = self._create_layer()
+
+        test_cases = [
+            (torch.tensor([15]), (1, self.embedding_dim)),
+            (torch.tensor([15, 35]), (2, self.embedding_dim)),
+            (torch.tensor([15, 35, 16, 36]), (4, self.embedding_dim)),
+        ]
+
+        for input_, expected_shape in test_cases:
+            with self.subTest(input=input_):
+                with patch(
+                        "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
+                        side_effect=lambda x: x):
+                    # Call the forward method
+                    output = layer.forward(input_)
+                self.assertEqual(output.shape, expected_shape)
 
 
 class TestAscendLogitsProcessor(unittest.TestCase):
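
Note on the re-enabled test: test_get_masked_input_and_mask pins down the index arithmetic expected of a tensor-parallel shard that owns original-vocab ids [10, 20), 5 padding slots, and added-vocab ids [30, 40). The sketch below reproduces that arithmetic as a standalone function so the expected values (masked input [0, 5, 0, 20, 0], mask [True, False, True, False, True]) can be checked outside the test harness. It is only an illustration of the behaviour the assertions encode, not the AscendVocabParallelEmbedding implementation, and the free-function name is made up.

import torch

def get_masked_input_and_mask(input_, org_vocab_start_index, org_vocab_end_index,
                              num_org_vocab_padding, added_vocab_start_index,
                              added_vocab_end_index):
    # Which ids fall into this rank's slice of the original or the added vocab.
    org_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
    added_mask = (input_ >= added_vocab_start_index) & (input_ < added_vocab_end_index)
    # Added-vocab ids are remapped to sit right after the padded original slice.
    added_offset = (added_vocab_start_index
                    - (org_vocab_end_index - org_vocab_start_index)
                    - num_org_vocab_padding)
    valid = org_mask | added_mask
    offset = org_mask * org_vocab_start_index + added_mask * added_offset
    masked_input = (input_ - offset) * valid  # out-of-shard ids collapse to 0
    return masked_input, ~valid  # mask is True for ids this rank does not own

masked, mask = get_masked_input_and_mask(torch.tensor([5, 15, 25, 35, 45]),
                                          10, 20, 5, 30, 40)
print(masked)  # tensor([ 0,  5,  0, 20,  0])
print(mask)    # tensor([ True, False,  True, False,  True])

test_forward_with_tp then asserts that exactly these remapped ids ([5, 20] for inputs [15, 35]) reach quant_method.embedding, and test_forward_with_invalid_vocab asserts that out-of-shard rows are zeroed before the all-reduce.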

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 1 deletion

@@ -1278,7 +1278,7 @@ def _prepare_inputs(
             logits_indices = spec_decode_metadata.logits_indices
 
         if lmhead_tp_enable():
-            max_num_reqs_across_dp = padded_num_tokens_across_dp if not with_prefill else self.max_num_reqs
+            max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
             logits_indices = nn.functional.pad(
                 logits_indices,
                 (0, max_num_reqs_across_dp - logits_indices.shape[0]))
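
Note on the model_runner_v1.py change: with lmhead_tp_enable() true, logits_indices is right-padded up to a common length (max_num_reqs_across_dp), presumably so ranks stay in lockstep for the tensor-parallel LM head; the fix only changes which variable supplies that length on the decode-only (not with_prefill) path, tracking the rename on main picked up by the rebase. A minimal sketch of the padding step, with made-up values (only the variable names come from the diff), is:

import torch
import torch.nn as nn

# Hypothetical values; only the variable names come from the diff.
logits_indices = torch.tensor([3, 7, 12])  # one logits position per request
maybe_padded_num_tokens = 8                # padded token count on the decode-only path
max_num_reqs = 8                           # cap used on the prefill path
with_prefill = False

max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else max_num_reqs
# Right-pad with zeros up to the common length expected across DP ranks.
logits_indices = nn.functional.pad(
    logits_indices, (0, max_num_reqs_across_dp - logits_indices.shape[0]))
print(logits_indices)  # tensor([ 3,  7, 12,  0,  0,  0,  0,  0])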
