feat: Smooth attention #140
```diff
@@ -82,7 +82,13 @@ def __init__(
         self.m2_bounded = m2_bounded
         self.qm1_mode = qm1_mode
         self.qm2_mode = qm2_mode
+        self.smooth_attn = qcfg.get("smooth_attn", False)
+        self.smooth_attn_alpha = qcfg.get("smooth_attn_alpha", 0.5)
+        if self.smooth_attn_alpha < 0 or self.smooth_attn_alpha > 1:
+            raise ValueError(
+                "smooth_attn_alpha must be in range [0,1] "
+                f"(given: {self.smooth_attn_alpha})"
+            )
         self.m1_clip_init_val = kwargs.get(
             "m1_clip_init_val", qcfg.get("m1_clip_init_val", 1.0)
         )
```
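For context, a minimal sketch of how these new keys would appear in a quantization config dict. Only the keys touched by this PR are shown; the dict fragment itself is illustrative, not taken from the repo:

```python
# Hypothetical qcfg fragment enabling the new smooth-attention option.
qcfg = {
    "smooth_attn": True,        # default added in this PR is False
    "smooth_attn_alpha": 0.5,   # migration strength; __init__ rejects values outside [0, 1]
    # ... all other quantization settings omitted ...
}
```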
```diff
@@ -191,6 +197,12 @@ def forward(self, m1, m2):
         Returns:
             torch.Tensor: Output tensor after quantized bmm.
         """
+        if self.smooth_attn:
+            attn_scales = m2.abs().amax(dim=(0, 1, 3)).clamp(min=1e-5)
+            attn_scales = attn_scales.pow(self.smooth_attn_alpha)
+            m1 *= attn_scales
+            m2 /= attn_scales.reshape(1, 1, m2.shape[2], 1)
+
         # pylint: disable = access-member-before-definition
         if self.calib_counter:
             with torch.no_grad():
```

Review comment (on the `smooth_attn` block): Same here
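For intuition, here is a small stand-alone sketch (plain PyTorch, made-up shapes) of the scale migration applied above: multiplying `m1` by per-reduction-dim scales and dividing `m2` by the same scales leaves the bmm product unchanged while flattening the dynamic range of `m2`, with `alpha` playing the role of `smooth_attn_alpha`, analogous to SmoothQuant.

```python
import torch

# Stand-alone illustration of the smooth_attn scaling; shapes are made up.
# m1 ~ attention probabilities, m2 ~ values, both 4D: (batch, heads, seq, ...).
torch.manual_seed(0)
m1 = torch.softmax(torch.randn(2, 4, 8, 8), dim=-1)  # (B, H, S, S)
m2 = torch.randn(2, 4, 8, 16) * 10                   # (B, H, S, Dh), exaggerated range

alpha = 0.5  # same role as smooth_attn_alpha
attn_scales = m2.abs().amax(dim=(0, 1, 3)).clamp(min=1e-5).pow(alpha)  # one scale per row of m2

out_ref = m1 @ m2
out_smoothed = (m1 * attn_scales) @ (m2 / attn_scales.reshape(1, 1, -1, 1))

# The product is mathematically unchanged; only the per-input dynamic ranges shift,
# which is what makes both bmm inputs easier to quantize.
print(torch.allclose(out_ref, out_smoothed, atol=1e-5))  # True up to float error
```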
```diff
@@ -123,12 +123,24 @@ def get_activation_quantizer(
             )
         elif qa_mode == "dorefa":
             act_quantizer = dorefa_quantize_activation
-        elif (
-            qa_mode == "max"
-        ):  # NOTE Need to be careful using this for activation, particular to 1 sided.
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
-        elif qa_mode == "minmax":
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+        elif "max" in qa_mode:
+            # NOTE Need to be careful using this for activation, particular to 1 sided.
+            if "min" in qa_mode:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+            elif "pertoken" in qa_mode or "perToken" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-1)
+            elif "per_channel" in qa_mode or "perCh" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-2)
+            elif "sym" in qa_mode:
+                act_quantizer = Qmax(
+                    nbits,
+                    align_zero=True,
+                    minmax=False,
+                    extend_act_range=extend_act_range,
+                )
+            else:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
         elif qa_mode == "fix":
             act_quantizer = QFixSymmetric(
                 nbits, init_clip_val=clip_val, align_zero=align_zero
```
```diff
@@ -140,13 +152,7 @@ def get_activation_quantizer(
                 minmax=False,
                 extend_act_range=extend_act_range,
             )
-        elif qa_mode == "pactsym":
-            act_quantizer = PACT2Sym(
-                nbits,
-                init_clip_val=clip_val,
-                dequantize=True,
-                inplace=False,
-            )
         elif qa_mode == "pactsym+":
             act_quantizer = PACTplusSym(
                 nbits,
```
```diff
@@ -179,8 +185,6 @@ def get_activation_quantizer(
                 perToken=perToken,
                 emulate=True,
             )
-        elif qa_mode == "pertokenmax":
-            act_quantizer = PerTokenMax(nbits)
         else:
             raise ValueError(f"unrecognized activation quantization mode {qa_mode}")
     else:  # swcap-compatible activation quantizers
```
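To summarize the new substring-based dispatch above, here is a self-contained sketch showing which quantizer each `qa_mode` string resolves to. The quantizer constructors are replaced by descriptive strings; the real code returns `Qmax` / `QMaxDynamic` instances with the arguments shown in the hunks.

```python
# Self-contained mirror of the new "max" dispatch, with quantizer constructors
# replaced by descriptive strings (the real code returns Qmax / QMaxDynamic objects).
def resolve_max_mode(qa_mode: str) -> str:
    if "max" not in qa_mode:
        raise ValueError(f"not a max-style mode: {qa_mode}")
    if "min" in qa_mode:
        return "Qmax(minmax=True)"                    # classic min/max
    if "pertoken" in qa_mode or "perToken" in qa_mode:
        return "QMaxDynamic(dim=-1)"                  # dynamic per-token abs-max
    if "per_channel" in qa_mode or "perCh" in qa_mode:
        return "QMaxDynamic(dim=-2)"                  # dynamic per-channel abs-max
    if "sym" in qa_mode:
        return "Qmax(align_zero=True, minmax=False)"  # symmetric abs-max
    return "Qmax(minmax=False)"                       # plain abs-max


assert resolve_max_mode("minmax") == "Qmax(minmax=True)"
assert resolve_max_mode("max_perToken") == "QMaxDynamic(dim=-1)"
assert resolve_max_mode("max_perCh") == "QMaxDynamic(dim=-2)"
assert resolve_max_mode("maxsym") == "Qmax(align_zero=True, minmax=False)"
assert resolve_max_mode("max") == "Qmax(minmax=False)"
```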
```diff
@@ -3488,6 +3492,42 @@ def __repr__(self):
         return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"


+class QMaxDynamic(nn.Module):
+    def __init__(self, num_bits, dim=-1):
+        """
+        Per-token or per-channel quantization using abs().max() as the scale, usually for
+        activations and possibly for Qbmm M2 as well.
+        (reduce) dim = -1 -> abs().max() outputs a column vector (if the input is 2D) => per-token
+                 dim = -2 -> per-channel
+        Zero is aligned so that the levels are symmetric around zero (losing one level).
+        Since the token length is unknown before running, the quantizer can only calculate the
+        scales dynamically at run time, meaning no trainable quantization scales are allowed
+        (unless the input seq length is always the same, not just padded to a fixed length).
+        """
+        super().__init__()
+        self.num_bits = num_bits
+        self.levels = 2 ** (self.num_bits - 1) - 1
+        if isinstance(dim, str):
+            if "perCh" in dim or "per_channel" in dim:
+                dim = -2
+            elif "perToken" in dim or "per_token" in dim or "per_Token" in dim:
+                dim = -1
+        elif dim in [-1, -2]:
+            self.reduce_dim = dim
+        else:
+            raise ValueError(
+                f"Reduce dim can only be [-1, -2] or ['perCh', 'perToken'] but found {dim}"
+            )
+
+    def forward(self, input_tensor):
+        amax_dim = input_tensor.abs().max(dim=self.reduce_dim, keepdim=True)[0]
+        scales = amax_dim.clamp(min=1e-5).div(self.levels)
+        return input_tensor.div(scales).round().mul(scales)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"
+
+
 class Qdynamic(nn.Module):
     def __init__(
         self,
```

Review comment: Need to store dim into self.reduce_dim for perCh and perToken
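As a reference for what `QMaxDynamic(num_bits=8, dim=-1)` computes in `forward()`, here is a small stand-alone sketch in plain torch. The tensor shape is made up, and the integer `dim` form is used since, as the review comment notes, the string form does not yet store `self.reduce_dim`.

```python
import torch

# What QMaxDynamic(num_bits=8, dim=-1) does in forward(), written out in plain torch.
num_bits = 8
levels = 2 ** (num_bits - 1) - 1             # 127: symmetric levels, zero aligned

x = torch.randn(4, 16, 64)                   # (batch, tokens, hidden) -- example shape
amax = x.abs().max(dim=-1, keepdim=True)[0]  # one abs-max per token (dim=-2 would be per channel)
scales = amax.clamp(min=1e-5) / levels       # scales computed at run time, nothing trainable
x_fake_quant = (x / scales).round() * scales # fake-quantized tensor, same shape as x

assert x_fake_quant.shape == x.shape
```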
```diff
@@ -149,6 +149,7 @@ def config_defaults() -> dict:
         "smoothq": False,
         "smoothq_scale_layers": [],
         "smoothq_act_scale_path": None,
+        "smooth_attn": False,
         # Other vars
         "which2patch_contextmanager": None,
         "force_stop_if_qbmm_auto_check_failed": False,
```
```diff
@@ -940,11 +941,16 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
         "pactsym+",
         "max",
         "minmax",
         "maxbmm",
         "maxsym",
         "pertokenmax",
         "lsq+",
         "fix",
         "brecq",
     ]
+    shared_modes = [
+        "max_perToken",
+        "max_perCh",
+        # fp8_e4m3
+        "fp8_e4m3_sat",
+        "fp8_e4m3_scale",
```
```diff
@@ -981,33 +987,34 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
         "brecq",
         "adaround",
         "pertokenmax",
-        # fp8_e4m3
-        "fp8_e4m3_sat",
-        "fp8_e4m3_scale",
-        "fp8_e4m3_sat_perCh",
-        "fp8_e4m3_scale_perCh",
-        "fp8_e4m3_sat_perToken",
-        "fp8_e4m3_scale_perToken",
-        # fp8_e5m2
-        "fp8_e5m2_sat",
-        "fp8_e5m2_scale",
-        "fp8_e5m2_sat_perCh",
-        "fp8_e5m2_scale_perCh",
-        "fp8_e5m2_sat_perToken",
-        "fp8_e5m2_scale_perToken",
+        # # fp8_e4m3
+        # "fp8_e4m3_sat",
+        # "fp8_e4m3_scale",
+        # "fp8_e4m3_sat_perCh",
+        # "fp8_e4m3_scale_perCh",
+        # "fp8_e4m3_sat_perToken",
+        # "fp8_e4m3_scale_perToken",
+        # # fp8_e5m2
+        # "fp8_e5m2_sat",
+        # "fp8_e5m2_scale",
+        # "fp8_e5m2_sat_perCh",
+        # "fp8_e5m2_scale_perCh",
+        # "fp8_e5m2_sat_perToken",
+        # "fp8_e5m2_scale_perToken",
     ]
     bmm_mode_settings = [
         "pact",
         "pactsym",
         "pactsym+",
         "maxsym",
         "maxbmm",
         "max",
         "minmax",
         "pertokenmax",
-        "fp8_e4m3_sat",
-        "fp8_e4m3_scale_perToken",
-        "fp8_e5m2_sat",
-        "fp8_e5m2_scale_perToken",
+        # "fp8_e4m3_sat",
+        # "fp8_e4m3_scale_perToken",
+        # "fp8_e5m2_sat",
+        # "fp8_e5m2_scale_perToken",
     ]

     # Get strings in config for qa_modes, qw_modes, bmm_modes
```

Review comment: what's the reason for commenting all FP8 quantizers in this config check?

Review comment (reply): So I did have this concern as well. They are moving them to the "shared" checks for both activations and weights, but I don't necessarily agree. We may want to disable some of these for activations in the future as they don't work.

Review comment (on the `bmm_mode_settings` list): same question as above: why are fp8 quantization options not checked?
```diff
@@ -1043,15 +1050,15 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
     # Check each for correct ranges
     for qa_mode_str in qa_modes_str:
         qa_mode = config.get(qa_mode_str, "pact+")
-        if not qa_mode in (qa_mode_settings + mx_spec_config_modes):
+        if not qa_mode in (qa_mode_settings + mx_spec_config_modes + shared_modes):
             raise ValueError(
                 f"{qa_mode_str} = {qa_mode} is not set to one of the following: "
                 f"{qa_mode_settings + mx_spec_config_modes}"
             )

     for qw_mode_str in qw_modes_str:
         qw_mode = config.get(qw_mode_str, "sawb+")
-        if not qw_mode in (qw_mode_settings + mx_spec_config_modes):
+        if not qw_mode in (qw_mode_settings + mx_spec_config_modes + shared_modes):
             raise ValueError(
                 f"{qw_mode_str} = {qw_mode} is not set to one of the following: "
                 f"{qw_mode_settings + mx_spec_config_modes}"
             )
```
```diff
@@ -1063,7 +1070,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
         bmm_mode_consistency += bmm_mode.startswith("mx_")
         # mx_specs doesn't have 4 individual bmmX_qmY_modes, it re-uses w and a fmt instead.
         # We will keep them in qcfg (with "mx_" prefix NOT removed).
-        if not bmm_mode in (bmm_mode_settings + mx_spec_config_modes):
+        if not bmm_mode in (bmm_mode_settings + mx_spec_config_modes + shared_modes):
             raise ValueError(
                 f"{bmm_mode_str} = {bmm_mode} is not set to one of the following: "
                 f"{bmm_mode_settings + mx_spec_config_modes}"
             )
```
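A minimal sketch of the relaxed validation after these hunks, with the mode lists truncated for brevity (the full lists come from the diffs above): entries in `shared_modes` are now accepted for activation, weight, and bmm modes alike.

```python
# Sketch of the relaxed membership check: entries in shared_modes are accepted
# for activation (qa), weight (qw), and bmm modes alike. Lists are truncated here.
qa_mode_settings = ["pact+", "max", "minmax", "maxsym", "pertokenmax"]
bmm_mode_settings = ["pactsym+", "maxsym", "max", "minmax", "pertokenmax"]
shared_modes = ["max_perToken", "max_perCh", "fp8_e4m3_sat", "fp8_e4m3_scale"]

def is_valid_qa_mode(mode: str) -> bool:
    return mode in (qa_mode_settings + shared_modes)

def is_valid_bmm_mode(mode: str) -> bool:
    return mode in (bmm_mode_settings + shared_modes)

assert is_valid_qa_mode("max_perToken")   # newly accepted via shared_modes
assert is_valid_bmm_mode("max_perCh")     # newly accepted via shared_modes
assert is_valid_qa_mode("fp8_e4m3_sat")   # fp8 entries now validated via shared_modes
```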
```diff
@@ -1101,6 +1108,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
         "qskip_large_mag_layers",
         "recompute_narrow_weights",
         "smoothq",
+        "smooth_attn",
     ]
     for boolean_var_str in boolean_vars_str:
         boolean_var = config.get(
```
Review discussion on the `amax` dimensions used for the smooth-attention scales:

Review comment: @chichun-charlie-liu I don't know enough about bmm, but is this the correct assumption for all bmm, to use these dims for amax()?

Review comment: @iqbal-saraf let's not hard-code the dimension like this... and we possibly want to keep the flexibility for per-token and per-channel as well.

Review comment: If BMM inputs are always 4D, then this should be the only way to create the smooth-attention scales, as they must correspond to the reduction dimension of the matmul (dim=2 for the second input). We can't compute per-token scales (dim=3 or dim=-1) here. This implementation mirrors SmoothQuant, where we compute weight scales for dim=1 only (computing max along dim=0). The difference in amax dimension choice is due to weights being 2D, as well as transposed, vs BMM being 4D, but in both cases the choice of dimension is fixed.

Review comment: I don't know if BMM inputs are necessarily 4D. I believe that if batch size is not passed, the tensor is unsqueezed from 3D to 4D, but I'm not sure that is guaranteed. If not, we should also handle 3D BMM or, alternatively, raise an error if BMM inputs are not 4D.