Commit a3319f4
[Bugfix] Enforce contiguous input for dynamic_per_token FP8/INT8 quant (#19452)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent 9d880f5

File tree: 1 file changed (+3 -3)


vllm/_custom_ops.py

Lines changed: 3 additions & 3 deletions
@@ -1270,7 +1270,7 @@ def scaled_fp8_quant(
                             device=input.device,
                             dtype=torch.float32)
         torch.ops._C.dynamic_per_token_scaled_fp8_quant(
-            output, input, scale, scale_ub)
+            output, input.contiguous(), scale, scale_ub)
     else:
         scale = torch.zeros(1, device=input.device, dtype=torch.float32)
         torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
@@ -1379,8 +1379,8 @@ def scaled_int8_quant(
                                 dtype=torch.float32)
     input_azp = None if symmetric else torch.empty_like(input_scales,
                                                         dtype=torch.int32)
-    torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales,
-                                           input_azp)
+    torch.ops._C.dynamic_scaled_int8_quant(output, input.contiguous(),
+                                           input_scales, input_azp)
     return output, input_scales, input_azp
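Why the change matters: the commit title says contiguous input must be enforced for the dynamic per-token quant ops, which suggests the underlying kernels (torch.ops._C.dynamic_per_token_scaled_fp8_quant and torch.ops._C.dynamic_scaled_int8_quant) index the input as a densely packed buffer, so a non-contiguous view (a transpose, a slice, etc.) would otherwise be read with the wrong strides. The sketch below is illustrative only, not code from this commit; the tensor names are made up and only standard PyTorch APIs are used.

import torch

x = torch.randn(16, 32)

view = x.t()                    # transposed view: same storage, swapped strides
print(view.is_contiguous())     # False

packed = view.contiguous()      # materializes a row-major (packed) copy
print(packed.is_contiguous())   # True

Note that torch.Tensor.contiguous returns self when the tensor is already packed, so the guard added by this commit only pays for a copy when one is actually needed.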