Define warmup allocator for torchao quantization #37764

Open · wants to merge 3 commits into base: main
39 changes: 39 additions & 0 deletions src/transformers/quantizers/quantizer_torchao.py
@@ -277,6 +277,45 @@ def is_serializable(self, safe_serialization=None) -> bool:
            return False
        return _is_torchao_serializable

    def get_cuda_warm_up_factor(self):
        """
        This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup.
        - A factor of 2 means we pre-allocate the full memory footprint of the model.
        - A factor of 4 means we pre-allocate half of that, and so on.

        However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() does not give
        the correct size for quantized weights (such as int4 or int8), because TorchAO internally represents quantized
        tensors with subtensors and metadata, and the reported element_size() still corresponds to the torch_dtype,
        not the actual bit-width of the quantized data.

        To correct for this:
        - Use a division factor of 8 for int4 weights
        - Use a division factor of 4 for int8 weights
        """
        if self.quantization_config._get_ao_version() > version.Version("0.9.0"):
            from torchao.core.config import AOBaseConfig

            quant_type = self.quantization_config.quant_type
            # The autoquant case falls through to the string mapping below (map_to_target_dtype)
            if isinstance(quant_type, AOBaseConfig):
                # Extract the bit-width digit with a fuzzy match on the config class name
                config_name = quant_type.__class__.__name__
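                # e.g. a class named "Int4WeightOnlyConfig" is expected to match to "4"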
                size_digit = fuzzy_match_size(config_name)

                if size_digit == "4":
                    return 8
                else:
                    return 4

        # Original mapping for non-AOBaseConfig (string) quant types
        map_to_target_dtype = {
            "int4_weight_only": 8,
            "int8_weight_only": 4,
            "int8_dynamic_activation_int8_weight": 4,
            "autoquant": 4,
        }

        return map_to_target_dtype[self.quantization_config.quant_type]
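
For context, here is roughly how a warmup routine can consume this factor. This is a minimal sketch, not the actual transformers implementation: the helper name warmup_sketch and its signature are illustrative, and it assumes the warmup buffer is allocated in fp16 (2 bytes per element), which is what makes a factor of 2 correspond to the full footprint.

import torch

def warmup_sketch(model, device, factor):
    # Footprint as reported by the parameters themselves. For torchao-quantized
    # weights this over-reports, since element_size() reflects the original
    # torch_dtype rather than the packed bit-width, hence the larger factors above.
    byte_count = sum(p.numel() * p.element_size() for p in model.parameters())
    # Allocating byte_count // factor fp16 elements touches (byte_count // factor) * 2
    # bytes: factor=2 warms up the full footprint, 4 half of it, 8 a quarter.
    _ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)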

    @property
    def is_trainable(self) -> bool:
        supported_quant_types_for_training = [