TE inference executor for 8 bit #2632
Conversation
import torch
import transformer_engine.pytorch
import transformer_engine_torch


def te_linear_fp8_impl(x, qweight, bias, absmax, scale):
    # Wrap the pre-quantized weight: describe how qweight was produced
    # (per-tensor scale/amax, E4M3) and reinterpret the raw data accordingly.
    wq = transformer_engine.pytorch.Float8Quantizer(
        scale=scale,
        amax=absmax,
        fp8_dtype=transformer_engine_torch.DType.kFloat8E4M3,
        rowwise=True,
        columnwise=False,
    )

    w = wq.create_tensor_from_data(qweight, fake_dtype=x.dtype, requires_grad=False)

    # Quantize the input on the fly with a per-tensor scale derived from its absmax.
    minmax = x.aminmax()
    xmax = torch.maximum(minmax.min.abs(), minmax.max.abs()).to(torch.float32)
    xq = transformer_engine.pytorch.Float8Quantizer(
        scale=1.0 / xmax,  # does this need to be 1 (or even somewhat smaller, for accumulation)?
        amax=xmax,
        fp8_dtype=transformer_engine_torch.DType.kFloat8E4M3,
        rowwise=True,
        columnwise=False,
    )

    # Inference-only forward: neither operand needs gradients.
    out, *_ = transformer_engine.pytorch.ops.BasicLinear._functional_forward(
        x,
        w,
        input_quantizer=xq,
        with_quantized_compute=True,
        weight_requires_grad=False,
        input_requires_grad=False,
    )
This is the core of how I compute linears. Does this make sense? What is a good scale for the input? I'm using 1 / absmax for now, so that x[j, k] * w[i, k] stays below fp8_max, but I'm not sure that is needed.
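For intuition, a back-of-the-envelope comparison of the two conventions (plain arithmetic, not TE code; whether a single product really needs to fit in the E4M3 range depends on the precision the GEMM accumulates in):

# Largest finite magnitude representable in float8 E4M3.
fp8_max = 448.0

# Full-range input convention (scale = fp8_max / absmax): a single product of two
# quantized elements can reach fp8_max * fp8_max.
worst_product_full_range = fp8_max * fp8_max   # 200704.0

# Current choice (scale = 1 / absmax): quantized inputs stay within [-1, 1], so a
# single product is bounded by fp8_max, as argued in the comment above.
worst_product_unit_input = 1.0 * fp8_max       # 448.0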
    minmax = x.aminmax()
    xmax = torch.maximum(minmax.min.abs(), minmax.max.abs()).to(torch.float32)
    xq = transformer_engine.pytorch.Float8Quantizer(
Do we necessarily want to use Float8Quantizer, which is tied to the DelayedScaling recipe (but works on both Hopper and Blackwell architectures)? We could instead use MXFP8Quantizer (note: this is only supported on Blackwell). With that recipe, the quantization of the input does not depend on a scale from a previous iteration.
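A hypothetical sketch of the suggested swap; the MXFP8Quantizer constructor arguments are an assumption (check the installed Transformer Engine version), and this path only runs on Blackwell-class GPUs:

import transformer_engine.pytorch
import transformer_engine_torch

# MXFP8 carries per-block scaling factors computed at quantization time, so no
# global scale/amax is passed in and nothing depends on a previous iteration.
xq = transformer_engine.pytorch.MXFP8Quantizer(
    fp8_dtype=transformer_engine_torch.DType.kFloat8E4M3,
    rowwise=True,
    columnwise=False,
)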
Sadly, I don't currently have easy access to Blackwell, so I'm keen to support Hopper, too.
Of course, I'd 100% love to have something flexible enough for fp8 + fp4 on Hopper / Blackwell, to the extent that it is supported.
What does this PR do?
This implements inference-only fp8 linears using Transformer Engine.
It is modelled after the bitsandbytes transform.
One question I'd have is what a good scaling would be. I'm currently using the max range (fp8_max / tensor.absmax()) on the weight and 1.0 on the input, but I have no idea what a good input scale would be (this would depend on the accumulation in fp8 matmuls).
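As a concrete reference for that question, a minimal sketch of the two per-tensor conventions mentioned above (the helper names are hypothetical; 448 is the E4M3 maximum that fp8_max refers to):

import torch

FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in E4M3


def full_range_scale(t: torch.Tensor) -> torch.Tensor:
    # fp8_max / absmax: the tensor's absmax lands at the edge of the E4M3 range
    # (the convention used for the weight in this PR).
    return FP8_E4M3_MAX / t.abs().amax().to(torch.float32)


def unit_range_scale(t: torch.Tensor) -> torch.Tensor:
    # 1 / absmax: the tensor's absmax maps to 1.0, leaving headroom but using
    # only a small part of the E4M3 range (the current choice for the input).
    return 1.0 / t.abs().amax().to(torch.float32)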
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃