
Conversation

@t-vi (Collaborator) commented Oct 12, 2025

What does this PR do?

This implements inference-only FP8 linears using Transformer Engine.
It is modelled after the bitsandbytes transform.

One question I have is what a good scaling would be. I'm currently using the max range (fp8_max / tensor.absmax()) on the weight and 1.0 on the input, but I have no idea what a good input scale would be (it depends on the accumulation in fp8 matmuls).
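For reference, a minimal sketch of the "max range" weight scale described above; the E4M3 maximum of 448 and the helper name are illustrative assumptions, not code from this PR:

import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3

def weight_scale(weight: torch.Tensor) -> torch.Tensor:
    # "max range" scaling: map the weight's absolute maximum onto the FP8 E4M3 maximum.
    absmax = weight.abs().amax().to(torch.float32)
    return FP8_E4M3_MAX / absmax

# The input scale is simply left at 1.0 for now, pending a better heuristic.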

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

Comment on lines 35 to 64
def te_linear_fp8_impl(x, qweight, bias, absmax, scale):
    # Quantizer for the pre-quantized weight; scale/amax were computed at transform time.
    wq = transformer_engine.pytorch.Float8Quantizer(
        scale=scale,
        amax=absmax,
        fp8_dtype=transformer_engine_torch.DType.kFloat8E4M3,
        rowwise=True,
        columnwise=False,
    )

    w = wq.create_tensor_from_data(qweight, fake_dtype=x.dtype, requires_grad=False)

    # Per-tensor absmax of the input, used to build the input quantizer on the fly.
    minmax = x.aminmax()
    xmax = torch.maximum(minmax.min.abs(), minmax.max.abs()).to(torch.float32)
    xq = transformer_engine.pytorch.Float8Quantizer(
        scale=1.0 / xmax,  # this needs to be 1 (or even somewhat smaller for accumulation?)
        amax=xmax,
        fp8_dtype=transformer_engine_torch.DType.kFloat8E4M3,
        rowwise=True,
        columnwise=False,
    )

    out, *_ = transformer_engine.pytorch.ops.BasicLinear._functional_forward(
        x,
        w,
        input_quantizer=xq,
        with_quantized_compute=True,
        weight_requires_grad=False,
        input_requires_grad=False,
    )

@t-vi (Author) commented Oct 12, 2025


This is the core of how I compute linears. Does this make sense? What is a good scale for the input? I use 1/absmax for now so that x[j, k] * w[i, k] stays below fp8_max, but I'm not sure whether that is needed.
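To spell out the bound aimed at here (a hedged illustration, not part of the PR): with an input scale of 1 / absmax(x), the quantized activations lie in roughly [-1, 1], so each elementwise product with an FP8 weight stays within the E4M3 range; only the accumulation over k can grow beyond it.

import torch

x = torch.randn(8, 64)
amax = x.abs().amax().to(torch.float32)
x_scaled = x * (1.0 / amax)      # what the quantizer sees with scale = 1 / absmax
print(x_scaled.abs().max())      # ~1.0, so |x_q * w_q| <= |w_q| <= 448 (the E4M3 max) per element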


    minmax = x.aminmax()
    xmax = torch.maximum(minmax.min.abs(), minmax.max.abs()).to(torch.float32)
    xq = transformer_engine.pytorch.Float8Quantizer(
A collaborator commented on the quoted lines above:

Do we necessarily want to use Float8Quantizer, which ties in with the DelayedScaling recipe (but works on both Hopper and Blackwell archs)?

Instead we could also use MXFP8Quantizer (note: this is only supported on Blackwell). For that recipe, the quantization of the input does not depend on a scale from the previous iteration.

https://github.com/NVIDIA/TransformerEngine/blob/7ad130efd52c3aa4a386d25f1d42b28d5aa20090/transformer_engine/pytorch/tensor/mxfp8_tensor.py#L29
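For concreteness, a hedged sketch of what constructing an MXFP8 input quantizer might look like; the import path and constructor arguments are inferred from the linked mxfp8_tensor.py and are assumptions, not something exercised in this PR:

import transformer_engine.pytorch
import transformer_engine_torch

# MXFP8 computes block-wise scales at quantization time, so no amax/scale state
# from a previous iteration needs to be tracked (Blackwell only).
xq = transformer_engine.pytorch.MXFP8Quantizer(
    fp8_dtype=transformer_engine_torch.DType.kFloat8E4M3,
    rowwise=True,
    columnwise=False,
)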

@t-vi (Author) commented Oct 13, 2025

Sadly, I don't currently have easy access to Blackwell, so I'm keen to support Hopper, too.
Of course, I'd 100% love to have something flexible enough for fp8 + fp4 on Hopper / Blackwell to the extent that it is supported.

@t-vi t-vi merged commit a955b66 into main Oct 17, 2025
48 of 51 checks passed
@t-vi t-vi deleted the tom/te-inference branch October 17, 2025 13:15