feat: Smooth attention #140
base: main
Conversation
Signed-off-by: Iqbal Saraf <iqbal.saraf@ibm.com>
torch.Tensor: Output tensor after quantized bmm.
"""
if self.smooth_attn:
    attn_scales = m2.abs().amax(dim=(0, 1, 3)).clamp(min=1e-5)
@chichun-charlie-liu I don't know enough about bmm, but is it a correct assumption for all bmm calls that these are the dims to use for amax()?
@iqbal-saraf let's not hard-code the dimensions like this... and we possibly want to keep the flexibility for per-token and per-channel as well.
If the BMM inputs are always 4D, then this is the only way to create the smooth attention scales, as they must correspond to the reduction dimension of the matmul (dim=2 for the second input). We can't compute per-token scales (dim=3 or dim=-1) here.
This implementation mirrors SmoothQuant, where we compute weight scales for dim=1 only (taking the max along dim=0). The difference in amax dimension choice comes from weights being 2D and transposed, vs. BMM inputs being 4D, but in both cases the choice of dimension is fixed.
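To make the dimension argument concrete, here is a minimal standalone sketch (shapes, names, and the alpha of 0.5 are illustrative assumptions, not the PR's code): for m1 @ m2 with m2 of shape (B, H, K, N), only per-K scales (amax over dims 0, 1, 3) can be absorbed by m1 without changing the product.

import torch

# Assumed 4D bmm shapes: m1 is (B, H, M, K), m2 is (B, H, K, N).
B, H, M, K, N = 2, 4, 8, 16, 8
m1 = torch.randn(B, H, M, K)
m2 = torch.randn(B, H, K, N)

# One scale per index of the shared reduction dim K: reduce over every
# other dim of m2, as the diff's amax(dim=(0, 1, 3)) does.
attn_scales = m2.abs().amax(dim=(0, 1, 3)).clamp(min=1e-5)   # shape (K,)
attn_scales = attn_scales.pow(0.5)                           # 0.5 stands in for smooth_attn_alpha

m1_s = m1 * attn_scales                          # broadcasts over m1's last dim (K)
m2_s = m2 / attn_scales.reshape(1, 1, K, 1)      # divides along m2's dim 2 (K)

# Smoothing must not change the bmm result (up to floating-point error).
assert torch.allclose(m1 @ m2, m1_s @ m2_s, atol=1e-5)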
I don't know whether BMM inputs are necessarily 4D. I believe that if no batch size is passed, the tensor is unsqueezed from 3D to 4D, but I'm not sure that is guaranteed. If not, we should also handle 3D BMM or, alternatively, raise an error when the BMM inputs are not 4D.
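A hedged sketch of the guard suggested here (the function name and the 3D dim choice are assumptions, not part of the PR): resolve the reduction dims from the input's rank, and raise if the rank is unexpected.

import torch

def bmm_smooth_scales(m2: torch.Tensor, alpha: float = 0.5) -> torch.Tensor:
    # m2 is the second bmm input; scales are per reduction-dim index.
    if m2.dim() == 4:                 # (B, H, K, N): reduce over B, H, N
        reduce_dims = (0, 1, 3)
    elif m2.dim() == 3:               # (B, K, N): reduce over B, N
        reduce_dims = (0, 2)
    else:
        raise ValueError(f"expected a 3D or 4D bmm input, got {m2.dim()}D")
    return m2.abs().amax(dim=reduce_dims).clamp(min=1e-5).pow(alpha)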
attn_scales = m2.abs().amax(dim=(0, 1, 3)).clamp(min=1e-5)
attn_scales = attn_scales.pow(self.smooth_attn_alpha)
m1 *= attn_scales
m2 /= attn_scales.reshape(1, 1, m2.shape[2], 1)
Same here
self.levels = 2 ** (self.num_bits - 1) - 1
if isinstance(dim, str):
    if "perCh" in dim or "per_channel" in dim:
        dim = -2
We need to store dim in self.reduce_dim for perCh and perToken.
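A minimal sketch of this suggestion (the class and the perToken mapping to -1 are assumptions; self.reduce_dim is the attribute name used in this comment):

class QuantizerSketch:                        # hypothetical container, not the PR's class
    def __init__(self, num_bits: int = 8, dim="perCh"):
        self.num_bits = num_bits
        self.levels = 2 ** (self.num_bits - 1) - 1
        if isinstance(dim, str):
            if "perCh" in dim or "per_channel" in dim:
                dim = -2
            elif "perToken" in dim or "per_token" in dim:
                dim = -1                      # assumption: per-token reduces over the last dim
        # Remember the resolved dim so perCh/perToken code paths can reuse it.
        self.reduce_dim = dim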
# "fp8_e5m2_sat_perCh", | ||
# "fp8_e5m2_scale_perCh", | ||
# "fp8_e5m2_sat_perToken", | ||
# "fp8_e5m2_scale_perToken", |
What's the reason for commenting out all the FP8 quantizers in this config check?
So I did have this concern as well. They are being moved to the "shared" checks for both activations and weights, but I don't necessarily agree. We may want to disable some of these for activations in the future, as they don't work.
# "fp8_e4m3_sat", | ||
# "fp8_e4m3_scale_perToken", | ||
# "fp8_e5m2_sat", | ||
# "fp8_e5m2_scale_perToken", |
Same question as above: why are the FP8 quantization options not checked?
Description of the change
Adding the smooth attention feature.
This PR adds smooth attention for smoothing the inputs of bmm attention operations.
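As a rough illustration of the idea (shapes and the alpha of 0.5 are assumptions, not taken from the PR): dividing the second bmm input by per-reduction-dim scales flattens its outliers before quantization, while the first input absorbs the scales so the bmm result is unchanged.

import torch

torch.manual_seed(0)
m1 = torch.randn(1, 2, 4, 8)          # (B, H, M, K)
m2 = torch.randn(1, 2, 8, 4)          # (B, H, K, N)
m2[..., 3, :] *= 50                   # inject an outlier along the reduction dim K

s = m2.abs().amax(dim=(0, 1, 3)).clamp(min=1e-5).pow(0.5)
m1_s, m2_s = m1 * s, m2 / s.reshape(1, 1, -1, 1)

# Dynamic range of m2 shrinks after smoothing, which is what helps quantization,
# while the product stays the same up to floating-point error.
print(m2.abs().amax() / m2.abs().median(), m2_s.abs().amax() / m2_s.abs().median())
assert torch.allclose(m1 @ m2, m1_s @ m2_s, atol=1e-4)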
Related issues or PRs
#139
How to verify the PR
Was the PR tested
Checklist for passing CI/CD:
git commit --signoff or equivalent
tox -e fix
tox -e lint
tox -e spellcheck
tox -e unit
Note: CI/CD performs unit tests on multiple versions of Python from a fresh install. There may be differences between your local environment and the test environment.