Fix bnb for the weights refactor #42043
base: refactor-weight-loading
Changes from all commits
710b1ff
f72f96d
e341529
0e51dec
0f022b5
1dabb4c
0c2b667
c48e1ed
d223635
bdbc01a
399388d
acbeeae
e16da23
386e259
9788014
72eff97
d841a04
e235eed
e4df752
3e69622
a052513
db4fe31
9fa1b7a
ea5822d
5881d8e
3651460
00b0044
7d8df52
| Original file line number | Diff line number | Diff line change |
|---|---|---|
@@ -26,12 +26,21 @@ | |
| from contextlib import contextmanager | ||
| from dataclasses import dataclass, field | ||
| from functools import partial | ||
| from types import MethodType | ||
| from typing import Any, Optional, Union | ||
| import torch | ||
| from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, TensorParallelLayer,DTensor,Replicate | ||
| from .utils import logging | ||
| from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, TensorParallelLayer, DTensor, Replicate | ||
| from .quantizers import HfQuantizer | ||
| from .utils import is_torch_greater_or_equal, logging | ||
| from .utils.quantization_config import QuantizationMethod | ||
| _torch_distributed_available = torch.distributed.is_available() | ||
| _is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5") | ||
| if _is_dtensor_available: | ||
| from torch.distributed.tensor import DTensor | ||
| import itertools | ||
@@ -81,6 +90,21 @@ | |
| logger = logging.get_logger(__name__) | ||
| str_to_torch_dtype = { | ||
| "BOOL": torch.bool, | ||
| "U8": torch.uint8, | ||
| "I8": torch.int8, | ||
| "I16": torch.int16, | ||
| "F16": torch.float16, | ||
| "BF16": torch.bfloat16, | ||
| "I32": torch.int32, | ||
| "F32": torch.float32, | ||
| "F64": torch.float64, | ||
| "I64": torch.int64, | ||
| "F8_E4M3": torch.float8_e4m3fn, | ||
| "F8_E5M2": torch.float8_e5m2, | ||
| } | ||
| def _glob_to_regex_src(glob: str, *, digits_only: bool = True) -> str: | ||
| """ | ||
@@ -361,6 +385,7 @@ def data(self): | |
| @data.setter | ||
| def data(self, new): | ||
| pass | ||
| def __lt__(self, other): return torch.Tensor.__lt__(self, other) | ||
| def __le__(self, other): return torch.Tensor.__le__(self, other) | ||
| def __gt__(self, other): return torch.Tensor.__gt__(self, other) | ||
@@ -462,6 +487,7 @@ def set_param_for_module( | |
| missing_keys: MutableSet[str], | ||
| misc: MutableMapping[str, Any], | ||
| distributed_operation: Optional[TensorParallelLayer], | ||
| hf_quantizer, | ||
| ): | ||
| with log_to_misc(layer_name, misc, layer_name): | ||
| module_path, _, param_name = layer_name.rpartition(".") | ||
@@ -470,7 +496,6 @@ def set_param_for_module( | |
| param_value = param_value[0] | ||
| elif not isinstance(param_value, torch.nn.Parameter): | ||
| param_value = param_value[...] | ||
| param_value = param_value[0] if isinstance(param_value, list) else param_value[...] | ||
| ref = meta_model_state_dict.get(layer_name, empty_param) | ||
@@ -485,14 +510,18 @@ def set_param_for_module( | |
| shape=ref.size(), | ||
| stride=ref.stride(), | ||
| ) | ||
| if not use_dtensor: | ||
| if not use_dtensor: | ||
| # we convert to local | ||
| param_value = param_value.to_local() | ||
| if param_name not in module_obj._buffers: | ||
| param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point()) | ||
| # to skip any inplace method that modifies the param data | ||
| param_value = get_loaded_parameter_class(param_value.__class__)(from_existing=param_value) | ||
| if ref is not None and ref.shape != param_value.shape: | ||
| # skip mismatch for hf_quantizer for now | ||
| if ref is not None and ref.shape != param_value.shape and hf_quantizer is None: | ||
| mismatch_keys.add((layer_name, param_value.shape, ref.shape)) | ||
| setattr(module_obj._parameters[param_name], "_is_hf_initialized", False) # Needs to be initialized | ||
| missing_keys.discard(layer_name) | ||
@@ -513,7 +542,7 @@ def convert_and_load_state_dict_in_model( | |
| state_dict, | ||
| weight_mapping, | ||
| tp_plan, | ||
| quantizer, | ||
| hf_quantizer, | ||
| dtype=None, | ||
| device_map=None, | ||
| dtype_plan=None, | ||
@@ -574,16 +603,19 @@ def convert_and_load_state_dict_in_model( | |
| empty_param = meta_model_state_dict.get(t) | ||
| # If it does not exist, it's unexpected | ||
| if empty_param is None: | ||
| unexpected_keys.add(t) | ||
| continue | ||
| if quantizer is not None and quantizer.param_needs_quantization(model, t): | ||
| if quantizer.__class__.__name__ == "FineGrainedFP8HfQuantizer": | ||
| from .integrations.finegrained_fp8 import Fp8Quantize | ||
| converter.quantization_operation = Fp8Quantize() # TODO support other methods | ||
| if hf_quantizer is not None and hf_quantizer.is_valid_unexpected_keys(t): | ||
Collaborator
ditto 😉
| pass | ||
| else: | ||
| raise ValueError("This quantization method is gonna be supported SOOOON") | ||
| unexpected_keys.add(t) | ||
| continue | ||
| if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, t): | ||
| converter.quantization_operation = hf_quantizer.get_quantize_ops() | ||
| # TODO: to clean later. We need to use the empty_param from the checkpoint to decide if we upcast the param to a specific dtype | ||
| k_dtype = tensor.get_dtype() | ||
| dtype = str_to_torch_dtype[k_dtype] | ||
| empty_param_checkpoint = torch.empty(size=tensor.get_shape(), dtype=dtype, device="meta") | ||
| _, _dtype = _infer_parameter_dtype(model, t, empty_param_checkpoint, hf_quantizer) | ||
Comment on lines +613 to +618
Collaborator
why? is it because BNB needs say bf16 always? can you elaborate here because I don't upcast any of the parameters, they just have the
Member Author
We need to infer the right
Collaborator
but if
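To make the new lines +613 to +618 concrete for this thread: the checkpoint entry only contributes its dtype string and shape, from which a meta tensor is built so the casting decision can be made without reading the data. A minimal sketch, with a made-up header entry standing in for `tensor.get_dtype()` / `tensor.get_shape()`:

```python
import torch

# Illustrative subset of the str_to_torch_dtype mapping added in this PR.
str_to_torch_dtype = {"U8": torch.uint8, "F16": torch.float16, "BF16": torch.bfloat16}

# Made-up header entry for a packed bnb 4-bit weight; real values come from the
# safetensors metadata via tensor.get_dtype() / tensor.get_shape().
header = {"dtype": "U8", "shape": [8388608, 1]}

# Build a meta tensor with the checkpoint's dtype and shape, so the dtype
# decision can be made without materializing the data.
empty_param_checkpoint = torch.empty(
    size=header["shape"],
    dtype=str_to_torch_dtype[header["dtype"]],
    device="meta",
)
print(empty_param_checkpoint.dtype, empty_param_checkpoint.shape)
# torch.uint8 torch.Size([8388608, 1])
```

A prequantized bnb weight shows up as a non-floating-point (here uint8) tensor, so the upcasting branch in `_infer_parameter_dtype` below never fires for it, while a weight that still needs on-the-fly quantization is floating point and gets cast to `model.config._pre_quantization_dtype`; that seems to be the point of the change.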
| else: | ||
| _dtype = dtype | ||
| matched_dtype_pattern = match_glob(t, dtype_policy_alt, dtype_policy_by_group_name) | ||
@@ -648,9 +680,7 @@ def convert_and_load_state_dict_in_model( | |
| if op := converter.quantization_operation: | ||
| with log_to_misc(layer_name, misc, op=op): | ||
| realized_value.update( | ||
| op.convert( | ||
| {k: realized_value.pop(k)}, quant_config=quantizer.quantization_config | ||
| ) | ||
| op.convert({k: realized_value.pop(k)}, model=model) | ||
Collaborator
not a fan of passing the whole model!
Member Author
I wish I could not do that but let's keep this for now
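To make the trade-off in this exchange concrete, a rough sketch of what an op does with the model handle after this change; the class below is illustrative, not the actual transformers code:

```python
from typing import Optional

import torch


class QuantizeOpSketch:
    """Illustrative stand-in for a ConversionOps subclass after this PR."""

    def convert(
        self,
        input_dict: dict[str, torch.Tensor],
        model: Optional[torch.nn.Module] = None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        target_key, value = next(iter(input_dict.items()))
        module_path, _, _ = target_key.rpartition(".")
        # `model` is used purely as a lookup table for the destination module;
        # a narrower handle (e.g. the module itself) would serve the same purpose.
        module = model.get_submodule(module_path)
        # ... build the quantized replacement for `module` from `value` here ...
        return {target_key: value}
```

Passing only the resolved target module (or a lookup callback) would address the concern, at the cost of doing the name-to-module resolution at every call site.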
| ) | ||
| for k, output_value in realized_value.items(): | ||
@@ -666,9 +696,10 @@ def convert_and_load_state_dict_in_model( | |
| missing_keys, | ||
| misc, | ||
| converter.distributed_operation, | ||
| hf_quantizer | ||
| ) | ||
| except SkipLayer: | ||
| continue | ||
| except Exception as e : | ||
| raise e | ||
| del group | ||
| # Update progress bar | ||
@@ -697,3 +728,34 @@ def revert_weight_conversion(model, state_dict): | |
| original_state_dict[key] = value | ||
| state_dict = original_state_dict | ||
| return state_dict | ||
| def _infer_parameter_dtype( | ||
| model: torch.nn.Module, | ||
| param_name: str, | ||
| empty_param: torch.Tensor, | ||
| hf_quantizer: Optional[HfQuantizer] = None, | ||
| ) -> tuple[bool, Optional[torch.dtype]]: | ||
| try: | ||
| old_param = model.get_parameter_or_buffer(param_name) | ||
| except Exception as e: | ||
| if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in { | ||
| QuantizationMethod.HQQ, | ||
| QuantizationMethod.QUARK, | ||
| QuantizationMethod.MXFP4, | ||
| QuantizationMethod.BITS_AND_BYTES, | ||
| }: | ||
| return True, None | ||
| else: | ||
| raise e | ||
| is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") | ||
| # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params | ||
| # in int/uint/bool and not cast them. | ||
| casting_dtype = None | ||
Comment on lines +750 to +753
Collaborator
can't the methods do this in the
| is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn | ||
| if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: | ||
| # dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes | ||
| if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, param_name): | ||
| casting_dtype = model.config._pre_quantization_dtype | ||
| else: | ||
| casting_dtype = old_param.dtype | ||
| return old_param is not None and old_param.is_contiguous(), casting_dtype | ||
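A condensed, self-contained restatement of the dtype rule `_infer_parameter_dtype` implements; the default `pre_quantization_dtype` below is an illustrative stand-in for `model.config._pre_quantization_dtype`, and the early return for HQQ/Quark/MXFP4/bnb params that are missing from the model is omitted:

```python
import torch


def casting_rule(
    checkpoint_dtype: torch.dtype,
    model_param_dtype: torch.dtype,
    needs_quantization: bool,
    pre_quantization_dtype: torch.dtype = torch.bfloat16,  # stand-in for config._pre_quantization_dtype
):
    """Return the dtype a loaded tensor should be cast to, or None to keep it as-is."""
    is_fp8 = checkpoint_dtype == getattr(torch, "float8_e4m3fn", None)
    if not checkpoint_dtype.is_floating_point or is_fp8:
        return None  # int/uint/bool and fp8 tensors keep their checkpoint dtype
    if needs_quantization:
        return pre_quantization_dtype  # quantize-on-load: restore the pre-quantization dtype
    return model_param_dtype  # otherwise follow the dtype instantiated in the meta model


print(casting_rule(torch.float16, torch.float32, needs_quantization=False))  # torch.float32
print(casting_rule(torch.float16, torch.float32, needs_quantization=True))   # torch.bfloat16 (pre-quantization dtype)
print(casting_rule(torch.uint8, torch.float32, needs_quantization=True))     # None (packed quantized data, not cast)
```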
| Original file line number | Diff line number | Diff line change |
|---|---|---|
@@ -435,6 +435,7 @@ def _get_device_map( | |||||
| if max_memory is not None and device_name in max_memory: | ||||||
| inferred_max_memory[device_name] = min(inferred_max_memory[device_name], max_memory[device_name]) | ||||||
| model.tie_weights() | ||||||
Collaborator
Suggested change
| device_map = infer_auto_device_map( | ||||||
| model, | ||||||
| max_memory=inferred_max_memory, | ||||||
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,9 @@ | ||
| import inspect | ||
| from collections import defaultdict | ||
| from inspect import signature | ||
| from typing import Optional | ||
| from ..quantizers.quantizers_utils import get_module_from_name | ||
| from ..utils import ( | ||
| get_available_devices, | ||
| is_accelerate_available, | ||
@@ -26,6 +29,49 @@ | |
| logger = logging.get_logger(__name__) | ||
| from ..core_model_loading import ConversionOps | ||
| class Bnb4bitQuantize(ConversionOps): | ||
| def __init__(self, hf_quantizer): | ||
| self.hf_quantizer = hf_quantizer | ||
| def convert(self, input_dict: torch.Tensor, model: Optional[torch.nn.Module] = None, **kwargs) -> dict[str, torch.Tensor]: | ||
| target_key, value = tuple(input_dict.items())[0] | ||
| value = value[0] if isinstance(value, list) else value | ||
| full_name = target_key | ||
| # update param name to get the weights instead of the quantized stats | ||
| target_key = self.hf_quantizer.get_param_name(target_key) | ||
| module, _ = get_module_from_name(model, target_key) | ||
| if not self.hf_quantizer.pre_quantized: | ||
| # Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization. | ||
| # Since weights are saved in the correct "orientation", we skip transposing when loading. | ||
| if issubclass(module.source_cls, Conv1D): | ||
| value = value.T | ||
| old_value = model.get_parameter_or_buffer(target_key) | ||
| new_value = bnb.nn.Params4bit(value, requires_grad=False, **old_value.__dict__).to(value.device) | ||
| return {target_key : new_value} | ||
| else: | ||
| module_name = target_key.rsplit(".", 1)[0] | ||
| # Save the states for later quantization when they are all gathered | ||
| if not hasattr(self.hf_quantizer, "param_quant_stats"): | ||
| self.hf_quantizer.param_quant_stats = defaultdict(dict) | ||
Comment on lines +58 to +60
Collaborator
they are gathered from what? sorry I am not familiar with it, you need which states?
Member Author
Basically, we need to store some parameters to create the quantized weight. For example, bnb requires 6 values that are stored in the checkpoint to recover the quantized weight. So we store them in a dict that is stored on the hf_quantizer for now, as we can't save it in the op since we create an op per tensor.
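For readers unfamiliar with bnb checkpoints, a toy sketch of the gather-then-rebuild flow described above; the entry names and the expected count of 3 are made up, whereas the real code waits for `len(hf_quantizer.bnb_keys) + 1` entries and then calls `bnb.nn.Params4bit.from_prequantized`:

```python
from collections import defaultdict

# Each tensor arrives through its own op call, so per-module state is parked on
# a shared dict until the packed weight and all of its quantization stats have
# been seen. Expecting 3 entries per module keeps the toy example small.
EXPECTED_PER_MODULE = 3

param_quant_stats: dict = defaultdict(dict)


def on_tensor(module_name: str, entry_name: str, value):
    param_quant_stats[module_name][entry_name] = value
    if len(param_quant_stats[module_name]) < EXPECTED_PER_MODULE:
        return None  # still waiting for the rest of this module's entries
    gathered = param_quant_stats.pop(module_name)
    packed_weight = gathered.pop("weight")
    # The real op now calls bnb.nn.Params4bit.from_prequantized(data=packed_weight,
    # quantized_stats=gathered, ...); here we just return what would be passed.
    return packed_weight, gathered


print(on_tensor("model.layers.0.mlp.up_proj", "weight.absmax", "..."))     # None
print(on_tensor("model.layers.0.mlp.up_proj", "weight.quant_map", "..."))  # None
print(on_tensor("model.layers.0.mlp.up_proj", "weight", b"packed-4bit"))   # (b'packed-4bit', {...})
```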
| self.hf_quantizer.param_quant_stats[module_name].update({full_name: value}) | ||
| # We are ready for quantization in this case (note, the +1 is for the weight itself) | ||
| if len(self.hf_quantizer.param_quant_stats[module_name]) == len(self.hf_quantizer.bnb_keys) + 1: | ||
| weight = self.hf_quantizer.param_quant_stats[module_name].pop(f"{module_name}.weight") | ||
| new_value = bnb.nn.Params4bit.from_prequantized( | ||
| data=weight, | ||
| quantized_stats=self.hf_quantizer.param_quant_stats[module_name], | ||
| requires_grad=False, | ||
| device=value.device, | ||
| module=module | ||
| ) | ||
| del self.hf_quantizer.param_quant_stats[module_name] | ||
| return {target_key : new_value} | ||
| return {} | ||
| def _replace_with_bnb_linear( | ||
| model, | ||
Collaborator
why is the shape of the BnbLinear not correct? this I also don't think we want long term no?
Member Author
this is because when we initialize the meta model with nn.Linear4bit, those don't have the right shape as the weights are not quantized yet. But yeah maybe we can fix this by overwriting the shape of the param when replacing the layers. In the long term, we will remove this, yes.
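A rough illustration of the mismatch being discussed: 4-bit packing stores two values per byte, so the tensor bitsandbytes keeps has a flattened uint8 shape with half as many elements as the `(out_features, in_features)` weight created on the meta device. The exact packed layout is a bitsandbytes internal; the shapes below are only indicative:

```python
import torch

out_features, in_features = 4096, 4096

# What the meta model instantiates for nn.Linear / Linear4bit before loading:
meta_weight = torch.empty(out_features, in_features, device="meta")

# What a prequantized 4-bit checkpoint typically stores: a flattened uint8 buffer
# holding two 4-bit values per byte, hence half as many elements.
packed_weight = torch.empty(out_features * in_features // 2, 1,
                            dtype=torch.uint8, device="meta")

print(meta_weight.shape)    # torch.Size([4096, 4096])
print(packed_weight.shape)  # torch.Size([8388608, 1])
# This is why the ref.shape != param_value.shape check is skipped when hf_quantizer is set.
```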