From 5cd909b25fe6a2ad8367646bd3c0812a03ff464c Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Fri, 16 May 2025 00:57:29 +0000
Subject: [PATCH] hacks

---
 .../compressed_tensors/compressed_tensors.py  |  37 ++++-
 .../compressed_tensors/schemes/__init__.py    |   4 +-
 .../schemes/compressed_tensors_w4a4_nvfp4.py  | 137 ++++++++++++++++++
 .../quantization/compressed_tensors/utils.py  |   2 +-
 4 files changed, 172 insertions(+), 8 deletions(-)
 create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 27547f315fe..bfa0d59fd3a 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -23,10 +23,10 @@
     CompressedTensorsMoEMethod)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
-    CompressedTensorsScheme, CompressedTensorsW4A16Fp4,
-    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
-    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
-    CompressedTensorsWNA16)
+    CompressedTensorsScheme, CompressedTensorsW4A4Fp4,
+    CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
+    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
+    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     find_matched_target, is_activation_quantization_format,
     should_ignore_layer)
@@ -183,8 +183,8 @@ def _quantization_scheme_map_from_config(
                             "weights"].type == QuantizationType.FLOAT
                     else:
                         target_scheme_map[target][
-                            "input_activations"] = QuantizationArgs.model_validate(  # noqa: E501
-                                quant_config.get("input_activations"))
+                            "input_activations"] = quant_config.get(
+                                "input_activations")
         return target_scheme_map
 
     @classmethod
@@ -232,6 +232,27 @@ def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
         return (is_weight_only and is_group_quant and is_float_type
                 and is_4_bits and is_group_size_16 and is_symmetric)
 
+    def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):
+
+        print(input_quant)
+        is_weight_act_quant = weight_quant is not None and input_quant is not None
+        is_group_quant = (input_quant.get("strategy") == "tensor_group"
+                          and weight_quant.strategy
+                          == QuantizationStrategy.GROUP.value)
+        is_symmetric = weight_quant.symmetric and input_quant.get("symmetric")
+
+        is_group_size_16 = weight_quant.group_size == 16 and input_quant.get(
+            "group_size") == 16
+        is_float_type = weight_quant.type == QuantizationType.FLOAT and input_quant.get(
+            "type") == QuantizationType.FLOAT.value
+        is_4_bits = weight_quant.num_bits == 4 and input_quant.get(
+            "num_bits") == 4
+
+        print(is_weight_act_quant, is_group_quant, is_symmetric,
+              is_group_size_16, is_float_type, is_4_bits)
+        return (is_weight_act_quant and is_group_quant and is_float_type
+                and is_4_bits and is_group_size_16 and is_symmetric)
+
     def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
@@ -334,6 +355,10 @@ def _get_scheme_from_parts(
         if self._is_fp4a16_nvfp4(weight_quant, input_quant):
             return CompressedTensorsW4A16Fp4()
 
+        if self._is_fp4a4_nvfp4(weight_quant, input_quant):
+            print("DONE")
+            return CompressedTensorsW4A4Fp4()
+
         if self._is_wNa16_group_channel(weight_quant, input_quant):
             if (self.quant_format == CompressionFormat.marlin_24.value
                     and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
index 79bf5c108ac..f44aada6d20 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from .compressed_tensors_scheme import CompressedTensorsScheme
+from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
 from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
                                           CompressedTensorsW4A16Sparse24)
 from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
@@ -17,5 +18,6 @@
     "CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24",
     "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
     "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
-    "CompressedTensors24", "CompressedTensorsW4A16Fp4"
+    "CompressedTensors24", "CompressedTensorsW4A16Fp4",
+    "CompressedTensorsW4A4Fp4"
 ]
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
new file mode 100644
index 00000000000..92b96796303
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
@@ -0,0 +1,137 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Callable, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
+    dequantize_to_dtype, ref_nvfp4_quant)
+from vllm.model_executor.parameter import (GroupQuantScaleParameter,
+                                           ModelWeightParameter,
+                                           PerTensorScaleParameter)
+
+__all__ = ["CompressedTensorsW4A4Fp4"]
+
+
+class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
+
+    def __init__(self):
+        self.group_size = 16
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # dont restrict as emulations
+        return 80
+
+    def create_weights(self, layer: torch.nn.Module,
+                       output_partition_sizes: list[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+
+        # Weight
+        weight = ModelWeightParameter(data=torch.empty(
+            sum(output_partition_sizes),
+            input_size_per_partition // 2,
+            dtype=torch.uint8),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
+        layer.register_parameter("weight_packed", weight)
+
+        # Global Weight Scale
+        weight_global_scale = PerTensorScaleParameter(
+            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+            weight_loader=weight_loader)
+        layer.register_parameter("weight_global_scale", weight_global_scale)
+
+        # Per Group Weight Scale
+        weight_scale = GroupQuantScaleParameter(data=torch.empty(
+            sum(output_partition_sizes),
+            input_size_per_partition // self.group_size,
+            dtype=torch.float8_e4m3fn,
+        ),
+                                                input_dim=1,
+                                                output_dim=0,
+                                                weight_loader=weight_loader)
+
+        layer.register_parameter("weight_scale", weight_scale)
+
+        input_global_scale = PerTensorScaleParameter(
+            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+            weight_loader=weight_loader)
+        layer.register_parameter("input_global_scale", input_global_scale)
+
+    def swizzle_blockscale(self, scale: torch.tensor):
+        assert (scale.dtype == torch.float8_e4m3fn)
+        # Pad and blockwise interleave weight_scale
+        scale_ndim = scale.ndim
+        if scale.ndim == 2:
+            scale = scale.unsqueeze(0)
+        assert scale.ndim == 3
+        B, M, K = scale.shape
+        round_up_multiple = lambda x, m: (x + m - 1) // m * m
+        M_padded = round_up_multiple(M, 128)
+        K_padded = round_up_multiple(K, 4)
+        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
+        padded_scale[:B, :M, :K] = scale
+        batches, rows, cols = padded_scale.shape
+        assert rows % 128 == 0
+        assert cols % 4 == 0
+        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
+                                            cols // 4, 4)
+        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
+        swizzled_scale = swizzled_scale.contiguous().cuda()
+        return (swizzled_scale.reshape(M, K)
+                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
+
+    def process_weights_after_loading(self, layer) -> None:
+
+        global_input_scale = layer.input_global_scale.max().to(torch.float32)
+        layer.input_global_scale = Parameter(global_input_scale,
+                                             requires_grad=False)
+
+        layer.weight_global_scale = Parameter(
+            layer.weight_global_scale.max().to(torch.float32),
+            requires_grad=False)
+
+        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
+        layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
+                                                requires_grad=False)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        x_m, x_k = x.shape
+        output_dtype = x.dtype
+
+        # quantize input to (FP4 and interleaved block scale)
+        x_global_scale = layer.input_global_scale
+        x_fp4, x_blockscale = ref_nvfp4_quant(x, x_global_scale,
+                                              self.group_size)
+
+        # dequantize input
+        x_fp4 = x_fp4.reshape(x_m, x_k // self.group_size, self.group_size)
+        x_blockscale = x_blockscale.unsqueeze(-1) / x_global_scale
+        x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
+        del x_fp4, x_blockscale
+
+        # dequantize weight
+        w_fp4 = layer.weight_packed.data.view(torch.uint8)
+        w_blockscale = layer.weight_scale_swizzled.data
+        w_global_scale = layer.weight_global_scale
+        w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
+                                   output_dtype, x.device, self.group_size)
+
+        # matmul
+        out = torch.matmul(x_dq, w_dq.t())
+        del w_dq, x_dq
+        return out
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
index ccd54281ceb..2dc79986072 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
@@ -13,7 +13,7 @@
 def is_activation_quantization_format(format: str) -> bool:
     _ACTIVATION_QUANTIZATION_FORMATS = [
         CompressionFormat.naive_quantized.value,
         CompressionFormat.int_quantized.value,
-        CompressionFormat.float_quantized.value,
+        CompressionFormat.float_quantized.value, "nvfp4-pack-quantized"
     ]
    return format in _ACTIVATION_QUANTIZATION_FORMATS
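
For reference only, not part of the patch: a minimal sketch of the kind of
config_groups entry the new _is_fp4a4_nvfp4 check is written to match, with
field names and values inferred from the checks in the hunk above; the exact
compressed-tensors checkpoint schema may differ. The checkpoint-level "format"
would be "nvfp4-pack-quantized" so that is_activation_quantization_format
accepts it.

    # Illustrative sketch only. "weights" is parsed into QuantizationArgs,
    # while "input_activations" is kept as a raw dict by the hunk above,
    # which is why _is_fp4a4_nvfp4 reads it with .get(...).
    nvfp4_config_group = {
        "weights": {
            "num_bits": 4,            # 4-bit FP4 weights
            "type": "float",
            "strategy": "group",      # QuantizationStrategy.GROUP
            "group_size": 16,
            "symmetric": True,
        },
        "input_activations": {
            "num_bits": 4,            # 4-bit FP4 activations
            "type": "float",
            "strategy": "tensor_group",
            "group_size": 16,
            "symmetric": True,
        },
        "targets": ["Linear"],        # hypothetical target list
    }

With a group like this, _get_scheme_from_parts falls through to
CompressedTensorsW4A4Fp4, which emulates the NVFP4 GEMM by dequantizing both
the FP4 activations and the packed FP4 weights and running a dense matmul.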