14 changes: 13 additions & 1 deletion fms_mo/modules/bmm.py
@@ -82,7 +82,13 @@ def __init__(
self.m2_bounded = m2_bounded
self.qm1_mode = qm1_mode
self.qm2_mode = qm2_mode

self.smooth_attn = qcfg.get("smooth_attn", False)
self.smooth_attn_alpha = qcfg.get("smooth_attn_alpha", 0.5)
if self.smooth_attn_alpha < 0 or self.smooth_attn_alpha > 1:
raise ValueError(
"smooth_attn_alpha must be in range [0,1] "
f"(given: {self.smooth_attn_alpha})"
)
self.m1_clip_init_val = kwargs.get(
"m1_clip_init_val", qcfg.get("m1_clip_init_val", 1.0)
)
@@ -191,6 +197,12 @@ def forward(self, m1, m2):
Returns:
torch.Tensor: Output tensor after quantized bmm.
"""
if self.smooth_attn:
attn_scales = m2.abs().amax(dim=(0, 1, 3)).clamp(min=1e-5)
@chichun-charlie-liu I don't know enough about bmm, but is this the correct assumption for all bmm to use these dims for amax()?

@iqbal-saraf let's not hard-code the dimension like this... and we possibly want to keep the flexibility for per-token and per-channel as well.

@andrea-fasoli (Collaborator), Jun 16, 2025:
If BMM inputs are always 4D, then this should be the only way to create the smooth-attention scales, as they must correspond to the reduction dimension of the matmul (dim=2 for the second input). We can't compute per-token scales (dim=3 or dim=-1) here.

This implementation mirrors SmoothQuant, where we compute weight scales for dim=1 only (computing the max along dim=0). The difference in the amax dimension choice comes from weights being 2D (and transposed) vs. BMM inputs being 4D, but in both cases the choice of dimension is fixed.

I don't know if BMM are necessarily 4D. I believe that if batch size is not passed, the tensor is unsqueezed from 3D to 4D, but not sure if it is a guarantee. If not, we should also deal with 3D BMM or, alternatively, raise an error if BMM are not 4D.

attn_scales = attn_scales.pow(self.smooth_attn_alpha)
m1 *= attn_scales
m2 /= attn_scales.reshape(1, 1, m2.shape[2], 1)
Same here


# pylint: disable = access-member-before-definition
if self.calib_counter:
with torch.no_grad():
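As a standalone illustration of the transform discussed in the thread above, here is a minimal sketch of the smoothing step, assuming 4D inputs of shape (batch, heads, seq, head_dim) for the second bmm; the function name and shapes are illustrative, not part of the module.

```python
import torch

def smooth_bmm_inputs(m1, m2, alpha=0.5):
    """Illustrative only: rescale the two bmm inputs along their shared reduction dim.

    m1: (B, H, S_q, S_k), e.g. attention probabilities
    m2: (B, H, S_k, D),   e.g. the value tensor
    The product m1 @ m2 is mathematically unchanged; only the per-input
    dynamic ranges shift, which is what makes the subsequent quantization easier.
    """
    # one scale per element of the reduction dim S_k (dims 0, 1, 3 of m2 reduced away)
    scales = m2.abs().amax(dim=(0, 1, 3)).clamp(min=1e-5).pow(alpha)
    m1_s = m1 * scales                       # broadcasts over the last dim of m1
    m2_s = m2 / scales.reshape(1, 1, -1, 1)  # divides along dim=2 of m2
    return m1_s, m2_s

# quick check that the bmm output is preserved
B, H, Sq, Sk, D = 2, 4, 8, 8, 16
m1 = torch.randn(B, H, Sq, Sk).softmax(dim=-1)
m2 = torch.randn(B, H, Sk, D)
m1_s, m2_s = smooth_bmm_inputs(m1, m2)
assert torch.allclose(m1 @ m2, m1_s @ m2_s, atol=1e-5)
```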
70 changes: 55 additions & 15 deletions fms_mo/quant/quantizers.py
@@ -123,12 +123,24 @@ def get_activation_quantizer(
)
elif qa_mode == "dorefa":
act_quantizer = dorefa_quantize_activation
elif (
qa_mode == "max"
): # NOTE Need to be careful using this for activation, particular to 1 sided.
act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
elif qa_mode == "minmax":
act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)

elif "max" in qa_mode:
# NOTE Be careful using this for activations, particularly one-sided ones.
if "min" in qa_mode:
act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
elif "pertoken" in qa_mode or "perToken" in qa_mode:
act_quantizer = QMaxDynamic(nbits, dim=-1)
elif "per_channel" in qa_mode or "perCh" in qa_mode:
act_quantizer = QMaxDynamic(nbits, dim=-2)
elif "sym" in qa_mode:
act_quantizer = Qmax(
nbits,
align_zero=True,
minmax=False,
extend_act_range=extend_act_range,
)
else:
act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
elif qa_mode == "fix":
act_quantizer = QFixSymmetric(
nbits, init_clip_val=clip_val, align_zero=align_zero
@@ -140,13 +152,7 @@ def get_activation_quantizer(
minmax=False,
extend_act_range=extend_act_range,
)
elif qa_mode == "pactsym":
act_quantizer = PACT2Sym(
nbits,
init_clip_val=clip_val,
dequantize=True,
inplace=False,
)

elif qa_mode == "pactsym+":
act_quantizer = PACTplusSym(
nbits,
@@ -179,8 +185,6 @@ def get_activation_quantizer(
perToken=perToken,
emulate=True,
)
elif qa_mode == "pertokenmax":
act_quantizer = PerTokenMax(nbits)
else:
raise ValueError(f"unrecognized activation quantization mode {qa_mode}")
else: # swcap-compatible activation quantizers
@@ -3488,6 +3492,42 @@ def __repr__(self):
return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"


class QMaxDynamic(nn.Module):
def __init__(self, num_bits, dim=-1):
"""
For per-token or per-channel quantization using abs().max() as the scale, usually for activations;
it could be used for Qbmm M2 as well.
(reduce) dim = -1 -> reduce the last dim, e.g. one scale per token (a column vector for 2D input)
dim = -2 -> one scale per channel
Zero is aligned so that the levels are symmetric around zero (losing one level).
Since the token length is unknown before running, the quantizer can only compute the scales
dynamically at run time, meaning no trainable quantization scales are allowed
(unless the input seq length is always the same, not just padded to a fixed length).
"""
super().__init__()
self.num_bits = num_bits
self.levels = 2 ** (self.num_bits - 1) - 1
if isinstance(dim, str):
if "perCh" in dim or "per_channel" in dim:
dim = -2
@BrandonGroth (Collaborator), Jun 16, 2025:
Need to store dim into self.reduce_dim for perCh and perToken

elif "perToken" in dim or "per_token" in dim or "per_Token" in dim:
dim = -1
elif dim in [-1, -2]:
self.reduce_dim = dim
else:
raise ValueError(
f"Reduce dim can only be [-1, -2] or ['perCh', 'perToken'] but found {dim}"
)

def forward(self, input_tensor):
amax_dim = input_tensor.abs().max(dim=self.reduce_dim, keepdim=True)[0]
scales = amax_dim.clamp(min=1e-5).div(self.levels)
return input_tensor.div(scales).round().mul(scales)

def __repr__(self):
return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"


class Qdynamic(nn.Module):
def __init__(
self,
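A short usage sketch of the new QMaxDynamic quantizer added above, assuming it is imported from fms_mo.quant.quantizers; integer dims are used here since the string forms still need the self.reduce_dim fix noted in the review comment.

```python
import torch

from fms_mo.quant.quantizers import QMaxDynamic  # class added in this diff

x = torch.randn(2, 128, 768)  # (batch, tokens, channels)

# one dynamic scale per token: reduce over the last (channel) dim
q_per_token = QMaxDynamic(num_bits=8, dim=-1)
# one dynamic scale per channel: reduce over the second-to-last (token) dim
q_per_channel = QMaxDynamic(num_bits=8, dim=-2)

xq_tok = q_per_token(x)   # fake-quantized, same shape and dtype as x
xq_ch = q_per_channel(x)

# scales are recomputed from each input, so there are no trainable parameters
print((x - xq_tok).abs().max().item(), (x - xq_ch).abs().max().item())
```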
2 changes: 2 additions & 0 deletions fms_mo/training_args.py
@@ -164,6 +164,8 @@ class FMSMOArguments(TypeChecker):
bmm2_qm1_mode: str = field(default="pact", metadata={"help": ("bmm2.m1 quantizer")})
bmm2_qm2_mode: str = field(default="pact", metadata={"help": ("bmm2.m2 quantizer")})
smoothq_alpha: float = field(default=0.65, metadata={"help": "smooth quant alpha"})
smooth_attn_alpha: float = field(default=0.5, metadata={"help": "smooth attention alpha"})
smooth_attn: bool = field(default=False, metadata={"help": "enable smooth attention"})
qmodel_calibration: int = field(
default=0,
metadata={"help": "Num of batches for Qmodel calibration, using model copy."},
50 changes: 29 additions & 21 deletions fms_mo/utils/qconfig_utils.py
@@ -149,6 +149,7 @@ def config_defaults() -> dict:
"smoothq": False,
"smoothq_scale_layers": [],
"smoothq_act_scale_path": None,
"smooth_attn": False,
# Other vars
"which2patch_contextmanager": None,
"force_stop_if_qbmm_auto_check_failed": False,
@@ -940,11 +941,16 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
"pactsym+",
"max",
"minmax",
"maxbmm",
"maxsym",
"pertokenmax",
"lsq+",
"fix",
"brecq",
]
shared_modes = [
"max_perToken",
"max_perCh",
# fp8_e4m3
"fp8_e4m3_sat",
"fp8_e4m3_scale",
@@ -981,33 +987,34 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
"brecq",
"adaround",
"pertokenmax",
# fp8_e4m3
"fp8_e4m3_sat",
"fp8_e4m3_scale",
"fp8_e4m3_sat_perCh",
"fp8_e4m3_scale_perCh",
"fp8_e4m3_sat_perToken",
"fp8_e4m3_scale_perToken",
# fp8_e5m2
"fp8_e5m2_sat",
"fp8_e5m2_scale",
"fp8_e5m2_sat_perCh",
"fp8_e5m2_scale_perCh",
"fp8_e5m2_sat_perToken",
"fp8_e5m2_scale_perToken",
# # fp8_e4m3
# "fp8_e4m3_sat",
# "fp8_e4m3_scale",
# "fp8_e4m3_sat_perCh",
# "fp8_e4m3_scale_perCh",
# "fp8_e4m3_sat_perToken",
# "fp8_e4m3_scale_perToken",
# # fp8_e5m2
# "fp8_e5m2_sat",
# "fp8_e5m2_scale",
# "fp8_e5m2_sat_perCh",
# "fp8_e5m2_scale_perCh",
# "fp8_e5m2_sat_perToken",
# "fp8_e5m2_scale_perToken",
what's the reason for commenting all FP8 quantizers in this config check?

So I did have this concern as well. They are moving them to the "shared" checks for both activations and weights, but I don't necessarily agree. We may want to disable some of these for activations in the future as they don't work.

]
bmm_mode_settings = [
"pact",
"pactsym",
"pactsym+",
"maxsym",
"maxbmm",
"max",
"minmax",
"pertokenmax",
"fp8_e4m3_sat",
"fp8_e4m3_scale_perToken",
"fp8_e5m2_sat",
"fp8_e5m2_scale_perToken",
# "fp8_e4m3_sat",
# "fp8_e4m3_scale_perToken",
# "fp8_e5m2_sat",
# "fp8_e5m2_scale_perToken",
same question as above: why are fp8 quantization options not checked?

]

# Get strings in config for qa_modes, qw_modes, bmm_modes
@@ -1043,15 +1050,15 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
# Check each for correct ranges
for qa_mode_str in qa_modes_str:
qa_mode = config.get(qa_mode_str, "pact+")
if not qa_mode in (qa_mode_settings + mx_spec_config_modes):
if not qa_mode in (qa_mode_settings + mx_spec_config_modes + shared_modes):
raise ValueError(
f"{qa_mode_str} = {qa_mode} is not set to one of the following: "
f"{qa_mode_settings + mx_spec_config_modes}"
)

for qw_mode_str in qw_modes_str:
qw_mode = config.get(qw_mode_str, "sawb+")
if not qw_mode in (qw_mode_settings + mx_spec_config_modes):
if not qw_mode in (qw_mode_settings + mx_spec_config_modes + shared_modes):
raise ValueError(
f"{qw_mode_str} = {qw_mode} is not set to one of the following: "
f"{qw_mode_settings + mx_spec_config_modes}"
Expand All @@ -1063,7 +1070,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
bmm_mode_consistency += bmm_mode.startswith("mx_")
# mx_specs doesn't have 4 individual bmmX_qmY_modes, it re-uses w and a fmt instead.
# We will keep them in qcfg (with "mx_" prefix NOT removed).
if not bmm_mode in (bmm_mode_settings + mx_spec_config_modes):
if not bmm_mode in (bmm_mode_settings + mx_spec_config_modes + shared_modes):
raise ValueError(
f"{bmm_mode_str} = {bmm_mode} is not set to one of the following: "
f"{bmm_mode_settings + mx_spec_config_modes}"
@@ -1101,6 +1108,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
"qskip_large_mag_layers",
"recompute_narrow_weights",
"smoothq",
"smooth_attn",
]
for boolean_var_str in boolean_vars_str:
boolean_var = config.get(
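A rough sketch of how the new options could be wired together through check_config; smooth_attn, smooth_attn_alpha, and the max_perToken / max_perCh modes come from this diff, while the qa_mode and bmm1_qm2_mode key names and the call pattern are assumptions rather than a tested invocation.

```python
from fms_mo.utils.qconfig_utils import check_config, config_defaults

qcfg = config_defaults()
qcfg.update(
    {
        "smooth_attn": True,           # new boolean flag, also added to boolean_vars_str
        "smooth_attn_alpha": 0.5,      # must lie in [0, 1], validated in QBmm.__init__
        "qa_mode": "max_perToken",     # assumed key name; accepted via the new shared_modes list
        "bmm1_qm2_mode": "max_perCh",  # assumed key name, mirrors bmm2_qm2_mode
    }
)
check_config(qcfg)  # raises ValueError if any mode is outside the allowed lists
```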