
[NVFP4] Compressed Tensors Hacks #88

Draft. Wants to merge 1 commit into base: modelopt_act_emulation
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -23,10 +23,10 @@
CompressedTensorsMoEMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
CompressedTensorsScheme, CompressedTensorsW4A16Fp4,
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16)
CompressedTensorsScheme, CompressedTensorsW4A4Fp4,
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, is_activation_quantization_format,
should_ignore_layer)
@@ -183,8 +183,8 @@
"weights"].type == QuantizationType.FLOAT
else:
target_scheme_map[target][
"input_activations"] = QuantizationArgs.model_validate( # noqa: E501
quant_config.get("input_activations"))
"input_activations"] = quant_config.get(
"input_activations")
return target_scheme_map

@classmethod
@@ -232,6 +232,27 @@
return (is_weight_only and is_group_quant and is_float_type
and is_4_bits and is_group_size_16 and is_symmetric)

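# Added for clarity (not part of the diff): this predicate is meant to match an
# NVFP4 W4A4 scheme where the weight args are validated QuantizationArgs and
# the activation args are still a raw dict, e.g. roughly:
#   weights:           num_bits=4, type=float, strategy="group",
#                      group_size=16, symmetric=True
#   input_activations: {"num_bits": 4, "type": "float",
#                       "strategy": "tensor_group", "group_size": 16,
#                       "symmetric": True}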
def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):

print(input_quant)
is_weight_act_quant = (weight_quant is not None
                       and input_quant is not None)

is_group_quant = (input_quant.get("strategy") == "tensor_group"
and weight_quant.strategy
== QuantizationStrategy.GROUP.value)
is_symmetric = weight_quant.symmetric and input_quant.get("symmetric")

is_group_size_16 = weight_quant.group_size == 16 and input_quant.get(
"group_size") == 16
is_float_type = (weight_quant.type == QuantizationType.FLOAT and
                 input_quant.get("type") == QuantizationType.FLOAT.value)
is_4_bits = weight_quant.num_bits == 4 and input_quant.get(
"num_bits") == 4

print(is_weight_act_quant, is_group_quant, is_symmetric,
is_group_size_16, is_float_type, is_4_bits)
return (is_weight_act_quant and is_group_quant and is_float_type
and is_4_bits and is_group_size_16 and is_symmetric)

def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
@@ -334,6 +355,10 @@
if self._is_fp4a16_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A16Fp4()

if self._is_fp4a4_nvfp4(weight_quant, input_quant):
print("DONE")
return CompressedTensorsW4A4Fp4()

if self._is_wNa16_group_channel(weight_quant, input_quant):
if (self.quant_format == CompressionFormat.marlin_24.value
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
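For reference (not part of the diff): a sketch of the kind of quantization_config this change is meant to pick up. The wrapper keys ("quant_method", "config_groups", "targets") are assumed from typical compressed-tensors checkpoints; the weights / input_activations fields mirror exactly what _is_fp4a4_nvfp4 checks, and the format string matches the one added to utils.py below.

quantization_config = {
    "quant_method": "compressed-tensors",
    "format": "nvfp4-pack-quantized",
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {"num_bits": 4, "type": "float", "symmetric": True,
                        "strategy": "group", "group_size": 16},
            "input_activations": {"num_bits": 4, "type": "float",
                                  "symmetric": True,
                                  "strategy": "tensor_group",
                                  "group_size": 16},
        }
    },
}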
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
CompressedTensorsW4A16Sparse24)
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
@@ -17,5 +18,6 @@
"CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24",
"CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
"WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
"CompressedTensors24", "CompressedTensorsW4A16Fp4"
"CompressedTensors24", "CompressedTensorsW4A16Fp4",
"CompressedTensorsW4A4Fp4"
]
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
@@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, Optional

import torch
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
dequantize_to_dtype, ref_nvfp4_quant)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)

__all__ = ["CompressedTensorsW4A4Fp4"]


class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):

def __init__(self):
self.group_size = 16

@classmethod
def get_min_capability(cls) -> int:
# don't restrict capability; this scheme runs as an emulation
return 80

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition

# Weight
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // 2,
dtype=torch.uint8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight_packed", weight)

# Global Weight Scale
weight_global_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("weight_global_scale", weight_global_scale)

# Per Group Weight Scale
weight_scale = GroupQuantScaleParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // self.group_size,
dtype=torch.float8_e4m3fn,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

layer.register_parameter("weight_scale", weight_scale)

input_global_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_global_scale", input_global_scale)

def swizzle_blockscale(self, scale: torch.Tensor):
assert (scale.dtype == torch.float8_e4m3fn)
# Pad and blockwise interleave weight_scale
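# Added for clarity (not in the PR): the reshape/permute below pads the
# (rows, cols) scale matrix up to multiples of (128, 4) and interleaves it
# into 128x4 tiles (viewed as 4 x 32 x 4 blocks), which appears to mirror the
# block-scale layout consumed by the non-emulated NVFP4 kernels.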
scale_ndim = scale.ndim
if scale.ndim == 2:
scale = scale.unsqueeze(0)
assert scale.ndim == 3
B, M, K = scale.shape
round_up_multiple = lambda x, m: (x + m - 1) // m * m
M_padded = round_up_multiple(M, 128)
K_padded = round_up_multiple(K, 4)
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
padded_scale[:B, :M, :K] = scale
batches, rows, cols = padded_scale.shape
assert rows % 128 == 0
assert cols % 4 == 0
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
cols // 4, 4)
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
swizzled_scale = swizzled_scale.contiguous().cuda()
return (swizzled_scale.reshape(M, K)
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))

def process_weights_after_loading(self, layer) -> None:

global_input_scale = layer.input_global_scale.max().to(torch.float32)
layer.input_global_scale = Parameter(global_input_scale,
requires_grad=False)

layer.weight_global_scale = Parameter(
layer.weight_global_scale.max().to(torch.float32),
requires_grad=False)

swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
requires_grad=False)

def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

x_m, x_k = x.shape
output_dtype = x.dtype

# quantize input to FP4 and an interleaved block scale
x_global_scale = layer.input_global_scale
x_fp4, x_blockscale = ref_nvfp4_quant(x, x_global_scale,
self.group_size)

# dequantize input
x_fp4 = x_fp4.reshape(x_m, x_k // self.group_size, self.group_size)
x_blockscale = x_blockscale.unsqueeze(-1) / x_global_scale
x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
del x_fp4, x_blockscale

# dequantize weight
w_fp4 = layer.weight_packed.data.view(torch.uint8)
w_blockscale = layer.weight_scale_swizzled.data
w_global_scale = layer.weight_global_scale
w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
output_dtype, x.device, self.group_size)

# matmul
out = torch.matmul(x_dq, w_dq.t())
del w_dq, x_dq
return out
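Added summary (not code from the PR): the emulation above dequantizes both operands as value * block_scale / global_scale and then runs a plain dense matmul; w_fp4_decoded is a stand-in name for the FP4 decode that dequantize_to_dtype performs internally.

x_dq = x_fp4 * (x_blockscale / x_global_scale)          # (M, K) activations
w_dq = w_fp4_decoded * (w_blockscale / w_global_scale)  # (N, K) weights
out = x_dq @ w_dq.T                                     # note: the bias argument is currently unused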
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
@@ -13,7 +13,7 @@ def is_activation_quantization_format(format: str) -> bool:
_ACTIVATION_QUANTIZATION_FORMATS = [
CompressionFormat.naive_quantized.value,
CompressionFormat.int_quantized.value,
CompressionFormat.float_quantized.value,
CompressionFormat.float_quantized.value, "nvfp4-pack-quantized"
]
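# Added note (not in the PR): "nvfp4-pack-quantized" above is a hardcoded
# string, unlike the other entries which use CompressionFormat members,
# presumably a stopgap for this draft.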
return format in _ACTIVATION_QUANTIZATION_FORMATS
