
Commit e07ab9a

MekkCyber authored and SunMarc committed
Define warmup allocator for torchao quantization (huggingface#37764)
* torchao allocator

* add comment

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
1 parent 413927e · commit e07ab9a

File tree: 1 file changed (+39 −0)


src/transformers/quantizers/quantizer_torchao.py

Lines changed: 39 additions & 0 deletions
@@ -277,6 +277,45 @@ def is_serializable(self, safe_serialization=None) -> bool:
             return False
         return _is_torchao_serializable
 
+    def get_cuda_warm_up_factor(self):
+        """
+        This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup.
+        - A factor of 2 means we pre-allocate the full memory footprint of the model.
+        - A factor of 4 means we pre-allocate half of that, and so on.
+
+        However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give
+        the correct size for quantized weights (like int4 or int8). That's because TorchAO internally represents
+        quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the
+        torch_dtype, not the actual bit-width of the quantized data.
+
+        To correct for this:
+        - Use a division factor of 8 for int4 weights
+        - Use a division factor of 4 for int8 weights
+        """
+        if self.quantization_config._get_ao_version() > version.Version("0.9.0"):
+            from torchao.core.config import AOBaseConfig
+
+            quant_type = self.quantization_config.quant_type
+            # The autoquant case is handled by the string mapping in map_to_target_dtype below
+            if isinstance(quant_type, AOBaseConfig):
+                # Extract the bit-width digit with a fuzzy match on the config class name
+                config_name = quant_type.__class__.__name__
+                size_digit = fuzzy_match_size(config_name)
+
+                if size_digit == "4":
+                    return 8
+                else:
+                    return 4
+
+        # Original mapping for non-AOBaseConfig types
+        map_to_target_dtype = {
+            "int4_weight_only": 8,
+            "int8_weight_only": 4,
+            "int8_dynamic_activation_int8_weight": 4,
+            "autoquant": 4,
+        }
+
+        return map_to_target_dtype[self.quantization_config.quant_type]
+
     @property
     def is_trainable(self) -> bool:
         supported_quant_types_for_training = [
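
As a sanity check on the chosen factors, a quick back-of-the-envelope sketch; the 7B parameter count and bfloat16 dtype are illustrative assumptions, while the "factor 2 = full footprint" convention comes from the docstring above:

    # Assumed example: 7B parameters loaded with torch_dtype=bfloat16 (2 bytes each)
    params = 7_000_000_000
    apparent_bytes = params * 2             # what numel() * element_size() reports: ~14 GB
    # warmup bytes = 2 * apparent / factor, so factor 2 pre-allocates the full footprint
    int4_warmup = 2 * apparent_bytes // 8   # ~3.5 GB, i.e. 0.5 bytes/param -- the true int4 size
    int8_warmup = 2 * apparent_bytes // 4   # ~7 GB,   i.e. 1 byte/param   -- the true int8 size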

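For context on where the factor lands, here is a minimal sketch of a warmup allocator, assuming only the behavior described in the docstring; the helper name and body are illustrative, not transformers' actual caching_allocator_warmup:

    import torch

    def warmup_cuda_allocator(model: torch.nn.Module, factor: int) -> None:
        # numel() * element_size() reflects the unquantized torch_dtype width,
        # which is exactly why TorchAO needs the larger division factors above.
        total_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
        # factor 2 -> full footprint, factor 4 -> half, factor 8 -> a quarter
        warmup_bytes = 2 * total_bytes // factor
        # Allocate then drop a raw byte buffer: the CUDA caching allocator keeps
        # the freed block, so subsequent weight loading avoids extra cudaMallocs.
        buf = torch.empty(warmup_bytes, dtype=torch.uint8, device="cuda")
        del buf

Pre-allocating one contiguous buffer up front means the many smaller allocations made while loading weights can be served from the allocator's cache rather than fresh device allocations.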