@@ -73,16 +73,20 @@ def test_override_quantization_method(self, mock_is_available):
73
73
self .assertIsNone (result )
74
74
75
75
def test_get_quant_method_for_linear (self ):
76
+ mock_config = MagicMock ()
77
+ mock_config .model_config .hf_config .model_type = None
76
78
linear_layer = MagicMock (spec = LinearBase )
77
79
# Test skipped layer
78
- with patch .object (self .ascend_config ,
80
+ with patch ("vllm_ascend.quantization.quant_config.get_current_vllm_config" , return_value = mock_config ), \
81
+ patch .object (self .ascend_config , \
79
82
'is_layer_skipped_ascend' ,
80
83
return_value = True ):
81
84
method = self .ascend_config .get_quant_method (linear_layer , ".attn" )
82
85
self .assertIsInstance (method , UnquantizedLinearMethod )
83
86
84
87
# Test quantized layer
85
88
with patch .object (self .ascend_config , 'is_layer_skipped_ascend' , return_value = False ), \
89
+ patch ("vllm_ascend.quantization.quant_config.get_current_vllm_config" , return_value = mock_config ), \
86
90
patch ('vllm_ascend.quantization.quant_config.AscendLinearMethod' , return_value = MagicMock ()) as mock_ascend_linear :
87
91
88
92
method = self .ascend_config .get_quant_method (linear_layer , ".attn" )
@@ -93,14 +97,18 @@ def test_get_quant_method_for_linear(self):
93
97
94
98
def test_get_quant_method_for_attention (self ):
95
99
attention_layer = MagicMock (spec = Attention )
96
- with patch ('vllm_ascend.quantization.quant_config.AscendKVCacheMethod' ,
100
+ mock_config = MagicMock ()
101
+ mock_config .model_config .hf_config .model_type = None
102
+ with patch ("vllm_ascend.quantization.quant_config.get_current_vllm_config" , return_value = mock_config ), \
103
+ patch ('vllm_ascend.quantization.quant_config.AscendKVCacheMethod' , \
97
104
return_value = MagicMock ()) as mock_ascend_kvcache :
98
105
# Test with fa_quant_type
99
106
method = self .ascend_config .get_quant_method (
100
107
attention_layer , ".attn" )
101
108
self .assertIs (method , mock_ascend_kvcache .return_value )
102
109
103
- with patch ('vllm_ascend.quantization.quant_config.AscendKVCacheMethod' ,
110
+ with patch ("vllm_ascend.quantization.quant_config.get_current_vllm_config" , return_value = mock_config ), \
111
+ patch ('vllm_ascend.quantization.quant_config.AscendKVCacheMethod' , \
104
112
return_value = MagicMock ()) as mock_ascend_kvcache :
105
113
# Test with kv_quant_type
106
114
modified_config = {"kv_quant_type" : "C8" }
@@ -113,16 +121,20 @@ def test_get_quant_method_for_fused_moe(self):
113
121
fused_moe_layer = MagicMock (spec = FusedMoE )
114
122
fused_moe_layer .moe = MagicMock (spec = FusedMoEConfig )
115
123
fused_moe_layer .moe_config = MagicMock (spec = FusedMoEConfig )
124
+ mock_config = MagicMock ()
125
+ mock_config .model_config .hf_config .model_type = None
116
126
117
127
# Test skipped layer
118
128
with patch .object (self .ascend_config , 'is_layer_skipped_ascend' , return_value = True ), \
129
+ patch ("vllm_ascend.quantization.quant_config.get_current_vllm_config" , return_value = mock_config ), \
119
130
patch ('vllm_ascend.quantization.quant_config.AscendUnquantizedFusedMoEMethod' , return_value = MagicMock ()) as mock_ascend_moe :
120
131
method = self .ascend_config .get_quant_method (
121
132
fused_moe_layer , "moe_layer" )
122
133
self .assertIs (method , mock_ascend_moe .return_value )
123
134
124
135
# Test quantized layer
125
136
with patch .object (self .ascend_config , 'is_layer_skipped_ascend' , return_value = False ), \
137
+ patch ("vllm_ascend.quantization.quant_config.get_current_vllm_config" , return_value = mock_config ), \
126
138
patch ('vllm_ascend.quantization.quant_config.AscendFusedMoEMethod' , return_value = MagicMock ()) as mock_ascend_moe :
127
139
method = self .ascend_config .get_quant_method (
128
140
fused_moe_layer , "moe_layer" )
0 commit comments