From f618e846a81e0dc170dea20a0322f170587bfd24 Mon Sep 17 00:00:00 2001
From: Dipika
Date: Mon, 10 Feb 2025 23:34:58 +0000
Subject: [PATCH 1/4] update

---
 .../quantization/lifecycle/apply.py      |  1 -
 .../quantization/lifecycle/forward.py    |  5 ++-
 .../quantization/lifecycle/initialize.py | 38 ++++++++++++++++---
 3 files changed, 37 insertions(+), 7 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 31f14df0..ed71a041 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -152,7 +152,6 @@ def apply_quantization_config(
             continue  # layer matches ignore list, continue
 
         targets = find_name_or_class_matches(name, submodule, target_to_scheme)
-
         if targets:
             # mark modules to be quantized by adding
             # quant scheme to the matching layers
diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index f4f93f27..7f1a88f0 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -269,6 +269,7 @@ def wrapped_forward(self, *args, **kwargs):
             # forward call
             return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
 
+        #breakpoint()
         input_ = args[0]
 
         compressed = module.quantization_status == QuantizationStatus.COMPRESSED
@@ -288,7 +289,8 @@ def wrapped_forward(self, *args, **kwargs):
         output = forward_func_orig.__get__(module, module.__class__)(
             input_, *args[1:], **kwargs
         )
-
+        #breakpoint()
+
         # restore back to unquantized_value
         if scheme.weights is not None and not compressed:
             self.weight.data = unquantized_weight
@@ -304,6 +306,7 @@ def wrapped_forward(self, *args, **kwargs):
             output = forward_quantize(
                 module, output, "output", scheme.output_activations
             )
+            print("running output qdq")
         return output
 
     # bind wrapped forward to module class so reference to `self` is correct
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 8dd8fc51..2365678d 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -77,7 +77,7 @@ def initialize_module_for_quantization(
 
     if is_attention_module(module):
         # quantized actions based on calltime status
-        _initialize_attn_scales(module)
+        _initialize_attn_scales(module, quant_args=scheme.output_activations)
 
     else:
 
@@ -109,8 +109,11 @@ def initialize_module_for_quantization(
 
         if scheme.output_activations is not None:
             if not is_kv_cache_quant_scheme(scheme):
+                weight_shape = None
+                if isinstance(module, torch.nn.Linear) and hasattr(module, "weight"):
+                    weight_shape = module.weight.shape
                 _initialize_scale_zero_point(
-                    module, "output", scheme.output_activations
+                    module, "output", scheme.output_activations, weight_shape=weight_shape
                 )
 
     module.quantization_scheme = scheme
@@ -152,7 +155,12 @@ def _initialize_scale_zero_point(
     else:
         expected_shape = 1
 
-    if base_name == "weight" and weight_shape is not None:
+    if quantization_args.strategy in (QuantizationStrategy.CHANNEL, QuantizationStrategy.GROUP):
+        assert weight_shape is not None
+        # only supported atm for weight quant, output_activations
+        assert base_name in ("weight", "output")
+        print("weight_shape", weight_shape)
+
         if quantization_args.strategy == QuantizationStrategy.CHANNEL:
             # (output_channels, 1)
             expected_shape = (weight_shape[0], 1)
@@ -160,6 +168,10 @@ def _initialize_scale_zero_point(
             num_groups = weight_shape[1] // quantization_args.group_size
             expected_shape = (weight_shape[0], max(num_groups, 1))
 
+    if base_name == "output":
+        expected_shape = tuple(reversed(expected_shape))
+    print("expected_shape", base_name, expected_shape)
+
     scale_dtype = module.weight.dtype
     if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
         scale_dtype = torch.float16
@@ -190,10 +202,26 @@ def _initialize_scale_zero_point(
         register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
 
 
-def _initialize_attn_scales(module: Module) -> None:
+def _initialize_attn_scales(module: Module, quant_args: QuantizationArgs) -> None:
     """Initlaize k_scale, v_scale for self_attn"""
 
-    expected_shape = 1  # per tensor
+    strategy = quant_args.strategy
+    # (1024, 1) - v_scale,
+    # (1024, 1) - q_scale
+
+    """
+    - We want channelwise scales for on output activations of k/q/v
+    - Requires num_heads_k scales (num_key_value_heads scales?)
+    - So the expected shape is (1024, 1) --> (8, 128) where 8 is the number of kv heads
+    """
+
+    # we get access to this after the k_proj and v_proj have been restructured already (reshaped)
+    if strategy == QuantizationStrategy.CHANNEL:
+        #expected_shape = (module.v_proj.weight.shape[0], 1)
+        #expected_shape = (8, 128) # num heads * head_dim
+        expected_shape = 1
+    else:
+        expected_shape = 1  # per tensor
 
     param = next(module.parameters())
     scale_dtype = param.dtype
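For context on the shape bookkeeping introduced above: a Linear weight is stored as (out_features, in_features), so a channel-wise scale naturally has shape (out_features, 1), while the layer's output activations carry the channel dimension last, which is why the patch reverses the tuple for base_name == "output". A minimal standalone sketch of that logic (illustrative only, not the patch or the library API; the helper name is hypothetical):

    import torch

    def sketch_expected_scale_shape(weight_shape, strategy, base_name, group_size=None):
        # channel-wise: one scale per output channel; group-wise: groups along the input dim
        if strategy == "channel":
            shape = (weight_shape[0], 1)
        elif strategy == "group":
            shape = (weight_shape[0], max(weight_shape[1] // group_size, 1))
        else:
            shape = (1,)  # per-tensor fallback
        if base_name == "output":
            # output activations are (..., out_features), so flip the layout
            shape = tuple(reversed(shape))
        return shape

    linear = torch.nn.Linear(in_features=128, out_features=256)
    print(sketch_expected_scale_shape(linear.weight.shape, "channel", "weight"))  # (256, 1)
    print(sketch_expected_scale_shape(linear.weight.shape, "channel", "output"))  # (1, 256)
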
From b9b28467ea0985fe9f9a8304353755c2b8ab10f3 Mon Sep 17 00:00:00 2001
From: Dipika
Date: Tue, 11 Feb 2025 17:47:41 +0000
Subject: [PATCH 2/4] update

---
 .../quantization/lifecycle/apply.py      |  3 +++
 .../quantization/lifecycle/initialize.py | 17 +++++++++++------
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index ed71a041..feec19ce 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -129,6 +129,9 @@ def apply_quantization_config(
     target_to_scheme = OrderedDict()
     config = process_quantization_config(config)
     names_to_scheme = OrderedDict()
+
+    # have to consolidate attn quant with weight/input quant
+    # currently treats them as mutually exclusive
     for scheme in config.config_groups.values():
         for target in scheme.targets:
             target_to_scheme[target] = scheme
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 2365678d..5753d4e2 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -159,18 +159,23 @@ def _initialize_scale_zero_point(
         assert weight_shape is not None
         # only supported atm for weight quant, output_activations
         assert base_name in ("weight", "output")
-        print("weight_shape", weight_shape)
+        #print("weight_shape", weight_shape)
 
         if quantization_args.strategy == QuantizationStrategy.CHANNEL:
             # (output_channels, 1)
            expected_shape = (weight_shape[0], 1)
+
+            if base_name == "output":
+                expected_shape = tuple(reversed(expected_shape))
+
         elif quantization_args.strategy == QuantizationStrategy.GROUP:
-            num_groups = weight_shape[1] // quantization_args.group_size
-            expected_shape = (weight_shape[0], max(num_groups, 1))
+            #num_groups = weight_shape[1] // quantization_args.group_size
+            #expected_shape = (weight_shape[0], max(num_groups, 1))
+
+            num_groups = weight_shape[0] // quantization_args.group_size
+            expected_shape = (1, num_groups)
 
-    if base_name == "output":
-        expected_shape = tuple(reversed(expected_shape))
-    print("expected_shape", base_name, expected_shape)
+    #print("expected_shape", base_name, expected_shape)
 
     scale_dtype = module.weight.dtype
     if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
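The reversal kept for the channel-wise case above comes down to broadcasting: a weight scale must broadcast over the rows of an (out_features, in_features) matrix, while an output-activation scale must broadcast over the last axis of a (batch, out_features) tensor. A small standalone illustration of per-output-channel FP8 fake quant/dequant (illustrative only, not library code; assumes a PyTorch build with torch.float8_e4m3fn, whose maximum representable value is 448):

    import torch

    out_features, in_features, batch = 4, 8, 2
    weight = torch.randn(out_features, in_features)
    output_act = torch.randn(batch, out_features)

    # one scale per output channel, shaped to broadcast over each layout
    w_scale = (weight.abs().amax(dim=1, keepdim=True) / 448.0).clamp(min=1e-12)      # (4, 1): per row
    o_scale = (output_act.abs().amax(dim=0, keepdim=True) / 448.0).clamp(min=1e-12)  # (1, 4): per column

    def fp8_qdq(x, scale):
        # quantize to FP8 E4M3 and immediately dequantize
        return (x / scale).clamp(-448, 448).to(torch.float8_e4m3fn).float() * scale

    w_qdq = fp8_qdq(weight, w_scale)
    o_qdq = fp8_qdq(output_act, o_scale)
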
From 6f64b382f6909419d05ee2992e2e8be18cb25740 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Thu, 6 Mar 2025 23:20:04 -0500
Subject: [PATCH 3/4] channel wise fp8 attn

---
 .../quantization/lifecycle/forward.py    |  6 +--
 .../quantization/lifecycle/initialize.py | 52 ++++++------------
 2 files changed, 20 insertions(+), 38 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index 7f1a88f0..c26cf54f 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -269,7 +269,7 @@ def wrapped_forward(self, *args, **kwargs):
             # forward call
             return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
 
-        #breakpoint()
+        # breakpoint()
         input_ = args[0]
 
         compressed = module.quantization_status == QuantizationStatus.COMPRESSED
@@ -289,8 +289,8 @@ def wrapped_forward(self, *args, **kwargs):
         output = forward_func_orig.__get__(module, module.__class__)(
             input_, *args[1:], **kwargs
         )
-        #breakpoint()
-
+        # breakpoint()
+
         # restore back to unquantized_value
         if scheme.weights is not None and not compressed:
             self.weight.data = unquantized_weight
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 639de110..9eb4aa5f 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -77,10 +77,9 @@ def initialize_module_for_quantization(
 
     if is_attention_module(module):
         # quantized actions based on calltime status
-        _initialize_attn_scales(module, quant_args=scheme.output_activations)
+        _initialize_attn_scales(module)
 
     else:
-
         if scheme.input_activations is not None:
             _initialize_scale_zero_point(
                 module,
@@ -113,7 +112,10 @@ def initialize_module_for_quantization(
                 if isinstance(module, torch.nn.Linear) and hasattr(module, "weight"):
                     weight_shape = module.weight.shape
                 _initialize_scale_zero_point(
-                    module, "output", scheme.output_activations, weight_shape=weight_shape
+                    module,
+                    "output",
+                    scheme.output_activations,
+                    weight_shape=weight_shape,
                 )
 
     module.quantization_scheme = scheme
@@ -155,27 +157,20 @@ def _initialize_scale_zero_point(
     else:
         expected_shape = 1
 
-    if quantization_args.strategy in (QuantizationStrategy.CHANNEL, QuantizationStrategy.GROUP):
-        assert weight_shape is not None
-        # only supported atm for weight quant, output_activations
-        assert base_name in ("weight", "output")
-        #print("weight_shape", weight_shape)
-
+    if base_name == "weight" and weight_shape is not None:
         if quantization_args.strategy == QuantizationStrategy.CHANNEL:
-            # (output_channels, 1)
             expected_shape = (weight_shape[0], 1)
-
-            if base_name == "output":
-                expected_shape = tuple(reversed(expected_shape))
-
         elif quantization_args.strategy == QuantizationStrategy.GROUP:
-            #num_groups = weight_shape[1] // quantization_args.group_size
-            #expected_shape = (weight_shape[0], max(num_groups, 1))
+            num_groups = weight_shape[1] // quantization_args.group_size
+            expected_shape = (weight_shape[0], max(num_groups, 1))
 
-            num_groups = weight_shape[0] // quantization_args.group_size
-            expected_shape = (1, num_groups)
+    if base_name == "output" and weight_shape is not None:
+        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
+            expected_shape = weight_shape[0]
 
-    #print("expected_shape", base_name, expected_shape)
+    # TODO: add support for output activations
 
     scale_dtype = module.weight.dtype
     if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
         scale_dtype = torch.float16
@@ -207,26 +202,13 @@ def _initialize_scale_zero_point(
         register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
 
 
-def _initialize_attn_scales(module: Module, quant_args: QuantizationArgs) -> None:
+def _initialize_attn_scales(
+    module: Module,
+) -> None:
     """Initlaize k_scale, v_scale for self_attn"""
 
-    strategy = quant_args.strategy
-    # (1024, 1) - v_scale,
-    # (1024, 1) - q_scale
-
-    """
-    - We want channelwise scales for on output activations of k/q/v
-    - Requires num_heads_k scales (num_key_value_heads scales?)
-    - So the expected shape is (1024, 1) --> (8, 128) where 8 is the number of kv heads
-    """
-
-    # we get access to this after the k_proj and v_proj have been restructured already (reshaped)
-    if strategy == QuantizationStrategy.CHANNEL:
-        #expected_shape = (module.v_proj.weight.shape[0], 1)
-        #expected_shape = (8, 128) # num heads * head_dim
-        expected_shape = 1
-    else:
-        expected_shape = 1  # per tensor
+    # per token for each layer
+    expected_shape = 1
 
     param = next(module.parameters())
    scale_dtype = param.dtype
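With the channel-wise attempt rolled back above, k_scale and v_scale are plain per-tensor scalars again. Roughly, the initialization amounts to the following sketch (illustrative only; register_parameter stands in for the library's offloading-aware register_offload_parameter, and the parameter names follow the docstring):

    import torch
    from torch.nn import Module, Parameter

    def sketch_init_attn_scales(module: Module) -> None:
        # one scalar scale each for cached keys and values (expected_shape = 1)
        scale_dtype = next(module.parameters()).dtype
        if scale_dtype not in (torch.float16, torch.bfloat16, torch.float32):
            scale_dtype = torch.float16
        for name in ("k_scale", "v_scale"):
            module.register_parameter(
                name, Parameter(torch.empty(1, dtype=scale_dtype), requires_grad=False)
            )
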
From 6fb81baa947f77141cec91cbd20a19131c4d9b78 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Fri, 7 Mar 2025 07:09:34 -0500
Subject: [PATCH 4/4] remove unnec comments

---
 src/compressed_tensors/quantization/lifecycle/apply.py      | 2 --
 src/compressed_tensors/quantization/lifecycle/forward.py    | 3 ---
 src/compressed_tensors/quantization/lifecycle/initialize.py | 2 --
 3 files changed, 7 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 09ece647..e688e500 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -130,8 +130,6 @@ def apply_quantization_config(
     config = process_quantization_config(config)
     names_to_scheme = OrderedDict()
 
-    # have to consolidate attn quant with weight/input quant
-    # currently treats them as mutually exclusive
     for scheme in config.config_groups.values():
         for target in scheme.targets:
             target_to_scheme[target] = scheme
diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index c26cf54f..f4f93f27 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -269,7 +269,6 @@ def wrapped_forward(self, *args, **kwargs):
             # forward call
             return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
 
-        # breakpoint()
         input_ = args[0]
 
         compressed = module.quantization_status == QuantizationStatus.COMPRESSED
@@ -289,7 +288,6 @@ def wrapped_forward(self, *args, **kwargs):
         output = forward_func_orig.__get__(module, module.__class__)(
             input_, *args[1:], **kwargs
         )
-        # breakpoint()
 
         # restore back to unquantized_value
         if scheme.weights is not None and not compressed:
@@ -306,7 +304,6 @@ def wrapped_forward(self, *args, **kwargs):
             output = forward_quantize(
                 module, output, "output", scheme.output_activations
             )
-            print("running output qdq")
         return output
 
     # bind wrapped forward to module class so reference to `self` is correct
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 9eb4aa5f..8ec01dab 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -170,8 +170,6 @@ def _initialize_scale_zero_point(
         if quantization_args.strategy == QuantizationStrategy.CHANNEL:
             expected_shape = weight_shape[0]
 
-    # TODO: add support for output activations
-
     scale_dtype = module.weight.dtype
     if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
         scale_dtype = torch.float16
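Net effect of the series on _initialize_scale_zero_point: weight scales keep their usual channel/group layouts, a channel-wise output-activation scale becomes a flat vector with one entry per output feature, and attention k/v scales stay per-tensor. A standalone sketch mirroring that final shape selection (illustrative only, not the library API):

    import torch

    def sketch_final_expected_shape(weight_shape, strategy, base_name, group_size=None):
        expected_shape = 1  # per-tensor default in this sketch
        if base_name == "weight" and weight_shape is not None:
            if strategy == "channel":
                expected_shape = (weight_shape[0], 1)
            elif strategy == "group":
                expected_shape = (weight_shape[0], max(weight_shape[1] // group_size, 1))
        if base_name == "output" and weight_shape is not None:
            if strategy == "channel":
                expected_shape = weight_shape[0]  # one scale per output feature
        return expected_shape

    shape = torch.nn.Linear(128, 256).weight.shape    # (256, 128)
    print(sketch_final_expected_shape(shape, "channel", "weight"))    # (256, 1)
    print(sketch_final_expected_shape(shape, "group", "weight", 64))  # (256, 2)
    print(sketch_final_expected_shape(shape, "channel", "output"))    # 256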