 try:
     from petit_kernel import mul_nvfp4_a16, process_nvfp4_scales, repack_nvfp4
+    _PETIT_AVAILABLE = True
 except ImportError:
+    _PETIT_AVAILABLE = False
 
-    def _check_petit_nvfp4_supported(
-        quant_method: str, group_size: Optional[int]
-    ) -> tuple[bool, Optional[str]]:
-        return (
-            False,
-            "Petit is not installed. Please install it with `pip install petit-kernel`.",
-        )
-
-    def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None:
-        raise ValueError(
-            "Petit is not installed. Please install it with `pip install petit-kernel`."
-        )
-
-    def apply_petit_nvfp4_linear(
-        input: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-        weight_scale_2: torch.Tensor,
-        size_n: int,
-        size_k: int,
-        bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        raise ValueError(
-            "Petit is not installed. Please install it with `pip install petit-kernel`."
-        )
+_PETIT_INSTALL_MSG = (
+    "Petit is not installed. Please install it with "
+    "`pip install petit-kernel`."
+)
 
+def _require_petit() -> None:
+    if not _PETIT_AVAILABLE:
+        # Single error exit; avoids duplicated code and overlong lines
+        raise ImportError(_PETIT_INSTALL_MSG)
 
 def _check_petit_nvfp4_supported(
     quant_method: str, group_size: Optional[int]
 ) -> tuple[bool, Optional[str]]:
     if quant_method != "NVFP4":
         return (
             False,
-            "Petit currently only supports: NVFP4"
-            " quantizations in sglang. Please check the "
-            "`hf_quant_config.json` file for your model's "
-            "quant configuration.",
+            (
+                "Petit currently only supports: NVFP4 quantizations in sglang. "
+                "Please check the `hf_quant_config.json` file for your model's "
+                "quant configuration."
+            ),
         )
     if group_size is not None and group_size != 16:
         return (
             False,
-            "Petit currently only supports: group_size=16" " quantizations.",
+            "Petit currently only supports: group_size=16 quantizations.",
         )
     return (True, None)
 
-
-def verify_petit_nvfp4_supported(quant_method: str, group_size: Optional[int]) -> None:
+def verify_petit_nvfp4_supported(
+    quant_method: str, group_size: Optional[int]
+) -> None:
     supported, error_msg = _check_petit_nvfp4_supported(quant_method, group_size)
     if not supported:
+        # Avoid a mypy error on Optional[str]
+        assert error_msg is not None
         raise ValueError(error_msg)
 
-
 def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None:
+    _require_petit()  # Raise the unified error here if petit is not installed
+
     # Repack weights to petit format
     part_size_n = layer.output_size_per_partition
     part_size_k = layer.input_size_per_partition
@@ -71,9 +61,6 @@ def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None:
     )
     layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
 
-    return
-
-
 def apply_petit_nvfp4_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -83,6 +70,8 @@ def apply_petit_nvfp4_linear(
     size_k: int,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
+    _require_petit()  # Raise the unified error here if petit is not installed
+
     reshaped_x = input.reshape(-1, input.shape[-1])
     out_shape = input.shape[:-1] + (size_n,)
 
@@ -100,4 +89,4 @@ def apply_petit_nvfp4_linear(
     if bias is not None:
         output.add_(bias)  # In-place add
 
-    return output.reshape(out_shape)
+    return output.reshape(out_shape)
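For context on the shape of the change: the PR replaces the per-function stub fallbacks with a module-level availability flag plus one shared `_require_petit()` guard. The sketch below reproduces that pattern in isolation, assuming nothing beyond the standard library; `apply_kernel` and its body are hypothetical stand-ins added only to make the snippet runnable, not part of the sglang code.

```python
# Minimal sketch of the optional-dependency guard used in this diff.
# Only the flag/helper names mirror the change; `apply_kernel` is hypothetical.
try:
    import petit_kernel  # noqa: F401  # real optional dependency
    _PETIT_AVAILABLE = True
except ImportError:
    _PETIT_AVAILABLE = False

_PETIT_INSTALL_MSG = (
    "Petit is not installed. Please install it with "
    "`pip install petit-kernel`."
)


def _require_petit() -> None:
    # Single raise site shared by every entry point that needs the kernel.
    if not _PETIT_AVAILABLE:
        raise ImportError(_PETIT_INSTALL_MSG)


def apply_kernel(x: float) -> float:
    _require_petit()  # fail fast with the install hint if petit is missing
    return x  # placeholder for the real kernel call


if __name__ == "__main__":
    try:
        print(apply_kernel(1.0))
    except ImportError as err:
        print(err)
```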