diff --git a/auto_round/data_type/mxfp.py b/auto_round/data_type/mxfp.py
index f5ad2aa08..6c6290445 100644
--- a/auto_round/data_type/mxfp.py
+++ b/auto_round/data_type/mxfp.py
@@ -15,7 +15,13 @@
 import torch
 
 from auto_round.data_type.register import QUANT_FUNC_WITH_DTYPE, register_dtype
-from auto_round.data_type.utils import floor_ste, reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste
+from auto_round.data_type.utils import (
+    ceil_ste,
+    floor_ste,
+    reshape_pad_tensor_by_group_size,
+    revert_tensor_by_pad,
+    round_ste,
+)
 
 MXFP_FORMAT_CACHE = {
     # data type: ebits, mbits, emax, max_norm, min_norm
@@ -95,17 +101,18 @@ def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0,
         KeyError: If `data_type` is not found in `MXFP_FORMAT_CACHE`.
     """
     tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
+    data_type = data_type if data_type in MXFP_FORMAT_CACHE else "mx_fp" + str(bits)
     ebits, mbits, emax, max_norm, min_norm = MXFP_FORMAT_CACHE[data_type]
     orig_dtype = tensor.dtype
     tensor = tensor.to(torch.float32)
-    shared_exp, _ = torch.max(torch.abs(tensor), dim=-1, keepdim=True)
+    max_val, _ = torch.max(torch.abs(tensor), dim=-1, keepdim=True)
     if isinstance(max_scale, torch.Tensor):
-        shared_exp *= (max_scale.unsqueeze(dim=-1)).to(tensor.device)
+        max_val *= (max_scale.unsqueeze(dim=-1)).to(tensor.device)
     else:
-        shared_exp *= max_scale
+        max_val *= max_scale
 
     # shared_exp = torch.log2(shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype))
-    shared_exp = torch.where(shared_exp == 0, torch.ones_like(shared_exp), torch.log2(shared_exp))
+    shared_exp = torch.where(max_val == 0, torch.ones_like(max_val), torch.log2(max_val))
     shared_exp = floor_ste(shared_exp)
     scale_emax = 2 ** (8 - 1) - 1
     shared_exp = (shared_exp - emax).clamp(min=-scale_emax, max=scale_emax)
@@ -120,8 +127,61 @@ def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0,
     return tensor.to(orig_dtype), shared_exp.to(orig_dtype), None
 
 
+@torch.compile()
+def quant_mx_rceil(tensor, bits=4, group_size=-1, v=0, max_scale=1.0,
+                   mantissa_rounding="even", data_type="mx_fp", **kwargs):
+    """Quantize the given tensor using the specified parameters.
+
+    This function quantizes `tensor` according to the given bit width (`bits`),
+    data type (`data_type`), and additional parameters. Unlike `quant_mx`, the
+    shared exponent is obtained by rounding up (ceil) the log2 of each group's
+    maximum divided by `max_norm`, instead of flooring the log2 of the maximum.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be quantized.
+        bits (int): The bit width to be used for quantization.
+        group_size (int): The number of elements sharing one scale and exponent.
+        data_type (str): The data type for quantization (e.g., 'mx_fp4').
+        v (float): A value added to the scaled tensor to adjust rounding.
+        max_scale (float or torch.Tensor): The maximum scale to be applied to the tensor.
+        mantissa_rounding (str): Rounding method for the mantissa; supports "even", "nearest", and "floor".
+
+    Returns:
+        tuple: A tuple containing the quantized tensor, the shared exponent, and None (reserved for future use).
+
+    Raises:
+        KeyError: If `data_type` is not found in `MXFP_FORMAT_CACHE`.
+ """ + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) + data_type = data_type if data_type in MXFP_FORMAT_CACHE else "mx_fp" + str(bits) + ebits, mbits, emax, max_norm, min_norm = MXFP_FORMAT_CACHE[data_type] + orig_dtype = tensor.dtype + tensor = tensor.to(torch.float32) + max_val, _ = torch.max(torch.abs(tensor), dim=-1, keepdim=True) + if isinstance(max_scale, torch.Tensor): + max_val *= (max_scale.unsqueeze(dim=-1)).to(tensor.device) + else: + max_val *= max_scale + + # shared_exp = torch.log2(shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype)) + shared_exp = torch.where(max_val == 0, torch.ones_like(max_val), ceil_ste(torch.log2(max_val / max_norm))) + scale_emax = 2 ** (8 - 1) - 1 + shared_exp = shared_exp.clamp(min=-scale_emax, max=scale_emax) + + scale = torch.pow(2, shared_exp) + tensor = tensor / scale + v + tensor = torch.clamp(tensor, min=-max_norm, max=max_norm) + tensor = quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding) + + tensor = tensor * scale + tensor = revert_tensor_by_pad(tensor, orig_shape=orig_shape, pad_len=pad_len) + return tensor.to(orig_dtype), shared_exp.to(orig_dtype), None + + for key in MXFP_FORMAT_CACHE.keys(): QUANT_FUNC_WITH_DTYPE[key] = quant_mx + QUANT_FUNC_WITH_DTYPE[key + "_rceil"] = quant_mx_rceil +QUANT_FUNC_WITH_DTYPE["mx_fp_rceil"] = quant_mx_rceil if __name__ == "__main__": data = torch.tensor([0.0, 0.25, 0.4, 0.75, 1.25, 1.4, 1.75, 2.5, 2.9, 3.5, 5.0, 5.1]) @@ -131,4 +191,4 @@ def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0, data_neg = data * -1 data2 = quant_element(data_neg, 2, 3, 6.0) - assert (torch.sum(torch.abs(data2 - gt * -1)) < 1e-6) \ No newline at end of file + assert (torch.sum(torch.abs(data2 - gt * -1)) < 1e-6) diff --git a/auto_round/data_type/utils.py b/auto_round/data_type/utils.py index 748417038..ea8c56ffa 100644 --- a/auto_round/data_type/utils.py +++ b/auto_round/data_type/utils.py @@ -163,6 +163,17 @@ def floor_ste(x: torch.Tensor): return (x.floor() - x).detach() + x +def ceil_ste(x: torch.Tensor): + """Straight-Through Estimator for ceil. + + Args: + x: torch.Tensor + + Returns: + torch.Tensor + """ + return (x.ceil() - x).detach() + x + def float8_e4m3fn_ste(x: torch.Tensor): """Straight-Through Estimator (STE) for float8.