
Conversation

@SunMarc (Member) commented Nov 5, 2025

What does this PR do?

This PR fixes bnb support in the new weight loading logic.

Testing

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_name = "meta-llama/Llama-3.2-3B-Instruct"
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")

# Alternatively, to test loading an already-quantized checkpoint:
# model_name = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"
# (and don't pass quantization_config below)

tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map=0
)

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, do_sample=False, max_new_tokens=1024)
print(tokenizer.decode(outputs[0]))

Remaining TODOs:
  • check why the memory usage is way too high when quantizing on the fly
  • add bnb tests

@SunMarc SunMarc requested a review from MekkCyber November 5, 2025 16:29
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@MekkCyber (Contributor) left a comment

Niice! makes sense

@SunMarc SunMarc changed the title Fix bnb on the fly for the weights refactor Fix bnb for the weights refactor Nov 6, 2025

@ArthurZucker (Collaborator) left a comment

nice!

 from .eetq import replace_with_eetq_linear
 from .fbgemm_fp8 import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts, replace_with_fbgemm_fp8_linear
-from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear
+from .finegrained_fp8 import FP8Linear, Fp8Quantize, replace_with_fp8_linear

Collaborator:

I think we removed it because it was imported with a jit decorator, making it slow.

Member Author (@SunMarc):

ok, I will import it directly from the correct file then

Comment on lines +53 to +55
# Save the states for later quantization when they are all gathered
if not hasattr(self.hf_quantizer, "param_quant_stats"):
    self.hf_quantizer.param_quant_stats = defaultdict(dict)

Collaborator:

Gathered from what? Sorry, I am not familiar with it; which states do you need?

Member Author (@SunMarc):

Basically, we need to store some parameters to create the quantized weight. For example, bnb requires 6 values that are stored in the checkpoint to recover the quantized weight. So we store them in a dict that is kept on the hf_quantizer for now, since we can't save it in the op: we create one op per tensor.
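
A minimal sketch of this gather-then-quantize pattern, assuming an op-per-tensor loading loop; the class name, method signature, and return convention here are illustrative assumptions, not the PR's actual API:

from collections import defaultdict


class BnbGatherSketch:
    """Illustrative only: collect per-weight quant stats until all pieces have arrived."""

    def __init__(self, hf_quantizer, n_expected_stats=6):
        self.hf_quantizer = hf_quantizer
        self.n_expected_stats = n_expected_stats  # e.g. absmax, quant map, nested stats for nf4

    def convert(self, key_and_tensor):
        # The shared store lives on hf_quantizer, as in the snippet discussed above,
        # because a new op is created per tensor and cannot hold state across tensors.
        if not hasattr(self.hf_quantizer, "param_quant_stats"):
            self.hf_quantizer.param_quant_stats = defaultdict(dict)

        (full_key, tensor), = key_and_tensor.items()
        weight_key, _, stat_name = full_key.rpartition(".")
        self.hf_quantizer.param_quant_stats[weight_key][stat_name] = tensor

        stats = self.hf_quantizer.param_quant_stats[weight_key]
        if len(stats) < self.n_expected_stats:
            return {}  # not all pieces gathered yet, nothing to materialize
        # All pieces are present: the real code would now rebuild the 4-bit bnb
        # parameter from the gathered stats; here we just hand them back.
        return {weight_key: stats}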

Comment on lines +426 to +431
def is_valid_unexpected_keys(self, k):
    """
    Check if the key is valid even if it is not in the state_dict of the meta model.
    This is needed because the state dict of the model might change after quantization, e.g. for 4-bit bnb.
    """
    return False

Collaborator:

I would love to avoid this. Can we make sure the first call to hf_quantizer.quantize_model just properly prepares the meta model?

Member Author (@SunMarc):

This is more to take care of the case where we load a pre-quantized checkpoint. I don't think there is a good way to fix this, so let's keep it for now and think about fixing it later.
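
For context, a hedged sketch of the kind of override a bnb quantizer could use when loading a pre-quantized checkpoint; the exact key suffixes are an assumption, not taken from this PR:

class Bnb4BitQuantizerSketch:
    # Hypothetical: these stat keys only appear once a weight has been 4-bit quantized,
    # so they are absent from the meta model's state_dict but valid in the checkpoint.
    _BNB_STAT_SUFFIXES = (
        ".absmax",
        ".quant_map",
        ".nested_absmax",
        ".nested_quant_map",
        ".quant_state.bitsandbytes__nf4",
        ".quant_state.bitsandbytes__fp4",
    )

    def is_valid_unexpected_keys(self, k):
        return k.endswith(self._BNB_STAT_SUFFIXES)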


-if ref is not None and ref.shape != param_value.shape:
+# skip mismatch for hf_quantizer for now
+if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:

Collaborator:

Why is the shape of the BnbLinear not correct? I also don't think we want this long term, no?

Member Author (@SunMarc):

This is because when we initialize the meta model with nn.Linear4bit, the params don't have the right shape yet, since the weights are not quantized. Maybe we can fix this by overwriting the shape of the param when replacing the layers. In the long term, we will remove this, yes.
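
As a small illustration of that mismatch (assuming bitsandbytes is installed and a CUDA device is available; the packed shape reflects bnb's usual two-4-bit-values-per-byte storage):

import bitsandbytes as bnb

layer = bnb.nn.Linear4bit(128, 256, quant_type="nf4")
print(layer.weight.shape)   # torch.Size([256, 128]): not quantized yet, full shape
layer = layer.to("cuda")    # moving to the GPU triggers the 4-bit quantization
print(layer.weight.shape)   # packed uint8 storage, roughly torch.Size([256 * 128 // 2, 1])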

from .integrations.finegrained_fp8 import Fp8Quantize

converter.quantization_operation = Fp8Quantize() # TODO support other methods
if hf_quantizer is not None and hf_quantizer.is_valid_unexpected_keys(t):

Collaborator:

ditto 😉

Comment on lines +499 to +504
converter.quantization_operation = hf_quantizer.get_quantize_ops()
# TODO: to clean later. We need to use the empty_param from the checkpoint to decide if we upcast the param to a specific dtype
k_dtype = tensor.get_dtype()
dtype = str_to_torch_dtype[k_dtype]
empty_param_checkpoint = torch.empty(size=tensor.get_shape(), dtype=dtype, device="meta")
_, _dtype = _infer_parameter_dtype(model, t, empty_param_checkpoint, hf_quantizer)

Collaborator:

Why? Is it because BNB always needs, say, bf16? Can you elaborate? I don't upcast any of the parameters; they just have the _dtype on meta and then get whatever was loaded from the weights.

Member Author (@SunMarc):

We need to infer the right dtype for each value in the checkpoint (a sketch of this follows below):

  • Some of the values are not parameters or buffers of the model, so we shouldn't change their dtype.
  • For some parameters/buffers, we should also keep the same dtype as the checkpoint (empty_param_checkpoint), because the _dtype on meta is not correct (fp16 instead of int8). This could potentially be fixed by initializing the correct dtype from the start; for bnb it should work, but I'm not sure about other methods like torchao, where the dtype is hard to infer from the beginning.
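
A minimal sketch of that dtype inference, loosely mirroring the _infer_parameter_dtype call quoted above; the helper name and exact rules are assumptions:

import torch


def infer_loading_dtype(model, key, empty_param_checkpoint, hf_quantizer=None, target_dtype=torch.bfloat16):
    """Illustrative only: pick the dtype a checkpoint tensor should be loaded with."""
    named = dict(model.named_parameters())
    named.update(dict(model.named_buffers()))

    # Keys that are not parameters/buffers of the model (e.g. bnb quant stats):
    # keep whatever dtype the checkpoint stored.
    if key not in named:
        return empty_param_checkpoint.dtype

    # Quantized storage (int8/uint8, ...) must keep the checkpoint dtype, since the
    # meta model may still report fp16 for a weight that is not quantized yet.
    if hf_quantizer is not None and not empty_param_checkpoint.dtype.is_floating_point:
        return empty_param_checkpoint.dtype

    # Regular floating-point params/buffers follow the requested target dtype.
    return target_dtype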


Collaborator:

But if we quantize, we never change the dtype of the param, which is the source of truth.

-op.convert(
-    {k: realized_value.pop(k)}, quant_config=quantizer.quantization_config
-)
+op.convert({k: realized_value.pop(k)}, model=model)

Collaborator:

not a fan of passing the whole model!

Member Author (@SunMarc):

I wish I could avoid that, but let's keep this for now.

Comment on lines +635 to +638
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
casting_dtype = None
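
A tiny sketch of the casting rule described in the comment above; the function name and the None-means-no-cast convention are assumptions:

import torch


def infer_casting_dtype(tensor, requested_dtype):
    # Cast floating-point values to the requested dtype, except float8_e4m3fn;
    # keep int/uint/bool values untouched (None means "do not cast").
    float8_e4m3fn = getattr(torch, "float8_e4m3fn", None)
    if tensor.dtype.is_floating_point and tensor.dtype != float8_e4m3fn:
        return requested_dtype
    return None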

Collaborator:

Can't the methods do this in the quantize step? Because they should!

@ArthurZucker ArthurZucker force-pushed the refactor-weight-loading branch from 28b620d to f692f4b on November 7, 2025 07:55

github-actions bot commented Nov 7, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: finegrained_fp8

SunMarc and others added 3 commits November 7, 2025 14:21
Comment on lines +309 to +327
    base_cls: an nn.Parameter subclass (or nn.Parameter)
    Returns a new class that combines the base_cls with LoadedParameterMixin
    """

    class LoadedParam(base_cls):
        _inplace_methods = [
            'add_', 'mul_', 'clamp_', 'zero_', 'fill_', 'normal_', 'uniform_',
            'copy_', 'erfinv_', 'log_'
        ]

        def __new__(cls, from_existing, **kwargs):
            inst = super().__new__(cls, from_existing.data, from_existing.requires_grad, **from_existing.__dict__)
            inst._original_param = from_existing
            # Explicitly override all in-place methods per instance
            for method_name in inst._inplace_methods:
                setattr(inst, method_name, MethodType(inst._skip, inst))

            return inst

        def _skip(self, *args, **kwargs):
            """Helper to skip in-place operations."""

@SunMarc (Member Author) commented Nov 7, 2025:

Simplified a bit to accommodate subclasses of nn.Parameter, cc @ArthurZucker. Also, if we are using this class, it means the param is already initialized, as you said, so let's simplify everything.
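
For illustration, a minimal self-contained sketch of how such a wrapper behaves; the factory name make_loaded_param_class and the trimmed method list are hypothetical, not the PR's API:

from types import MethodType

import torch
from torch import nn


def make_loaded_param_class(base_cls=nn.Parameter):
    # Hypothetical restatement of the pattern above, trimmed to the essentials.
    class LoadedParam(base_cls):
        _inplace_methods = ["normal_", "uniform_", "zero_", "fill_", "copy_"]

        def __new__(cls, from_existing):
            inst = super().__new__(cls, from_existing.data, from_existing.requires_grad)
            # In-place initializers become per-instance no-ops, so values loaded from
            # the checkpoint survive a later _init_weights pass.
            for name in inst._inplace_methods:
                setattr(inst, name, MethodType(lambda self, *a, **kw: self, inst))
            return inst

    return LoadedParam


param = nn.Parameter(torch.ones(2, 2))
loaded = make_loaded_param_class()(param)
loaded.normal_()   # skipped: the loaded values are preserved
print(loaded)      # still all ones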

Collaborator:

yep absolutely, ty
