
Commit 02b5d4b

[CI] fix
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
1 parent 5bfd2a2 commit 02b5d4b

4 files changed (+63, -15 lines)


tests/ut/models/test_deepseek_v2.py

Lines changed: 7 additions & 11 deletions
@@ -22,8 +22,8 @@
 from vllm.distributed.parallel_state import GroupCoordinator

 from vllm_ascend.models.deepseek_v2 import (
-    CustomDeepseekV2ForCausalLM, CustomDeepseekV2MergedReplicatedLinear,
-    CustomDeepseekV2MLAAttention, CustomDeepseekV2MLP, CustomDeepseekV2MoE,
+    CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention,
+    CustomDeepseekV2MLP, CustomDeepseekV2MoE,
     CustomDeepseekV2RowParallelLinear,
     CustomDeepseekV2RowParallelLinearReplaceAllreduce,
     CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead)
@@ -267,33 +267,29 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
                                           prefix="layers.1.self_attn")
     assert hasattr(attn, "q_proj")

+
 def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
     # Create a simple config object
     class SimpleConfig:
+
         def __init__(self):
             self.vocab_size = 10000
             self.hidden_size = 128

     config = SimpleConfig()
-
+
     # Create the lmhead and logits_processor directly
     lmhead = ParallelLMHead(config.vocab_size, config.hidden_size)
     logits_processor = LogitsProcessor(config.vocab_size)

-    # Create test inputs
-    input_ids = torch.randint(0, config.vocab_size, (2, 4))
-    positions = torch.arange(4).repeat(2, 1)
-
     # Create mock outputs
     mock_output = torch.randn(2, 4, config.hidden_size)
     mock_logits = torch.randn(2, 4, config.vocab_size)

     # Test the logits_processor directly
-    with patch.object(lmhead.quant_method,
-                      "apply",
-                      return_value=mock_logits):
+    with patch.object(lmhead.quant_method, "apply", return_value=mock_logits):
         with patch.object(logits_processor,
                           "_gather_logits",
                           return_value=mock_logits):
             logits = logits_processor(lmhead, mock_output)
-            assert logits.shape == (2, 4, config.vocab_size)
+    assert logits.shape == (2, 4, config.vocab_size)
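
For readers unfamiliar with the mocking pattern this test leans on, here is a minimal, self-contained sketch of patch.object(..., return_value=...). FakeLMHead and FakeQuantMethod are hypothetical stand-ins invented for illustration; they are not classes from vLLM or vllm-ascend.

# Minimal sketch of the patch.object pattern used in test_deepseek_v2_lmhead.
# FakeLMHead and FakeQuantMethod are made-up stand-ins, not real classes.
from unittest.mock import patch

import torch


class FakeQuantMethod:

    def apply(self, layer, x):
        # A real quant method would run a (possibly quantized) matmul here.
        return x @ layer.weight.t()


class FakeLMHead:

    def __init__(self, vocab_size, hidden_size):
        self.weight = torch.randn(vocab_size, hidden_size)
        self.quant_method = FakeQuantMethod()


lmhead = FakeLMHead(vocab_size=10000, hidden_size=128)
mock_logits = torch.randn(2, 4, 10000)

# While the patch is active, apply() ignores its inputs and returns the
# canned tensor, so the assertion only exercises shape plumbing.
with patch.object(lmhead.quant_method, "apply", return_value=mock_logits):
    logits = lmhead.quant_method.apply(lmhead, torch.randn(2, 4, 128))

assert logits.shape == (2, 4, 10000)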

tests/ut/ops/test_vocab_parallel_embedding.py

Lines changed: 54 additions & 2 deletions
@@ -18,8 +18,8 @@

 import torch

-from vllm_ascend.ops.vocab_parallel_embedding import \
-    AscendVocabParallelEmbedding
+from vllm_ascend.ops.vocab_parallel_embedding import (
+    AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding)

 VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128

@@ -178,3 +178,55 @@ def test_output_shape(self):
         # Call the forward method
         output = layer.forward(input_)
         self.assertEqual(output.shape, expected_shape)
+
+
+class TestAscendLogitsProcessor(unittest.TestCase):
+
+    def setUp(self):
+        self.vocab_size = 50
+        self.num_embeddings = 50
+        self.embedding_dim = 10
+        self.org_num_embeddings = 40
+        self.padding_size = 8
+
+        self.mock_group = MagicMock()
+        self.mock_group.world_size = 2
+        self.mock_group.rank_in_group = 0
+        self.mock_ascend_config = MagicMock()
+        self.mock_quant_method = MagicMock()
+        self.mock_quant_method.apply = MagicMock(
+            return_value=torch.randn(1, self.vocab_size))
+        self.patches = [
+            patch("vllm_ascend.ascend_config.get_ascend_config",
+                  return_value=self.mock_ascend_config),
+            patch(
+                "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group",
+                return_value=self.mock_group),
+            patch("vllm_ascend.ops.vocab_parallel_embedding.lmhead_tp_enable",
+                  return_value=True),
+            patch(
+                "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_to_all",
+                return_value=torch.randn(1, self.vocab_size))
+        ]
+
+        for p in self.patches:
+            p.start()
+
+    def tearDown(self):
+        for p in self.patches:
+            p.stop()
+
+    def test_create_processor(self):
+        processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
+        self.assertEqual(processor.vocab_size, self.vocab_size)
+
+    def test_get_logits(self):
+        processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
+        lmhead = AscendParallelLMHead(num_embeddings=self.num_embeddings,
+                                      embedding_dim=self.embedding_dim,
+                                      prefix="lm_head")
+        lmhead.quant_method = self.mock_quant_method
+        lmhead.quant_method.apply = self.mock_quant_method.apply
+        hidden_state = torch.randn(1, self.org_num_embeddings)
+        processor._get_logits(hidden_state, lmhead)
+        self.mock_quant_method.apply.assert_called_once()
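
The new test class drives all of its mocks through patcher objects started in setUp and stopped in tearDown. Below is a stripped-down sketch of that lifecycle using only the standard library; the patch target (os.getcwd) is chosen purely for illustration, while the real tests patch vllm_ascend module paths as shown in the diff above.

# Sketch of the setUp/tearDown patch lifecycle used by TestAscendLogitsProcessor.
# The target here (os.getcwd) is illustrative only.
import os
import unittest
from unittest.mock import patch


class PatchLifecycleExample(unittest.TestCase):

    def setUp(self):
        # start() activates each patch for every test method in the class...
        self.patches = [patch("os.getcwd", return_value="/tmp/fake-cwd")]
        for p in self.patches:
            p.start()

    def tearDown(self):
        # ...and stop() restores the original attributes afterwards.
        for p in self.patches:
            p.stop()

    def test_patched(self):
        self.assertEqual(os.getcwd(), "/tmp/fake-cwd")


if __name__ == "__main__":
    unittest.main()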

vllm_ascend/ops/vocab_parallel_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -207,7 +207,7 @@ def _get_logits(
         self,
         hidden_states: torch.Tensor,
         lm_head: AscendParallelLMHead,
-        embedding_bias: Optional[torch.Tensor],
+        embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[torch.Tensor]:
         if lmhead_tp_enable():
             return self._get_logits_lmheadtp(hidden_states, lm_head,
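
The only functional change here is the default value: callers such as the new test_get_logits, which invokes processor._get_logits(hidden_state, lmhead) with two arguments, no longer have to pass embedding_bias explicitly. A standalone sketch of the difference is below; the two functions are stand-ins for illustration, not the real AscendLogitsProcessor implementation.

# Stand-in functions illustrating the effect of the "= None" default.
from typing import Optional

import torch


def get_logits_required(hidden_states: torch.Tensor,
                        lm_head: torch.nn.Linear,
                        embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
    logits = lm_head(hidden_states)
    return logits if embedding_bias is None else logits + embedding_bias


def get_logits_defaulted(
        hidden_states: torch.Tensor,
        lm_head: torch.nn.Linear,
        embedding_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    logits = lm_head(hidden_states)
    return logits if embedding_bias is None else logits + embedding_bias


head = torch.nn.Linear(10, 50, bias=False)
x = torch.randn(1, 10)

get_logits_defaulted(x, head)  # OK: embedding_bias falls back to None
try:
    get_logits_required(x, head)  # TypeError: missing positional argument
except TypeError as exc:
    print(exc)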

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 1 deletion
@@ -1278,7 +1278,7 @@ def _prepare_inputs(
             logits_indices = spec_decode_metadata.logits_indices

         if lmhead_tp_enable():
-            max_num_reqs_across_dp = padded_num_tokens_across_dp if not with_prefill else self.max_num_reqs
+            max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
             logits_indices = nn.functional.pad(
                 logits_indices,
                 (0, max_num_reqs_across_dp - logits_indices.shape[0]))
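
When lmhead tensor parallelism is enabled, this code right-pads the 1-D logits_indices tensor to a common length; the fix only changes which variable supplies that length when there is no prefill. A small sketch of the padding call itself, with made-up values:

# Illustration of the nn.functional.pad call on a 1-D index tensor.
# The concrete values are made up; only the padding mechanics match the code.
import torch
import torch.nn as nn

logits_indices = torch.tensor([3, 7, 12])  # sampling positions on this rank
max_num_reqs_across_dp = 8                 # common padded length across ranks

padded = nn.functional.pad(
    logits_indices,
    (0, max_num_reqs_across_dp - logits_indices.shape[0]))  # right-pad with 0

print(padded)  # tensor([ 3,  7, 12,  0,  0,  0,  0,  0])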
