diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index d49b7bf16..6a11b075f 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -537,24 +537,6 @@ def parse_format_to_list(self, format: str) -> list:
             self.scale_dtype = torch.float32
             logger.info(f"change `scale_dtype` to `torch.float32`")
 
-        # only support to export afp8
-        if self.act_bits <= 8:
-            if "fp8" not in self.act_data_type:
-                if len(formats) > 1 or "fake" not in formats:
-                    logger.warning(
-                        f"Currently only support to export auto_round format quantized model"
-                        " with fp8 dtype activation for activation quantization."
-                        " Change format to fake and save."
-                    )
-                    formats = ["fake"]
-            else:
-                if len(formats) > 1 or "auto_round" not in formats:
-                    logger.warning(
-                        f"Currently only support to export auto_round format for W{self.bits}AFP8 model,"
-                        " change format to auto_round"
-                    )
-                    formats = ["auto_round"]
-
         # Adjust format settings based on compatibility
         for index in range(len(formats)):
             format = formats[index]
@@ -579,8 +561,9 @@ def remove_duplicates(lst):
             return [x for x in lst if not (x in seen or seen.add(x))]
 
         formats = remove_duplicates(formats)
-        for format in formats:
-            self._check_supported_format(format)
+        for i in range(len(formats)):
+            formats[i] = self._check_supported_format(formats[i])
+        formats = remove_duplicates(formats)
         return formats
 
     def _check_supported_format(self, format: str) -> bool:
@@ -615,7 +598,7 @@ def _check_supported_format(self, format: str) -> bool:
                     )
                     format = "fake"
             else:
-                if format != "auto_round":
+                if not (format == "auto_round" or format == "auto_round:fp8"):
                     logger.warning(
                         f"Currently only support to export auto_round format for static W{self.bits}AFP8 model,"
                         " change format to auto_round"
@@ -629,6 +612,7 @@
                 )
                 sys.exit(-1)
 
+        return format
 
     def quantize_and_save(self, output_dir: str = "tmp_autoround", format: str = "auto_round", inplace=True, **kwargs):
         """Quantizes the model and saves it in the specified format(s).
@@ -1107,7 +1091,6 @@ def quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         """
         if self.amp:
             self.model.to(self.amp_dtype)
-        self.model.to("cpu")
 
         all_to_quantized_module_names: list[str] = [
             n for n, m in self.model.named_modules() if check_to_quantized(m)
@@ -1117,8 +1100,26 @@ def quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
 
         self.quantize_embedding_layer()
 
+        self.model.to("cpu")
         if has_gguf_k and not self.disable_opt_rtn:
             self.quant_rtn_with_imatrix(all_to_quantized_module_names)
+        elif self.act_bits <= 8 and self.act_dynamic is False:
+            hook_handles = self.register_act_max_hook(self.model)
+            try:
+                self.quantize_via_rtn_blockwise(all_to_quantized_module_names)
+            except RuntimeError as e:
+                logger.warning("Fallback to CPU. Consider using more GPUs via `--device 0,1,2,3`.")
+                self.model = self.model.to("cpu")
+                clear_memory()
+                if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
+                    import accelerate
+                    accelerate.hooks.remove_hook_from_submodules(self.model)
+                orig_device = self.device
+                self.device = "cpu"
+                self.quantize_via_rtn_blockwise(all_to_quantized_module_names)
+                self.device = orig_device
+            for handle in hook_handles:
+                handle.remove()
         else:
             block_names_cnt = len(flatten_list(get_block_names(self.model, True)))
             clear_mem_freq = len(all_to_quantized_module_names) // block_names_cnt
@@ -1200,6 +1201,8 @@ def quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) -> None:
                         continue
                 hook = AlignDevicesHook(m.tuning_device, io_same_device=True)
                 add_hook_to_module(m, hook, True)
+        else:
+            block = block.to(self.device)
 
         input_ids = self.get_block_outputs(
             block,
@@ -1209,7 +1212,6 @@
             self.device,
             self.cache_device,
         )
-
         if self.device_map is not None:
             accelerate.hooks.remove_hook_from_submodules(block)
 
@@ -2092,9 +2094,20 @@ def get_act_max_hook(module, input, output):
 
         hook_handles = []
        for n, m in model.named_modules():
+            # for block
             if hasattr(m, "act_dynamic") and m.act_dynamic == False and check_to_quantized(m):
                 hook = m.register_forward_hook(get_act_max_hook)
                 hook_handles.append(hook)
+                continue
+
+            # for whole model, RTN
+            if n in self.layer_config:
+                config = self.layer_config[n]
+                if config["bits"] <= 8 and "act_dynamic" in config and config[
+                        "act_dynamic"] is False and check_to_quantized(config):
+                    hook = m.register_forward_hook(get_act_max_hook)
+                    hook_handles.append(hook)
+                    continue
         return hook_handles
 
     def quant_block(self, block, input_ids, input_others, q_input=None, device=torch.device("cpu")):
@@ -2420,7 +2433,7 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **kwargs):
         Returns:
             object: The compressed model object.
         """
-        self._check_supported_format(format)
+        format = self._check_supported_format(format)
 
         if self.low_cpu_mem_usage:
             self.model = self.model.to('cpu')
diff --git a/auto_round/data_type/fp8.py b/auto_round/data_type/fp8.py
index 91de8b187..70873e6ed 100644
--- a/auto_round/data_type/fp8.py
+++ b/auto_round/data_type/fp8.py
@@ -18,46 +18,6 @@
 from auto_round.data_type.register import register_dtype
 
 
-@register_dtype("fp8_dynamic_per_token_sym")
-def fp8_dynamic_per_token_sym(tensor, max_scale=1.0, **kwargs):
-    """Dynamic per-token symmetric quantization using float8.
-
-    This function dynamically calculates a per-token scaling factor for each group of tokens
-    and applies symmetric quantization using float8 format.
-
-    Args:
-        tensor (torch.Tensor): Input tensor to quantize.
-        max_scale (float, optional): Maximum scaling factor. Defaults to 1.0.
-        **kwargs: Additional arguments for compatibility.
-
-    Returns:
-        tuple:
-            - Quantized and dequantized tensor (torch.Tensor).
-            - Scale tensor used for quantization (torch.Tensor).
-            - Placeholder for zp (None).
-    """
-    orig_shape = tensor.shape
-    info = torch.finfo(torch.float8_e4m3fn)
-    orig_dtype = tensor.dtype
-
-    tensor = tensor.reshape(-1, orig_shape[-1])
-    max_tensor = torch.max(torch.abs(tensor), dim=-1)[
-        0] * max_scale
-
-    scale = max_tensor.to(torch.float32) / info.max
-    min_scaling_factor = float(1.0 / (info.max * 512.0)) ##copy from vllm
-    scale = torch.clip(scale, min=min_scaling_factor)
-    if tensor.dtype == torch.float16: ## Avoid NaN gradients with float16
-        tensor = tensor.to(torch.bfloat16)
-    scale = scale.unsqueeze(dim=-1)
-    fp8_res = (tensor / scale)
-    fp8_res = torch.clip(fp8_res, info.min, info.max)
-    fp8_res = float8_e4m3fn_ste(fp8_res)
-    qdq_res = fp8_res * scale
-    qdq_res = qdq_res.to(orig_dtype).reshape(orig_shape)
-    return qdq_res, scale, None
-
-
 @register_dtype(("fp8_sym","fp8","fp8_e4m3"))
 def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0,**kwargs):
     """Symmetric quantization using float8 format.
@@ -79,6 +39,10 @@ def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0,**kwargs):
     info = torch.finfo(torch.float8_e4m3fn)
     orig_dtype = tensor.dtype
     tensor,orig_shape,pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
+    if isinstance(max_scale, torch.Tensor):
+        max_scale = max_scale.to(tensor.device)
+    if isinstance(v, torch.Tensor):
+        v = v.to(tensor.device)
     if tensor_max is None: ##dynamic per-token
         max_tensor = torch.max(torch.abs(tensor), dim=-1)[
             0] * max_scale
diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py
index 7a78a2f83..10da020a2 100644
--- a/auto_round/export/export_to_autoround/export.py
+++ b/auto_round/export/export_to_autoround/export.py
@@ -158,13 +158,13 @@ def pack_layer(layer_name, model, backend):
     if not isinstance(layer, SUPPORTED_LAYER_TYPES): ##already packed
         return
 
-    if int(layer.act_bits) <= 8:
-        return pack_qact_layer(layer_name, model)
-
     if "fp8" in backend:
         from auto_round.export.export_to_autoround.export_to_fp8_woq import pack_layer
         return pack_layer(layer_name,model,backend)
 
+    if int(layer.act_bits) <= 8:
+        return pack_qact_layer(layer_name, model)
+
     if not check_to_quantized(layer):
         return
diff --git a/auto_round/export/export_to_autoround/export_to_fp8_woq.py b/auto_round/export/export_to_autoround/export_to_fp8_woq.py
index 7c6862b5c..d29e021c2 100644
--- a/auto_round/export/export_to_autoround/export_to_fp8_woq.py
+++ b/auto_round/export/export_to_autoround/export_to_fp8_woq.py
@@ -47,7 +47,17 @@ def check_neq_config(config, data_type, bits, group_size, sym):
 
 
 class FP8WOQLinear(torch.nn.Module):
-    def __init__(self, in_features, out_features, weight, weight_scale, bias=None, weight_zp=None):
+
+    def __init__(
+            self,
+            in_features,
+            out_features,
+            weight,
+            weight_scale,
+            bias=None,
+            weight_zp=None,
+            act_scale=None,
+            dtype=torch.bfloat16):
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
@@ -58,10 +68,13 @@ def __init__(self, in_features, out_features, weight, weight_scale, bias=None, weight_zp=None):
         else:
             self.register_parameter("bias", None)
 
-        self.register_buffer('weight_scale', weight_scale.to(torch.bfloat16))
+        self.register_buffer('weight_scale', weight_scale.to(dtype))
 
         if weight_zp:
-            self.register_buffer('weight_zp', weight_zp.to(torch.bfloat16))
+            self.register_buffer('weight_zp', weight_zp.to(dtype))
+
+        if act_scale:
+            self.register_buffer('act_scale', act_scale.to(dtype))
 
 
 def pack_layer(layer_name, model, data_type, packing_device=None):
@@ -101,6 +114,7 @@ def pack_layer(layer_name, model, data_type, packing_device=None):
     scale = layer.scale
     zp = layer.zp
     weight = layer.weight
+    act_scale = layer.act_scale if hasattr(layer, "act_scale") else None
     torch_dtype = torch.float8_e4m3fn
     if "fp8_e5m2" in data_type:
         torch_dtype = torch.float8_e5m2
@@ -121,7 +135,15 @@ def pack_layer(layer_name, model, data_type, packing_device=None):
     in_features = layer.weight.shape[0]
     out_features = layer.weight.shape[1]
     bias = layer.bias
-    my_linear = FP8WOQLinear(in_features, out_features, q_weight, scale, bias, zp)
+    my_linear = FP8WOQLinear(
+        in_features,
+        out_features,
+        weight=q_weight,
+        weight_scale=scale,
+        bias=bias,
+        weight_zp=zp,
+        act_scale=act_scale,
+        dtype=model.dtype)
     my_linear.to(device)
     set_module(model, layer_name, my_linear)
 
@@ -141,7 +163,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round",
         quantization_config["fmt"] = "e5m2"
     else:
         quantization_config["fmt"] = "e4m3"
-    quantization_config["activation_scheme"] = "dynamic"
+    quantization_config["activation_scheme"] = "dynamic" if quantization_config['act_dynamic'] else "static"
 
     tokenizer = kwargs.get("tokenizer", None)
     processor = kwargs.get("processor", None)
diff --git a/auto_round/script/llm.py b/auto_round/script/llm.py
index d237035e5..8faea4e8f 100644
--- a/auto_round/script/llm.py
+++ b/auto_round/script/llm.py
@@ -489,15 +489,30 @@ def tune(args):
 
     model_name = args.model.rstrip("/")
     if model_name.split('/')[-1].strip('.') == "" and "gguf" not in args.format:
-        export_dir = os.path.join(args.output_dir, f"w{autoround.bits}g{autoround.group_size}")
+        if autoround.group_size <= 0:
+            if "fp" in autoround.act_data_type:
+                suffix = f"afp{autoround.act_bits}"
+            else:
+                suffix = f"a{autoround.act_bits}"
+        else:
+            suffix = f"g{autoround.group_size}"
+        export_dir = os.path.join(args.output_dir, f"w{autoround.bits}{suffix}")
     elif model_name.split('/')[-1].strip('.') == "" and "gguf" in args.format:
         export_dir = args.output_dir
     elif model_name.split('./')[-1].strip('./') != "" and "gguf" in args.format:
         export_dir = os.path.join(args.output_dir, model_name.split('/')[-1] + "-gguf")
     else:
-        export_dir = os.path.join(args.output_dir,
-                                  model_name.split('/')[-1] + f"-w{autoround.bits}g{autoround.group_size}")
+        if autoround.group_size <= 0:
+            if "fp" in autoround.act_data_type:
+                suffix = f"afp{autoround.act_bits}"
+            else:
+                suffix = f"a{autoround.act_bits}"
+        else:
+            suffix = f"g{autoround.group_size}"
+        export_dir = os.path.join(
+            args.output_dir,
+            model_name.split('/')[-1] + f"-w{autoround.bits}{suffix}")
 
     model, folders = autoround.quantize_and_save(export_dir, format=args.format)
diff --git a/auto_round/utils.py b/auto_round/utils.py
index ac0f9b039..f64c8fb41 100644
--- a/auto_round/utils.py
+++ b/auto_round/utils.py
@@ -479,6 +479,9 @@ def check_to_quantized(config):
     if isinstance(config, dict):
         bits = int(config.get("bits", 16))
         act_bits = int(config.get("act_bits", 16))
+    elif hasattr(config, "orig_layer"):
+        bits = int(config.orig_layer.bits) if hasattr(config.orig_layer, "bits") else 16
+        act_bits = int(config.orig_layer.act_bits) if hasattr(config.orig_layer, "act_bits") else 16
     else:
         bits = int(config.bits) if hasattr(config, "bits") else 16
         act_bits = int(config.act_bits) if hasattr(config, "act_bits") else 16
diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py
index 98f9599f7..a226daeed 100644
--- a/auto_round/wrapper.py
+++ b/auto_round/wrapper.py
@@ -301,9 +301,12 @@ def _set_dict_attr(attr_dict, attr_name):
             tmp_shape = (1)
             if self.orig_layer.act_group_size > 1:
                 tmp_shape = (1, self.orig_layer.act_group_size)
-            _, act_scale, _ = self._qdq_act(torch.zeros(tmp_shape).to(self.device),
+            if act_max is not None:
+                _, act_scale, _ = self._qdq_act(torch.zeros(tmp_shape).to(self.device),
                                             act_max_scale=self.act_max_scale, act_max=act_max)
-            self.orig_layer.act_max = torch.tensor(self.orig_layer.act_max * act_max_scale.item()).to("cpu")
+                self.orig_layer.act_max = torch.tensor(self.orig_layer.act_max * act_max_scale.item()).to("cpu")
+            else:
+                act_scale = torch.ones_like(act_max_scale, dtype=self.act_data_type)
             self.orig_layer.act_scale = act_scale.to("cpu")
             self.orig_layer.q_scale_thresh = self.q_scale_thresh
diff --git a/test/test_cpu/test_act_quantization.py b/test/test_cpu/test_act_quantization.py
index f2d3bcb8c..f65c6d690 100644
--- a/test/test_cpu/test_act_quantization.py
+++ b/test/test_cpu/test_act_quantization.py
@@ -65,8 +65,8 @@ def test_wint4fp8_dynamic(self):
             seqlen=2,
             dataset=self.llm_dataloader,
             act_bits=8,
-            data_type="fp8_to_int_sym",
-            act_data_type="fp8_dynamic_per_token"
+            data_type="fp8",
+            act_data_type="fp8",
         )
         autoround.quantize()
 
@@ -88,3 +88,5 @@ def test_wint4fp8_static(self):
         )
         autoround.quantize()
 
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/test_cpu/test_export.py b/test/test_cpu/test_export.py
index 06dee0181..8915c548e 100644
--- a/test/test_cpu/test_export.py
+++ b/test/test_cpu/test_export.py
@@ -203,6 +203,57 @@ def test_autoround_3bit_sym_format(self):
         print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
         shutil.rmtree(quantized_model_path, ignore_errors=True)
 
+
+    def test_static_afp8_export(self):
+        import os
+        from safetensors import safe_open
+
+        model_name = "facebook/opt-125m"
+        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
+        autoround = AutoRound(
+            model,
+            self.tokenizer,
+            bits=8,
+            group_size=-1,
+            iters=0,
+            act_bits=8,
+            nsamples=2,
+            data_type="fp8",
+            act_data_type="fp8",
+            act_dynamic=False,
+        )
+        quantized_model_path = "./saved"
+        autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round")
+        f = safe_open(os.path.join(quantized_model_path, "model.safetensors"), framework="pt")
+        self.assertIn("model.decoder.layers.8.self_attn.k_proj.act_scale", f.keys())
+        self.assertIn("model.decoder.layers.8.self_attn.k_proj.weight_scale", f.keys())
+        self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.act_scale").shape, torch.Size([1,1]))
+        self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.weight").dtype, torch.float8_e4m3fn)
+        shutil.rmtree(quantized_model_path, ignore_errors=True)
+
+        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
+        autoround = AutoRound(
+            model,
+            self.tokenizer,
+            bits=8,
+            group_size=-1,
+            iters=1,
+            act_bits=8,
+            nsamples=2,
+            data_type="fp8",
+            act_data_type="fp8",
+            act_dynamic=False,
+        )
+        quantized_model_path = "./saved"
+        autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round")
+
+        f = safe_open(os.path.join(quantized_model_path, "model.safetensors"), framework="pt")
+        self.assertIn("model.decoder.layers.8.self_attn.k_proj.act_scale", f.keys())
+        self.assertIn("model.decoder.layers.8.self_attn.k_proj.weight_scale", f.keys())
+        self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.act_scale").shape, torch.Size([1,1]))
+        self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.weight").dtype, torch.float8_e4m3fn)
+        shutil.rmtree(quantized_model_path, ignore_errors=True)
+
 if __name__ == "__main__":
     unittest.main()