Skip to content

support rceil for mxfp #660

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 66 additions & 6 deletions auto_round/data_type/mxfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
import torch

from auto_round.data_type.register import QUANT_FUNC_WITH_DTYPE, register_dtype
from auto_round.data_type.utils import floor_ste, reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste
from auto_round.data_type.utils import (
ceil_ste,
floor_ste,
reshape_pad_tensor_by_group_size,
revert_tensor_by_pad,
round_ste,
)

MXFP_FORMAT_CACHE = {
# data type: ebits, mbits, emax, max_norm, min_norm
Expand Down Expand Up @@ -95,17 +101,18 @@ def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0,
KeyError: If `data_type` is not found in `MXFP_FORMAT_CACHE`.
"""
tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
data_type = data_type if data_type in MXFP_FORMAT_CACHE else "mx_fp" + str(bits)
ebits, mbits, emax, max_norm, min_norm = MXFP_FORMAT_CACHE[data_type]
orig_dtype = tensor.dtype
tensor = tensor.to(torch.float32)
shared_exp, _ = torch.max(torch.abs(tensor), dim=-1, keepdim=True)
max_val, _ = torch.max(torch.abs(tensor), dim=-1, keepdim=True)
if isinstance(max_scale, torch.Tensor):
shared_exp *= (max_scale.unsqueeze(dim=-1)).to(tensor.device)
max_val *= (max_scale.unsqueeze(dim=-1)).to(tensor.device)
else:
shared_exp *= max_scale
max_val *= max_scale

# shared_exp = torch.log2(shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype))
shared_exp = torch.where(shared_exp == 0, torch.ones_like(shared_exp), torch.log2(shared_exp))
shared_exp = torch.where(max_val == 0, torch.ones_like(max_val), torch.log2(max_val))
shared_exp = floor_ste(shared_exp)
scale_emax = 2 ** (8 - 1) - 1
shared_exp = (shared_exp - emax).clamp(min=-scale_emax, max=scale_emax)
Expand All @@ -120,8 +127,61 @@ def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0,
return tensor.to(orig_dtype), shared_exp.to(orig_dtype), None


@torch.compile()
def quant_mx_rceil(tensor, bits=4, group_size=-1, v=0, max_scale=1.0,
mantissa_rounding="even", data_type="mx_fp", **kwargs):
"""Quantize the given tensor using the specified parameters.

This function performs quantization on the `tensor` tensor according to the
given bit width (`bits`), data type (`data_type`), and additional parameters.
The quantization process involves scaling the tensor values and adjusting
the exponent and mantissa to fit within the specified format.

Args:
tensor (torch.Tensor): The tensor containing the tensors to be quantized.
bits (int): The bit width to be used for quantization.
group_size (int): The group size of sharing scale and exponent.
data_type (str): The data type for quantization (e.g., 'mx_fp4').
v (float): A value used for adjusting the tensors.
max_scale (float or torch.Tensor): The maximum scale to be applied to the tensors.
mantissa_rounding (str): rounding method for mantissa,currently support even,nearest,floor

Returns:
tuple: A tuple containing the quantized tensors, shared exponent, and None (reserved for future use).

Raises:
KeyError: If `data_type` is not found in `MXFP_FORMAT_CACHE`.
"""
tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
data_type = data_type if data_type in MXFP_FORMAT_CACHE else "mx_fp" + str(bits)
ebits, mbits, emax, max_norm, min_norm = MXFP_FORMAT_CACHE[data_type]
orig_dtype = tensor.dtype
tensor = tensor.to(torch.float32)
max_val, _ = torch.max(torch.abs(tensor), dim=-1, keepdim=True)
if isinstance(max_scale, torch.Tensor):
max_val *= (max_scale.unsqueeze(dim=-1)).to(tensor.device)
else:
max_val *= max_scale

# shared_exp = torch.log2(shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype))
shared_exp = torch.where(max_val == 0, torch.ones_like(max_val), ceil_ste(torch.log2(max_val / max_norm)))
scale_emax = 2 ** (8 - 1) - 1
shared_exp = shared_exp.clamp(min=-scale_emax, max=scale_emax)

scale = torch.pow(2, shared_exp)
tensor = tensor / scale + v
tensor = torch.clamp(tensor, min=-max_norm, max=max_norm)
tensor = quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding)

tensor = tensor * scale
tensor = revert_tensor_by_pad(tensor, orig_shape=orig_shape, pad_len=pad_len)
return tensor.to(orig_dtype), shared_exp.to(orig_dtype), None


for key in MXFP_FORMAT_CACHE.keys():
QUANT_FUNC_WITH_DTYPE[key] = quant_mx
QUANT_FUNC_WITH_DTYPE[key + "_rceil"] = quant_mx_rceil
QUANT_FUNC_WITH_DTYPE["mx_fp_rceil"] = quant_mx_rceil

if __name__ == "__main__":
data = torch.tensor([0.0, 0.25, 0.4, 0.75, 1.25, 1.4, 1.75, 2.5, 2.9, 3.5, 5.0, 5.1])
Expand All @@ -131,4 +191,4 @@ def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0,

data_neg = data * -1
data2 = quant_element(data_neg, 2, 3, 6.0)
assert (torch.sum(torch.abs(data2 - gt * -1)) < 1e-6)
assert (torch.sum(torch.abs(data2 - gt * -1)) < 1e-6)
11 changes: 11 additions & 0 deletions auto_round/data_type/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,17 @@ def floor_ste(x: torch.Tensor):
return (x.floor() - x).detach() + x


def ceil_ste(x: torch.Tensor):
"""Straight-Through Estimator for ceil.

Args:
x: torch.Tensor

Returns:
torch.Tensor
"""
return (x.ceil() - x).detach() + x


def float8_e4m3fn_ste(x: torch.Tensor):
"""Straight-Through Estimator (STE) for float8.
Expand Down