Changes from all commits
28 commits
710b1ff
fix
SunMarc Nov 5, 2025
f72f96d
fixes for more models torch_bc
ArthurZucker Nov 5, 2025
e341529
nits and fixes
ArthurZucker Nov 5, 2025
0e51dec
last update
ArthurZucker Nov 5, 2025
0f022b5
Revert "tied weight first shot to the fiiiixxxxxx"
ArthurZucker Nov 5, 2025
1dabb4c
here we go again
ArthurZucker Nov 5, 2025
0c2b667
an attempt
ArthurZucker Nov 6, 2025
c48e1ed
up?
ArthurZucker Nov 6, 2025
d223635
nits
ArthurZucker Nov 6, 2025
bdbc01a
Fix bnb loading !
SunMarc Nov 6, 2025
399388d
rm print
SunMarc Nov 6, 2025
acbeeae
Merge branch 'refactor-weight-loading' into fix-bnb
SunMarc Nov 6, 2025
e16da23
rm import
SunMarc Nov 7, 2025
386e259
update
SunMarc Nov 7, 2025
9788014
Merge remote-tracking branch 'upstream/refactor-weight-loading' into …
SunMarc Nov 7, 2025
72eff97
Update src/transformers/core_model_loading.py
SunMarc Nov 7, 2025
d841a04
Fix loadedparam
SunMarc Nov 7, 2025
e235eed
Merge remote-tracking branch 'upstream/fix-bnb' into fix-bnb
SunMarc Nov 7, 2025
e4df752
rm report
SunMarc Nov 7, 2025
3e69622
Fix tests single gpu
SunMarc Nov 7, 2025
a052513
should fix it
SunMarc Nov 7, 2025
db4fe31
Merge branch 'refactor-weight-loading' into fix-bnb
SunMarc Nov 10, 2025
9fa1b7a
guard needed for compressed-tensors
SunMarc Nov 10, 2025
ea5822d
Merge branch 'refactor-weight-loading' into fix-bnb
SunMarc Nov 10, 2025
5881d8e
deal with buffers
SunMarc Nov 10, 2025
3651460
Merge branch 'refactor-weight-loading' into fix-bnb
SunMarc Nov 10, 2025
00b0044
Merge branch 'refactor-weight-loading' into fix-bnb
SunMarc Nov 10, 2025
7d8df52
fix
SunMarc Nov 10, 2025
102 changes: 82 additions & 20 deletions src/transformers/core_model_loading.py
@@ -26,12 +26,21 @@
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from types import MethodType
from typing import Any, Optional, Union

import torch

from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, TensorParallelLayer,DTensor,Replicate
from .utils import logging
from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, TensorParallelLayer, DTensor, Replicate
from .quantizers import HfQuantizer
from .utils import is_torch_greater_or_equal, logging
from .utils.quantization_config import QuantizationMethod


_torch_distributed_available = torch.distributed.is_available()
_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
if _is_dtensor_available:
from torch.distributed.tensor import DTensor


import itertools
@@ -81,6 +90,21 @@

logger = logging.get_logger(__name__)

str_to_torch_dtype = {
"BOOL": torch.bool,
"U8": torch.uint8,
"I8": torch.int8,
"I16": torch.int16,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I32": torch.int32,
"F32": torch.float32,
"F64": torch.float64,
"I64": torch.int64,
"F8_E4M3": torch.float8_e4m3fn,
"F8_E5M2": torch.float8_e5m2,
}


def _glob_to_regex_src(glob: str, *, digits_only: bool = True) -> str:
"""
@@ -361,6 +385,7 @@ def data(self):
@data.setter
def data(self, new):
pass

def __lt__(self, other): return torch.Tensor.__lt__(self, other)
def __le__(self, other): return torch.Tensor.__le__(self, other)
def __gt__(self, other): return torch.Tensor.__gt__(self, other)
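
A minimal, self-contained sketch of the "loaded parameter" idea the overrides above serve, assuming a hypothetical _LoadedParameter wrapper rather than the PR's actual get_loaded_parameter_class factory: the data setter is a no-op and in-place init methods return self, so a later _init_weights pass cannot clobber freshly loaded weights.

import torch

class _LoadedParameter(torch.nn.Parameter):
    """Wraps an already-loaded parameter so later init code cannot overwrite it."""

    def __new__(cls, from_existing: torch.nn.Parameter):
        return super().__new__(cls, from_existing.data, from_existing.requires_grad)

    @property
    def data(self):
        return super().data

    @data.setter
    def data(self, new):
        # Ignore attempts to replace the freshly loaded tensor.
        pass

    def normal_(self, *args, **kwargs):
        # In-place init methods become no-ops on loaded weights (illustrative).
        return self

p = _LoadedParameter(torch.nn.Parameter(torch.ones(2, 2)))
p.data = torch.zeros(2, 2)  # silently ignored by the no-op setter
p.normal_()                 # no-op
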
@@ -462,6 +487,7 @@ def set_param_for_module(
missing_keys: MutableSet[str],
misc: MutableMapping[str, Any],
distributed_operation: Optional[TensorParallelLayer],
hf_quantizer,
):
with log_to_misc(layer_name, misc, layer_name):
module_path, _, param_name = layer_name.rpartition(".")
@@ -470,7 +496,6 @@
param_value = param_value[0]
elif not isinstance(param_value, torch.nn.Parameter):
param_value = param_value[...]
param_value = param_value[0] if isinstance(param_value, list) else param_value[...]
ref = meta_model_state_dict.get(layer_name, empty_param)


@@ -485,14 +510,18 @@
shape=ref.size(),
stride=ref.stride(),
)
if not use_dtensor:
if not use_dtensor:
# we convert to local
param_value = param_value.to_local()

if param_name not in module_obj._buffers:
param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())

# to skip any inplace method that modifies the param data
param_value = get_loaded_parameter_class(param_value.__class__)(from_existing=param_value)

if ref is not None and ref.shape != param_value.shape:
# skip mismatch for hf_quantizer for now
if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
Collaborator
Why is the shape of the BnbLinear not correct? I also don't think we want this long term, no?

Member Author
This is because when we initialize the meta model with nn.Linear4bit, those layers don't have the right shape, since the weights are not quantized yet. But maybe we can fix this by overwriting the shape of the param when replacing the layers. In the long term, we will remove this, yes.
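
As a concrete illustration of the mismatch discussed in this thread (shapes below are hypothetical, not taken from the PR): bitsandbytes packs two 4-bit values per uint8, so the tensor stored in a quantized checkpoint no longer matches the fp16 placeholder that nn.Linear4bit exposes on the meta device.

import torch

out_features, in_features = 4096, 4096
meta_weight = torch.empty(out_features, in_features, device="meta")                  # placeholder shape in the meta model
packed_weight = torch.empty(out_features * in_features // 2, 1, dtype=torch.uint8)   # 4-bit packed checkpoint layout (assumed)
print(meta_weight.shape, packed_weight.shape)  # torch.Size([4096, 4096]) torch.Size([8388608, 1])
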

mismatch_keys.add((layer_name, param_value.shape, ref.shape))
setattr(module_obj._parameters[param_name], "_is_hf_initialized", False) # Needs to be initialized
missing_keys.discard(layer_name)
@@ -513,7 +542,7 @@ def convert_and_load_state_dict_in_model(
state_dict,
weight_mapping,
tp_plan,
quantizer,
hf_quantizer,
dtype=None,
device_map=None,
dtype_plan=None,
@@ -574,16 +603,19 @@
empty_param = meta_model_state_dict.get(t)
# If it does not exist, it's unexpected
if empty_param is None:
unexpected_keys.add(t)
continue

if quantizer is not None and quantizer.param_needs_quantization(model, t):
if quantizer.__class__.__name__ == "FineGrainedFP8HfQuantizer":
from .integrations.finegrained_fp8 import Fp8Quantize

converter.quantization_operation = Fp8Quantize() # TODO support other methods
if hf_quantizer is not None and hf_quantizer.is_valid_unexpected_keys(t):
Collaborator
ditto 😉

pass
else:
raise ValueError("This quantization method is gonna be supported SOOOON")
unexpected_keys.add(t)
continue

if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, t):
converter.quantization_operation = hf_quantizer.get_quantize_ops()
# TODO: to clean later. We need to use the empty_param from the checkpoint to decide if we upcast the param to a specific dtype
k_dtype = tensor.get_dtype()
dtype = str_to_torch_dtype[k_dtype]
empty_param_checkpoint = torch.empty(size=tensor.get_shape(), dtype=dtype, device="meta")
_, _dtype = _infer_parameter_dtype(model, t, empty_param_checkpoint, hf_quantizer)
Comment on lines +613 to +618
Collaborator
Why? Is it because bnb always needs, say, bf16? Can you elaborate? I don't upcast any of the parameters: they just have the _dtype on meta, and then get whatever was loaded from the weights.

Member Author
We need to infer the right dtype for each value in the checkpoint:

  • some of the values are not parameters or buffers of the model, so we shouldn't change their dtype
  • for some parameters/buffers, we should also keep the same dtype as the checkpoint (empty_param_checkpoint), because the _dtype on meta is not correct (fp16 instead of int8). This could potentially be fixed by initializing the correct dtype. For bnb it should work, but I'm not sure about other methods like torchao, where the dtype is hard to infer from the beginning.

Collaborator
But if we quantize, we never change the dtype of the param, which is the source of truth.

else:
_dtype = dtype
matched_dtype_pattern = match_glob(t, dtype_policy_alt, dtype_policy_by_group_name)
@@ -648,9 +680,7 @@
if op := converter.quantization_operation:
with log_to_misc(layer_name, misc, op=op):
realized_value.update(
op.convert(
{k: realized_value.pop(k)}, quant_config=quantizer.quantization_config
)
op.convert({k: realized_value.pop(k)}, model=model)
Collaborator
not a fan of passing the whole model!

Member Author
I wish I could avoid that, but let's keep this for now.

)

for k, output_value in realized_value.items():
@@ -666,9 +696,10 @@
missing_keys,
misc,
converter.distributed_operation,
hf_quantizer
)
except SkipLayer:
continue
except Exception as e :
raise e
del group

# Update progress bar
@@ -697,3 +728,34 @@
original_state_dict[key] = value
state_dict = original_state_dict
return state_dict

def _infer_parameter_dtype(
model: torch.nn.Module,
param_name: str,
empty_param: torch.Tensor,
hf_quantizer: Optional[HfQuantizer] = None,
) -> tuple[bool, Optional[torch.dtype]]:
try:
old_param = model.get_parameter_or_buffer(param_name)
except Exception as e:
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.QUARK,
QuantizationMethod.MXFP4,
QuantizationMethod.BITS_AND_BYTES,
}:
return True, None
else:
raise e
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
casting_dtype = None
Comment on lines +750 to +753
Collaborator
Can't the methods do this in the quantize step? They should!

is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
# dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes
if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, param_name):
casting_dtype = model.config._pre_quantization_dtype
else:
casting_dtype = old_param.dtype
return old_param is not None and old_param.is_contiguous(), casting_dtype
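
A hedged, self-contained sketch of the casting rule _infer_parameter_dtype implements: floating-point checkpoint tensors are cast to the dtype the model expects (or _pre_quantization_dtype when the param is about to be quantized), while int/uint/bool and float8_e4m3fn tensors keep their checkpoint dtype. The function name and toy inputs below are illustrative only.

import torch

def casting_dtype_for(checkpoint_dtype: torch.dtype, model_dtype: torch.dtype):
    is_fp8 = hasattr(torch, "float8_e4m3fn") and checkpoint_dtype == torch.float8_e4m3fn
    if checkpoint_dtype.is_floating_point and not is_fp8:
        return model_dtype
    return None  # None means "keep the checkpoint dtype"

print(casting_dtype_for(torch.float16, torch.bfloat16))  # torch.bfloat16
print(casting_dtype_for(torch.uint8, torch.bfloat16))    # None: quantized/int data is left untouched
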
2 changes: 2 additions & 0 deletions src/transformers/integrations/__init__.py
@@ -36,6 +36,7 @@
"get_keys_to_not_convert",
"replace_with_bnb_linear",
"validate_bnb_backend_availability",
"Bnb4bitQuantize",
],
"deepspeed": [
"HfDeepSpeedConfig",
@@ -177,6 +178,7 @@
unpack_weights,
)
from .bitsandbytes import (
Bnb4bitQuantize,
dequantize_and_replace,
get_keys_to_not_convert,
replace_with_bnb_linear,
1 change: 1 addition & 0 deletions src/transformers/integrations/accelerate.py
@@ -435,6 +435,7 @@ def _get_device_map(
if max_memory is not None and device_name in max_memory:
inferred_max_memory[device_name] = min(inferred_max_memory[device_name], max_memory[device_name])

model.tie_weights()
Collaborator
Suggested change
model.tie_weights()
model.tie_weights() # TODO make sure to remove this later on if possible, does not make sense

device_map = infer_auto_device_map(
model,
max_memory=inferred_max_memory,
46 changes: 46 additions & 0 deletions src/transformers/integrations/bitsandbytes.py
@@ -1,6 +1,9 @@
import inspect
from collections import defaultdict
from inspect import signature
from typing import Optional

from ..quantizers.quantizers_utils import get_module_from_name
from ..utils import (
get_available_devices,
is_accelerate_available,
@@ -26,6 +29,49 @@

logger = logging.get_logger(__name__)

from ..core_model_loading import ConversionOps


class Bnb4bitQuantize(ConversionOps):
def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer

def convert(self, input_dict: torch.Tensor, model: Optional[torch.nn.Module] = None, **kwargs) -> dict[str, torch.Tensor]:
target_key, value = tuple(input_dict.items())[0]
value = value[0] if isinstance(value, list) else value

full_name = target_key
# update param name to get the weights instead of the quantized stats
target_key = self.hf_quantizer.get_param_name(target_key)
module, _ = get_module_from_name(model, target_key)

if not self.hf_quantizer.pre_quantized:
# Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization.
# Since weights are saved in the correct "orientation", we skip transposing when loading.
if issubclass(module.source_cls, Conv1D):
value = value.T
old_value = model.get_parameter_or_buffer(target_key)
new_value = bnb.nn.Params4bit(value, requires_grad=False, **old_value.__dict__).to(value.device)
return {target_key : new_value}
else:
module_name = target_key.rsplit(".", 1)[0]
# Save the states for later quantization when they are all gathered
if not hasattr(self.hf_quantizer, "param_quant_stats"):
self.hf_quantizer.param_quant_stats = defaultdict(dict)
Comment on lines +58 to +60
Collaborator
Gathered from what? Sorry, I'm not familiar with it; which states do you need?

Member Author
Basically, we need to store some parameters to create the quantized weight. For example, bnb requires 6 values stored in the checkpoint to recover the quantized weight. So we keep them in a dict on the hf_quantizer for now, since we can't keep it on the op: we create one op per tensor.

self.hf_quantizer.param_quant_stats[module_name].update({full_name: value})
# We are ready for quantization in this case (note, the +1 is for the weight itself)
if len(self.hf_quantizer.param_quant_stats[module_name]) == len(self.hf_quantizer.bnb_keys) + 1:
weight = self.hf_quantizer.param_quant_stats[module_name].pop(f"{module_name}.weight")
new_value = bnb.nn.Params4bit.from_prequantized(
data=weight,
quantized_stats=self.hf_quantizer.param_quant_stats[module_name],
requires_grad=False,
device=value.device,
module=module
)
del self.hf_quantizer.param_quant_stats[module_name]
return {target_key : new_value}
return {}
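
A hedged sketch of the "gather then build" pattern convert() uses when the checkpoint is pre-quantized: every checkpoint entry arrives as its own op call, so the quant-state tensors are buffered per module and the Params4bit is only created once the packed weight plus all of its stats have been seen. The key names below follow bitsandbytes' usual 4-bit serialization but are an assumption, not read from this diff.

from collections import defaultdict

import torch

# Typical per-module quant-state suffixes for a 4-bit bnb checkpoint (assumed).
EXPECTED_STATS = (
    "weight.absmax",
    "weight.quant_map",
    "weight.nested_absmax",
    "weight.nested_quant_map",
    "weight.quant_state.bitsandbytes__nf4",
)

param_quant_stats: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)

def on_tensor_loaded(module_name: str, full_name: str, value: torch.Tensor):
    """Buffer one checkpoint entry; return (weight, stats) once the module is complete."""
    param_quant_stats[module_name][full_name] = value
    # The +1 accounts for the packed weight itself, mirroring the length check in convert().
    if len(param_quant_stats[module_name]) == len(EXPECTED_STATS) + 1:
        gathered = param_quant_stats.pop(module_name)
        weight = gathered.pop(f"{module_name}.weight")
        return weight, gathered  # ready for bnb.nn.Params4bit.from_prequantized(...)
    return None
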

def _replace_with_bnb_linear(
model,
27 changes: 9 additions & 18 deletions src/transformers/integrations/finegrained_fp8.py
@@ -568,37 +568,29 @@ def replace_with_fp8_linear(
)

return model


class QuantizationOp(ConversionOps):
"""Base class for quantization operations."""

pass


class Fp8Quantize(QuantizationOp):
class Fp8Quantize(ConversionOps):
"""
A quantization operation that creates two tensors, weight and scale out of a weight.
"""

reverse_op: type[ConversionOps]

def __init__(self, block_size: Optional[tuple[int, int]] = None):
self.block_size = block_size
def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer
self.reverse_op = Fp8Dequantize

def convert(self, input_dict: torch.Tensor, *, quant_config: dict[str, Any]) -> dict[str, torch.Tensor]:
def convert(self, input_dict: torch.Tensor, **kwargs) -> dict[str, torch.Tensor]:
# Unpack single key/value (value may be wrapped in a list)
target_keys, value = tuple(input_dict.items())[0]
value = value[0] if isinstance(value, list) else value

# Resolve block size (support dict-like or attr-like quant_config)
block_size = None
if quant_config is not None:
if isinstance(quant_config, dict):
block_size = quant_config.get("weight_block_size")
if self.hf_quantizer.quantization_config is not None:
if isinstance(self.hf_quantizer.quantization_config, dict):
block_size = self.hf_quantizer.quantization_config.get("weight_block_size")
else:
block_size = getattr(quant_config, "weight_block_size", None)
block_size = getattr(self.hf_quantizer.quantization_config, "weight_block_size", None)
if block_size is None:
block_size = (value.shape[-2], value.shape[-1])

@@ -655,8 +647,7 @@ def convert(self, input_dict: torch.Tensor, *, quant_config: dict[str, Any]) ->
scale_key: inv_scales,
}
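
For readers unfamiliar with the scheme, a self-contained sketch of block-wise fp8 quantization in the spirit of Fp8Quantize (not the PR's exact implementation; 448.0 is the largest E4M3 value, and torch>=2.1 is assumed for torch.float8_e4m3fn): each block is scaled by its own max-abs so it fits the fp8 range, and the inverse scales are returned alongside the quantized weight.

import torch

def fp8_block_quantize(weight: torch.Tensor, block_size=(128, 128)):
    rows, cols = weight.shape
    br, bc = block_size
    # View the matrix as a grid of (br, bc) blocks; assumes rows % br == 0 and cols % bc == 0.
    blocks = weight.reshape(rows // br, br, cols // bc, bc)
    amax = blocks.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scales = amax / 448.0                                      # map each block's max to the E4M3 max
    quantized = (blocks / scales).to(torch.float8_e4m3fn).reshape(rows, cols)
    inv_scales = scales.reciprocal().squeeze(-1).squeeze(1)    # one inverse scale per block
    return quantized, inv_scales

w = torch.randn(256, 256)
qw, inv_s = fp8_block_quantize(w)
print(qw.dtype, inv_s.shape)  # torch.float8_e4m3fn torch.Size([2, 2])
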


class Fp8Dequantize(QuantizationOp):
class Fp8Dequantize(ConversionOps):
"""Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor."""

def __init__(self, block_size: Optional[tuple[int, int]] = None):