
Conversation

iqbal-saraf

Description of the change

Adding smooth attention feature

This PR adds the smooth attention feature for smoothing bmm attention inputs.
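
For readers new to the feature, here is a minimal sketch of the smoothing idea, assuming 4D BMM inputs m1 of shape (batch, heads, M, K) and m2 of shape (batch, heads, K, N); the function name and shapes are illustrative, not the exact code added by this PR:

```python
import torch

def smooth_bmm_inputs(m1: torch.Tensor, m2: torch.Tensor, alpha: float = 0.5):
    """Rescale BMM inputs so outliers along m2's reduction dim are shifted to m1.

    Assumes 4D inputs: m1 is (batch, heads, M, K) and m2 is (batch, heads, K, N),
    so the reduction (contraction) dimension of m2 is dim=2.
    """
    # One scale per reduction channel K, reduced over batch, head, and N dims.
    scales = m2.abs().amax(dim=(0, 1, 3)).clamp(min=1e-5).pow(alpha)  # shape (K,)
    m1_smoothed = m1 * scales                        # broadcasts over m1's last dim (K)
    m2_smoothed = m2 / scales.reshape(1, 1, -1, 1)   # divides along m2's dim=2
    return m1_smoothed, m2_smoothed                  # product is unchanged mathematically
```

Mathematically the product is unchanged, but the dynamic ranges of the two inputs are better balanced before quantization.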

Related issues or PRs

#139

How to verify the PR

Was the PR tested

  • I have added >=1 unit test(s) for every new method I have added (if that coverage is difficult, please briefly explain the reason)
  • I have ensured all unit tests pass

Checklist for passing CI/CD:

  • All commits are signed showing "Signed-off-by: Name <email@domain.com>" with git commit --signoff or equivalent
  • PR title and commit messages adhere to Conventional Commits
  • Contribution is formatted with tox -e fix
  • Contribution passes linting with tox -e lint
  • Contribution passes spellcheck with tox -e spellcheck
  • Contribution passes all unit tests with tox -e unit

Note: CI/CD performs unit tests on multiple versions of Python from a fresh install. There may be differences between your local environment and the test environment.

Signed-off-by: Iqbal Saraf <iqbal.saraf@ibm.com>
github-actions bot added the feat label on Jun 16, 2025
    torch.Tensor: Output tensor after quantized bmm.
"""
if self.smooth_attn:
    attn_scales = m2.abs().amax(dim=(0, 1, 3)).clamp(min=1e-5)
Collaborator

@chichun-charlie-liu I don't know enough about bmm, but is it a correct assumption that these are the right dims for amax() for all bmm inputs?

Collaborator

@iqbal-saraf let's not hard-code the dimensions like this... and we probably want to keep the flexibility for per-token and per-channel as well.

Collaborator

andrea-fasoli Jun 16, 2025

If BMM inputs are always 4D, then this should be the only way to create the smooth attention scales, as they must correspond to the reduction dimension of the matmul (dim=2 for the second input). We can't compute per-token scales (dim=3 or dim=-1) here.

This implementation mirrors SmoothQuant, where we compute weight scales for dim=1 only (computing the max along dim=0). The difference in amax dimension choice comes from weights being 2D (and transposed) vs BMM inputs being 4D, but in both cases the choice of dimension is fixed.
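
To make the dimension choice concrete, a small illustration with made-up shapes (not taken from the PR):

```python
import torch

# 2D weight (SmoothQuant-style): scales live on dim=1, so the max runs along dim=0.
w = torch.randn(4096, 11008)                # illustrative (dim0, dim1) weight shape
w_scales = w.abs().amax(dim=0)              # shape (11008,): one scale per dim=1 channel

# 4D BMM second input: (batch, heads, K, N), where K (dim=2) is the reduction dim.
m2 = torch.randn(2, 32, 128, 64)            # illustrative shapes
attn_scales = m2.abs().amax(dim=(0, 1, 3))  # shape (128,): one scale per reduction channel
```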

Collaborator

I don't know if BMM inputs are necessarily 4D. I believe that if the batch size is not passed, the tensor is unsqueezed from 3D to 4D, but I'm not sure that is guaranteed. If not, we should also handle 3D BMM or, alternatively, raise an error if the inputs are not 4D.
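
One possible way to handle this, sketched under the assumption that BMM inputs arrive as either 3D (batch, K, N) or 4D (batch, heads, K, N); the helper name and signature are hypothetical:

```python
import torch

def attn_smoothing_scales(m2: torch.Tensor, alpha: float, eps: float = 1e-5) -> torch.Tensor:
    """Compute per-reduction-channel scales for a 3D or 4D second BMM input."""
    if m2.dim() == 4:        # (batch, heads, K, N) -> reduce everything but K (dim=2)
        reduce_dims = (0, 1, 3)
    elif m2.dim() == 3:      # (batch, K, N) -> reduce everything but K (dim=1)
        reduce_dims = (0, 2)
    else:
        raise ValueError(f"Expected a 3D or 4D BMM input, got {m2.dim()}D")
    return m2.abs().amax(dim=reduce_dims).clamp(min=eps).pow(alpha)
```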

attn_scales = m2.abs().amax(dim=(0, 1, 3)).clamp(min=1e-5)
attn_scales = attn_scales.pow(self.smooth_attn_alpha)
m1 *= attn_scales
m2 /= attn_scales.reshape(1, 1, m2.shape[2], 1)
Collaborator

Same here

self.levels = 2 ** (self.num_bits - 1) - 1
if isinstance(dim, str):
    if "perCh" in dim or "per_channel" in dim:
        dim = -2
Collaborator

BrandonGroth Jun 16, 2025

Need to store dim into self.reduce_dim for perCh and perToken
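
A minimal sketch of that suggestion; the class name and constructor signature are hypothetical and only mirror the snippet above:

```python
class PerDimQuantizer:
    """Hypothetical quantizer init that remembers the resolved reduction dim."""

    def __init__(self, num_bits: int, dim="perCh"):
        self.num_bits = num_bits
        self.levels = 2 ** (self.num_bits - 1) - 1
        if isinstance(dim, str):
            if "perCh" in dim or "per_channel" in dim:
                dim = -2
            elif "perToken" in dim or "per_token" in dim:
                dim = -1
        # Store the resolved dimension so the perCh / perToken code paths
        # can reuse it later instead of re-parsing the string.
        self.reduce_dim = dim
```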

# "fp8_e5m2_sat_perCh",
# "fp8_e5m2_scale_perCh",
# "fp8_e5m2_sat_perToken",
# "fp8_e5m2_scale_perToken",
Collaborator

What's the reason for commenting out all the FP8 quantizers in this config check?

Collaborator

So I did have this concern as well. They are moving them to the "shared" checks for both activations and weights, but I don't necessarily agree. We may want to disable some of these for activations in the future as they don't work.

# "fp8_e4m3_sat",
# "fp8_e4m3_scale_perToken",
# "fp8_e5m2_sat",
# "fp8_e5m2_scale_perToken",
Collaborator

Same question as above: why are the FP8 quantization options not checked?
