Commits (32)
710b1ff  fix (SunMarc, Nov 5, 2025)
f72f96d  fixes for more models torch_bc (ArthurZucker, Nov 5, 2025)
e341529  nits and fixes (ArthurZucker, Nov 5, 2025)
0e51dec  last update (ArthurZucker, Nov 5, 2025)
0f022b5  Revert "tied weight first shot to the fiiiixxxxxx" (ArthurZucker, Nov 5, 2025)
1dabb4c  here we go again (ArthurZucker, Nov 5, 2025)
0c2b667  an attempt (ArthurZucker, Nov 6, 2025)
c48e1ed  up? (ArthurZucker, Nov 6, 2025)
d223635  nits (ArthurZucker, Nov 6, 2025)
bdbc01a  Fix bnb loading ! (SunMarc, Nov 6, 2025)
399388d  rm print (SunMarc, Nov 6, 2025)
acbeeae  Merge branch 'refactor-weight-loading' into fix-bnb (SunMarc, Nov 6, 2025)
e16da23  rm import (SunMarc, Nov 7, 2025)
386e259  update (SunMarc, Nov 7, 2025)
9788014  Merge remote-tracking branch 'upstream/refactor-weight-loading' into … (SunMarc, Nov 7, 2025)
72eff97  Update src/transformers/core_model_loading.py (SunMarc, Nov 7, 2025)
d841a04  Fix loadedparam (SunMarc, Nov 7, 2025)
e235eed  Merge remote-tracking branch 'upstream/fix-bnb' into fix-bnb (SunMarc, Nov 7, 2025)
e4df752  rm report (SunMarc, Nov 7, 2025)
3e69622  Fix tests single gpu (SunMarc, Nov 7, 2025)
a052513  should fix it (SunMarc, Nov 7, 2025)
db4fe31  Merge branch 'refactor-weight-loading' into fix-bnb (SunMarc, Nov 10, 2025)
9fa1b7a  guard needed for compressed-tensors (SunMarc, Nov 10, 2025)
ea5822d  Merge branch 'refactor-weight-loading' into fix-bnb (SunMarc, Nov 10, 2025)
5881d8e  deal with buffers (SunMarc, Nov 10, 2025)
3651460  Merge branch 'refactor-weight-loading' into fix-bnb (SunMarc, Nov 10, 2025)
00b0044  Merge branch 'refactor-weight-loading' into fix-bnb (SunMarc, Nov 10, 2025)
7d8df52  fix (SunMarc, Nov 10, 2025)
5ce08ca  for now let's do this (SunMarc, Nov 12, 2025)
1a2b5ca  Merge remote-tracking branch 'upstream/refactor-weight-loading' into … (SunMarc, Nov 12, 2025)
2f1b69c  fix (SunMarc, Nov 12, 2025)
3c2d946  fix small test (SunMarc, Nov 12, 2025)
2 changes: 2 additions & 0 deletions examples/modular-transformers/modeling_dummy_bert.py
@@ -509,6 +509,8 @@ def __init__(self, config):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias



def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
2 changes: 2 additions & 0 deletions examples/modular-transformers/modeling_roberta.py
@@ -512,6 +512,8 @@ def __init__(self, config):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias



def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
1 change: 1 addition & 0 deletions report.json

Large diffs are not rendered by default.

209 changes: 124 additions & 85 deletions src/transformers/core_model_loading.py
@@ -26,17 +26,40 @@
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from types import MethodType
from typing import Any, Optional, Union

import torch
from torch.distributed.tensor import DTensor

from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, TensorParallelLayer
from .utils import logging
from .quantizers import HfQuantizer
from .utils import is_torch_greater_or_equal, logging
from .utils.quantization_config import QuantizationMethod


_torch_distributed_available = torch.distributed.is_available()
_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
if _is_dtensor_available:
from torch.distributed.tensor import DTensor


logger = logging.get_logger(__name__)

str_to_torch_dtype = {
"BOOL": torch.bool,
"U8": torch.uint8,
"I8": torch.int8,
"I16": torch.int16,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I32": torch.int32,
"F32": torch.float32,
"F64": torch.float64,
"I64": torch.int64,
"F8_E4M3": torch.float8_e4m3fn,
"F8_E5M2": torch.float8_e5m2,
}


def _glob_to_regex_src(glob: str, *, digits_only: bool = True) -> str:
"""
@@ -280,85 +303,59 @@ class ConversionEntry:
GLOBAL_WORKERS = min(16, (os.cpu_count() or 8) * 2) # NVMe: 8-16; HDD/NFS: 2-4


class LoadedParameter(torch.nn.Parameter):
r"""
Because `transformers` initialized the missing keys we need to make sure
we can skip the ones that are actually loaded. Now we could force something, but
we want people to have an intuitive API usage, thus they can keep the well know API, and
just define their custom `_init_weight`, as long as they don't use `module.xxx.data`.

We added a check for this in `make fixup` to force people to use it.
After the `missing` weights are initialized, LoadedParameters become just nn.Parameters.
# Factory function to create LoadedParameter subclasses dynamically
def get_loaded_parameter_class(base_cls):
"""

def __new__(cls, data=None, requires_grad=True):
inst = super().__new__(cls, data, requires_grad)
inst._is_hf_initialized = False
return inst

def __repr__(self):
return f"LoadedParameter(_is_hf_initialized={self._is_hf_initialized}, data={self.data}"
# block .data assignment when flagged
@property
def data(self):
return super().data

@data.setter
def data(self, new):
if not getattr(self, "_is_hf_initialized", False):
super(LoadedParameter, LoadedParameter).data.__set__(self, new) # delegate to base
# else: skip or warn

# shadow common in-place init methods
def _guard(self, fn, *a, **k):
if getattr(self, "_is_hf_initialized", False):
base_cls: an nn.Parameter subclass (or nn.Parameter)
Returns a new class that combines the base_cls with LoadedParameterMixin
"""
class LoadedParam(base_cls):
_inplace_methods = [
'add_', 'mul_', 'clamp_', 'zero_', 'fill_', 'normal_', 'uniform_',
'copy_', 'erfinv_', 'log_'
]
def __new__(cls, from_existing, **kwargs):
inst = super().__new__(cls, from_existing.data, from_existing.requires_grad, **from_existing.__dict__)
inst._original_param = from_existing
# Explicitly override all in-place methods per instance
for method_name in inst._inplace_methods:
setattr(inst, method_name, MethodType(inst._skip, inst))

return inst

def _skip(self, *args, **kwargs):
"""Helper to skip in-place operations."""
@SunMarc (Member, Author) commented on Nov 7, 2025:
simplified a bit to accommodate subclasses of nn.Parameter, cc @ArthurZucker. Also, if we are using this class it means the param is initialized, as you said, so let's simplify everything

Collaborator replied:
yep absolutely, ty
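
As a side note for readers of this thread, here is a minimal usage sketch of the factory being discussed (not part of the diff); it assumes `get_loaded_parameter_class` is importable from `transformers.core_model_loading`, where this PR defines it:

```python
import torch

from transformers.core_model_loading import get_loaded_parameter_class

# Wrap an already-loaded parameter; the dynamically created subclass turns the
# usual in-place init methods and the `.data` setter into no-ops, so a later
# `_init_weights` pass cannot overwrite values that came from the checkpoint.
p = torch.nn.Parameter(torch.ones(2, 2))
loaded = get_loaded_parameter_class(p.__class__)(from_existing=p)

loaded.normal_()                 # skipped: returns the parameter unchanged
loaded.data = torch.zeros(2, 2)  # skipped: the data setter is a no-op
assert torch.equal(loaded.detach(), torch.ones(2, 2))
```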

return self
return fn(*a, **k)

def normal_(self, *a, **k):
return self._guard(super().normal_, *a, **k)

def uniform_(self, *a, **k):
return self._guard(super().uniform_, *a, **k)
def __repr__(self):
return f"LoadedParameter(data={self.data})"

def zero_(self):
return self._guard(super().zero_)
@property
def data(self):
return super().data

def fill_(self, *a, **k):
return self._guard(super().fill_, *a, **k)
@data.setter
def data(self, new):
pass

def copy_(self, *a, **k):
return self._guard(super().copy_, *a, **k)
return LoadedParam

def mul_(self, *a, **k):
return self._guard(super().copy_, *a, **k)

def add_(self, *a, **k):
return self._guard(super().copy_, *a, **k)

def clamp_(self, *a, **k):
return self._guard(super().copy_, *a, **k)

def erfinv_(self, *a, **k):
return self._guard(super().copy_, *a, **k)

def log_(self, *a, **k):
return self._guard(super().copy_, *a, **k)


def _materialize_copy(tensor, dtype):
def _materialize_copy(tensor, dtype=None):
# PyTorch: this runs in C and releases the GIL; good for threads.
return tensor[...].to(dtype)
tensor = tensor[...]
if dtype is not None:
tensor = tensor.to(dtype)
return tensor


def spawn_materialize(thread_pool, tensor, dtype) -> Future:
def spawn_materialize(thread_pool, tensor, dtype=None) -> Future:
def _job():
return _materialize_copy(tensor, dtype)

return thread_pool.submit(_job)


def spawn_tp_materialize(thread_pool, tensor, dtype, sharding_method, tensor_idx) -> Future:
def spawn_tp_materialize(thread_pool, tensor, sharding_method, tensor_idx, dtype=None) -> Future:
def _job():
return sharding_method.shard_tensor(tensor, param_casting_dtype=dtype, tensor_idx=tensor_idx)[0]

@@ -423,11 +420,17 @@ def set_param_for_module(
missing_keys: MutableSet[str],
misc: MutableMapping[str, Any],
distributed_operation: Optional[TensorParallelLayer],
hf_quantizer,
):
with log_to_misc(layer_name, misc, layer_name):
module_path, _, param_name = layer_name.rpartition(".")
module_obj = model.get_submodule(module_path) if module_path else model
param_value = param_value[0] if isinstance(param_value, list) else param_value[...]
if isinstance(param_value, list):
param_value = param_value[0]
elif isinstance(param_value, torch.nn.Parameter):
pass
Collaborator commented:
I am guessing that's for BNB? Can't we force it to return the data instead of a param?

@SunMarc (Member, Author) replied:
If you do tensor[...], the nn.Parameter becomes a plain tensor, so the isinstance check later becomes invalid. But yeah, this is for bnb: we need to return an nn.Params4bit, which is a subclass of torch.nn.Parameter, and I didn't want to pollute this function with hf_quantizer-related logic any more than necessary.
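
For illustration (plain PyTorch, independent of this PR): indexing with `[...]` drops the Parameter subclass, which is exactly why the generic `param_value[...]` path cannot be used when bnb hands back a `Params4bit`:

```python
import torch

p = torch.nn.Parameter(torch.ones(2, 2))
sliced = p[...]  # materializes the values but returns a plain tensor

print(type(p).__name__)       # Parameter
print(type(sliced).__name__)  # Tensor -- the subclass information is gone
print(isinstance(sliced, torch.nn.Parameter))  # False, so a later isinstance check would misfire
```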

else:
param_value = param_value[...]
ref = meta_model_state_dict.get(layer_name, empty_param)
use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
if not isinstance(param_value, torch.nn.Parameter):
@@ -442,13 +445,16 @@
)
else:
pass # TODO for "local" stuff, it will trigger missmatched no?
param_value: LoadedParameter = LoadedParameter(param_value, requires_grad=param_value.is_floating_point())
else:
param_value: LoadedParameter = LoadedParameter(param_value.data)
param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())

# to skip any inplace method that modifies the param data
param_value = get_loaded_parameter_class(param_value.__class__)(from_existing=param_value)

if ref is not None and ref.shape != param_value.shape:
# skip mismatch for hf_quantizer for now
if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
Collaborator commented:
Why is the shape of the BnbLinear not correct? I also don't think we want this long term, no?

@SunMarc (Member, Author) replied:
This is because when we initialize the meta model with nn.Linear4bit, those don't have the right shape, as the weights are not quantized yet. But yeah, maybe we can fix this by overwriting the shape of the param when replacing the layers. Long term, we will remove this, yes.
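
Conceptual sketch only (the exact packed layout is decided by bitsandbytes; the shapes below are merely indicative): the meta model reports the unquantized layout while a pre-quantized checkpoint stores a packed buffer, so a naive shape comparison flags a false mismatch.

```python
import torch

# What the meta model reports for an nn.Linear4bit weight before quantization:
ref = torch.empty(4096, 4096, dtype=torch.float16, device="meta")
# Roughly how the same weight arrives from a pre-quantized 4-bit checkpoint:
# two 4-bit values packed per uint8 byte, hence a completely different shape.
packed = torch.empty(4096 * 4096 // 2, 1, dtype=torch.uint8, device="meta")

assert ref.shape != packed.shape  # the false "mismatch" that the hf_quantizer guard above skips
```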

mismatch_keys.add((layer_name, param_value.shape, ref.shape))
missing_keys.discard(layer_name)
# maybe not needed anymore, but let's keep this as there are still code related to that
param_value._is_hf_initialized = True # super important otherwise _init_weight re-initi if bias is missing
setattr(module_obj, param_name, param_value)

@@ -464,7 +470,7 @@ def convert_and_load_state_dict_in_model(
state_dict,
weight_mapping,
tp_plan,
quantizer,
hf_quantizer,
dtype=None,
device_map=None,
dtype_plan=None,
@@ -525,16 +531,19 @@ def convert_and_load_state_dict_in_model(
empty_param = meta_model_state_dict.get(t)
# If it does not exist, it's unexpected
if empty_param is None:
unexpected_keys.add(t)
continue

if quantizer is not None and quantizer.param_needs_quantization(model, t):
if quantizer.__class__.__name__ == "FineGrainedFP8HfQuantizer":
from .integrations.finegrained_fp8 import Fp8Quantize

converter.quantization_operation = Fp8Quantize() # TODO support other methods
if hf_quantizer is not None and hf_quantizer.is_valid_unexpected_keys(t):
Collaborator commented:
ditto 😉

pass
else:
raise ValueError("This quantization method is gonna be supported SOOOON")
unexpected_keys.add(t)
continue

if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, t):
converter.quantization_operation = hf_quantizer.get_quantize_ops()
# TODO: to clean later. We need to use the empty_param from the checkpoint to decide if we upcast the param to a specific dtype
k_dtype = tensor.get_dtype()
dtype = str_to_torch_dtype[k_dtype]
empty_param_checkpoint = torch.empty(size=tensor.get_shape(), dtype=dtype, device="meta")
_, _dtype = _infer_parameter_dtype(model, t, empty_param_checkpoint, hf_quantizer)
Comment on lines +639 to +644
Collaborator commented:
Why? Is it because BNB needs, say, bf16 always? Can you elaborate here? Because I don't upcast any of the parameters; they just have the _dtype on meta, and then get whatever was loaded from the weights.

@SunMarc (Member, Author) replied:
We need to infer the right dtype for each value in the checkpoint:

  • some of the values are not parameters or buffers of the model, so we shouldn't change their dtype
  • for some parameters/buffers, we should also keep the same dtype as the checkpoint (empty_param_checkpoint) because the _dtype on meta is not correct (fp16 instead of int8). This can potentially be fixed if we initialize the correct dtype. For bnb it should work, but I'm not sure about other methods like torchao, as the dtype is hard to infer from the beginning.

Collaborator replied:
But if we quantize, we never change the dtype of the param, which is the source of truth.
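
A hedged sketch of what the annotated block does with the checkpoint metadata; the file name and tensor key are made up, and it assumes the `str_to_torch_dtype` mapping added at the top of `core_model_loading.py` in this PR:

```python
import torch
from safetensors import safe_open

from transformers.core_model_loading import str_to_torch_dtype

with safe_open("model.safetensors", framework="pt") as f:        # hypothetical checkpoint file
    tensor = f.get_slice("model.layers.0.mlp.down_proj.weight")  # hypothetical tensor key
    k_dtype = tensor.get_dtype()                                 # string metadata, e.g. "BF16"
    dtype = str_to_torch_dtype[k_dtype]                          # -> torch.bfloat16
    # An empty tensor on the meta device with the *checkpoint's* shape and dtype;
    # it is only used to decide the casting dtype, nothing is materialized here.
    empty_param_checkpoint = torch.empty(size=tensor.get_shape(), dtype=dtype, device="meta")
```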

else:
_dtype = dtype
matched_dtype_pattern = match_glob(t, dtype_policy_alt, dtype_policy_by_group_name)
@@ -599,9 +608,7 @@ def convert_and_load_state_dict_in_model(
if op := converter.quantization_operation:
with log_to_misc(layer_name, misc, op=op):
realized_value.update(
op.convert(
{k: realized_value.pop(k)}, quant_config=quantizer.quantization_config
)
op.convert({k: realized_value.pop(k)}, model=model)
Collaborator commented:
not a fan of passing the whole model!

@SunMarc (Member, Author) replied:
I wish I could not do that but let's keep this for now

)

for k, output_value in realized_value.items():
@@ -617,9 +624,10 @@
missing_keys,
misc,
converter.distributed_operation,
hf_quantizer
)
except SkipLayer:
continue
except Exception as e :
raise e
del group

# Update progress bar
@@ -648,3 +656,34 @@ def revert_weight_conversion(model, state_dict):
original_state_dict[key] = value
state_dict = original_state_dict
return state_dict

def _infer_parameter_dtype(
model: torch.nn.Module,
param_name: str,
empty_param: torch.Tensor,
hf_quantizer: Optional[HfQuantizer] = None,
) -> tuple[bool, Optional[torch.dtype]]:
try:
old_param = model.get_parameter_or_buffer(param_name)
except Exception as e:
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.QUARK,
QuantizationMethod.MXFP4,
QuantizationMethod.BITS_AND_BYTES,
}:
return True, None
else:
raise e
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
casting_dtype = None
Comment on lines +774 to +777
Collaborator commented:
Can't the methods do this in quantize? Because they should!

is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
# dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes
if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, param_name):
casting_dtype = model.config._pre_quantization_dtype
else:
casting_dtype = old_param.dtype
return old_param is not None and old_param.is_contiguous(), casting_dtype
2 changes: 2 additions & 0 deletions src/transformers/integrations/__init__.py
@@ -36,6 +36,7 @@
"get_keys_to_not_convert",
"replace_with_bnb_linear",
"validate_bnb_backend_availability",
"Bnb4bitQuantize",
],
"deepspeed": [
"HfDeepSpeedConfig",
@@ -177,6 +178,7 @@
unpack_weights,
)
from .bitsandbytes import (
Bnb4bitQuantize,
dequantize_and_replace,
get_keys_to_not_convert,
replace_with_bnb_linear,
41 changes: 41 additions & 0 deletions src/transformers/integrations/bitsandbytes.py
@@ -1,6 +1,9 @@
import inspect
from collections import defaultdict
from inspect import signature
from typing import Optional

from ..quantizers.quantizers_utils import get_module_from_name
from ..utils import (
get_available_devices,
is_accelerate_available,
@@ -26,6 +29,44 @@

logger = logging.get_logger(__name__)

from ..core_model_loading import ConversionOps


class Bnb4bitQuantize(ConversionOps):
def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer

def convert(self, input_dict: torch.Tensor, model: Optional[torch.nn.Module] = None, **kwargs) -> dict[str, torch.Tensor]:
target_key, value = tuple(input_dict.items())[0]
value = value[0] if isinstance(value, list) else value
if not self.hf_quantizer.pre_quantized:
old_value = model.get_parameter_or_buffer(target_key)
new_value = bnb.nn.Params4bit(value, requires_grad=False, **old_value.__dict__).to(value.device)
return {target_key : new_value}
else:
full_name = target_key
# update param name to get the weights instead of the quantized stats
target_key = self.hf_quantizer.get_param_name(target_key)
module, _ = get_module_from_name(model, target_key)

module_name = target_key.rsplit(".", 1)[0]
# Save the states for later quantization when they are all gathered
if not hasattr(self.hf_quantizer, "param_quant_stats"):
self.hf_quantizer.param_quant_stats = defaultdict(dict)
Comment on lines +61 to +63
Collaborator commented:
They are gathered from what? Sorry, I am not familiar with it. Which states do you need?

@SunMarc (Member, Author) replied:
Basically, we need to store some parameters to create the quantized weight. For example, bnb requires 6 values that are stored in the checkpoint to recover the quantized weight. So we store them in a dict that is kept on the hf_quantizer for now, as we can't save it in the op, since we create an op per tensor.
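
A rough, self-contained sketch of that "park the stats until the module is complete" pattern (names and the helper function are illustrative; the real code builds the final parameter with `bnb.nn.Params4bit.from_prequantized`, as shown in the diff):

```python
from collections import defaultdict

EXPECTED_STATS = 6  # per the comment above: bnb ships 6 quant-state tensors per weight
param_quant_stats = defaultdict(dict)

def register(module_name, full_name, value):
    """Collect one checkpoint tensor; return (weight, stats) once the module is complete."""
    param_quant_stats[module_name][full_name] = value
    # the +1 accounts for the packed weight itself
    if len(param_quant_stats[module_name]) == EXPECTED_STATS + 1:
        stats = param_quant_stats.pop(module_name)
        weight = stats.pop(f"{module_name}.weight")
        return weight, stats  # ready to hand to Params4bit.from_prequantized(...)
    return None  # still waiting for the remaining pieces
```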

self.hf_quantizer.param_quant_stats[module_name].update({full_name: value})
# We are ready for quantization in this case (note, the +1 is for the weight itself)
if len(self.hf_quantizer.param_quant_stats[module_name]) == len(self.hf_quantizer.bnb_keys) + 1:
weight = self.hf_quantizer.param_quant_stats[module_name].pop(f"{module_name}.weight")
new_value = bnb.nn.Params4bit.from_prequantized(
data=weight,
quantized_stats=self.hf_quantizer.param_quant_stats[module_name],
requires_grad=False,
device=value.device,
module=module
)
del self.hf_quantizer.param_quant_stats[module_name]
return {target_key : new_value}
return {}

def _replace_with_bnb_linear(
model,