|
22 | 22 | from vllm.distributed.parallel_state import GroupCoordinator
|
23 | 23 |
|
24 | 24 | from vllm_ascend.torchair.models.torchair_deepseek_v2 import (
|
25 |
| - TorchairDeepseekV2DecoderLayer, TorchairDeepseekV2MergedReplicatedLinear, |
26 |
| - TorchairDeepseekV2MLAAttention, TorchairDeepseekV2MLP, |
27 |
| - TorchairDeepseekV2MoE, TorchairDeepseekV2RowParallelLinear, |
| 25 | + TorchairDeepseekV2DecoderLayer, TorchairDeepseekV2ForCausalLM, |
| 26 | + TorchairDeepseekV2MergedReplicatedLinear, TorchairDeepseekV2MLAAttention, |
| 27 | + TorchairDeepseekV2MLP, TorchairDeepseekV2MoE, |
| 28 | + TorchairDeepseekV2RowParallelLinear, |
28 | 29 | TorchairDeepseekV2RowParallelLinearReplaceAllreduce,
|
29 | 30 | TorchairDeepseekV2SiluAndMul)
|
30 | 31 |
|
@@ -309,3 +310,22 @@ def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual,
|
309 | 310 | model_config=vllm_config.model_config,
|
310 | 311 | quant_config=None)
|
311 | 312 | assert isinstance(layer.mlp, TorchairDeepseekV2MLP)
|
| 313 | + |
| 314 | + |
def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config):
    """Smoke-test the TorchairDeepseekV2ForCausalLM wrapper.

    Checks two things: (1) the wrapper's forward pass returns the hidden
    states produced by the inner model, and (2) load_weights runs to
    completion and returns a non-None result.
    """
    model = TorchairDeepseekV2ForCausalLM(vllm_config=vllm_config)

    batch, seq_len, hidden = 2, 4, 128
    input_ids = torch.randint(0, 10000, (batch, seq_len))
    positions = torch.arange(seq_len).repeat(batch, 1)

    # Stub the inner model so only the wrapper's forward path is exercised.
    fake_hidden = torch.randn(batch, seq_len, hidden)
    with patch.object(model.model, "forward", return_value=fake_hidden):
        hidden_states = model(input_ids, positions)
    assert hidden_states.shape == (batch, seq_len, hidden)

    weights = [("model.embed_tokens.weight", torch.randn(10000, hidden))]
    # NOTE(review): this patches default_weight_loader at its definition
    # site; if the model module imported the name directly
    # ("from ... import default_weight_loader"), the patch would not take
    # effect there — confirm the intended target.
    with patch(
            "vllm.model_executor.model_loader.weight_utils.default_weight_loader"
    ):
        loaded = model.load_weights(weights)
    assert loaded is not None
0 commit comments