From 646803258c7e33d2e8c3545a0a33dc206dda245d Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Wed, 1 Oct 2025 06:07:01 +0000 Subject: [PATCH 01/27] Prefer loading model from pretrained instead of config --- unsloth/models/llama.py | 1 + unsloth/models/vision.py | 1 + 2 files changed, 2 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 1b2254251..8ff74872a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1930,6 +1930,7 @@ def from_pretrained( token = token, attn_implementation = "sdpa", ) + model_config.model_name = model_name model_max_seq_length = model_config.max_position_embeddings # Check if RoPE Scaling is even allowed diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index ae76c573d..c994d0a17 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -584,6 +584,7 @@ def from_pretrained( token = token, attn_implementation = "sdpa" if supports_sdpa else "eager", ) + model_config.model_name = model_name if fast_inference: fast_inference, model_name = fast_inference_setup(model_name, model_config) From e3184a3f99538febf128514824924132db06933d Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Wed, 1 Oct 2025 12:41:14 +0000 Subject: [PATCH 02/27] Fixup FP8 forward pass and inference --- unsloth/kernels/utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index cb8982df0..72fbe2cde 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -691,10 +691,16 @@ def fast_gemv(X, W, quant_state, out = None): def fast_linear_forward(proj, X, temp_lora = None, out = None): W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj) + base_layer = getattr(proj, "base_layer", proj) + W_scale_inv = getattr(base_layer, "weight_scale_inv", None) bsz, q_len, in_dim = X.shape - if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S) - if W_quant is None: + if W_scale_inv is not None: + # This is fp8. we'll use the same function as hf does. Always take this path for FP8 + out = base_layer(X) + elif q_len != 1: + return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S) + elif W_quant is None: out = torch_matmul(X, W.t(), out = out) elif bsz == 1 and q_len == 1: out = fast_gemv(X, W, W_quant, out = out) From 6ef688431f37ef94561122b10536f356f508ecde Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Wed, 1 Oct 2025 15:38:24 +0000 Subject: [PATCH 03/27] [WIP] Fix lora forwards --- unsloth/kernels/utils.py | 48 ++++++++++++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 72fbe2cde..6193ad740 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -201,6 +201,10 @@ def get_lora_parameters(proj): if weight_fake_quantizer is not None: W = weight_fake_quantizer(W) + if W.dtype == torch.float8_e4m3fn: + # we need to somehow store and pass this information :) + W.block_size = getattr(base_layer, 'block_size', [128,128]) + # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: if getattr(proj, "disable_adapters", True) or proj.merged: return W, getattr(W, "quant_state", None), None, None, None @@ -224,7 +228,7 @@ def get_lora_parameters(proj): return ( W, - getattr(W, "quant_state", None), + getattr(W, "quant_state", None) or getattr(base_layer,'weight_scale_inv', None), A, B, proj.scaling[adapter], @@ -237,6 +241,10 @@ def get_lora_parameters_bias(proj): base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj) W = base_layer.weight + if W.dtype == torch.float8_e4m3fn: + # we need to somehow store and pass this information :) + W.block_size = getattr(base_layer, 'block_size', [128,128]) + # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: if getattr(proj, "disable_adapters", True) or proj.merged: return W, getattr(W, "quant_state", None), None, None, None, base_layer.bias @@ -248,7 +256,7 @@ def get_lora_parameters_bias(proj): return ( W, - getattr(W, "quant_state", None), + getattr(W, "quant_state", None) or getattr(base_layer,'weight_scale_inv', None), proj.lora_A [adapter].weight, proj.lora_B [adapter].weight, proj.scaling[adapter], @@ -687,21 +695,33 @@ def fast_gemv(X, W, quant_state, out = None): pass pass +def fp8_forward(X, weight, weight_scale): + block_size = getattr(weight, 'block_size', [128,128]) + # this is replica of https://github.com/huggingface/transformers/blob/01c9e1ba683b3e50d7c76bf92f2d470759fd5e81/src/transformers/integrations/finegrained_fp8.py#L331-L353 + from transformers.integrations.finegrained_fp8 import act_quant, w8a8_block_fp8_matmul_triton + qinput, scale = act_quant(X, block_size[1]) + output = w8a8_block_fp8_matmul_triton( + qinput, + weight, + scale, + weight_scale, + block_size, + output_dtype=X.dtype, + ) + return output.to(X.dtype) + def fast_linear_forward(proj, X, temp_lora = None, out = None): W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj) - base_layer = getattr(proj, "base_layer", proj) - W_scale_inv = getattr(base_layer, "weight_scale_inv", None) bsz, q_len, in_dim = X.shape + if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S) - if W_scale_inv is not None: - # This is fp8. we'll use the same function as hf does. Always take this path for FP8 - out = base_layer(X) - elif q_len != 1: - return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S) - elif W_quant is None: + if W_quant is None: out = torch_matmul(X, W.t(), out = out) + elif W.dtype == torch.float8_e4m3fn: + # In case of fp8, we'll let the base layer forward pass handle this. LoRA is anyway 16bit + out = base_layer(X) elif bsz == 1 and q_len == 1: out = fast_gemv(X, W, W_quant, out = out) else: @@ -739,7 +759,6 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): def matmul_lora(X, W, W_quant, A, B, s, out = None): dtype = X.dtype - W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) if X.dim() == 3: batch, seq_len, d = X.shape @@ -748,7 +767,12 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): else: reshape = False pass - out = torch_matmul(X, W, out = out) + + if W.dtype==torch.float8_e4m3fn: + out = fp8_forward(X, W, W_quant,) + else: + W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) + out = torch_matmul(X, W, out = out) if W_quant is not None: del W if A is not None: From 51d6626fcc91fa05db2e922e96141522f0162d14 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 2 Oct 2025 04:33:13 +0000 Subject: [PATCH 04/27] Infer block size from weight shapes --- unsloth/kernels/utils.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 6193ad740..1c06b8d3f 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -201,10 +201,6 @@ def get_lora_parameters(proj): if weight_fake_quantizer is not None: W = weight_fake_quantizer(W) - if W.dtype == torch.float8_e4m3fn: - # we need to somehow store and pass this information :) - W.block_size = getattr(base_layer, 'block_size', [128,128]) - # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: if getattr(proj, "disable_adapters", True) or proj.merged: return W, getattr(W, "quant_state", None), None, None, None @@ -241,10 +237,6 @@ def get_lora_parameters_bias(proj): base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj) W = base_layer.weight - if W.dtype == torch.float8_e4m3fn: - # we need to somehow store and pass this information :) - W.block_size = getattr(base_layer, 'block_size', [128,128]) - # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: if getattr(proj, "disable_adapters", True) or proj.merged: return W, getattr(W, "quant_state", None), None, None, None, base_layer.bias @@ -696,7 +688,11 @@ def fast_gemv(X, W, quant_state, out = None): pass def fp8_forward(X, weight, weight_scale): - block_size = getattr(weight, 'block_size', [128,128]) + # block_size = getattr(weight, 'block_size', [128,128]) + m,n = weight.shape + p,q = weight_scale.shape + assert m % p == 0 and n % q == 0, "FP8 Forward: weight and weight_scale shapes are not compatible" + block_size = [m//p,n//q] # this is replica of https://github.com/huggingface/transformers/blob/01c9e1ba683b3e50d7c76bf92f2d470759fd5e81/src/transformers/integrations/finegrained_fp8.py#L331-L353 from transformers.integrations.finegrained_fp8 import act_quant, w8a8_block_fp8_matmul_triton qinput, scale = act_quant(X, block_size[1]) From 9888e87ab6bd7e071d346505c70a22954b43d06e Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 2 Oct 2025 06:57:23 +0000 Subject: [PATCH 05/27] reconstruct weights from fp8 quants for lora matmul --- unsloth/kernels/utils.py | 62 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 1c06b8d3f..6d2ed48fc 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -201,6 +201,10 @@ def get_lora_parameters(proj): if weight_fake_quantizer is not None: W = weight_fake_quantizer(W) + if W.dtype == torch.float8_e4m3fn: + # we need to somehow store and pass this information :) + W.block_size = getattr(base_layer, 'block_size', [128,128]) + # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: if getattr(proj, "disable_adapters", True) or proj.merged: return W, getattr(W, "quant_state", None), None, None, None @@ -242,6 +246,10 @@ def get_lora_parameters_bias(proj): return W, getattr(W, "quant_state", None), None, None, None, base_layer.bias pass + if W.dtype == torch.float8_e4m3fn: + # we need to somehow store and pass this information :) + W.block_size = getattr(base_layer, 'block_size', [128,128]) + adapter = getattr(proj, "active_adapters", None) if adapter is None: adapter = getattr(proj, "active_adapter", ("default")) adapter = adapter[0] @@ -707,6 +715,55 @@ def fp8_forward(X, weight, weight_scale): return output.to(X.dtype) +def reconstruct_weight_fp8( + W_fp8: torch.Tensor, + W_scale: torch.Tensor, + group_k: int, + group_n: int, + *, + out_dtype=torch.float16, +): + K, N = W_fp8.shape + num_k_groups = math.ceil(K / group_k) + num_n_groups = math.ceil(N / group_n) + + # normalize scale to (num_k_groups, num_n_groups) + if W_scale.numel() == 1: + W_scale = W_scale.reshape(1, 1).expand(num_k_groups, num_n_groups) + elif W_scale.dim() == 1 and W_scale.numel() == num_k_groups * num_n_groups: + W_scale = W_scale.reshape(num_k_groups, num_n_groups) + elif W_scale.dim() == 2 and W_scale.shape == (num_k_groups, num_n_groups): + pass + else: + raise ValueError("Unsupported W_scale shape") + + W = W_fp8.to(dtype=W_scale.dtype).contiguous() + W_scale = W_scale + + # If K or N not divisible by groups, handle last partial groups by padding + Kpad = num_k_groups * group_k + Npad = num_n_groups * group_n + if Kpad != K or Npad != N: + W_pad = W.new_zeros((Kpad, Npad)) + W_pad[:K, :N] = W + W = W_pad + + Wg = W.view(num_k_groups, group_k, num_n_groups, group_n) + Wg = Wg.permute(0, 2, 1, 3).contiguous() + W_flat = Wg.view(num_k_groups * num_n_groups, group_k * group_n) + + ws_flat = W_scale.reshape(-1, 1) + W_flat = W_flat * ws_flat + + # reshape back + Wg = W_flat.view(num_k_groups, num_n_groups, group_k, group_n) + Wg = Wg.permute(0, 2, 1, 3).to(out_dtype).contiguous() + W_out = Wg.view(Kpad, Npad)[:K, :N] + return W_out + +# This cuts down the time taken from ~100us to ~30us for (4096,4096) weight and (32,32) scale :) +reconstruct_weight_fp8 = torch.compile(reconstruct_weight_fp8) + def fast_linear_forward(proj, X, temp_lora = None, out = None): W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj) @@ -765,10 +822,11 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): pass if W.dtype==torch.float8_e4m3fn: - out = fp8_forward(X, W, W_quant,) + k,n = W.block_size + W = reconstruct_weight_fp8(W, W_quant, k, n) else: W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) - out = torch_matmul(X, W, out = out) + out = torch_matmul(X, W, out = out) if W_quant is not None: del W if A is not None: From 91db140709096bdebf8c525cb5937cfdef637b02 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 2 Oct 2025 07:33:11 +0000 Subject: [PATCH 06/27] Return weight transpose and fix dtype --- unsloth/kernels/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 6d2ed48fc..c742bdadc 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -721,7 +721,7 @@ def reconstruct_weight_fp8( group_k: int, group_n: int, *, - out_dtype=torch.float16, + out_dtype=torch.bfloat16, ): K, N = W_fp8.shape num_k_groups = math.ceil(K / group_k) @@ -759,7 +759,7 @@ def reconstruct_weight_fp8( Wg = W_flat.view(num_k_groups, num_n_groups, group_k, group_n) Wg = Wg.permute(0, 2, 1, 3).to(out_dtype).contiguous() W_out = Wg.view(Kpad, Npad)[:K, :N] - return W_out + return W_out.T # This cuts down the time taken from ~100us to ~30us for (4096,4096) weight and (32,32) scale :) reconstruct_weight_fp8 = torch.compile(reconstruct_weight_fp8) @@ -823,7 +823,7 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): if W.dtype==torch.float8_e4m3fn: k,n = W.block_size - W = reconstruct_weight_fp8(W, W_quant, k, n) + W = reconstruct_weight_fp8(W, W_quant, k, n, out_dtype=X.dtype) else: W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) out = torch_matmul(X, W, out = out) From bff4612cee24ff798b114e7377863ea88d9a6335 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 2 Oct 2025 08:24:13 +0000 Subject: [PATCH 07/27] Refactor FP8 operations --- unsloth/kernels/fp8.py | 108 +++++++++++++++++++++++++++++++++++++++ unsloth/kernels/utils.py | 75 ++------------------------- 2 files changed, 112 insertions(+), 71 deletions(-) create mode 100644 unsloth/kernels/fp8.py diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py new file mode 100644 index 000000000..0bc635480 --- /dev/null +++ b/unsloth/kernels/fp8.py @@ -0,0 +1,108 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +torch_matmul = torch.matmul + +@torch.no_grad +def reconstruct_weight_fp8( + W_fp8: torch.Tensor, + W_scale: torch.Tensor, + group_k: int, + group_n: int, + *, + out_dtype=torch.bfloat16, +): + K, N = W_fp8.shape + num_k_groups = math.ceil(K / group_k) + num_n_groups = math.ceil(N / group_n) + + # normalize scale to (num_k_groups, num_n_groups) + if W_scale.numel() == 1: + W_scale = W_scale.reshape(1, 1).expand(num_k_groups, num_n_groups) + elif W_scale.dim() == 1 and W_scale.numel() == num_k_groups * num_n_groups: + W_scale = W_scale.reshape(num_k_groups, num_n_groups) + elif W_scale.dim() == 2 and W_scale.shape == (num_k_groups, num_n_groups): + pass + else: + raise ValueError("Unsupported W_scale shape") + + W = W_fp8.to(dtype=W_scale.dtype).contiguous() + W_scale = W_scale + + # If K or N not divisible by groups, handle last partial groups by padding + Kpad = num_k_groups * group_k + Npad = num_n_groups * group_n + if Kpad != K or Npad != N: + W_pad = W.new_zeros((Kpad, Npad)) + W_pad[:K, :N] = W + W = W_pad + + Wg = W.view(num_k_groups, group_k, num_n_groups, group_n) + Wg = Wg.permute(0, 2, 1, 3).contiguous() + W_flat = Wg.view(num_k_groups * num_n_groups, group_k * group_n) + + ws_flat = W_scale.reshape(-1, 1) + W_flat = W_flat * ws_flat + + # reshape back + Wg = W_flat.view(num_k_groups, num_n_groups, group_k, group_n) + Wg = Wg.permute(0, 2, 1, 3).to(out_dtype).contiguous() + W_out = Wg.view(Kpad, Npad)[:K, :N] + return W_out.T + + +class FP8_E5M2Linear(torch.autograd.Function): + + @torch.compile + def forward_compiled(ctx, X, weight, weight_scale): + # block_size = getattr(weight, 'block_size', [128,128]) + m,n = weight.shape + p,q = weight_scale.shape + assert m % p == 0 and n % q == 0, "FP8 Forward: weight and weight_scale shapes are not compatible" + block_size = getattr(weight, 'block_size', [m//p,n//q]) + # this is replica of https://github.com/huggingface/transformers/blob/01c9e1ba683b3e50d7c76bf92f2d470759fd5e81/src/transformers/integrations/finegrained_fp8.py#L331-L353 + from transformers.integrations.finegrained_fp8 import act_quant, w8a8_block_fp8_matmul_triton + qinput, scale = act_quant(X, block_size[1]) + output = w8a8_block_fp8_matmul_triton( + qinput, + weight, + scale, + weight_scale, + block_size, + output_dtype=X.dtype, + ) + + ctx.weight = weight + ctx.weight_scale = weight_scale + ctx.block_size = block_size + + return output.to(X.dtype) + + @staticmethod + def forward(ctx, X, weight, weight_scale): + return FP8_E5M2Linear.forward_compiled(ctx, X, weight, weight_scale) + + @torch.compile + def backward_compiled(ctx, grad_output): + W_deq = reconstruct_weight_fp8(ctx.weight, ctx.weight_scale, ctx.block_size[0], ctx.block_size[1]) + grad_X = torch_matmul(grad_output, W_deq.t()) + return grad_X, None, None + + @staticmethod + def backward(ctx, grad_output): + return FP8_E5M2Linear.backward_compiled(ctx, grad_output) + +def fp8_e5m2_forward(X, weight, weight_scale): + return FP8_E5M2Linear.apply(X, weight, weight_scale) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index c742bdadc..0b882119b 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -19,6 +19,7 @@ import functools from typing import Optional from unsloth import DEVICE_TYPE, DEVICE_COUNT +from .fp8 import fp8_e5m2_forward # torch.cuda.amp.custom_fwd is deprecated >= 2.4 import torch @@ -695,74 +696,6 @@ def fast_gemv(X, W, quant_state, out = None): pass pass -def fp8_forward(X, weight, weight_scale): - # block_size = getattr(weight, 'block_size', [128,128]) - m,n = weight.shape - p,q = weight_scale.shape - assert m % p == 0 and n % q == 0, "FP8 Forward: weight and weight_scale shapes are not compatible" - block_size = [m//p,n//q] - # this is replica of https://github.com/huggingface/transformers/blob/01c9e1ba683b3e50d7c76bf92f2d470759fd5e81/src/transformers/integrations/finegrained_fp8.py#L331-L353 - from transformers.integrations.finegrained_fp8 import act_quant, w8a8_block_fp8_matmul_triton - qinput, scale = act_quant(X, block_size[1]) - output = w8a8_block_fp8_matmul_triton( - qinput, - weight, - scale, - weight_scale, - block_size, - output_dtype=X.dtype, - ) - return output.to(X.dtype) - - -def reconstruct_weight_fp8( - W_fp8: torch.Tensor, - W_scale: torch.Tensor, - group_k: int, - group_n: int, - *, - out_dtype=torch.bfloat16, -): - K, N = W_fp8.shape - num_k_groups = math.ceil(K / group_k) - num_n_groups = math.ceil(N / group_n) - - # normalize scale to (num_k_groups, num_n_groups) - if W_scale.numel() == 1: - W_scale = W_scale.reshape(1, 1).expand(num_k_groups, num_n_groups) - elif W_scale.dim() == 1 and W_scale.numel() == num_k_groups * num_n_groups: - W_scale = W_scale.reshape(num_k_groups, num_n_groups) - elif W_scale.dim() == 2 and W_scale.shape == (num_k_groups, num_n_groups): - pass - else: - raise ValueError("Unsupported W_scale shape") - - W = W_fp8.to(dtype=W_scale.dtype).contiguous() - W_scale = W_scale - - # If K or N not divisible by groups, handle last partial groups by padding - Kpad = num_k_groups * group_k - Npad = num_n_groups * group_n - if Kpad != K or Npad != N: - W_pad = W.new_zeros((Kpad, Npad)) - W_pad[:K, :N] = W - W = W_pad - - Wg = W.view(num_k_groups, group_k, num_n_groups, group_n) - Wg = Wg.permute(0, 2, 1, 3).contiguous() - W_flat = Wg.view(num_k_groups * num_n_groups, group_k * group_n) - - ws_flat = W_scale.reshape(-1, 1) - W_flat = W_flat * ws_flat - - # reshape back - Wg = W_flat.view(num_k_groups, num_n_groups, group_k, group_n) - Wg = Wg.permute(0, 2, 1, 3).to(out_dtype).contiguous() - W_out = Wg.view(Kpad, Npad)[:K, :N] - return W_out.T - -# This cuts down the time taken from ~100us to ~30us for (4096,4096) weight and (32,32) scale :) -reconstruct_weight_fp8 = torch.compile(reconstruct_weight_fp8) def fast_linear_forward(proj, X, temp_lora = None, out = None): @@ -774,6 +707,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): out = torch_matmul(X, W.t(), out = out) elif W.dtype == torch.float8_e4m3fn: # In case of fp8, we'll let the base layer forward pass handle this. LoRA is anyway 16bit + base_layer = getattr(proj, 'base_layer', proj) out = base_layer(X) elif bsz == 1 and q_len == 1: out = fast_gemv(X, W, W_quant, out = out) @@ -822,11 +756,10 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): pass if W.dtype==torch.float8_e4m3fn: - k,n = W.block_size - W = reconstruct_weight_fp8(W, W_quant, k, n, out_dtype=X.dtype) + out = fp8_e5m2_forward(X, W, W_quant) else: W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) - out = torch_matmul(X, W, out = out) + out = torch_matmul(X, W, out = out) if W_quant is not None: del W if A is not None: From fb1849ce6c6c6e48847ac5635d49e3102240ab75 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 2 Oct 2025 08:39:51 +0000 Subject: [PATCH 08/27] Fix naming :) --- unsloth/kernels/fp8.py | 10 +++++----- unsloth/kernels/utils.py | 8 +++----- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index 0bc635480..415d6de1b 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -63,7 +63,7 @@ def reconstruct_weight_fp8( return W_out.T -class FP8_E5M2Linear(torch.autograd.Function): +class FP8_E4M3Linear(torch.autograd.Function): @torch.compile def forward_compiled(ctx, X, weight, weight_scale): @@ -92,7 +92,7 @@ def forward_compiled(ctx, X, weight, weight_scale): @staticmethod def forward(ctx, X, weight, weight_scale): - return FP8_E5M2Linear.forward_compiled(ctx, X, weight, weight_scale) + return FP8_E4M3Linear.forward_compiled(ctx, X, weight, weight_scale) @torch.compile def backward_compiled(ctx, grad_output): @@ -102,7 +102,7 @@ def backward_compiled(ctx, grad_output): @staticmethod def backward(ctx, grad_output): - return FP8_E5M2Linear.backward_compiled(ctx, grad_output) + return FP8_E4M3Linear.backward_compiled(ctx, grad_output) -def fp8_e5m2_forward(X, weight, weight_scale): - return FP8_E5M2Linear.apply(X, weight, weight_scale) +def fp8_e4m3_forward(X, weight, weight_scale): + return FP8_E4M3Linear.apply(X, weight, weight_scale) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 0b882119b..2fed13cca 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -19,7 +19,7 @@ import functools from typing import Optional from unsloth import DEVICE_TYPE, DEVICE_COUNT -from .fp8 import fp8_e5m2_forward +from .fp8 import fp8_e4m3_forward # torch.cuda.amp.custom_fwd is deprecated >= 2.4 import torch @@ -706,9 +706,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): if W_quant is None: out = torch_matmul(X, W.t(), out = out) elif W.dtype == torch.float8_e4m3fn: - # In case of fp8, we'll let the base layer forward pass handle this. LoRA is anyway 16bit - base_layer = getattr(proj, 'base_layer', proj) - out = base_layer(X) + out = fp8_e4m3_forward(X, W, W_quant) elif bsz == 1 and q_len == 1: out = fast_gemv(X, W, W_quant, out = out) else: @@ -756,7 +754,7 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): pass if W.dtype==torch.float8_e4m3fn: - out = fp8_e5m2_forward(X, W, W_quant) + out = fp8_e4m3_forward(X, W, W_quant) else: W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) out = torch_matmul(X, W, out = out) From 85791f37b179262319963761ac896fd0f72b0447 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 3 Oct 2025 03:58:03 +0000 Subject: [PATCH 09/27] Saner compile --- unsloth/kernels/fp8.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index 415d6de1b..25fe7ade0 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -65,8 +65,8 @@ def reconstruct_weight_fp8( class FP8_E4M3Linear(torch.autograd.Function): - @torch.compile - def forward_compiled(ctx, X, weight, weight_scale): + @staticmethod + def forward(ctx, X, weight, weight_scale): # block_size = getattr(weight, 'block_size', [128,128]) m,n = weight.shape p,q = weight_scale.shape @@ -91,18 +91,11 @@ def forward_compiled(ctx, X, weight, weight_scale): return output.to(X.dtype) @staticmethod - def forward(ctx, X, weight, weight_scale): - return FP8_E4M3Linear.forward_compiled(ctx, X, weight, weight_scale) - - @torch.compile - def backward_compiled(ctx, grad_output): + def backward(ctx, grad_output): W_deq = reconstruct_weight_fp8(ctx.weight, ctx.weight_scale, ctx.block_size[0], ctx.block_size[1]) grad_X = torch_matmul(grad_output, W_deq.t()) return grad_X, None, None - @staticmethod - def backward(ctx, grad_output): - return FP8_E4M3Linear.backward_compiled(ctx, grad_output) - +@torch.compile def fp8_e4m3_forward(X, weight, weight_scale): return FP8_E4M3Linear.apply(X, weight, weight_scale) From 4a4f7e2073b388a29907cde0596673c4f2476806 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 3 Oct 2025 04:49:20 +0000 Subject: [PATCH 10/27] do not depend on transformers --- unsloth/kernels/fp8.py | 195 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 193 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index 25fe7ade0..2f2212c43 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch +import torch.nn as nn +import triton +import triton.language as tl +from torch.nn import functional as F torch_matmul = torch.matmul @@ -62,6 +66,195 @@ def reconstruct_weight_fp8( W_out = Wg.view(Kpad, Npad)[:K, :N] return W_out.T +# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py +@triton.jit +def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offs).to(tl.float32) + s = tl.max(tl.abs(x)) / 448.0 + y = x / s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y) + tl.store(s_ptr + pid, s) + +def act_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]: + assert x.is_contiguous() + assert x.shape[-1] % block_size == 0 + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) + + def grid(meta): + return (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),) + + act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) + return y, s + + +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py +@triton.jit +def _w8a8_block_fp8_matmul( + # Pointers to inputs and output + A, + B, + C, + As, + Bs, + # Shape for matmul + M, + N, + K, + # Block size for block-wise quantization + group_n, + group_k, + # Stride for inputs and output + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_As_m, + stride_As_k, + stride_Bs_k, + stride_Bs_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Triton-accelerated function used to perform linear operations (dot + product) on input tensors `A` and `B` with block-wise quantization, and + store the result in output tensor `C`. + """ + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + As_ptrs = As + offs_am * stride_As_m + offs_bsn = offs_bn // group_n + Bs_ptrs = Bs + offs_bsn * stride_Bs_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if C.dtype.element_ty == tl.bfloat16: + c = accumulator.to(tl.bfloat16) + elif C.dtype.element_ty == tl.float16: + c = accumulator.to(tl.float16) + else: + c = accumulator.to(tl.float32) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def w8a8_block_fp8_matmul_triton( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """This function performs matrix multiplication with block-wise + quantization. + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + Args: + A: The input tensor, e.g., activation. + B: The input tensor, e.g., weight. + As: The per-token-group quantization scale for `A`. + Bs: The per-block quantization scale for `B`. + block_size: The block size for per-block quantization. It should + be 2-dim, e.g., [128, 128]. + output_dytpe: The dtype of the returned tensor. + Returns: + torch.Tensor: The result of matmul. + """ + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N,) + C = A.new_empty(C_shape, dtype=output_dtype) + + BLOCK_SIZE_M = 128 + if M < BLOCK_SIZE_M: + BLOCK_SIZE_M = triton.next_power_of_2(M) + BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16) + BLOCK_SIZE_K = block_k + assert block_k % BLOCK_SIZE_K == 0 + BLOCK_SIZE_N = block_n + + def grid(META): + return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),) + + _w8a8_block_fp8_matmul[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + ) + + return C class FP8_E4M3Linear(torch.autograd.Function): @@ -70,10 +263,8 @@ def forward(ctx, X, weight, weight_scale): # block_size = getattr(weight, 'block_size', [128,128]) m,n = weight.shape p,q = weight_scale.shape - assert m % p == 0 and n % q == 0, "FP8 Forward: weight and weight_scale shapes are not compatible" block_size = getattr(weight, 'block_size', [m//p,n//q]) # this is replica of https://github.com/huggingface/transformers/blob/01c9e1ba683b3e50d7c76bf92f2d470759fd5e81/src/transformers/integrations/finegrained_fp8.py#L331-L353 - from transformers.integrations.finegrained_fp8 import act_quant, w8a8_block_fp8_matmul_triton qinput, scale = act_quant(X, block_size[1]) output = w8a8_block_fp8_matmul_triton( qinput, From 0b93d94cfa87e9b79a3a3b8697ca525e99f27986 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 3 Oct 2025 08:42:04 +0000 Subject: [PATCH 11/27] [WIP] fix training --- unsloth/kernels/fp8.py | 81 ++++++++++++++++++++++++++++++---------- unsloth/kernels/utils.py | 19 +++++++--- 2 files changed, 76 insertions(+), 24 deletions(-) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index 2f2212c43..068af45c5 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -16,23 +16,44 @@ import triton import triton.language as tl from torch.nn import functional as F +import math torch_matmul = torch.matmul -@torch.no_grad def reconstruct_weight_fp8( W_fp8: torch.Tensor, W_scale: torch.Tensor, - group_k: int, - group_n: int, + group_k=None, + group_n=None, *, out_dtype=torch.bfloat16, ): + """ + Dequantize an FP8-block-quantized weight matrix. + + W_fp8: (K, N) tensor in fp8 dtype + W_scale: scalar, 1D (num_k_groups * num_n_groups) or 2D (num_k_groups, num_n_groups) + group_k, group_n: block sizes (defaults read from attributes) + returns: dequantized weight (N, K).contiguous() with block_size metadata set to [group_n, group_k] + """ + + # infer block sizes if not provided + block_from_fp8 = getattr(W_fp8, 'block_size', None) + block_from_scale = getattr(W_scale, 'block_size', None) + if group_k is None or group_n is None: + if block_from_fp8 is not None: + group_k, group_n = block_from_fp8 + elif block_from_scale is not None: + group_k, group_n = block_from_scale + else: + # fallback to default used elsewhere in your codebase + group_k, group_n = (128, 128) + K, N = W_fp8.shape num_k_groups = math.ceil(K / group_k) num_n_groups = math.ceil(N / group_n) - # normalize scale to (num_k_groups, num_n_groups) + # Normalize scale to (num_k_groups, num_n_groups) if W_scale.numel() == 1: W_scale = W_scale.reshape(1, 1).expand(num_k_groups, num_n_groups) elif W_scale.dim() == 1 and W_scale.numel() == num_k_groups * num_n_groups: @@ -40,31 +61,43 @@ def reconstruct_weight_fp8( elif W_scale.dim() == 2 and W_scale.shape == (num_k_groups, num_n_groups): pass else: - raise ValueError("Unsupported W_scale shape") + raise ValueError( + f"Unsupported W_scale shape {tuple(W_scale.shape)} " + f"for W_fp8.shape={tuple(W_fp8.shape)}, group_k={group_k}, group_n={group_n}" + ) - W = W_fp8.to(dtype=W_scale.dtype).contiguous() - W_scale = W_scale + # Move W to same dtype as scale for safe multiplication (and same device) + device = W_fp8.device + scale_dtype = W_scale.dtype + W = W_fp8.to(dtype=scale_dtype, device=device).contiguous() - # If K or N not divisible by groups, handle last partial groups by padding + # pad to full groups if needed (this keeps block grouping simple) Kpad = num_k_groups * group_k Npad = num_n_groups * group_n if Kpad != K or Npad != N: W_pad = W.new_zeros((Kpad, Npad)) W_pad[:K, :N] = W W = W_pad + # now W is (Kpad, Npad) and contiguous + + # View into blocks: shape -> (num_k_groups, group_k, num_n_groups, group_n) + W_blocks = W.view(num_k_groups, group_k, num_n_groups, group_n) + + # Permute to (num_k_groups, num_n_groups, group_k, group_n) so scales broadcast naturally + W_blocks = W_blocks.permute(0, 2, 1, 3) # may be non-contiguous, but multiplication will allocate a new tensor + + # Broadcast multiply by per-block scales -> result is a new contiguous tensor + scales = W_scale.reshape(num_k_groups, num_n_groups, 1, 1).to(device=device, dtype=scale_dtype) + W_scaled = W_blocks * scales # (num_k_groups, num_n_groups, group_k, group_n) - contiguous - Wg = W.view(num_k_groups, group_k, num_n_groups, group_n) - Wg = Wg.permute(0, 2, 1, 3).contiguous() - W_flat = Wg.view(num_k_groups * num_n_groups, group_k * group_n) + # permute back to (num_k_groups, group_k, num_n_groups, group_n), make contiguous, then reshape to (Kpad, Npad) + W_reordered = W_scaled.permute(0, 2, 1, 3).contiguous().reshape(Kpad, Npad).to(dtype=out_dtype) - ws_flat = W_scale.reshape(-1, 1) - W_flat = W_flat * ws_flat + # slice off padding, transpose and return contiguous result + W_out = W_reordered[:K, :N].T.contiguous() + W_out.block_size = [group_n, group_k] # returning transpose => block_size swapped + return W_out - # reshape back - Wg = W_flat.view(num_k_groups, num_n_groups, group_k, group_n) - Wg = Wg.permute(0, 2, 1, 3).to(out_dtype).contiguous() - W_out = Wg.view(Kpad, Npad)[:K, :N] - return W_out.T # Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py @triton.jit @@ -263,7 +296,17 @@ def forward(ctx, X, weight, weight_scale): # block_size = getattr(weight, 'block_size', [128,128]) m,n = weight.shape p,q = weight_scale.shape - block_size = getattr(weight, 'block_size', [m//p,n//q]) + block_size = getattr(weight, 'block_size', None) or getattr(weight_scale, 'block_size', None) + assert block_size is not None, "block_size is not set" + if triton.cdiv(m,block_size[0])!=p or triton.cdiv(n,block_size[1])!=q: + if triton.cdiv(m,block_size[0])==q and triton.cdiv(n,block_size[1])==p: + # for some reaosn sometimes the weights seem to be transposed for training. + weight_scale = weight_scale.T + else: + raise ValueError(f"Weight shape {weight.shape} and scales shape {weight_scale.shape} is not compatible with block size {block_size}") + + if not weight.is_contiguous(): + weight = weight.contiguous() # this is replica of https://github.com/huggingface/transformers/blob/01c9e1ba683b3e50d7c76bf92f2d470759fd5e81/src/transformers/integrations/finegrained_fp8.py#L331-L353 qinput, scale = act_quant(X, block_size[1]) output = w8a8_block_fp8_matmul_triton( diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 2fed13cca..f90f0a528 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -19,7 +19,7 @@ import functools from typing import Optional from unsloth import DEVICE_TYPE, DEVICE_COUNT -from .fp8 import fp8_e4m3_forward +from .fp8 import fp8_e4m3_forward, reconstruct_weight_fp8 # torch.cuda.amp.custom_fwd is deprecated >= 2.4 import torch @@ -202,13 +202,16 @@ def get_lora_parameters(proj): if weight_fake_quantizer is not None: W = weight_fake_quantizer(W) + W_quant = getattr(W, "quant_state", None) or getattr(base_layer,'weight_scale_inv', None) + if W.dtype == torch.float8_e4m3fn: # we need to somehow store and pass this information :) W.block_size = getattr(base_layer, 'block_size', [128,128]) + W_quant.block_size = W.block_size # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: if getattr(proj, "disable_adapters", True) or proj.merged: - return W, getattr(W, "quant_state", None), None, None, None + return W, W_quant, None, None, None pass adapter = getattr(proj, "active_adapters", None) @@ -229,7 +232,7 @@ def get_lora_parameters(proj): return ( W, - getattr(W, "quant_state", None) or getattr(base_layer,'weight_scale_inv', None), + W_quant, A, B, proj.scaling[adapter], @@ -242,14 +245,17 @@ def get_lora_parameters_bias(proj): base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj) W = base_layer.weight + W_quant = getattr(W, "quant_state", None) or getattr(base_layer,'weight_scale_inv', None) + # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: if getattr(proj, "disable_adapters", True) or proj.merged: - return W, getattr(W, "quant_state", None), None, None, None, base_layer.bias + return W, W_quant, None, None, None, base_layer.bias pass if W.dtype == torch.float8_e4m3fn: # we need to somehow store and pass this information :) W.block_size = getattr(base_layer, 'block_size', [128,128]) + W_quant.block_size = W.block_size adapter = getattr(proj, "active_adapters", None) if adapter is None: adapter = getattr(proj, "active_adapter", ("default")) @@ -257,7 +263,7 @@ def get_lora_parameters_bias(proj): return ( W, - getattr(W, "quant_state", None) or getattr(base_layer,'weight_scale_inv', None), + W_quant, proj.lora_A [adapter].weight, proj.lora_B [adapter].weight, proj.scaling[adapter], @@ -286,6 +292,7 @@ def _maybe_fake_quantize_activations(X: torch.Tensor, proj: torch.nn.Module) -> def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): # TODO: After adding XPU BNB support, check this function if quant_state is None: return W + if W.dtype == torch.float8_e4m3fn: return reconstruct_weight_fp8(W.t().contiguous(), quant_state) if type(quant_state) is not list: # New quant_state as a class # https://github.com/TimDettmers/bitsandbytes/pull/763/files @@ -361,6 +368,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W + if W.dtype == torch.float8_e4m3fn: return reconstruct_weight_fp8(W.t().contiguous(), quant_state) if type(quant_state) is not list: # New quant_state as a class # https://github.com/TimDettmers/bitsandbytes/pull/763/files @@ -436,6 +444,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W + if W.dtype == torch.float8_e4m3fn: return reconstruct_weight_fp8(W.t().contiguous(), quant_state) if type(quant_state) is not list: # New quant_state as a class # https://github.com/TimDettmers/bitsandbytes/pull/763/files From fb61bf6142b0121fa44447f7563aa878c62d892b Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 3 Oct 2025 08:49:45 +0000 Subject: [PATCH 12/27] Update comment --- unsloth/kernels/fp8.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index 068af45c5..a65b08584 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -300,7 +300,8 @@ def forward(ctx, X, weight, weight_scale): assert block_size is not None, "block_size is not set" if triton.cdiv(m,block_size[0])!=p or triton.cdiv(n,block_size[1])!=q: if triton.cdiv(m,block_size[0])==q and triton.cdiv(n,block_size[1])==p: - # for some reaosn sometimes the weights seem to be transposed for training. + # weights are tranposed during backward pass for training :) + # We tranpose weight scale to counter that. Note that transposing weight would cause issues with matmul with input X weight_scale = weight_scale.T else: raise ValueError(f"Weight shape {weight.shape} and scales shape {weight_scale.shape} is not compatible with block size {block_size}") From 039fa9d7f3b4425de94a1c70714334fa58f23e8b Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 3 Oct 2025 11:41:57 +0000 Subject: [PATCH 13/27] fixup training --- unsloth/kernels/fp8.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index a65b08584..b278603de 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -36,6 +36,10 @@ def reconstruct_weight_fp8( group_k, group_n: block sizes (defaults read from attributes) returns: dequantized weight (N, K).contiguous() with block_size metadata set to [group_n, group_k] """ + if W_fp8.dtype != torch.float8_e4m3fn: + import traceback + traceback.print_stack() + raise ValueError(f'Reconstruct weight from fp8 function called on non fp8 input {W_fp8.dtype}. Returning original input but transposed') # infer block sizes if not provided block_from_fp8 = getattr(W_fp8, 'block_size', None) @@ -106,6 +110,11 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(x_ptr + offs).to(tl.float32) s = tl.max(tl.abs(x)) / 448.0 + if s==0: + # For a row of all zeros, lets return zeros as is + # for LoRA, there are cases where dY has 0 in it and we should not let it be NaN + # this is a deviation from the original implementation. + s = 1.0 y = x / s y = y.to(y_ptr.dtype.element_ty) tl.store(y_ptr + offs, y) From 182e3ce97ecba596bc8c52f228a2c9a310d18278 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 9 Oct 2025 08:57:09 +0000 Subject: [PATCH 14/27] use dequant kernel from deepseek --- unsloth/kernels/fp8.py | 29 ++++++++++++++++++++++++++++- unsloth/kernels/utils.py | 2 +- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index b278603de..92f70ee76 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -20,6 +20,7 @@ torch_matmul = torch.matmul +@torch.compile def reconstruct_weight_fp8( W_fp8: torch.Tensor, W_scale: torch.Tensor, @@ -102,6 +103,30 @@ def reconstruct_weight_fp8( W_out.block_size = [group_n, group_k] # returning transpose => block_size swapped return W_out +@triton.jit +def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.load(s_ptr + pid_m * n + pid_n) + y = x * s + tl.store(y_ptr + offs, y, mask=mask) + + +def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor: + assert x.is_contiguous() and s.is_contiguous() + assert x.dim() == 2 and s.dim() == 2 + M, N = x.size() + y = torch.empty_like(x, dtype=torch.get_default_dtype()) + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) + weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) + return y + # Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py @triton.jit @@ -336,8 +361,10 @@ def forward(ctx, X, weight, weight_scale): @staticmethod def backward(ctx, grad_output): - W_deq = reconstruct_weight_fp8(ctx.weight, ctx.weight_scale, ctx.block_size[0], ctx.block_size[1]) + # W_deq = reconstruct_weight_fp8(ctx.weight, ctx.weight_scale, ctx.block_size[0], ctx.block_size[1]) + W_deq = weight_dequant(ctx.weight, ctx.weight_scale, ctx.block_size[1]) grad_X = torch_matmul(grad_output, W_deq.t()) + del W_deq return grad_X, None, None @torch.compile diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index f90f0a528..d470aa138 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -18,7 +18,7 @@ next_power_of_2 = triton.next_power_of_2 import functools from typing import Optional -from unsloth import DEVICE_TYPE, DEVICE_COUNT +from .. import DEVICE_TYPE, DEVICE_COUNT from .fp8 import fp8_e4m3_forward, reconstruct_weight_fp8 # torch.cuda.amp.custom_fwd is deprecated >= 2.4 From c8e7261951696190df397ca7644fcb3a2d324fc1 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 9 Oct 2025 13:24:28 +0000 Subject: [PATCH 15/27] Differentiate between fp8 and fbgemmfp8 --- unsloth/kernels/fp8.py | 41 ++++++++++++++++++++++++++++++++++++++++ unsloth/kernels/utils.py | 21 +++++++++++++------- 2 files changed, 55 insertions(+), 7 deletions(-) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index 92f70ee76..3cd192d04 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -370,3 +370,44 @@ def backward(ctx, grad_output): @torch.compile def fp8_e4m3_forward(X, weight, weight_scale): return FP8_E4M3Linear.apply(X, weight, weight_scale) + + +class FbgemmFp8Linear(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, weight, weight_scale, bias=None): + # quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here + output_shape = (*x.shape[:-1], -1) + # x_quantized and x_scale are not necessarily on the same device as x, this is an issue. + # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45 + x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row( + x.view(-1, x.shape[-1]).contiguous(), scale_ub=getattr(weight, 'input_scale_ub') + ) + # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works + # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device) + + # The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight + weight_scale_float32 = weight_scale.to(torch.float32) + output = torch.ops.fbgemm.f8f8bf16_rowwise( + x_quantized, weight, x_scale, weight_scale_float32, use_fast_accum=True + ) + output = output + bias if bias is not None else output + # Hacky for now, we have the output to the device of x + output = output.to(x.device, x.dtype) + output = output.reshape(output_shape) + del x_quantized, x_scale + + ctx.weight = weight + ctx.weight_scale = weight_scale + + return output + + @staticmethod + def backward(ctx, grad_output): + W_deq = reconstruct_weight_fp8(ctx.weight, ctx.weight_scale, ctx.block_size[0], ctx.block_size[1]) + grad_X = torch_matmul(grad_output, W_deq.t()) + return grad_X, None, None, None, None + +@torch.compile +def fbgemm_fp8_linear(X, weight, weight_scale, bias=None, ): + return FbgemmFp8Linear.apply(X, weight, weight_scale, bias) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index d470aa138..9e3e274b3 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -18,8 +18,9 @@ next_power_of_2 = triton.next_power_of_2 import functools from typing import Optional + from .. import DEVICE_TYPE, DEVICE_COUNT -from .fp8 import fp8_e4m3_forward, reconstruct_weight_fp8 +from .fp8 import fp8_e4m3_forward, reconstruct_weight_fp8, fbgemm_fp8_linear # torch.cuda.amp.custom_fwd is deprecated >= 2.4 import torch @@ -202,9 +203,9 @@ def get_lora_parameters(proj): if weight_fake_quantizer is not None: W = weight_fake_quantizer(W) - W_quant = getattr(W, "quant_state", None) or getattr(base_layer,'weight_scale_inv', None) + W_quant = getattr(W, "quant_state", None) or getattr(base_layer,'weight_scale_inv', None) or getattr(base_layer, 'weight_scale', None) - if W.dtype == torch.float8_e4m3fn: + if getattr(base_layer, 'quant_method', None)=='fp8': # we need to somehow store and pass this information :) W.block_size = getattr(base_layer, 'block_size', [128,128]) W_quant.block_size = W.block_size @@ -245,14 +246,14 @@ def get_lora_parameters_bias(proj): base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj) W = base_layer.weight - W_quant = getattr(W, "quant_state", None) or getattr(base_layer,'weight_scale_inv', None) + W_quant = getattr(W, "quant_state", None) or getattr(base_layer,'weight_scale_inv', None) or getattr(base_layer, 'weight_scale', None) # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: if getattr(proj, "disable_adapters", True) or proj.merged: return W, W_quant, None, None, None, base_layer.bias pass - if W.dtype == torch.float8_e4m3fn: + if getattr(base_layer, 'quant_method', None)=='fp8': # we need to somehow store and pass this information :) W.block_size = getattr(base_layer, 'block_size', [128,128]) W_quant.block_size = W.block_size @@ -715,7 +716,10 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): if W_quant is None: out = torch_matmul(X, W.t(), out = out) elif W.dtype == torch.float8_e4m3fn: - out = fp8_e4m3_forward(X, W, W_quant) + if getattr(W, 'quant_method','fp8')=='fp8': + out = fp8_e4m3_forward(X, W, W_quant) + else: + out = fbgemm_fp8_linear(X, W, W_quant, ) elif bsz == 1 and q_len == 1: out = fast_gemv(X, W, W_quant, out = out) else: @@ -763,7 +767,10 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): pass if W.dtype==torch.float8_e4m3fn: - out = fp8_e4m3_forward(X, W, W_quant) + if getattr(W, 'quant_method','fp8')=='fp8': + out = fp8_e4m3_forward(X, W, W_quant) + else: + out = fbgemm_fp8_linear(X, W, W_quant, ) else: W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) out = torch_matmul(X, W, out = out) From a3a0a3dea3eea8ff511c12c3546ba6df6b86e224 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 9 Oct 2025 15:18:40 +0000 Subject: [PATCH 16/27] fixup differentiation b/w fp8 and fbgemm_fp8 --- unsloth/kernels/fp8.py | 94 +++------------------------------------- unsloth/kernels/utils.py | 26 ++++++----- 2 files changed, 21 insertions(+), 99 deletions(-) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index 3cd192d04..33094a968 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -20,89 +20,6 @@ torch_matmul = torch.matmul -@torch.compile -def reconstruct_weight_fp8( - W_fp8: torch.Tensor, - W_scale: torch.Tensor, - group_k=None, - group_n=None, - *, - out_dtype=torch.bfloat16, -): - """ - Dequantize an FP8-block-quantized weight matrix. - - W_fp8: (K, N) tensor in fp8 dtype - W_scale: scalar, 1D (num_k_groups * num_n_groups) or 2D (num_k_groups, num_n_groups) - group_k, group_n: block sizes (defaults read from attributes) - returns: dequantized weight (N, K).contiguous() with block_size metadata set to [group_n, group_k] - """ - if W_fp8.dtype != torch.float8_e4m3fn: - import traceback - traceback.print_stack() - raise ValueError(f'Reconstruct weight from fp8 function called on non fp8 input {W_fp8.dtype}. Returning original input but transposed') - - # infer block sizes if not provided - block_from_fp8 = getattr(W_fp8, 'block_size', None) - block_from_scale = getattr(W_scale, 'block_size', None) - if group_k is None or group_n is None: - if block_from_fp8 is not None: - group_k, group_n = block_from_fp8 - elif block_from_scale is not None: - group_k, group_n = block_from_scale - else: - # fallback to default used elsewhere in your codebase - group_k, group_n = (128, 128) - - K, N = W_fp8.shape - num_k_groups = math.ceil(K / group_k) - num_n_groups = math.ceil(N / group_n) - - # Normalize scale to (num_k_groups, num_n_groups) - if W_scale.numel() == 1: - W_scale = W_scale.reshape(1, 1).expand(num_k_groups, num_n_groups) - elif W_scale.dim() == 1 and W_scale.numel() == num_k_groups * num_n_groups: - W_scale = W_scale.reshape(num_k_groups, num_n_groups) - elif W_scale.dim() == 2 and W_scale.shape == (num_k_groups, num_n_groups): - pass - else: - raise ValueError( - f"Unsupported W_scale shape {tuple(W_scale.shape)} " - f"for W_fp8.shape={tuple(W_fp8.shape)}, group_k={group_k}, group_n={group_n}" - ) - - # Move W to same dtype as scale for safe multiplication (and same device) - device = W_fp8.device - scale_dtype = W_scale.dtype - W = W_fp8.to(dtype=scale_dtype, device=device).contiguous() - - # pad to full groups if needed (this keeps block grouping simple) - Kpad = num_k_groups * group_k - Npad = num_n_groups * group_n - if Kpad != K or Npad != N: - W_pad = W.new_zeros((Kpad, Npad)) - W_pad[:K, :N] = W - W = W_pad - # now W is (Kpad, Npad) and contiguous - - # View into blocks: shape -> (num_k_groups, group_k, num_n_groups, group_n) - W_blocks = W.view(num_k_groups, group_k, num_n_groups, group_n) - - # Permute to (num_k_groups, num_n_groups, group_k, group_n) so scales broadcast naturally - W_blocks = W_blocks.permute(0, 2, 1, 3) # may be non-contiguous, but multiplication will allocate a new tensor - - # Broadcast multiply by per-block scales -> result is a new contiguous tensor - scales = W_scale.reshape(num_k_groups, num_n_groups, 1, 1).to(device=device, dtype=scale_dtype) - W_scaled = W_blocks * scales # (num_k_groups, num_n_groups, group_k, group_n) - contiguous - - # permute back to (num_k_groups, group_k, num_n_groups, group_n), make contiguous, then reshape to (Kpad, Npad) - W_reordered = W_scaled.permute(0, 2, 1, 3).contiguous().reshape(Kpad, Npad).to(dtype=out_dtype) - - # slice off padding, transpose and return contiguous result - W_out = W_reordered[:K, :N].T.contiguous() - W_out.block_size = [group_n, group_k] # returning transpose => block_size swapped - return W_out - @triton.jit def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): pid_m = tl.program_id(axis=0) @@ -118,15 +35,14 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): tl.store(y_ptr + offs, y, mask=mask) -def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor: +def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128, dtype=torch.bfloat16) -> torch.Tensor: assert x.is_contiguous() and s.is_contiguous() assert x.dim() == 2 and s.dim() == 2 M, N = x.size() - y = torch.empty_like(x, dtype=torch.get_default_dtype()) + y = torch.empty_like(x, dtype=dtype) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) - return y - + return y.T # Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py @triton.jit @@ -381,7 +297,7 @@ def forward(ctx, x, weight, weight_scale, bias=None): # x_quantized and x_scale are not necessarily on the same device as x, this is an issue. # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45 x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row( - x.view(-1, x.shape[-1]).contiguous(), scale_ub=getattr(weight, 'input_scale_ub') + x.view(-1, x.shape[-1]).contiguous(), scale_ub=getattr(weight, 'input_scale_ub', None) ) # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device) @@ -404,7 +320,7 @@ def forward(ctx, x, weight, weight_scale, bias=None): @staticmethod def backward(ctx, grad_output): - W_deq = reconstruct_weight_fp8(ctx.weight, ctx.weight_scale, ctx.block_size[0], ctx.block_size[1]) + W_deq = weight_dequant(ctx.weight, ctx.weight_scale, ) grad_X = torch_matmul(grad_output, W_deq.t()) return grad_X, None, None, None, None diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 9e3e274b3..9b6112a41 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -20,7 +20,7 @@ from typing import Optional from .. import DEVICE_TYPE, DEVICE_COUNT -from .fp8 import fp8_e4m3_forward, reconstruct_weight_fp8, fbgemm_fp8_linear +from .fp8 import fp8_e4m3_forward, fbgemm_fp8_linear, weight_dequant # torch.cuda.amp.custom_fwd is deprecated >= 2.4 import torch @@ -203,7 +203,7 @@ def get_lora_parameters(proj): if weight_fake_quantizer is not None: W = weight_fake_quantizer(W) - W_quant = getattr(W, "quant_state", None) or getattr(base_layer,'weight_scale_inv', None) or getattr(base_layer, 'weight_scale', None) + W_quant = next((x for x in [getattr(W, "quant_state", None), getattr(base_layer, "weight_scale_inv", None), getattr(base_layer, "weight_scale", None)] if x is not None), None) if getattr(base_layer, 'quant_method', None)=='fp8': # we need to somehow store and pass this information :) @@ -246,7 +246,7 @@ def get_lora_parameters_bias(proj): base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj) W = base_layer.weight - W_quant = getattr(W, "quant_state", None) or getattr(base_layer,'weight_scale_inv', None) or getattr(base_layer, 'weight_scale', None) + W_quant = next((x for x in [getattr(W, "quant_state", None), getattr(base_layer, "weight_scale_inv", None), getattr(base_layer, "weight_scale", None)] if x is not None), None) # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: if getattr(proj, "disable_adapters", True) or proj.merged: @@ -293,7 +293,7 @@ def _maybe_fake_quantize_activations(X: torch.Tensor, proj: torch.nn.Module) -> def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): # TODO: After adding XPU BNB support, check this function if quant_state is None: return W - if W.dtype == torch.float8_e4m3fn: return reconstruct_weight_fp8(W.t().contiguous(), quant_state) + if W.dtype == torch.float8_e4m3fn: return weight_dequant(W.t().contiguous(), quant_state) if type(quant_state) is not list: # New quant_state as a class # https://github.com/TimDettmers/bitsandbytes/pull/763/files @@ -369,7 +369,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W - if W.dtype == torch.float8_e4m3fn: return reconstruct_weight_fp8(W.t().contiguous(), quant_state) + if W.dtype == torch.float8_e4m3fn: return weight_dequant(W.t().contiguous(), quant_state) if type(quant_state) is not list: # New quant_state as a class # https://github.com/TimDettmers/bitsandbytes/pull/763/files @@ -445,7 +445,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W - if W.dtype == torch.float8_e4m3fn: return reconstruct_weight_fp8(W.t().contiguous(), quant_state) + if W.dtype == torch.float8_e4m3fn: return weight_dequant(W.t().contiguous(), quant_state) if type(quant_state) is not list: # New quant_state as a class # https://github.com/TimDettmers/bitsandbytes/pull/763/files @@ -716,10 +716,13 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): if W_quant is None: out = torch_matmul(X, W.t(), out = out) elif W.dtype == torch.float8_e4m3fn: - if getattr(W, 'quant_method','fp8')=='fp8': + quant_method = getattr(W, 'quant_method', None) or getattr(W_quant, 'quant_method', None) + if quant_method=='fp8': out = fp8_e4m3_forward(X, W, W_quant) - else: + elif quant_method=='fbgemm_fp8': out = fbgemm_fp8_linear(X, W, W_quant, ) + else: + raise ValueError(f'no quant method found. {W.shape=} {W_quant.shape}') elif bsz == 1 and q_len == 1: out = fast_gemv(X, W, W_quant, out = out) else: @@ -767,10 +770,13 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): pass if W.dtype==torch.float8_e4m3fn: - if getattr(W, 'quant_method','fp8')=='fp8': + quant_method = getattr(W, 'quant_method', None) or getattr(W_quant, 'quant_method', None) + if quant_method=='fp8': out = fp8_e4m3_forward(X, W, W_quant) - else: + elif quant_method=='fbgemm_fp8': out = fbgemm_fp8_linear(X, W, W_quant, ) + else: + raise ValueError(f'no quant method found. {W.shape=} {W_quant.shape}') else: W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) out = torch_matmul(X, W, out = out) From 5603730128b4ca989df988445c5a86c7a61e5c23 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 9 Oct 2025 17:10:20 +0000 Subject: [PATCH 17/27] make inputs contiguous if required --- unsloth/kernels/fp8.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index 33094a968..d92f12e7a 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -36,7 +36,10 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128, dtype=torch.bfloat16) -> torch.Tensor: - assert x.is_contiguous() and s.is_contiguous() + if not x.is_contiguous(): + x = x.contiguous() + if not s.is_contiguous(): + s = s.contiguous() assert x.dim() == 2 and s.dim() == 2 M, N = x.size() y = torch.empty_like(x, dtype=dtype) @@ -62,7 +65,8 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): tl.store(s_ptr + pid, s) def act_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]: - assert x.is_contiguous() + if not x.is_contiguous(): + x = x.contiguous() assert x.shape[-1] % block_size == 0 y = torch.empty_like(x, dtype=torch.float8_e4m3fn) s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) From bfb45b12bdf86a922aede41d48e24c761b585a8c Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 9 Oct 2025 18:58:44 +0000 Subject: [PATCH 18/27] Improve dequant --- unsloth/kernels/fp8.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index d92f12e7a..d0db94264 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -35,7 +35,7 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): tl.store(y_ptr + offs, y, mask=mask) -def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128, dtype=torch.bfloat16) -> torch.Tensor: +def weight_dequant_block(x: torch.Tensor, s: torch.Tensor, block_size: int = 128, dtype=torch.bfloat16) -> torch.Tensor: if not x.is_contiguous(): x = x.contiguous() if not s.is_contiguous(): @@ -45,7 +45,17 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128, dtyp y = torch.empty_like(x, dtype=dtype) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) - return y.T + return y + +def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype=torch.bfloat16): + if s.shape[1] == 1: + # this is row quantized weight, just simple multiplication suffices + y = x.to(torch.float32) * s.to(torch.float32) + return y.to(dtype).T + else: + # this is block quantized weight + return weight_dequant_block(x, s, dtype=dtype) + # Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py @triton.jit @@ -282,7 +292,7 @@ def forward(ctx, X, weight, weight_scale): @staticmethod def backward(ctx, grad_output): # W_deq = reconstruct_weight_fp8(ctx.weight, ctx.weight_scale, ctx.block_size[0], ctx.block_size[1]) - W_deq = weight_dequant(ctx.weight, ctx.weight_scale, ctx.block_size[1]) + W_deq = ctx.weight.to(torch.bfloat16) * ctx.weight_scale.to(torch.bfloat16) grad_X = torch_matmul(grad_output, W_deq.t()) del W_deq return grad_X, None, None @@ -308,6 +318,12 @@ def forward(ctx, x, weight, weight_scale, bias=None): # The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight weight_scale_float32 = weight_scale.to(torch.float32) + + if not weight.is_contiguous(): + weight = weight.contiguous() + if not weight_scale.is_contiguous(): + weight_scale = weight_scale.contiguous() + output = torch.ops.fbgemm.f8f8bf16_rowwise( x_quantized, weight, x_scale, weight_scale_float32, use_fast_accum=True ) @@ -324,8 +340,9 @@ def forward(ctx, x, weight, weight_scale, bias=None): @staticmethod def backward(ctx, grad_output): - W_deq = weight_dequant(ctx.weight, ctx.weight_scale, ) + W_deq = weight_dequant(ctx.weight, ctx.weight_scale) grad_X = torch_matmul(grad_output, W_deq.t()) + del W_deq return grad_X, None, None, None, None @torch.compile From 3f277faa2a81678ebeaf580650fb04b846a1829e Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 10 Oct 2025 03:27:42 +0000 Subject: [PATCH 19/27] More robust handling --- unsloth/kernels/fp8.py | 4 ++-- unsloth/kernels/utils.py | 24 +++++++++++------------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index d0db94264..4611e4a05 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -292,7 +292,7 @@ def forward(ctx, X, weight, weight_scale): @staticmethod def backward(ctx, grad_output): # W_deq = reconstruct_weight_fp8(ctx.weight, ctx.weight_scale, ctx.block_size[0], ctx.block_size[1]) - W_deq = ctx.weight.to(torch.bfloat16) * ctx.weight_scale.to(torch.bfloat16) + W_deq = weight_dequant(ctx.weight, ctx.weight_scale) grad_X = torch_matmul(grad_output, W_deq.t()) del W_deq return grad_X, None, None @@ -341,7 +341,7 @@ def forward(ctx, x, weight, weight_scale, bias=None): @staticmethod def backward(ctx, grad_output): W_deq = weight_dequant(ctx.weight, ctx.weight_scale) - grad_X = torch_matmul(grad_output, W_deq.t()) + grad_X = torch_matmul(grad_output, W_deq) del W_deq return grad_X, None, None, None, None diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 9b6112a41..9fb15cfbe 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -293,7 +293,7 @@ def _maybe_fake_quantize_activations(X: torch.Tensor, proj: torch.nn.Module) -> def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): # TODO: After adding XPU BNB support, check this function if quant_state is None: return W - if W.dtype == torch.float8_e4m3fn: return weight_dequant(W.t().contiguous(), quant_state) + if W.dtype == torch.float8_e4m3fn: return weight_dequant(W.t(), quant_state) if type(quant_state) is not list: # New quant_state as a class # https://github.com/TimDettmers/bitsandbytes/pull/763/files @@ -369,7 +369,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W - if W.dtype == torch.float8_e4m3fn: return weight_dequant(W.t().contiguous(), quant_state) + if W.dtype == torch.float8_e4m3fn: return weight_dequant(W.t(), quant_state) if type(quant_state) is not list: # New quant_state as a class # https://github.com/TimDettmers/bitsandbytes/pull/763/files @@ -445,7 +445,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W - if W.dtype == torch.float8_e4m3fn: return weight_dequant(W.t().contiguous(), quant_state) + if W.dtype == torch.float8_e4m3fn: return weight_dequant(W.t(), quant_state) if type(quant_state) is not list: # New quant_state as a class # https://github.com/TimDettmers/bitsandbytes/pull/763/files @@ -716,13 +716,12 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): if W_quant is None: out = torch_matmul(X, W.t(), out = out) elif W.dtype == torch.float8_e4m3fn: - quant_method = getattr(W, 'quant_method', None) or getattr(W_quant, 'quant_method', None) - if quant_method=='fp8': + if W_quant.ndim==2 and W_quant.shape[1]>1: + # This is block quantized FP8 matmul out = fp8_e4m3_forward(X, W, W_quant) - elif quant_method=='fbgemm_fp8': - out = fbgemm_fp8_linear(X, W, W_quant, ) else: - raise ValueError(f'no quant method found. {W.shape=} {W_quant.shape}') + # Row quantized FP8 + out = fbgemm_fp8_linear(X, W, W_quant, ) elif bsz == 1 and q_len == 1: out = fast_gemv(X, W, W_quant, out = out) else: @@ -770,13 +769,12 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): pass if W.dtype==torch.float8_e4m3fn: - quant_method = getattr(W, 'quant_method', None) or getattr(W_quant, 'quant_method', None) - if quant_method=='fp8': + if W_quant.ndim==2 and W_quant.shape[1]>1: + # This is block quantized FP8 matmul out = fp8_e4m3_forward(X, W, W_quant) - elif quant_method=='fbgemm_fp8': - out = fbgemm_fp8_linear(X, W, W_quant, ) else: - raise ValueError(f'no quant method found. {W.shape=} {W_quant.shape}') + # Row quantized FP8 + out = fbgemm_fp8_linear(X, W, W_quant, ) else: W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) out = torch_matmul(X, W, out = out) From dc4c855f8d809ba21aa0dd61588909b218bd397f Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 10 Oct 2025 07:46:56 +0000 Subject: [PATCH 20/27] Fixup backward pass for fbgemm_fp8 --- unsloth/kernels/fp8.py | 74 +++++++++++++++++++++++++--------------- unsloth/kernels/utils.py | 6 ++-- 2 files changed, 49 insertions(+), 31 deletions(-) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index 4611e4a05..e33959958 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -50,8 +50,15 @@ def weight_dequant_block(x: torch.Tensor, s: torch.Tensor, block_size: int = 128 def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype=torch.bfloat16): if s.shape[1] == 1: # this is row quantized weight, just simple multiplication suffices - y = x.to(torch.float32) * s.to(torch.float32) - return y.to(dtype).T + if x.shape[0]==s.shape[0]: + y = x.to(torch.float32) * s.to(torch.float32) + elif x.shape[1]==s.shape[0]: + # sometimes, this is called with the transpose of the weight. Adjust for that. + y = x.t().to(torch.float32) * s.to(torch.float32) + y = y.t() + else: + raise ValueError(f'Incompatible shapes {x.shape=}, {s.shape=}') + return y.to(dtype) else: # this is block quantized weight return weight_dequant_block(x, s, dtype=dtype) @@ -306,32 +313,43 @@ class FbgemmFp8Linear(torch.autograd.Function): @staticmethod def forward(ctx, x, weight, weight_scale, bias=None): - # quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here - output_shape = (*x.shape[:-1], -1) - # x_quantized and x_scale are not necessarily on the same device as x, this is an issue. - # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45 - x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row( - x.view(-1, x.shape[-1]).contiguous(), scale_ub=getattr(weight, 'input_scale_ub', None) - ) - # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works - # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device) - - # The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight - weight_scale_float32 = weight_scale.to(torch.float32) - - if not weight.is_contiguous(): - weight = weight.contiguous() - if not weight_scale.is_contiguous(): - weight_scale = weight_scale.contiguous() - output = torch.ops.fbgemm.f8f8bf16_rowwise( - x_quantized, weight, x_scale, weight_scale_float32, use_fast_accum=True - ) - output = output + bias if bias is not None else output - # Hacky for now, we have the output to the device of x - output = output.to(x.device, x.dtype) - output = output.reshape(output_shape) - del x_quantized, x_scale + if weight.shape[0]!=weight_scale.shape[0]: + if weight.shape[1]==weight_scale.shape[0]: + # This is generally the case when we do backward pass. The only way is to dequantize as there is no column wise fp8 matmul + W_deq = weight_dequant(weight, weight_scale).T + x = torch_matmul(x, W_deq) + del W_deq + return x + else: + raise ValueError(f"Shapes are incompatible {weight.shape=}, {weight_scale.shape=}, {x.shape=}") + else: + # quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here + output_shape = (*x.shape[:-1], -1) + # x_quantized and x_scale are not necessarily on the same device as x, this is an issue. + # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45 + x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row( + x.view(-1, x.shape[-1]).contiguous(), scale_ub=getattr(weight, 'input_scale_ub', None) + ) + # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works + # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device) + + # The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight + weight_scale_float32 = weight_scale.to(torch.float32) + + if not weight.is_contiguous(): + weight = weight.contiguous() + if not weight_scale.is_contiguous(): + weight_scale = weight_scale.contiguous() + + output = torch.ops.fbgemm.f8f8bf16_rowwise( + x_quantized, weight, x_scale, weight_scale_float32, use_fast_accum=True + ) + output = output + bias if bias is not None else output + # Hacky for now, we have the output to the device of x + output = output.to(x.device, x.dtype) + output = output.reshape(output_shape) + del x_quantized, x_scale ctx.weight = weight ctx.weight_scale = weight_scale @@ -341,7 +359,7 @@ def forward(ctx, x, weight, weight_scale, bias=None): @staticmethod def backward(ctx, grad_output): W_deq = weight_dequant(ctx.weight, ctx.weight_scale) - grad_X = torch_matmul(grad_output, W_deq) + grad_X = torch_matmul(grad_output, W_deq.t()) del W_deq return grad_X, None, None, None, None diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 9fb15cfbe..84680f058 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -293,7 +293,7 @@ def _maybe_fake_quantize_activations(X: torch.Tensor, proj: torch.nn.Module) -> def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): # TODO: After adding XPU BNB support, check this function if quant_state is None: return W - if W.dtype == torch.float8_e4m3fn: return weight_dequant(W.t(), quant_state) + if W.dtype == torch.float8_e4m3fn: return weight_dequant(W, quant_state) if type(quant_state) is not list: # New quant_state as a class # https://github.com/TimDettmers/bitsandbytes/pull/763/files @@ -369,7 +369,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W - if W.dtype == torch.float8_e4m3fn: return weight_dequant(W.t(), quant_state) + if W.dtype == torch.float8_e4m3fn: return weight_dequant(W, quant_state) if type(quant_state) is not list: # New quant_state as a class # https://github.com/TimDettmers/bitsandbytes/pull/763/files @@ -445,7 +445,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W - if W.dtype == torch.float8_e4m3fn: return weight_dequant(W.t(), quant_state) + if W.dtype == torch.float8_e4m3fn: return weight_dequant(W, quant_state) if type(quant_state) is not list: # New quant_state as a class # https://github.com/TimDettmers/bitsandbytes/pull/763/files From 5b7d7557076431fb7cfbea7d5dff03778aff8ac4 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 10 Oct 2025 09:53:52 +0000 Subject: [PATCH 21/27] refactor and use bf16 for dequant --- unsloth/kernels/fp8.py | 16 +++++++++++++--- unsloth/kernels/utils.py | 16 +++------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index e33959958..107ea085b 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -51,14 +51,14 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype=torch.bfloat16): if s.shape[1] == 1: # this is row quantized weight, just simple multiplication suffices if x.shape[0]==s.shape[0]: - y = x.to(torch.float32) * s.to(torch.float32) + y = x.to(dtype) * s.to(dtype) elif x.shape[1]==s.shape[0]: # sometimes, this is called with the transpose of the weight. Adjust for that. - y = x.t().to(torch.float32) * s.to(torch.float32) + y = x.t().to(dtype) * s.to(dtype) y = y.t() else: raise ValueError(f'Incompatible shapes {x.shape=}, {s.shape=}') - return y.to(dtype) + return y else: # this is block quantized weight return weight_dequant_block(x, s, dtype=dtype) @@ -366,3 +366,13 @@ def backward(ctx, grad_output): @torch.compile def fbgemm_fp8_linear(X, weight, weight_scale, bias=None, ): return FbgemmFp8Linear.apply(X, weight, weight_scale, bias) + +@torch.compile +def fp8_linear(X, weight, weight_scale, bias=None): + if weight_scale.ndim==2 and weight_scale.shape[1]>1: + # This is block quantized FP8 matmul + out = fp8_e4m3_forward(X, weight, weight_scale) + else: + # Row quantized FP8 + out = fbgemm_fp8_linear(X, weight, weight_scale, bias) + return out diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 84680f058..fddf5ade5 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -20,7 +20,7 @@ from typing import Optional from .. import DEVICE_TYPE, DEVICE_COUNT -from .fp8 import fp8_e4m3_forward, fbgemm_fp8_linear, weight_dequant +from .fp8 import weight_dequant, fp8_linear # torch.cuda.amp.custom_fwd is deprecated >= 2.4 import torch @@ -716,12 +716,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): if W_quant is None: out = torch_matmul(X, W.t(), out = out) elif W.dtype == torch.float8_e4m3fn: - if W_quant.ndim==2 and W_quant.shape[1]>1: - # This is block quantized FP8 matmul - out = fp8_e4m3_forward(X, W, W_quant) - else: - # Row quantized FP8 - out = fbgemm_fp8_linear(X, W, W_quant, ) + out = fp8_linear(X, W, W_quant, bias) elif bsz == 1 and q_len == 1: out = fast_gemv(X, W, W_quant, out = out) else: @@ -769,12 +764,7 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): pass if W.dtype==torch.float8_e4m3fn: - if W_quant.ndim==2 and W_quant.shape[1]>1: - # This is block quantized FP8 matmul - out = fp8_e4m3_forward(X, W, W_quant) - else: - # Row quantized FP8 - out = fbgemm_fp8_linear(X, W, W_quant, ) + out = fp8_linear(X, W, W_quant, ) else: W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) out = torch_matmul(X, W, out = out) From da7d3f952be1abe97ce34937cef7ee3e867eeb51 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Sat, 11 Oct 2025 09:46:17 +0000 Subject: [PATCH 22/27] Use torch fp8 block matmul --- unsloth/kernels/fp8.py | 53 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index 107ea085b..8f4d15947 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -17,6 +17,7 @@ import triton.language as tl from torch.nn import functional as F import math +from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import triton_quantize_fp8_block torch_matmul = torch.matmul @@ -367,11 +368,61 @@ def backward(ctx, grad_output): def fbgemm_fp8_linear(X, weight, weight_scale, bias=None, ): return FbgemmFp8Linear.apply(X, weight, weight_scale, bias) + +class FP8_torch_linear(torch.autograd.Function): + @staticmethod + def forward(ctx, X, weight, weight_scale, bias=None): + + orig_shape = X.shape + X = X.view(-1,X.shape[-1]) + + bs_n, bs_k = getattr(weight, 'block_size', None) or getattr(weight_scale, 'block_size', [128,128]) + bs_m = bs_n + + m,n = weight.shape + p,q = weight_scale.shape + + if triton.cdiv(m,bs_n)!=p or triton.cdiv(n,bs_k)!=q: + if triton.cdiv(m,bs_n)==q and triton.cdiv(n,bs_k)==p: + # weights are tranposed during backward pass for training :) + # We tranpose weight scale to counter that. Note that transposing weight would cause issues with matmul with input X + weight_scale = weight_scale.T + else: + raise ValueError(f"Weight shape {weight.shape} and scales shape {weight_scale.shape} is not compatible with block size {block_size}") + + xq, xs = triton_quantize_fp8_block(X, bs_m, bs_n, None) + output = torch.ops.fbgemm.f8f8bf16_blockwise(xq, weight.contiguous(), xs, weight_scale.contiguous(),bs_m,bs_n, bs_k) + output = output + bias if bias is not None else output + + output = output.view(*orig_shape[:-1], -1) + + del xq + del xs + + ctx.weight = weight + ctx.weight_scale = weight_scale + ctx.block_size = [bs_m, bs_n, bs_k] + + return output + + @staticmethod + def backward(ctx, grad_output): + W_deq = weight_dequant(ctx.weight, ctx.weight_scale) + grad_X = torch_matmul(grad_output, W_deq.t()) + del W_deq + return grad_X, None, None, None, None + +@torch.compile +def fp8_torch_linear(X, weight, weight_scale, bias=None): + return FP8_torch_linear.apply(X, weight, weight_scale, bias) + + @torch.compile def fp8_linear(X, weight, weight_scale, bias=None): if weight_scale.ndim==2 and weight_scale.shape[1]>1: # This is block quantized FP8 matmul - out = fp8_e4m3_forward(X, weight, weight_scale) + #out = fp8_e4m3_forward(X, weight, weight_scale) + out = fp8_torch_linear(X, weight, weight_scale, bias) else: # Row quantized FP8 out = fbgemm_fp8_linear(X, weight, weight_scale, bias) From 5af9f6294e9e04b33adeb98f8873ca76e154b23f Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Sun, 12 Oct 2025 12:54:46 +0000 Subject: [PATCH 23/27] Disable torch block matmul for now --- unsloth/kernels/fp8.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index 8f4d15947..74ceea4e0 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -421,8 +421,11 @@ def fp8_torch_linear(X, weight, weight_scale, bias=None): def fp8_linear(X, weight, weight_scale, bias=None): if weight_scale.ndim==2 and weight_scale.shape[1]>1: # This is block quantized FP8 matmul - #out = fp8_e4m3_forward(X, weight, weight_scale) - out = fp8_torch_linear(X, weight, weight_scale, bias) + out = fp8_e4m3_forward(X, weight, weight_scale) + # These operations fall apart when X have large values in it. So disabling for the timebeing? + # The above operation makes the training loop ~4x slower per step but the output is correct :( + # TODO: Fix the outlier handling in torch implementation and enable this + # out = fp8_torch_linear(X, weight, weight_scale, bias) else: # Row quantized FP8 out = fbgemm_fp8_linear(X, weight, weight_scale, bias) From 5e90163cfac5e9ef5a7d6938709f90fc1d7ce521 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Tue, 14 Oct 2025 18:15:00 +0000 Subject: [PATCH 24/27] safer import and cosmetics --- unsloth/kernels/fp8.py | 32 ++++++++++++++++++-------------- unsloth/kernels/utils.py | 10 +++++----- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index 74ceea4e0..28fbe6a04 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -17,7 +17,12 @@ import triton.language as tl from torch.nn import functional as F import math -from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import triton_quantize_fp8_block +try: + from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import triton_quantize_fp8_block +except ImportError: + triton_quantize_fp8_block = None + +from unsloth_zoo.temporary_patches.common import torch_compile torch_matmul = torch.matmul @@ -51,9 +56,9 @@ def weight_dequant_block(x: torch.Tensor, s: torch.Tensor, block_size: int = 128 def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype=torch.bfloat16): if s.shape[1] == 1: # this is row quantized weight, just simple multiplication suffices - if x.shape[0]==s.shape[0]: + if x.shape[0] == s.shape[0]: y = x.to(dtype) * s.to(dtype) - elif x.shape[1]==s.shape[0]: + elif x.shape[1] == s.shape[0]: # sometimes, this is called with the transpose of the weight. Adjust for that. y = x.t().to(dtype) * s.to(dtype) y = y.t() @@ -258,7 +263,6 @@ def grid(META): BLOCK_SIZE_K=BLOCK_SIZE_K, GROUP_SIZE_M=8, ) - return C class FP8_E4M3Linear(torch.autograd.Function): @@ -270,8 +274,8 @@ def forward(ctx, X, weight, weight_scale): p,q = weight_scale.shape block_size = getattr(weight, 'block_size', None) or getattr(weight_scale, 'block_size', None) assert block_size is not None, "block_size is not set" - if triton.cdiv(m,block_size[0])!=p or triton.cdiv(n,block_size[1])!=q: - if triton.cdiv(m,block_size[0])==q and triton.cdiv(n,block_size[1])==p: + if triton.cdiv(m,block_size[0]) != p or triton.cdiv(n,block_size[1]) != q: + if triton.cdiv(m,block_size[0]) == q and triton.cdiv(n,block_size[1]) == p: # weights are tranposed during backward pass for training :) # We tranpose weight scale to counter that. Note that transposing weight would cause issues with matmul with input X weight_scale = weight_scale.T @@ -305,7 +309,7 @@ def backward(ctx, grad_output): del W_deq return grad_X, None, None -@torch.compile +@torch_compile def fp8_e4m3_forward(X, weight, weight_scale): return FP8_E4M3Linear.apply(X, weight, weight_scale) @@ -315,8 +319,8 @@ class FbgemmFp8Linear(torch.autograd.Function): @staticmethod def forward(ctx, x, weight, weight_scale, bias=None): - if weight.shape[0]!=weight_scale.shape[0]: - if weight.shape[1]==weight_scale.shape[0]: + if weight.shape[0] != weight_scale.shape[0]: + if weight.shape[1] == weight_scale.shape[0]: # This is generally the case when we do backward pass. The only way is to dequantize as there is no column wise fp8 matmul W_deq = weight_dequant(weight, weight_scale).T x = torch_matmul(x, W_deq) @@ -364,7 +368,7 @@ def backward(ctx, grad_output): del W_deq return grad_X, None, None, None, None -@torch.compile +@torch_compile def fbgemm_fp8_linear(X, weight, weight_scale, bias=None, ): return FbgemmFp8Linear.apply(X, weight, weight_scale, bias) @@ -382,8 +386,8 @@ def forward(ctx, X, weight, weight_scale, bias=None): m,n = weight.shape p,q = weight_scale.shape - if triton.cdiv(m,bs_n)!=p or triton.cdiv(n,bs_k)!=q: - if triton.cdiv(m,bs_n)==q and triton.cdiv(n,bs_k)==p: + if triton.cdiv(m,bs_n) != p or triton.cdiv(n,bs_k) != q: + if triton.cdiv(m,bs_n) == q and triton.cdiv(n,bs_k) == p: # weights are tranposed during backward pass for training :) # We tranpose weight scale to counter that. Note that transposing weight would cause issues with matmul with input X weight_scale = weight_scale.T @@ -412,12 +416,12 @@ def backward(ctx, grad_output): del W_deq return grad_X, None, None, None, None -@torch.compile +@torch_compile def fp8_torch_linear(X, weight, weight_scale, bias=None): return FP8_torch_linear.apply(X, weight, weight_scale, bias) -@torch.compile +@torch_compile def fp8_linear(X, weight, weight_scale, bias=None): if weight_scale.ndim==2 and weight_scale.shape[1]>1: # This is block quantized FP8 matmul diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index fddf5ade5..adfce97f4 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -205,9 +205,9 @@ def get_lora_parameters(proj): W_quant = next((x for x in [getattr(W, "quant_state", None), getattr(base_layer, "weight_scale_inv", None), getattr(base_layer, "weight_scale", None)] if x is not None), None) - if getattr(base_layer, 'quant_method', None)=='fp8': + if getattr(base_layer, 'quant_method', None) == 'fp8': # we need to somehow store and pass this information :) - W.block_size = getattr(base_layer, 'block_size', [128,128]) + W.block_size = getattr(base_layer, 'block_size', [128, 128]) W_quant.block_size = W.block_size # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: @@ -253,9 +253,9 @@ def get_lora_parameters_bias(proj): return W, W_quant, None, None, None, base_layer.bias pass - if getattr(base_layer, 'quant_method', None)=='fp8': + if getattr(base_layer, 'quant_method', None) == 'fp8': # we need to somehow store and pass this information :) - W.block_size = getattr(base_layer, 'block_size', [128,128]) + W.block_size = getattr(base_layer, 'block_size', [128, 128]) W_quant.block_size = W.block_size adapter = getattr(proj, "active_adapters", None) @@ -763,7 +763,7 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): reshape = False pass - if W.dtype==torch.float8_e4m3fn: + if W.dtype == torch.float8_e4m3fn: out = fp8_linear(X, W, W_quant, ) else: W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) From 80a044956c907587db21e9e29e34b23907bcec64 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Wed, 15 Oct 2025 06:51:29 +0000 Subject: [PATCH 25/27] more cosmectics --- unsloth/kernels/fp8.py | 58 ++++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index 28fbe6a04..a5178fb47 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -48,7 +48,7 @@ def weight_dequant_block(x: torch.Tensor, s: torch.Tensor, block_size: int = 128 s = s.contiguous() assert x.dim() == 2 and s.dim() == 2 M, N = x.size() - y = torch.empty_like(x, dtype=dtype) + y = torch.empty_like(x, dtype = dtype) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) return y @@ -77,13 +77,12 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(x_ptr + offs).to(tl.float32) s = tl.max(tl.abs(x)) / 448.0 - if s==0: - # For a row of all zeros, lets return zeros as is - # for LoRA, there are cases where dY has 0 in it and we should not let it be NaN - # this is a deviation from the original implementation. - s = 1.0 + # For a row of all zeros, lets return zeros as is + # for LoRA, there are cases where dY has 0 in it and we should not let it be NaN + # this is a deviation from the original implementation. + s = 1.0 if s == 0 else s y = x / s - y = y.to(y_ptr.dtype.element_ty) + y = y.to(y_ptr.dtype) tl.store(y_ptr + offs, y) tl.store(s_ptr + pid, s) @@ -92,12 +91,12 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, tor x = x.contiguous() assert x.shape[-1] % block_size == 0 y = torch.empty_like(x, dtype=torch.float8_e4m3fn) - s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) + s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype = torch.float32) def grid(meta): return (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),) - act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) + act_quant_kernel[grid](x, y, s, BLOCK_SIZE = block_size) return y, s @@ -159,7 +158,7 @@ def _w8a8_block_fp8_matmul( offs_bsn = offs_bn // group_n Bs_ptrs = Bs + offs_bsn * stride_Bs_n - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype = tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) @@ -184,7 +183,7 @@ def _w8a8_block_fp8_matmul( offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, c, mask=c_mask) + tl.store(c_ptrs, c, mask = c_mask) def w8a8_block_fp8_matmul_triton( @@ -224,7 +223,7 @@ def w8a8_block_fp8_matmul_triton( assert triton.cdiv(K, block_k) == Bs.shape[1] C_shape = A.shape[:-1] + (N,) - C = A.new_empty(C_shape, dtype=output_dtype) + C = A.new_empty(C_shape, dtype = output_dtype) BLOCK_SIZE_M = 128 if M < BLOCK_SIZE_M: @@ -258,10 +257,10 @@ def grid(META): As.stride(-1), Bs.stride(1), Bs.stride(0), - BLOCK_SIZE_M=BLOCK_SIZE_M, - BLOCK_SIZE_N=BLOCK_SIZE_N, - BLOCK_SIZE_K=BLOCK_SIZE_K, - GROUP_SIZE_M=8, + BLOCK_SIZE_M = BLOCK_SIZE_M, + BLOCK_SIZE_N = BLOCK_SIZE_N, + BLOCK_SIZE_K = BLOCK_SIZE_K, + GROUP_SIZE_M = 8, ) return C @@ -274,8 +273,8 @@ def forward(ctx, X, weight, weight_scale): p,q = weight_scale.shape block_size = getattr(weight, 'block_size', None) or getattr(weight_scale, 'block_size', None) assert block_size is not None, "block_size is not set" - if triton.cdiv(m,block_size[0]) != p or triton.cdiv(n,block_size[1]) != q: - if triton.cdiv(m,block_size[0]) == q and triton.cdiv(n,block_size[1]) == p: + if triton.cdiv(m, block_size[0]) != p or triton.cdiv(n, block_size[1]) != q: + if triton.cdiv(m, block_size[0]) == q and triton.cdiv(n, block_size[1]) == p: # weights are tranposed during backward pass for training :) # We tranpose weight scale to counter that. Note that transposing weight would cause issues with matmul with input X weight_scale = weight_scale.T @@ -303,7 +302,6 @@ def forward(ctx, X, weight, weight_scale): @staticmethod def backward(ctx, grad_output): - # W_deq = reconstruct_weight_fp8(ctx.weight, ctx.weight_scale, ctx.block_size[0], ctx.block_size[1]) W_deq = weight_dequant(ctx.weight, ctx.weight_scale) grad_X = torch_matmul(grad_output, W_deq.t()) del W_deq @@ -317,7 +315,7 @@ def fp8_e4m3_forward(X, weight, weight_scale): class FbgemmFp8Linear(torch.autograd.Function): @staticmethod - def forward(ctx, x, weight, weight_scale, bias=None): + def forward(ctx, x, weight, weight_scale, bias=None): if weight.shape[0] != weight_scale.shape[0]: if weight.shape[1] == weight_scale.shape[0]: @@ -334,7 +332,7 @@ def forward(ctx, x, weight, weight_scale, bias=None): # x_quantized and x_scale are not necessarily on the same device as x, this is an issue. # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45 x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row( - x.view(-1, x.shape[-1]).contiguous(), scale_ub=getattr(weight, 'input_scale_ub', None) + x.view(-1, x.shape[-1]).contiguous(), scale_ub = getattr(weight, 'input_scale_ub', None) ) # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device) @@ -348,7 +346,7 @@ def forward(ctx, x, weight, weight_scale, bias=None): weight_scale = weight_scale.contiguous() output = torch.ops.fbgemm.f8f8bf16_rowwise( - x_quantized, weight, x_scale, weight_scale_float32, use_fast_accum=True + x_quantized, weight, x_scale, weight_scale_float32, use_fast_accum = True ) output = output + bias if bias is not None else output # Hacky for now, we have the output to the device of x @@ -378,16 +376,16 @@ class FP8_torch_linear(torch.autograd.Function): def forward(ctx, X, weight, weight_scale, bias=None): orig_shape = X.shape - X = X.view(-1,X.shape[-1]) + X = X.view(-1, X.shape[-1]) - bs_n, bs_k = getattr(weight, 'block_size', None) or getattr(weight_scale, 'block_size', [128,128]) + bs_n, bs_k = getattr(weight, 'block_size', None) or getattr(weight_scale, 'block_size', [128, 128]) bs_m = bs_n m,n = weight.shape p,q = weight_scale.shape - if triton.cdiv(m,bs_n) != p or triton.cdiv(n,bs_k) != q: - if triton.cdiv(m,bs_n) == q and triton.cdiv(n,bs_k) == p: + if triton.cdiv(m, bs_n) != p or triton.cdiv(n, bs_k) != q: + if triton.cdiv(m, bs_n) == q and triton.cdiv(n, bs_k) == p: # weights are tranposed during backward pass for training :) # We tranpose weight scale to counter that. Note that transposing weight would cause issues with matmul with input X weight_scale = weight_scale.T @@ -395,7 +393,11 @@ def forward(ctx, X, weight, weight_scale, bias=None): raise ValueError(f"Weight shape {weight.shape} and scales shape {weight_scale.shape} is not compatible with block size {block_size}") xq, xs = triton_quantize_fp8_block(X, bs_m, bs_n, None) - output = torch.ops.fbgemm.f8f8bf16_blockwise(xq, weight.contiguous(), xs, weight_scale.contiguous(),bs_m,bs_n, bs_k) + ## TODO: Investigate and resolve the high divergence of this output from baseline + # WARNING: This causes the outputs to diverge from expected when X has high values in it. + # That results in the model producing gibberish, especially on longer sequences and training loss starting at high values like 8 instead of <1 ideally + # Please refrain from using this till this issue is resolved. This exists here just for a future headstart. + output = torch.ops.fbgemm.f8f8bf16_blockwise(xq, weight.contiguous(), xs, weight_scale.contiguous(), bs_m, bs_n, bs_k) output = output + bias if bias is not None else output output = output.view(*orig_shape[:-1], -1) @@ -423,7 +425,7 @@ def fp8_torch_linear(X, weight, weight_scale, bias=None): @torch_compile def fp8_linear(X, weight, weight_scale, bias=None): - if weight_scale.ndim==2 and weight_scale.shape[1]>1: + if weight_scale.ndim == 2 and weight_scale.shape[1] > 1: # This is block quantized FP8 matmul out = fp8_e4m3_forward(X, weight, weight_scale) # These operations fall apart when X have large values in it. So disabling for the timebeing? From dd4bf135f9c50366d08841e6ff7fda9c019f882f Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Wed, 15 Oct 2025 07:26:50 +0000 Subject: [PATCH 26/27] add torchao operations --- unsloth/kernels/fp8.py | 42 +++++++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index a5178fb47..3df11afa7 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -22,6 +22,13 @@ except ImportError: triton_quantize_fp8_block = None +try: + from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import ( + blockwise_fp8_gemm as torchao_blockwise_gemm, + ) +except ImportError: + torchao_blockwise_gemm = None + from unsloth_zoo.temporary_patches.common import torch_compile torch_matmul = torch.matmul @@ -82,7 +89,7 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): # this is a deviation from the original implementation. s = 1.0 if s == 0 else s y = x / s - y = y.to(y_ptr.dtype) + y = y.to(y_ptr.dtype.element_ty) tl.store(y_ptr + offs, y) tl.store(s_ptr + pid, s) @@ -264,7 +271,28 @@ def grid(META): ) return C -class FP8_E4M3Linear(torch.autograd.Function): +def torchao_block_matmul( + act_q: torch.Tensor, + weight_q: torch.Tensor, + act_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: tuple[int, int], + output_dtype: torch.dtype = torch.bfloat16, +): + out = torchao_blockwise_gemm( + act_q.contiguous(), + act_scale.contiguous(), + weight_q.contiguous(), + weight_scale.contiguous(), + block_size=block_size[1], + ) + return out.to(output_dtype) + +# This torchao FP8 matmul seems to be ~3x faster than the w8a8_block_fp8_matmul_triton. Though this is 15-30% slower than fbgemm implementation. +# But this gives very comparable results when it comes to training loss, so we prefer using it when available. +fp8_block_matmul = torchao_block_matmul if torchao_blockwise_gemm is not None else w8a8_block_fp8_matmul_triton + +class FP8BlockQuantLinear(torch.autograd.Function): @staticmethod def forward(ctx, X, weight, weight_scale): @@ -285,7 +313,7 @@ def forward(ctx, X, weight, weight_scale): weight = weight.contiguous() # this is replica of https://github.com/huggingface/transformers/blob/01c9e1ba683b3e50d7c76bf92f2d470759fd5e81/src/transformers/integrations/finegrained_fp8.py#L331-L353 qinput, scale = act_quant(X, block_size[1]) - output = w8a8_block_fp8_matmul_triton( + output = fp8_block_matmul( qinput, weight, scale, @@ -308,8 +336,8 @@ def backward(ctx, grad_output): return grad_X, None, None @torch_compile -def fp8_e4m3_forward(X, weight, weight_scale): - return FP8_E4M3Linear.apply(X, weight, weight_scale) +def fp8_block_quant_forward(X, weight, weight_scale): + return FP8BlockQuantLinear.apply(X, weight, weight_scale) class FbgemmFp8Linear(torch.autograd.Function): @@ -427,9 +455,9 @@ def fp8_torch_linear(X, weight, weight_scale, bias=None): def fp8_linear(X, weight, weight_scale, bias=None): if weight_scale.ndim == 2 and weight_scale.shape[1] > 1: # This is block quantized FP8 matmul - out = fp8_e4m3_forward(X, weight, weight_scale) + out = fp8_block_quant_forward(X, weight, weight_scale) # These operations fall apart when X have large values in it. So disabling for the timebeing? - # The above operation makes the training loop ~4x slower per step but the output is correct :( + # The above operation makes the training loop ~15-30% slower if torchao is available ~4x slower if not :( # TODO: Fix the outlier handling in torch implementation and enable this # out = fp8_torch_linear(X, weight, weight_scale, bias) else: From 82c8eeff6d2d6b7b8cd0c41fe5dcfe7c85a8ab4f Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 16 Oct 2025 09:55:57 +0000 Subject: [PATCH 27/27] Spaceeeeeee --- unsloth/kernels/fp8.py | 13 ++++--------- unsloth/kernels/utils.py | 2 +- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py index 3df11afa7..aa2d9074b 100644 --- a/unsloth/kernels/fp8.py +++ b/unsloth/kernels/fp8.py @@ -297,8 +297,8 @@ class FP8BlockQuantLinear(torch.autograd.Function): @staticmethod def forward(ctx, X, weight, weight_scale): # block_size = getattr(weight, 'block_size', [128,128]) - m,n = weight.shape - p,q = weight_scale.shape + m, n = weight.shape + p, q = weight_scale.shape block_size = getattr(weight, 'block_size', None) or getattr(weight_scale, 'block_size', None) assert block_size is not None, "block_size is not set" if triton.cdiv(m, block_size[0]) != p or triton.cdiv(n, block_size[1]) != q: @@ -321,11 +321,9 @@ def forward(ctx, X, weight, weight_scale): block_size, output_dtype=X.dtype, ) - ctx.weight = weight ctx.weight_scale = weight_scale ctx.block_size = block_size - return output.to(X.dtype) @staticmethod @@ -344,7 +342,6 @@ class FbgemmFp8Linear(torch.autograd.Function): @staticmethod def forward(ctx, x, weight, weight_scale, bias=None): - if weight.shape[0] != weight_scale.shape[0]: if weight.shape[1] == weight_scale.shape[0]: # This is generally the case when we do backward pass. The only way is to dequantize as there is no column wise fp8 matmul @@ -384,7 +381,6 @@ def forward(ctx, x, weight, weight_scale, bias=None): ctx.weight = weight ctx.weight_scale = weight_scale - return output @staticmethod @@ -409,8 +405,8 @@ def forward(ctx, X, weight, weight_scale, bias=None): bs_n, bs_k = getattr(weight, 'block_size', None) or getattr(weight_scale, 'block_size', [128, 128]) bs_m = bs_n - m,n = weight.shape - p,q = weight_scale.shape + m, n = weight.shape + p, q = weight_scale.shape if triton.cdiv(m, bs_n) != p or triton.cdiv(n, bs_k) != q: if triton.cdiv(m, bs_n) == q and triton.cdiv(n, bs_k) == p: @@ -436,7 +432,6 @@ def forward(ctx, X, weight, weight_scale, bias=None): ctx.weight = weight ctx.weight_scale = weight_scale ctx.block_size = [bs_m, bs_n, bs_k] - return output @staticmethod diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index adfce97f4..9a46d0d5d 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -764,7 +764,7 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): pass if W.dtype == torch.float8_e4m3fn: - out = fp8_linear(X, W, W_quant, ) + out = fp8_linear(X, W, W_quant) else: W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) out = torch_matmul(X, W, out = out)