Fix bnb for the weights refactor #42043
base: refactor-weight-loading
Changes from 18 commits
@@ -26,17 +26,40 @@
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from types import MethodType
from typing import Any, Optional, Union

import torch
from torch.distributed.tensor import DTensor

from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, TensorParallelLayer
from .utils import logging
from .quantizers import HfQuantizer
from .utils import is_torch_greater_or_equal, logging
from .utils.quantization_config import QuantizationMethod

_torch_distributed_available = torch.distributed.is_available()
_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
if _is_dtensor_available:
    from torch.distributed.tensor import DTensor


logger = logging.get_logger(__name__)

str_to_torch_dtype = {
    "BOOL": torch.bool,
    "U8": torch.uint8,
    "I8": torch.int8,
    "I16": torch.int16,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
    "I32": torch.int32,
    "F32": torch.float32,
    "F64": torch.float64,
    "I64": torch.int64,
    "F8_E4M3": torch.float8_e4m3fn,
    "F8_E5M2": torch.float8_e5m2,
}


def _glob_to_regex_src(glob: str, *, digits_only: bool = True) -> str:
    """
@@ -280,85 +303,59 @@ class ConversionEntry:
GLOBAL_WORKERS = min(16, (os.cpu_count() or 8) * 2)  # NVMe: 8-16; HDD/NFS: 2-4


class LoadedParameter(torch.nn.Parameter):
    r"""
    Because `transformers` initializes the missing keys, we need to make sure
    we can skip the ones that are actually loaded. We could force something, but
    we want people to have an intuitive API, so they keep the well-known API and
    just define their custom `_init_weight`, as long as they don't use `module.xxx.data`.

    We added a check for this in `make fixup` to force people to use it.
    After the `missing` weights are initialized, LoadedParameters become just nn.Parameters.
# Factory function to create LoadedParameter subclasses dynamically
def get_loaded_parameter_class(base_cls):
    """

    def __new__(cls, data=None, requires_grad=True):
        inst = super().__new__(cls, data, requires_grad)
        inst._is_hf_initialized = False
        return inst

    def __repr__(self):
        return f"LoadedParameter(_is_hf_initialized={self._is_hf_initialized}, data={self.data}"
    # block .data assignment when flagged
    @property
    def data(self):
        return super().data

    @data.setter
    def data(self, new):
        if not getattr(self, "_is_hf_initialized", False):
            super(LoadedParameter, LoadedParameter).data.__set__(self, new)  # delegate to base
        # else: skip or warn

    # shadow common in-place init methods
    def _guard(self, fn, *a, **k):
        if getattr(self, "_is_hf_initialized", False):
    base_cls: an nn.Parameter subclass (or nn.Parameter)
    Returns a new class that combines the base_cls with LoadedParameterMixin
    """
    class LoadedParam(base_cls):
        _inplace_methods = [
            "add_", "mul_", "clamp_", "zero_", "fill_", "normal_", "uniform_",
            "copy_", "erfinv_", "log_",
        ]

        def __new__(cls, from_existing, **kwargs):
            inst = super().__new__(cls, from_existing.data, from_existing.requires_grad, **from_existing.__dict__)
            inst._original_param = from_existing
            # Explicitly override all in-place methods per instance
            for method_name in inst._inplace_methods:
                setattr(inst, method_name, MethodType(inst._skip, inst))

            return inst

        def _skip(self, *args, **kwargs):
            """Helper to skip in-place operations."""
            return self
            return fn(*a, **k)

    def normal_(self, *a, **k):
        return self._guard(super().normal_, *a, **k)

    def uniform_(self, *a, **k):
        return self._guard(super().uniform_, *a, **k)
        def __repr__(self):
            return f"LoadedParameter(data={self.data})"

    def zero_(self):
        return self._guard(super().zero_)
        @property
        def data(self):
            return super().data

    def fill_(self, *a, **k):
        return self._guard(super().fill_, *a, **k)
        @data.setter
        def data(self, new):
            pass

    def copy_(self, *a, **k):
        return self._guard(super().copy_, *a, **k)
    return LoadedParam

    def mul_(self, *a, **k):
        return self._guard(super().copy_, *a, **k)

    def add_(self, *a, **k):
        return self._guard(super().copy_, *a, **k)

    def clamp_(self, *a, **k):
        return self._guard(super().copy_, *a, **k)

    def erfinv_(self, *a, **k):
        return self._guard(super().copy_, *a, **k)

    def log_(self, *a, **k):
        return self._guard(super().copy_, *a, **k)
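For orientation, a minimal usage sketch of the factory above (assuming `get_loaded_parameter_class` from this hunk is in scope; the tensors are made up): wrapping an already-loaded parameter returns an object whose `.data` setter and in-place initializers are no-ops, so a later `_init_weight` pass cannot overwrite values that came from the checkpoint.

```python
import torch

# Hypothetical usage sketch of get_loaded_parameter_class (defined above).
param = torch.nn.Parameter(torch.randn(4, 4))
LoadedCls = get_loaded_parameter_class(param.__class__)  # dynamically built subclass
loaded = LoadedCls(from_existing=param)

loaded.normal_()                   # skipped: in-place methods are overridden per instance to return self
loaded.data = torch.zeros(4, 4)    # skipped: the data setter is a no-op
assert torch.equal(loaded, param)  # the values loaded from the checkpoint survive
```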
def _materialize_copy(tensor, dtype):
def _materialize_copy(tensor, dtype=None):
    # PyTorch: this runs in C and releases the GIL; good for threads.
    return tensor[...].to(dtype)
    tensor = tensor[...]
    if dtype is not None:
        tensor = tensor.to(dtype)
    return tensor

SunMarc marked this conversation as resolved.


def spawn_materialize(thread_pool, tensor, dtype) -> Future:
def spawn_materialize(thread_pool, tensor, dtype=None) -> Future:
    def _job():
        return _materialize_copy(tensor, dtype)

    return thread_pool.submit(_job)


def spawn_tp_materialize(thread_pool, tensor, dtype, sharding_method, tensor_idx) -> Future:
def spawn_tp_materialize(thread_pool, tensor, sharding_method, tensor_idx, dtype=None) -> Future:
    def _job():
        return sharding_method.shard_tensor(tensor, param_casting_dtype=dtype, tensor_idx=tensor_idx)[0]
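A rough illustration of how these helpers are driven from a thread pool, per the GIL note above. The pool size and tensors are invented, and `spawn_materialize` from the hunk above is assumed to be in scope; real callers pass lazy safetensors slices, but plain tensors behave the same for this purpose.

```python
from concurrent.futures import ThreadPoolExecutor
import torch

# Illustrative only: materialize several checkpoint tensors concurrently.
tensors = {f"layer.{i}.weight": torch.randn(1024, 1024) for i in range(4)}
with ThreadPoolExecutor(max_workers=8) as pool:  # GLOBAL_WORKERS-style sizing
    futures = {name: spawn_materialize(pool, t, dtype=torch.bfloat16) for name, t in tensors.items()}
    materialized = {name: fut.result() for name, fut in futures.items()}

assert all(v.dtype == torch.bfloat16 for v in materialized.values())
```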
@@ -423,11 +420,17 @@ def set_param_for_module(
    missing_keys: MutableSet[str],
    misc: MutableMapping[str, Any],
    distributed_operation: Optional[TensorParallelLayer],
    hf_quantizer,
):
    with log_to_misc(layer_name, misc, layer_name):
        module_path, _, param_name = layer_name.rpartition(".")
        module_obj = model.get_submodule(module_path) if module_path else model
        param_value = param_value[0] if isinstance(param_value, list) else param_value[...]
        if isinstance(param_value, list):
            param_value = param_value[0]
        elif isinstance(param_value, torch.nn.Parameter):
            pass
        else:
            param_value = param_value[...]
        ref = meta_model_state_dict.get(layer_name, empty_param)
        use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
        if not isinstance(param_value, torch.nn.Parameter):

@@ -442,13 +445,16 @@ def set_param_for_module(
            )
        else:
            pass  # TODO for "local" stuff, it will trigger missmatched no?
            param_value: LoadedParameter = LoadedParameter(param_value, requires_grad=param_value.is_floating_point())
        else:
            param_value: LoadedParameter = LoadedParameter(param_value.data)
            param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())

        # to skip any inplace method that modifies the param data
        param_value = get_loaded_parameter_class(param_value.__class__)(from_existing=param_value)

        if ref is not None and ref.shape != param_value.shape:
        # skip mismatch for hf_quantizer for now
        if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
Collaborator: Why is the shape of the BnbLinear not correct? This is also something I don't think we want long term, no?

Member (Author): This is because when we initialize the meta model with nn.Linear4bit, those don't have the right shape, as the weights are not quantized yet. But maybe we can fix this by overwriting the shape of the param when replacing the layers. In the long term, we will remove this, yes.
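For readers outside the quantization code: bitsandbytes packs two 4-bit values per byte and serializes the packed weight as a flattened uint8 tensor, so its shape cannot match the float weight the meta model declares. A hedged illustration of the mismatch (the shapes below show the typical packing and are not taken from this PR):

```python
import torch

# Meta-model side: the float weight nn.Linear4bit still advertises before quantization.
out_features, in_features = 4096, 11008
meta_weight = torch.empty(out_features, in_features, dtype=torch.float16, device="meta")

# Checkpoint side: 4-bit values packed two per byte, flattened to a column vector.
packed = torch.empty(out_features * in_features // 2, 1, dtype=torch.uint8, device="meta")

assert meta_weight.shape != packed.shape  # hence the `hf_quantizer is None` guard above
```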
            mismatch_keys.add((layer_name, param_value.shape, ref.shape))
            missing_keys.discard(layer_name)
        # maybe not needed anymore, but let's keep this as there is still code related to that
        param_value._is_hf_initialized = True  # super important, otherwise _init_weight re-inits if the bias is missing
        setattr(module_obj, param_name, param_value)
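As a reminder of what the `_is_hf_initialized` flag protects against, a sketch of the guard it feeds; this mirrors the pattern transformers uses when initializing missing weights, but the function name and internals here are purely illustrative.

```python
# Illustrative guard: only parameters NOT flagged as loaded get (re-)initialized.
def init_missing_weights(model, init_fn):
    for name, param in model.named_parameters():
        if getattr(param, "_is_hf_initialized", False):
            continue  # came from the checkpoint, leave it alone
        init_fn(param)  # genuinely missing weight: run the module's init
```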
@@ -464,7 +470,7 @@ def convert_and_load_state_dict_in_model(
    state_dict,
    weight_mapping,
    tp_plan,
    quantizer,
    hf_quantizer,
    dtype=None,
    device_map=None,
    dtype_plan=None,
@@ -525,16 +531,19 @@ def convert_and_load_state_dict_in_model(
            empty_param = meta_model_state_dict.get(t)
            # If it does not exist, it's unexpected
            if empty_param is None:
                unexpected_keys.add(t)
                continue

            if quantizer is not None and quantizer.param_needs_quantization(model, t):
                if quantizer.__class__.__name__ == "FineGrainedFP8HfQuantizer":
                    from .integrations.finegrained_fp8 import Fp8Quantize

                    converter.quantization_operation = Fp8Quantize()  # TODO support other methods
            if hf_quantizer is not None and hf_quantizer.is_valid_unexpected_keys(t):
                pass
            else:
                raise ValueError("This quantization method is gonna be supported SOOOON")
                unexpected_keys.add(t)
                continue

            if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, t):
                converter.quantization_operation = hf_quantizer.get_quantize_ops()
                # TODO: to clean later. We need to use the empty_param from the checkpoint to decide if we upcast the param to a specific dtype
                k_dtype = tensor.get_dtype()
                dtype = str_to_torch_dtype[k_dtype]
                empty_param_checkpoint = torch.empty(size=tensor.get_shape(), dtype=dtype, device="meta")
                _, _dtype = _infer_parameter_dtype(model, t, empty_param_checkpoint, hf_quantizer)
Comment on lines +639 to +644

Collaborator: Why? Is it because BNB needs, say, bf16 always? Can you elaborate here, because I don't upcast any of the parameters, they just have the …

Member (Author): We need to infer the right …

Collaborator: But if …
            else:
                _dtype = dtype
            matched_dtype_pattern = match_glob(t, dtype_policy_alt, dtype_policy_by_group_name)
@@ -599,9 +608,7 @@ def convert_and_load_state_dict_in_model(
                if op := converter.quantization_operation:
                    with log_to_misc(layer_name, misc, op=op):
                        realized_value.update(
                            op.convert(
                                {k: realized_value.pop(k)}, quant_config=quantizer.quantization_config
                            )
                            op.convert({k: realized_value.pop(k)}, model=model)
                        )

                for k, output_value in realized_value.items():
@@ -617,9 +624,10 @@ def convert_and_load_state_dict_in_model(
                        missing_keys,
                        misc,
                        converter.distributed_operation,
                        hf_quantizer,
                    )
                except SkipLayer:
                    continue
                except Exception as e:
                    raise e
            del group

        # Update progress bar
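The call-site change above means quantization ops now receive the model rather than a quant config. A minimal hedged sketch of an op written against that contract; the class below is illustrative and not part of the PR, and the absolute import path is an assumption based on the relative import used in the integration file further down.

```python
import torch
from typing import Optional

from transformers.core_model_loading import ConversionOps  # assumed import path

# Illustrative no-op conversion matching the new convert(..., model=...) call site.
class IdentityOp(ConversionOps):
    def convert(self, input_dict: dict, model: Optional[torch.nn.Module] = None, **kwargs) -> dict:
        target_key, value = next(iter(input_dict.items()))
        value = value[0] if isinstance(value, list) else value  # values may arrive as 1-element lists
        return {target_key: value}
```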
@@ -648,3 +656,34 @@ def revert_weight_conversion(model, state_dict):
            original_state_dict[key] = value
        state_dict = original_state_dict
    return state_dict
def _infer_parameter_dtype(
    model: torch.nn.Module,
    param_name: str,
    empty_param: torch.Tensor,
    hf_quantizer: Optional[HfQuantizer] = None,
) -> tuple[bool, Optional[torch.dtype]]:
    try:
        old_param = model.get_parameter_or_buffer(param_name)
    except Exception as e:
        if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in {
            QuantizationMethod.HQQ,
            QuantizationMethod.QUARK,
            QuantizationMethod.MXFP4,
            QuantizationMethod.BITS_AND_BYTES,
        }:
            return True, None
        else:
            raise e
    is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
    # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
    # in int/uint/bool and not cast them.
    casting_dtype = None

Comment on lines +774 to +777

Collaborator: Can't the methods do this in the …

    is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
    if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
        # dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes
        if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, param_name):
            casting_dtype = model.config._pre_quantization_dtype
        else:
            casting_dtype = old_param.dtype
    return old_param is not None and old_param.is_contiguous(), casting_dtype
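To make the casting rule concrete, a hedged sketch of the checkpoint-side decision made by the caller of `_infer_parameter_dtype`; the dtype string and shape stand in for what the safetensors slice reports via `tensor.get_dtype()` and `tensor.get_shape()`.

```python
import torch

# Subset of the str_to_torch_dtype map defined earlier, inlined to keep this self-contained.
str_to_torch_dtype = {"BF16": torch.bfloat16, "U8": torch.uint8}
empty_param_checkpoint = torch.empty(size=(4096, 4096), dtype=str_to_torch_dtype["BF16"], device="meta")

# Floating-point checkpoint tensors get a casting dtype (the pre-quantization dtype when the
# param will be quantized, otherwise the meta model's dtype); integer tensors such as packed
# 4-bit uint8 weights are left uncast (casting_dtype stays None).
print(empty_param_checkpoint.dtype.is_floating_point)  # True -> a cast will be chosen
```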
@@ -1,6 +1,9 @@
import inspect
from collections import defaultdict
from inspect import signature
from typing import Optional

from ..quantizers.quantizers_utils import get_module_from_name
from ..utils import (
    get_available_devices,
    is_accelerate_available,
@@ -26,6 +29,44 @@
logger = logging.get_logger(__name__)

from ..core_model_loading import ConversionOps


class Bnb4bitQuantize(ConversionOps):
    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer

    def convert(self, input_dict: torch.Tensor, model: Optional[torch.nn.Module] = None, **kwargs) -> dict[str, torch.Tensor]:
        target_key, value = tuple(input_dict.items())[0]
        value = value[0] if isinstance(value, list) else value
        if not self.hf_quantizer.pre_quantized:
            old_value = model.get_parameter_or_buffer(target_key)
            new_value = bnb.nn.Params4bit(value, requires_grad=False, **old_value.__dict__).to(value.device)
            return {target_key: new_value}
        else:
            full_name = target_key
            # update param name to get the weights instead of the quantized stats
            target_key = self.hf_quantizer.get_param_name(target_key)
            module, _ = get_module_from_name(model, target_key)

            module_name = target_key.rsplit(".", 1)[0]
            # Save the states for later quantization when they are all gathered
            if not hasattr(self.hf_quantizer, "param_quant_stats"):
                self.hf_quantizer.param_quant_stats = defaultdict(dict)
Comment on lines +61 to +63

Collaborator: They are gathered from what? Sorry, I am not familiar with it; which states do you need?

Member (Author): Basically, we need to store some parameters to create the quantized weight. For example, bnb requires 6 values that are stored in the checkpoint to recover the quantized weight. So we store them in a dict kept on the hf_quantizer for now, as we can't save it in the op since we create an op per tensor.
            self.hf_quantizer.param_quant_stats[module_name].update({full_name: value})
            # We are ready for quantization in this case (note, the +1 is for the weight itself)
            if len(self.hf_quantizer.param_quant_stats[module_name]) == len(self.hf_quantizer.bnb_keys) + 1:
                weight = self.hf_quantizer.param_quant_stats[module_name].pop(f"{module_name}.weight")
                new_value = bnb.nn.Params4bit.from_prequantized(
                    data=weight,
                    quantized_stats=self.hf_quantizer.param_quant_stats[module_name],
                    requires_grad=False,
                    device=value.device,
                    module=module,
                )
                del self.hf_quantizer.param_quant_stats[module_name]
                return {target_key: new_value}
            return {}


def _replace_with_bnb_linear(
    model,
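To make the gathering discussed in the review thread above concrete, this is roughly the checkpoint layout a pre-quantized bnb-4bit module ships with. The key names follow the usual bitsandbytes serialization and are shown purely for illustration; the authoritative list is whatever `hf_quantizer.bnb_keys` contains.

```python
# Illustrative serialized keys for a single pre-quantized bnb-4bit linear layer.
# The packed weight plus its quantization statistics arrive as separate checkpoint
# tensors and must all be collected before Params4bit.from_prequantized() can run.
prequantized_keys = [
    "model.layers.0.mlp.down_proj.weight",                               # packed uint8 data
    "model.layers.0.mlp.down_proj.weight.absmax",
    "model.layers.0.mlp.down_proj.weight.quant_map",
    "model.layers.0.mlp.down_proj.weight.nested_absmax",
    "model.layers.0.mlp.down_proj.weight.nested_quant_map",
    "model.layers.0.mlp.down_proj.weight.quant_state.bitsandbytes__nf4",
]
```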
Comment: Simplified a bit to accommodate subclasses of nn.Parameter, cc @ArthurZucker. Also, if we are using this class, it means the param is initialized, as you said, so let's simplify everything.

Reply: Yep, absolutely, ty.