Changes shown from 1 commit; 28 commits in total:
710b1ff  fix (SunMarc, Nov 5, 2025)
f72f96d  fixes for more models torch_bc (ArthurZucker, Nov 5, 2025)
e341529  nits and fixes (ArthurZucker, Nov 5, 2025)
0e51dec  last update (ArthurZucker, Nov 5, 2025)
0f022b5  Revert "tied weight first shot to the fiiiixxxxxx" (ArthurZucker, Nov 5, 2025)
1dabb4c  here we go again (ArthurZucker, Nov 5, 2025)
0c2b667  an attempt (ArthurZucker, Nov 6, 2025)
c48e1ed  up? (ArthurZucker, Nov 6, 2025)
d223635  nits (ArthurZucker, Nov 6, 2025)
bdbc01a  Fix bnb loading ! (SunMarc, Nov 6, 2025)
399388d  rm print (SunMarc, Nov 6, 2025)
acbeeae  Merge branch 'refactor-weight-loading' into fix-bnb (SunMarc, Nov 6, 2025)
e16da23  rm import (SunMarc, Nov 7, 2025)
386e259  update (SunMarc, Nov 7, 2025)
9788014  Merge remote-tracking branch 'upstream/refactor-weight-loading' into … (SunMarc, Nov 7, 2025)
72eff97  Update src/transformers/core_model_loading.py (SunMarc, Nov 7, 2025)
d841a04  Fix loadedparam (SunMarc, Nov 7, 2025)
e235eed  Merge remote-tracking branch 'upstream/fix-bnb' into fix-bnb (SunMarc, Nov 7, 2025)
e4df752  rm report (SunMarc, Nov 7, 2025)
3e69622  Fix tests single gpu (SunMarc, Nov 7, 2025)
a052513  should fix it (SunMarc, Nov 7, 2025)
db4fe31  Merge branch 'refactor-weight-loading' into fix-bnb (SunMarc, Nov 10, 2025)
9fa1b7a  guard needed for compressed-tensors (SunMarc, Nov 10, 2025)
ea5822d  Merge branch 'refactor-weight-loading' into fix-bnb (SunMarc, Nov 10, 2025)
5881d8e  deal with buffers (SunMarc, Nov 10, 2025)
3651460  Merge branch 'refactor-weight-loading' into fix-bnb (SunMarc, Nov 10, 2025)
00b0044  Merge branch 'refactor-weight-loading' into fix-bnb (SunMarc, Nov 10, 2025)
7d8df52  fix (SunMarc, Nov 10, 2025)
24 changes: 13 additions & 11 deletions src/transformers/core_model_loading.py
@@ -357,11 +357,17 @@ def set_param_for_module(
    missing_keys: MutableSet[str],
    misc: MutableMapping[str, Any],
    distributed_operation: Optional[TensorParallelLayer],
    hf_quantizer,
):
    with log_to_misc(layer_name, misc, layer_name):
        module_path, _, param_name = layer_name.rpartition(".")
        module_obj = model.get_submodule(module_path) if module_path else model
        param_value = param_value[0] if isinstance(param_value, list) else param_value[...]
        if isinstance(param_value, list):
            param_value = param_value[0]
        elif isinstance(param_value, torch.nn.Parameter):
            pass
Collaborator:

I am guessing that's for BNB? Can't we force it to return data instead of a param?

Member Author:

If you do tensor[...], the nn.Parameter becomes a plain tensor, so the later isinstance check no longer holds. But yes, this is for bnb: we need to return an nn.Params4bit, which is a subclass of torch.nn.Parameter, and I wanted to keep hf_quantizer-related logic out of this function as much as possible.
        else:
            param_value = param_value[...]
        ref = meta_model_state_dict.get(layer_name, empty_param)
        use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
        if not isinstance(param_value, torch.nn.Parameter):
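Note: the behavior discussed in the review thread above can be checked in isolation (illustrative snippet, not part of the diff; the last check assumes bitsandbytes is installed):

```python
import torch

p = torch.nn.Parameter(torch.randn(4, 4))

# Indexing with [...] materializes a plain Tensor, so a later
# isinstance(..., torch.nn.Parameter) check no longer matches.
print(type(p[...]))                                   # <class 'torch.Tensor'>
print(isinstance(p[...], torch.nn.Parameter))         # False

# Params4bit subclasses torch.nn.Parameter, which is why Parameter inputs are
# passed through untouched above instead of being indexed.
import bitsandbytes as bnb
print(issubclass(bnb.nn.Params4bit, torch.nn.Parameter))  # True
```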
@@ -395,7 +401,7 @@ def convert_and_load_state_dict_in_model(
    state_dict,
    weight_mapping,
    tp_plan,
    quantizer,
    hf_quantizer,
    dtype=None,
    device_map=None,
    dtype_plan=None,
@@ -460,14 +466,9 @@ def convert_and_load_state_dict_in_model(
            if empty_param is None:
                unexpected_keys.add(t)
                continue

            if quantizer is not None and quantizer.param_needs_quantization(model, t):
                if quantizer.__class__.__name__ == "FineGrainedFP8HfQuantizer":
                    from .integrations.finegrained_fp8 import Fp8Quantize

                    converter.quantization_operation = Fp8Quantize()  # TODO support other methods
                else:
                    raise ValueError("This quantization method is gonna be supported SOOOON")

            if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, t):
                converter.quantization_operation = hf_quantizer.get_quantize_ops()
            else:
                _dtype = dtype
                matched_dtype_pattern = match_glob(t, dtype_policy_alt, dtype_policy_by_group_name)
@@ -532,7 +533,7 @@ def convert_and_load_state_dict_in_model(
                    with log_to_misc(layer_name, misc, op=op):
                        realized_value.update(
                            op.convert(
                                {k: realized_value.pop(k)}, quant_config=quantizer.quantization_config
                                {k: realized_value.pop(k)}, quant_config=hf_quantizer.quantization_config, model=model
                            )
                        )

@@ -549,6 +550,7 @@
                    missing_keys,
                    misc,
                    converter.distributed_operation,
                    hf_quantizer
                )
            except SkipLayer:
                continue
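Note: taken together, the changes above replace the FP8-only special case with a generic dispatch through the quantizer. A rough sketch of the resulting flow (simplified paraphrase; the wrapper function below is hypothetical, while the attribute and method names follow the diff):

```python
# Hypothetical condensation of the dispatch added above: the quantizer decides
# whether a checkpoint key needs quantization and returns a ConversionOps
# instance; its convert() may expand one key into several quantized tensors.
def realize_one_key(model, hf_quantizer, key, tensor):
    realized = {key: tensor}
    if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, key):
        op = hf_quantizer.get_quantize_ops()   # e.g. Bnb4bitQuantize() or Fp8Quantize()
        realized = op.convert(
            {key: realized.pop(key)},
            quant_config=hf_quantizer.quantization_config,
            model=model,
        )
    # each resulting key/value pair is then handed to set_param_for_module(...)
    return realized
```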
6 changes: 4 additions & 2 deletions src/transformers/integrations/__init__.py
@@ -36,6 +36,7 @@
        "get_keys_to_not_convert",
        "replace_with_bnb_linear",
        "validate_bnb_backend_availability",
        "Bnb4bitQuantize",
    ],
    "deepspeed": [
        "HfDeepSpeedConfig",
@@ -51,7 +52,7 @@
    ],
    "eetq": ["replace_with_eetq_linear"],
    "fbgemm_fp8": ["FbgemmFp8Linear", "FbgemmFp8Llama4TextExperts", "replace_with_fbgemm_fp8_linear"],
    "finegrained_fp8": ["FP8Linear", "replace_with_fp8_linear"],
    "finegrained_fp8": ["FP8Linear", "replace_with_fp8_linear", "Fp8Quantize"],
    "fsdp": ["is_fsdp_enabled", "is_fsdp_managed_module"],
    "ggml": [
        "GGUF_CONFIG_MAPPING",
@@ -181,6 +182,7 @@
        get_keys_to_not_convert,
        replace_with_bnb_linear,
        validate_bnb_backend_availability,
        Bnb4bitQuantize,
    )
    from .deepspeed import (
        HfDeepSpeedConfig,
@@ -196,7 +198,7 @@
    )
    from .eetq import replace_with_eetq_linear
    from .fbgemm_fp8 import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts, replace_with_fbgemm_fp8_linear
    from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear
    from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear, Fp8Quantize
    from .fsdp import is_fsdp_enabled, is_fsdp_managed_module
    from .ggml import (
        GGUF_CONFIG_MAPPING,
10 changes: 10 additions & 0 deletions src/transformers/integrations/bitsandbytes.py
@@ -1,5 +1,6 @@
import inspect
from inspect import signature
from typing import Optional

from ..utils import (
    get_available_devices,
@@ -26,7 +27,16 @@

logger = logging.get_logger(__name__)

from ..core_model_loading import ConversionOps

class Bnb4bitQuantize(ConversionOps):
    def convert(self, input_dict: torch.Tensor, model: Optional[torch.nn.Module] = None, **kwargs) -> dict[str, torch.Tensor]:
        target_key, value = tuple(input_dict.items())[0]
        value = value[0] if isinstance(value, list) else value
        old_value = model.get_parameter_or_buffer(target_key)
        new_value = bnb.nn.Params4bit(value, **old_value.__dict__).to(value.device)
        return {target_key : new_value}

def _replace_with_bnb_linear(
    model,
    modules_to_not_convert=None,
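Note: a hedged usage sketch of the new op (not runnable as-is: `model` is assumed to be a transformers model already prepared for bnb 4-bit so the hypothetical key resolves to an existing Params4bit, and a CUDA device is assumed for the actual quantization step):

```python
import torch
from transformers.integrations.bitsandbytes import Bnb4bitQuantize

op = Bnb4bitQuantize()
dense = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")  # stand-in checkpoint tensor

# convert() reuses the quantization attributes (quant_type, blocksize, ...) of the
# existing parameter found on `model`, re-wraps the loaded tensor as a Params4bit,
# and the .to(value.device) call performs the 4-bit quantization on the GPU.
out = op.convert({"model.layers.0.mlp.down_proj.weight": dense}, model=model)
# -> {"model.layers.0.mlp.down_proj.weight": <bnb.nn.Params4bit on the tensor's device>}
```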
15 changes: 3 additions & 12 deletions src/transformers/integrations/finegrained_fp8.py
@@ -568,15 +568,7 @@ def replace_with_fp8_linear(
        )

    return model


class QuantizationOp(ConversionOps):
    """Base class for quantization operations."""

    pass


class Fp8Quantize(QuantizationOp):
class Fp8Quantize(ConversionOps):
    """
    A quantization operation that creates two tensors, weight and scale out of a weight.
    """
@@ -587,7 +579,7 @@ def __init__(self, block_size: Optional[tuple[int, int]] = None):
        self.block_size = block_size
        self.reverse_op = Fp8Dequantize

    def convert(self, input_dict: torch.Tensor, *, quant_config: dict[str, Any]) -> dict[str, torch.Tensor]:
    def convert(self, input_dict: torch.Tensor, quant_config: Optional[dict[str, Any]]= None, **kwargs) -> dict[str, torch.Tensor]:
        # Unpack single key/value (value may be wrapped in a list)
        target_keys, value = tuple(input_dict.items())[0]
        value = value[0] if isinstance(value, list) else value
@@ -655,8 +647,7 @@ def convert(self, input_dict: torch.Tensor, *, quant_config: dict[str, Any]) ->
            scale_key: inv_scales,
        }


class Fp8Dequantize(QuantizationOp):
class Fp8Dequantize(ConversionOps):
    """Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor."""

    def __init__(self, block_size: Optional[tuple[int, int]] = None):
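Note: for intuition about what Fp8Quantize/Fp8Dequantize compute, here is a minimal blockwise FP8 round trip. It is an independent sketch of the general technique (assumed 128x128 blocks, float8_e4m3fn storage, PyTorch >= 2.1), not the exact implementation:

```python
import torch

FP8_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def quantize_block_fp8(w: torch.Tensor, block: int = 128):
    # Per (block x block) tile: scale so the largest magnitude maps to FP8_MAX,
    # cast to float8, and keep the inverse scale for dequantization.
    rows, cols = w.shape
    tiles = w.reshape(rows // block, block, cols // block, block)
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = FP8_MAX / amax
    q = (tiles * scale).to(torch.float8_e4m3fn).reshape(rows, cols)
    inv_scale = (1.0 / scale).squeeze(3).squeeze(1)  # shape (rows//block, cols//block)
    return q, inv_scale

def dequantize_block_fp8(q: torch.Tensor, inv_scale: torch.Tensor, block: int = 128):
    rows, cols = q.shape
    tiles = q.to(torch.float32).reshape(rows // block, block, cols // block, block)
    return (tiles * inv_scale[:, None, :, None]).reshape(rows, cols)

w = torch.randn(256, 256)
q, s = quantize_block_fp8(w)
print((dequantize_block_fp8(q, s) - w).abs().max())  # error bounded by fp8 rounding
```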
9 changes: 7 additions & 2 deletions src/transformers/modeling_utils.py
@@ -4733,8 +4733,13 @@ def _initialize_missing_keys(self, missing_keys: list[str], is_quantized: bool)
        for key in self.state_dict():
            # If it's part of the keys that will be loaded, mark it as already initialized
            if key not in missing_keys:
                param_or_buffer = self.get_parameter_or_buffer(key)
                param_or_buffer._is_hf_initialized = True
                # some quantization methods save in the state_dict tensors that are not stored as buffer or parameters
                try:
                    param_or_buffer = self.get_parameter_or_buffer(key)
                    param_or_buffer._is_hf_initialized = True
                except AttributeError as e:
                    if not is_quantized:
                        raise e

        def set_is_initialized_for_modules(module):
            # A module is already initialized if and only if all its children are also already initialized, and all
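Note: the guard exists for checkpoint keys that only quantized models carry. A sketch with a hypothetical bnb-4bit metadata key (`model` stands in for the loaded quantized model):

```python
# Hypothetical bnb-4bit metadata key: it appears in the model's state_dict but is
# neither an nn.Parameter nor a buffer, so get_parameter_or_buffer() raises
# AttributeError, which is exactly what the new try/except tolerates.
key = "model.layers.0.mlp.down_proj.weight.quant_state.bitsandbytes__nf4"
try:
    model.get_parameter_or_buffer(key)
except AttributeError:
    pass  # expected for quantized checkpoints; a non-quantized model re-raises
```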
8 changes: 6 additions & 2 deletions src/transformers/quantizers/base.py
@@ -417,8 +417,12 @@ def _convert_model_for_quantization(self, model):
                parent_module._modules[name] = MODULES_TO_PATCH_FOR_QUANTIZATION[module_class_name]["module_name"](
                    model.config.get_text_config()
                )



    def get_quantize_ops(self):
        raise NotImplementedError(
            f"{self.quantization_config.quant_method} is not available yet and will be supported soon."
        )

class SequentialLlama4TextExperts(ModuleList):
    """
    A module that implements a compressed version of a list of expert modules.
4 changes: 4 additions & 0 deletions src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -305,3 +305,7 @@ def _dequantize(self, model):
            model, self.modules_to_not_convert, quantization_config=self.quantization_config
        )
        return model

    def get_quantize_ops(self):
        from ..integrations.bitsandbytes import Bnb4bitQuantize
        return Bnb4bitQuantize()
4 changes: 4 additions & 0 deletions src/transformers/quantizers/quantizer_finegrained_fp8.py
@@ -226,3 +226,7 @@ def is_trainable(self) -> bool:
    def get_accelerator_warm_up_factor(self):
        # Pre-processing is done cleanly, so we can allocate everything here
        return 2

    def get_quantize_ops(self):
        from ..integrations import Fp8Quantize
        return Fp8Quantize()
2 changes: 1 addition & 1 deletion tests/utils/test_core_model_loading_helpers.py
@@ -157,7 +157,7 @@ def test_moe_and_qkv_conversion(self):
        ]

        missing, unexpected, mismatch, misc = convert_and_load_state_dict_in_model(
            model, state_dict, weight_mapping, tp_plan=None, quantizer=None
            model, state_dict, weight_mapping, tp_plan=None, hf_quantizer=None
        )

        self.assertEqual(missing, set())