Skip to content

Commit bdae3e4

Browse files
committed
fix ut
Signed-off-by: AlvisGong <gwly0401@163.com>
1 parent f839093 commit bdae3e4

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11

2-

tests/ut/torchair/models/test_torchair_deepseek_v2.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
from vllm.distributed.parallel_state import GroupCoordinator
2323

2424
from vllm_ascend.torchair.models.torchair_deepseek_v2 import (
25-
TorchairDeepseekV2DecoderLayer, TorchairDeepseekV2MergedReplicatedLinear,
26-
TorchairDeepseekV2MLAAttention, TorchairDeepseekV2MLP,
27-
TorchairDeepseekV2MoE, TorchairDeepseekV2RowParallelLinear,
25+
TorchairDeepseekV2DecoderLayer, TorchairDeepseekV2ForCausalLM,
26+
TorchairDeepseekV2MergedReplicatedLinear, TorchairDeepseekV2MLAAttention,
27+
TorchairDeepseekV2MLP, TorchairDeepseekV2MoE,
28+
TorchairDeepseekV2RowParallelLinear,
2829
TorchairDeepseekV2RowParallelLinearReplaceAllreduce,
2930
TorchairDeepseekV2SiluAndMul)
3031

@@ -309,3 +310,22 @@ def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual,
309310
model_config=vllm_config.model_config,
310311
quant_config=None)
311312
assert isinstance(layer.mlp, TorchairDeepseekV2MLP)
313+
314+
315+
def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config):
316+
model = TorchairDeepseekV2ForCausalLM(vllm_config=vllm_config)
317+
318+
input_ids = torch.randint(0, 10000, (2, 4))
319+
positions = torch.arange(4).repeat(2, 1)
320+
with patch.object(model.model,
321+
"forward",
322+
return_value=torch.randn(2, 4, 128)):
323+
output = model(input_ids, positions)
324+
assert output.shape == (2, 4, 128)
325+
326+
weights = [("model.embed_tokens.weight", torch.randn(10000, 128))]
327+
with patch(
328+
"vllm.model_executor.model_loader.weight_utils.default_weight_loader"
329+
):
330+
loaded = model.load_weights(weights)
331+
assert loaded is not None

0 commit comments

Comments
 (0)