Fix bnb for the weights refactor #42043
base: refactor-weight-loading
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
MekkCyber left a comment
Niice! makes sense
ArthurZucker left a comment
nice!
```diff
  from .eetq import replace_with_eetq_linear
  from .fbgemm_fp8 import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts, replace_with_fbgemm_fp8_linear
- from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear
+ from .finegrained_fp8 import FP8Linear, Fp8Quantize, replace_with_fp8_linear
```
I think we removed it because importing it went through a jit decorator, making it slow.
OK, I will import it directly from the correct file then.
```python
# Save the states for later quantization when they are all gathered
if not hasattr(self.hf_quantizer, "param_quant_stats"):
    self.hf_quantizer.param_quant_stats = defaultdict(dict)
```
Gathered from what? Sorry, I am not familiar with this; which states do you need?
Basically, we need to store some parameters to create the quantized weight. For example, bnb requires 6 values that are stored in the checkpoint to recover the quantized weight. So we store them in a dict that is kept on the hf_quantizer for now, since we can't save it in the op (we create one op per tensor).
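To make that gathering pattern concrete, here is a minimal, hypothetical sketch. The attribute name `param_quant_stats` comes from the diff, but the helper class and the exact set of bnb 4-bit quant-state components below are assumptions for illustration, not the PR's code.

```python
from collections import defaultdict

# Assumed components a bnb 4-bit weight needs before it can be rebuilt;
# the real checkpoint key suffixes may differ slightly.
EXPECTED_SUFFIXES = (
    "absmax",
    "quant_map",
    "nested_absmax",
    "nested_quant_map",
    "quant_state.bitsandbytes__nf4",
)


class QuantStateGatherer:
    """Collects per-parameter quant-state tensors until the full set has arrived."""

    def __init__(self):
        self.param_quant_stats = defaultdict(dict)

    def add(self, checkpoint_key, tensor):
        for suffix in EXPECTED_SUFFIXES:
            if checkpoint_key.endswith("." + suffix):
                param_name = checkpoint_key[: -(len(suffix) + 1)]
                self.param_quant_stats[param_name][suffix] = tensor
                # True once every component for this parameter has been seen,
                # i.e. the quantized weight can now be reconstructed
                return all(s in self.param_quant_stats[param_name] for s in EXPECTED_SUFFIXES)
        # the packed weight itself, or an unrelated tensor: nothing to gather here
        return False
```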
```python
def is_valid_unexpected_keys(self, k):
    """
    Check if the key is valid even if it is not in the state_dict of the meta model.
    This is because the state dict of the model might change after quantization, e.g. for 4-bit bnb.
    """
    return False
```
I would love to avoid this; can we make sure the first call to hf_quantizer.quantize_model just properly prepares the meta model?
This is more to take care of the case where we load a quantized checkpoint. I don't think there is a good way to fix this, but let's keep it for now; we can think about fixing it later.
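For context, a hedged sketch of what such an override might look like when loading a pre-quantized bnb checkpoint. The hook name `is_valid_unexpected_keys` is taken from the diff; the class and the accepted key suffixes are illustrative assumptions.

```python
# Illustrative only: not the PR's actual quantizer code.
BNB_4BIT_SUFFIXES = (
    ".absmax",
    ".quant_map",
    ".nested_absmax",
    ".nested_quant_map",
    ".quant_state.bitsandbytes__nf4",
)


class Bnb4BitQuantizerSketch:
    def is_valid_unexpected_keys(self, k):
        # Quant-state tensors of a pre-quantized checkpoint do not exist in the meta
        # model's state dict, but they must still be loaded rather than reported as
        # unexpected keys.
        return k.endswith(BNB_4BIT_SUFFIXES)
```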
```diff
- if ref is not None and ref.shape != param_value.shape:
+ # skip mismatch for hf_quantizer for now
+ if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
```
Why is the shape of the BnbLinear weight not correct? I also don't think we want this long term, no?
This is because when we initialize the meta model with nn.Linear4bit, the weights don't have the right shape, since they are not quantized yet. But maybe we can fix this by overwriting the shape of the param when replacing the layers. In the long term, we will remove this, yes.
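To illustrate the mismatch, a small sketch, under the assumption that bnb packs two 4-bit values per uint8 byte and stores the result flattened as `(numel // 2, 1)`:

```python
import torch

out_features, in_features = 32, 64

# Shape reported by the meta model built with nn.Linear4bit before quantization
meta_weight = torch.empty(out_features, in_features, device="meta")

# Shape found in a pre-quantized 4-bit checkpoint (assumed packed layout)
packed = torch.empty(meta_weight.numel() // 2, 1, dtype=torch.uint8)

print(meta_weight.shape)  # torch.Size([32, 64])
print(packed.shape)       # torch.Size([1024, 1]) -> the mismatch the loader must tolerate
```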
```python
from .integrations.finegrained_fp8 import Fp8Quantize

converter.quantization_operation = Fp8Quantize()  # TODO support other methods
if hf_quantizer is not None and hf_quantizer.is_valid_unexpected_keys(t):
```
ditto 😉
```python
converter.quantization_operation = hf_quantizer.get_quantize_ops()
# TODO: to clean later. We need to use the empty_param from the checkpoint to decide if we upcast the param to a specific dtype
k_dtype = tensor.get_dtype()
dtype = str_to_torch_dtype[k_dtype]
empty_param_checkpoint = torch.empty(size=tensor.get_shape(), dtype=dtype, device="meta")
_, _dtype = _infer_parameter_dtype(model, t, empty_param_checkpoint, hf_quantizer)
```
Why? Is it because BNB always needs, say, bf16? Can you elaborate here? I don't upcast any of the parameters; they just have the _dtype on meta, and then get whatever was loaded from the weights.
We need to infer the right dtype for each value in the checkpoint:
- some of the values are not parameters or buffers of the model, so we shouldn't change their dtype;
- for some parameters/buffers, we should also keep the same dtype as the checkpoint (`empty_param_checkpoint`), because the `_dtype` on meta is not correct (fp16 instead of int8). This can potentially be fixed if we initialize with the correct `dtype`; it should work for bnb, but I am not sure about other methods like torchao, where the dtype is hard to infer from the beginning.
But if we quantize, we never change the dtype of the param, which is the source of truth.
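A hedged reconstruction of the dtype decision being discussed: `_infer_parameter_dtype` appears in the diff, but the helper below is an illustrative sketch, not the PR's implementation, and its name is assumed.

```python
import torch


def infer_casting_dtype(model, key, empty_param_checkpoint, hf_quantizer, requested_dtype):
    named = dict(model.named_parameters())
    named.update(dict(model.named_buffers()))
    if key not in named:
        # quant-state tensors (absmax, quant_map, ...) are neither parameters nor
        # buffers, so their dtype is left untouched
        return None
    if hf_quantizer is not None and not empty_param_checkpoint.dtype.is_floating_point:
        # e.g. a pre-quantized int8/uint8 weight: the dtype recorded on the meta model
        # (often fp16) is wrong, so trust the checkpoint dtype instead of casting
        return None
    # otherwise, floating-point values are cast to the user-requested dtype
    return requested_dtype
```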
```diff
- op.convert(
-     {k: realized_value.pop(k)}, quant_config=quantizer.quantization_config
- )
+ op.convert({k: realized_value.pop(k)}, model=model)
```
not a fan of passing the whole model!
I wish I could avoid that, but let's keep this for now.
```python
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
casting_dtype = None
```
Can't the methods do this in quantize? Because they should!
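For reference, the casting rule stated in the comment of the diff above could be expressed roughly like this (an assumed reconstruction, not the actual function):

```python
import torch


def casting_dtype_for(tensor_dtype, requested_dtype):
    is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
    if not tensor_dtype.is_floating_point:
        return None  # int/uint/bool buffers and params are never cast
    if is_torch_e4m3fn_available and tensor_dtype == torch.float8_e4m3fn:
        return None  # already-quantized fp8 weights keep their dtype
    return requested_dtype  # other floating dtypes are cast to the `dtype` passed
```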
Force-pushed from 28b620d to f692f4b
[For maintainers] Suggested jobs to run (before merge): run-slow: finegrained_fp8
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
```python
        base_cls: an nn.Parameter subclass (or nn.Parameter)
    Returns a new class that combines the base_cls with LoadedParameterMixin
    """

    class LoadedParam(base_cls):
        _inplace_methods = [
            'add_', 'mul_', 'clamp_', 'zero_', 'fill_', 'normal_', 'uniform_',
            'copy_', 'erfinv_', 'log_'
        ]

        def __new__(cls, from_existing, **kwargs):
            inst = super().__new__(cls, from_existing.data, from_existing.requires_grad, **from_existing.__dict__)
            inst._original_param = from_existing
            # Explicitly override all in-place methods per instance
            for method_name in inst._inplace_methods:
                setattr(inst, method_name, MethodType(inst._skip, inst))

            return inst

        def _skip(self, *args, **kwargs):
            """Helper to skip in-place operations."""
```
Simplified a bit to accommodate subclasses of nn.Parameter, cc @ArthurZucker. Also, if we are using this class, it means the param is initialized, as you said, so let's simplify everything.
yep absolutely, ty
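A short usage sketch of the LoadedParam factory shown in the diff above. The factory name `make_loaded_param_class` is assumed (the diff does not show the enclosing function's name), and `_skip` is assumed to leave the underlying data untouched.

```python
import torch
import torch.nn as nn

linear = nn.Linear(4, 4)

# Wrap the already-loaded weight so later in-place init calls become no-ops.
LoadedParam = make_loaded_param_class(nn.Parameter)  # hypothetical factory name
linear.weight = LoadedParam(from_existing=linear.weight)

before = linear.weight.detach().clone()
linear.weight.normal_()  # skipped: the weight already holds the loaded values
assert torch.equal(before, linear.weight.detach())
```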
What does this PR do?
This PR fixes bnb support in the new weight loading logic.
Testing