From 456b4d6197961fd21dd28f841c0ac01064ff4b8a Mon Sep 17 00:00:00 2001 From: "Zhang, Weiwei1" Date: Wed, 16 Apr 2025 18:11:33 +0800 Subject: [PATCH 1/8] enable llama4 int8 quantization baseline Signed-off-by: Zhang, Weiwei1 --- auto_round/autoround.py | 31 +++- auto_round/data_type/int.py | 32 +--- auto_round/script/mllm.py | 14 +- auto_round/special_model_handler.py | 5 +- auto_round/utils.py | 133 ++++++++++++- auto_round/wrapper.py | 279 ++++++++++++++++++++++++++-- run_llama4_quant.sh | 24 +++ 7 files changed, 472 insertions(+), 46 deletions(-) create mode 100644 run_llama4_quant.sh diff --git a/auto_round/autoround.py b/auto_round/autoround.py index f63827cf..09074651 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -513,6 +513,7 @@ def remove_duplicates(lst): return model, folders + @torch.inference_mode def quantize_rtn(self): if self.amp: @@ -529,7 +530,12 @@ def quantize_rtn(self): m = get_module(self.model, name) m.to(self.device) - m = WrapperLinear(m, enable_minmax_tuning=False, enable_norm_bias_tuning=False, enable_round_tuning=False) + if "_fake" not in name: + m = WrapperLinear(m, enable_minmax_tuning=False, enable_norm_bias_tuning=False, enable_round_tuning=False) + else: + from .wrapper import WrapperParameter + m = WrapperParameter(m, enable_minmax_tuning=False, + enable_norm_bias_tuning=False) m = m.unwrapper({}) m.to("cpu") if self.is_packing_immediate: @@ -542,6 +548,7 @@ def quantize_rtn(self): self.quantized = True return self.model, self.layer_config + def quantize(self): """Quantize the model and return the quantized model along with layer configurations. the entry of AutoRound. @@ -754,10 +761,11 @@ def set_layerwise_config(self, layer_config): # If the layer is outside a block and requires quantization, mark it as a quantized layer outside the block if n not in layers_in_blocks and check_to_quantized(layer_config[n]): has_qlayer_outside_block = True - - in_features, out_features = get_layer_features(m) - if in_features <= layer_config[n]["group_size"]: - layer_config[n]["group_size"] = -1 + from .utils import ParamWrapper + if not isinstance(m , ParamWrapper): + in_features, out_features = get_layer_features(m) + if in_features <= layer_config[n]["group_size"]: + layer_config[n]["group_size"] = -1 # Apply the configuration to the corresponding layer in the model for key in keys: @@ -1391,7 +1399,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch mse_reduction = "sum" mse_loss = torch.nn.MSELoss(reduction=mse_reduction).to(device) scaler = self.get_scaler() # pylint: disable=assignment-from-none - init_loss = None + init_loss = 0 best_params = {} total_loss = 0 @@ -1635,13 +1643,21 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **k return if format == "fake" or format == "qdq": ##TODO fix act quantizaiton later self.model = self.model.to("cpu") - self.model.save_pretrained(output_dir) + os.makedirs(output_dir, exist_ok=True) + if "llama4" not in str(self.model.__class__.__name__).lower(): + self.model.save_pretrained(output_dir) + else: + from .utils import pack_to_int8 + pack_to_int8(self.model, output_dir) + if self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir) processor = kwargs.get("processor", None) if processor is not None: processor.save_pretrained(output_dir) + return + if self.act_bits <= 8 and format == "qdq": logger.warning( "Support for exporting activation quantization is limited. 
" @@ -2159,3 +2175,4 @@ def __init__( super_group_size=super_group_size, **kwargs, ) + diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index 899abf17..3ddc1826 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -38,29 +38,16 @@ def quant_tensor_sym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scal Returns: Quantized and de-quantized tensor, scale, zero-point """ - tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) - maxq = 2 ** (bits - 1) - if tensor_min is None or tensor_max is None: - wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) - wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) - else: - wmin_tmp = tensor_min - wmax_tmp = tensor_max - - wmin_abs = -(wmin_tmp * min_scale) # pylint: disable=E1130 - wmax_abs = wmax_tmp * max_scale - max_v = (2 * (wmax_abs < wmin_abs).int() - 1) * torch.max(wmax_abs, wmin_abs) - scale = (max_v / maxq).to(scale_dtype) - scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh)) - zp = torch.full_like(scale, maxq) # pylint: disable=E1130 - scale = scale.unsqueeze(dim=-1) - zp = zp.unsqueeze(dim=-1) - int_w = round_ste(tensor / scale + v) - q = torch.clamp(int_w + zp, 0, 2 ** bits - 1) - qdq_result = (scale * (q - zp)).to(tensor.dtype) - qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) - return qdq_result, scale, zp + assert tensor.dim() == 2 + qmax = 127.0 + abs_max = torch.abs(tensor).max(dim=1, keepdim=True)[0] # [rows, 1] + scale = abs_max / qmax # [rows, 1] + assert scale.shape == (tensor.shape[0], 1) + quantized = torch.round(tensor / scale) + quantized = torch.clamp(quantized, -qmax, qmax) + quantized = revert_tensor_by_pad(quantized, orig_shape=orig_shape, pad_len=pad_len) + return quantized, scale.to(torch.float32), None ## the values should be positive @@ -276,3 +263,4 @@ def quant_tensor_asym_wo_round(tensor, bits=4, group_size=-1, v=0, min_scale=1.0 qdq_result = (scale * (q - zp)).to(tensor.dtype) qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) return qdq_result, scale, zp + diff --git a/auto_round/script/mllm.py b/auto_round/script/mllm.py index 6d06aca1..3faf288a 100644 --- a/auto_round/script/mllm.py +++ b/auto_round/script/mllm.py @@ -327,11 +327,20 @@ def tune(args): model_name, torch_dtype=torch_dtype, use_auto_mapping=use_auto_mapping, - trust_remote_code=not args.disable_trust_remote_code) + trust_remote_code=not args.disable_trust_remote_code, + model_dtype=args.model_dtype) from auto_round import AutoRoundMLLM model = model.eval() + + from auto_round.utils import (set_module, ParamWrapper) + if "llama4" in str(model.__class__.__name__).lower(): + for n, p in model.named_parameters(): + if '.experts.gate_up_proj' in n or '.experts.down_proj' in n: + name = f"{n}_fake" + set_module(model, name, ParamWrapper(p)) + round = AutoRoundMLLM @@ -349,7 +358,7 @@ def tune(args): if args.fp_layers != "": fp_layers = args.fp_layers.replace(" ", "").split(",") for n, m in model.named_modules(): - if not isinstance(m, (torch.nn.Linear, transformers.modeling_utils.Conv1D)): + if not isinstance(m, (torch.nn.Linear, transformers.modeling_utils.Conv1D, ParamWrapper)): continue for fp_layer in fp_layers: if fp_layer in n: @@ -564,3 +573,4 @@ def lmms_eval(args): apply_chat_template=False, ) return results + diff --git a/auto_round/special_model_handler.py b/auto_round/special_model_handler.py index f602f9b3..27e2b29f 100644 --- 
a/auto_round/special_model_handler.py +++ b/auto_round/special_model_handler.py @@ -22,7 +22,9 @@ "qwen2_vl", "deepseek_vl_v2", "chatglm", - "idefics3" + "idefics3", + "llama4", + "phi4mm" ] SPECIAL_SHARED_CACHE_KEYS = { @@ -104,3 +106,4 @@ def check_mllm_model_batch(model, batch_size, gradient_accumulate_steps=1): f"batch_size=1. As an alternative, set the gradient_accumulate_steps={accumulate_steps}") return 1, accumulate_steps return batch_size, gradient_accumulate_steps + diff --git a/auto_round/utils.py b/auto_round/utils.py index 00db00aa..549f5145 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -40,7 +40,12 @@ supported_formats = supported_formats + tuple(GGUF_CONFIG.keys()) -supported_layer_types = (torch.nn.Linear, transformers.modeling_utils.Conv1D) +class ParamWrapper(torch.nn.Module): + def __init__(self, param: torch.nn.Parameter): + super().__init__() + self.weight = param + +supported_layer_types = (torch.nn.Linear, transformers.modeling_utils.Conv1D, ParamWrapper) @lru_cache(None) @@ -768,7 +773,7 @@ def check_memory_availability(device, inputs, weight, org_seqlen, org_bs): def get_layer_names_in_block(model, supported_types=(torch.nn.Linear, - transformers.modeling_utils.Conv1D), quant_block_list=None): + transformers.modeling_utils.Conv1D, ParamWrapper), quant_block_list=None): """Retrieves the names of layers within each block of the model. Returns: @@ -1062,7 +1067,7 @@ def get_fp_layer_names(model, fp_layers): fp_layers = fp_layers.replace(" ", "").split(",") all_layer_names = [] for n, m in model.named_modules(): - if isinstance(m, (torch.nn.Linear, transformers.modeling_utils.Conv1D)): + if isinstance(m, (torch.nn.Linear, transformers.modeling_utils.Conv1D, ParamWrapper)): all_layer_names.append(n) not_to_quantized_layers = [] @@ -1136,6 +1141,127 @@ def get_device_and_parallelism(device): parallelism = False return device, parallelism +def translate_2_sglang_int8(model): + state_dict = model.state_dict() + count=0 + state_list = list(state_dict.keys()) + for name in state_list: + if ".experts." 
in name and "_fake" not in name: + state_dict.pop(name, None) + gc.collect() + for name, module in model.named_modules(): + if hasattr(module, "weight_scale"): + count+=1 + state_dict[f"{name}.weight_scale"] = module.weight_scale + state_dict[f"{name}.weight"] = state_dict[f"{name}.weight"].to(torch.int8) + gc.collect() + print(f"quantized_count: {count}") + + # handle specific large experts + new_state_dict = {} + from tqdm import tqdm + state_list = list(state_dict.keys()) + for name in tqdm(state_list): + if name.endswith("_fake.weight"): + weight = state_dict[name] + if weight.dim() != 3: + continue # skip any unexpected format + for id in range(int(weight.size(0))): + scale_name = f"{name}_scale" + weight_name_expert = name.replace("_fake", "") + weight_name_expert = weight_name_expert.replace("experts.", "experts."+str(id)+".") + weight_expert = weight[id].transpose(0,1).contiguous() + scale_expert = state_dict[scale_name][id] + if "gate_up_proj" in name: + weight_expert_0, weight_expert_1 = weight_expert.chunk(2,dim=0) + weight_expert_0 = weight_expert_0.contiguous() + weight_expert_1 = weight_expert_1.contiguous() + scale_0, scale_1 = scale_expert.chunk(2) + scale_0 = scale_0.contiguous() + scale_1 = scale_1.contiguous() + weight_name_expert_0 = weight_name_expert.replace("gate_up_proj", "gate_proj") + weight_name_expert_1 = weight_name_expert.replace("gate_up_proj", "up_proj") + new_scale_name_0 = f"{weight_name_expert_0}_scale" + new_scale_name_1 = f"{weight_name_expert_1}_scale" + new_state_dict[weight_name_expert_0] = weight_expert_0 + new_state_dict[new_scale_name_0] = scale_0 + new_state_dict[weight_name_expert_1] = weight_expert_1 + new_state_dict[new_scale_name_1] = scale_1 + else: + new_scale_name = f"{weight_name_expert}_scale" + new_state_dict[weight_name_expert] = weight_expert + new_state_dict[new_scale_name] = scale_expert + state_dict.pop(name, None) + state_dict.pop(scale_name, None) + gc.collect() + elif ".experts." 
not in name: + new_state_dict[name] = state_dict[name] + else: + continue + return new_state_dict + + +def pack_to_int8(model, output_dir): + import json + from safetensors.torch import save_file + with torch.no_grad(): + state_dict = translate_2_sglang_int8(model) + max_shard_size = 40 * 1024**3 # 40GB + shards = {} + current_shard = {} + current_size = 0 + shard_id = 1 + + for name, param in state_dict.items(): + param_size = param.numel() * param.element_size() # count param size + + # limit spilt size and save to files + if current_size + param_size > max_shard_size: + shard_name = f"model-{shard_id:05d}-of-00000.safetensors" + shard_path = os.path.join(output_dir, shard_name) + save_file(current_shard, shard_path) + + shards[shard_name] = list(current_shard.keys()) # record shard names + current_shard = {} + current_size = 0 + shard_id += 1 + + current_shard[name] = param + current_size += param_size + + # save last shard + if current_shard: + shard_name = f"model-{shard_id:05d}-of-00000.safetensors" + shard_path = os.path.join(output_dir, shard_name) + save_file(current_shard, shard_path) + shards[shard_name] = list(current_shard.keys()) + + # update files number + total_shards = shard_id + for old_name in list(shards.keys()): + new_name = old_name.replace("00000", f"{total_shards:05d}") + old_path = os.path.join(output_dir, old_name) + new_path = os.path.join(output_dir, new_name) + os.rename(old_path, new_path) + shards[new_name] = shards.pop(old_name) + + # build weight_map(params -> spilt file) + weight_map = {} + for shard_file, param_names in shards.items(): + for param_name in param_names: + weight_map[param_name] = shard_file + + # generate the model.safetensors.index.json + index = { + "metadata": {"total_size": sum(os.path.getsize(os.path.join(output_dir, f)) for f in shards.keys())}, + "weight_map": weight_map + } + + index_path = os.path.join(output_dir, "model.safetensors.index.json") + with open(index_path, "w") as f: + json.dump(index, f, indent=2) + return + def set_cuda_visible_devices(device): devices = device.replace(" ", "").split(',') @@ -1439,3 +1565,4 @@ def get_shared_keys(model): shared_keys = shared_cache_keys shared_keys += SPECIAL_SHARED_CACHE_KEYS.get(model.__class__.__name__, ()) return shared_keys + diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py index a624d9f2..bafcbcf6 100644 --- a/auto_round/wrapper.py +++ b/auto_round/wrapper.py @@ -187,8 +187,9 @@ def _qdq_weight(self, value, min_scale, max_scale): data_type=self.data_type, q_scale_thresh=self.q_scale_thresh, **quant_kwargs - ) + ) weight_q = weight_q.to(weight.dtype) + if isinstance(self.orig_layer, transformers.modeling_utils.Conv1D): weight_q = weight_q.t() return weight_q, scale, zp @@ -243,14 +244,17 @@ def unwrapper(self, best_params): if self.orig_layer.weight.device.type == 'meta': self.orig_layer.to(self.device) ##unwrapper weight - qdq_weight, scale, zp = self._qdq_weight(v, min_scale, max_scale) + # qdq_weight, scale, zp = self._qdq_weight(v, min_scale, max_scale) + q_weight, scale, zp = self._qdq_weight(v, min_scale, max_scale) - self.orig_layer.weight.data.copy_(qdq_weight) + self.orig_layer.weight.data.copy_(q_weight) + # self.orig_layer.weight.data = self.orig_layer.weight.data.to(q_weight.dtype) # force to int8 self.orig_layer.weight.grad = None + - shape = qdq_weight.shape + shape = q_weight.shape if isinstance(self.orig_layer, transformers.modeling_utils.Conv1D): - shape = qdq_weight.t().shape + shape = q_weight.t().shape def _set_dict_attr(attr_dict, attr_name): for 
key in attr_dict.keys(): @@ -263,7 +267,8 @@ def _set_dict_attr(attr_dict, attr_name): if isinstance(scale, dict): _set_dict_attr(scale, "scale") else: - self.orig_layer.scale = scale.reshape(shape[0], -1).to("cpu") + # self.orig_layer.scale = scale.reshape(shape[0], -1).to("cpu") + self.orig_layer.weight_scale = scale.reshape(shape[0], -1).to("cpu") if zp is not None: if isinstance(zp, dict): @@ -428,6 +433,246 @@ def forward(self, input): return F.layer_norm( input, self.orig_layer.normalized_shape, weight_q, self.orig_layer.bias, self.orig_layer.eps).to( self.output_device) + + +class WrapperParameter(torch.nn.Module): + """A wrapper for layer parameter with quantized weights. + + This class wraps a given layer normalization module and applies quantization without round + to its weights. The quantization is parameterized by the number of bits and + an optional group size. + """ + + def __init__(self, orig_layer, enable_minmax_tuning=False, enable_norm_bias_tuning=False, device='cpu', **kwargs): + super(WrapperParameter, self).__init__() + self.orig_layer = orig_layer + self.output_device = device + self.group_size = -1 ## hard code + self.bits = 8 + self.device = self.orig_layer.tuning_device if hasattr(self.orig_layer, "tuning_device") else device + self.enable_minmax_tuning = enable_minmax_tuning + self.enable_norm_bias_tuning = False + self.enable_act_quant = False + self.q_scale_thresh = 1e-5 + self._init_tuning_params_and_quant_func() + # self.orig_forward = self.linear_forward if isinstance(self.orig_layer, torch.nn.Linear) else self.conv1d_forward + + def _init_tuning_params_and_quant_func(self): + """Initializes tuning parameters and quantization functions. + + This method sets up required parameters and functions for weight quantization, + activation quantization, and bias/normalization. + """ + self.params = {} + p_dtype = torch.float32 ##parameter dtype + + orig_layer = self.orig_layer + orig_weight = getattr(orig_layer, "get_weight", lambda: orig_layer.weight)() + weight_reshape = reshape_and_pad_tensor(orig_layer.weight.data, orig_layer.group_size) + self.weight_min = None + self.weight_max = None + # self.weight_min = torch.clamp(weight_reshape.min(1)[0], max=0) + # self.weight_max = torch.clamp(weight_reshape.max(1)[0], min=0) + # self._init_params("value", p_dtype, weight_reshape.shape, 0, True) + # Min-max scale initialization + shape = get_scale_shape(orig_layer.weight, orig_layer.group_size) + # self._init_params("min_scale", p_dtype, shape, 1.0, self.enable_minmax_tuning) + # self._init_params("max_scale", p_dtype, shape, 1.0, self.enable_minmax_tuning) + + self.weight_quant_func, self.data_type = get_quant_func(orig_layer.data_type, orig_layer.bits, + orig_layer.sym) + + def _init_params(self, name, dtype, shape, value, tunable): + """Initializes a parameter for tuning or uses a constant if tuning is disabled. + + Args: + name (str): Name of the parameter. + dtype (torch.dtype): Data type of the parameter. + shape (tuple): Shape of the parameter. + value (float): Initial value for the parameter. + tunable (bool): Whether the parameter should be tunable. + """ + if tunable: + p = torch.nn.Parameter(torch.ones(shape, device=self.device, dtype=dtype) * value, requires_grad=True) + self.params.update({name: p}) + else: + p = torch.tensor(1.0 * value, device=self.device, dtype=dtype) + + setattr(self, name, p) + + def _qdq_weight(self, value=0, min_scale=1.0, max_scale=1.0): + """Quantizes and dequantizes weights with tuning parameters. 
+ + Args: + value (torch.Tensor): Value added for rounding for tuning. + min_scale (torch.Tensor): Minimum scale for the min value of quantization. + max_scale (torch.Tensor): Maximum scale for the max value of quantization. + + Returns: + tuple: Quantized weight, scale, and zero point. + """ + if isinstance(min_scale, torch.Tensor): + min_scale.data.clamp_(0, 1.0) + max_scale.data.clamp_(0, 1.0) + weight = self.orig_layer.weight + if weight.device.type == 'meta': + weight = self.orig_layer.weight.to(self.device) + # if isinstance(self.orig_layer, transformers.modeling_utils.Conv1D): + # weight = weight.t() + quant_kwargs = {} + if hasattr(self.orig_layer, "super_bits"): + quant_kwargs["super_bits"] = self.orig_layer.super_bits + quant_kwargs["super_group_size"] = self.orig_layer.super_group_size + weight_q, scale = [], [] + with torch.no_grad(): + # chunk_list = torch.chunk(weight, chunks=weight.shape[0], dim=0) + for id in range(weight.shape[0]): + chunk_weight = weight[id] + chunk_list = [] + if chunk_weight.shape[-1] > 8192: # gate_up_proj + chunk_weight_0, chunk_weight_1 = chunk_weight.chunk(2,dim=-1) + chunk_list.append(chunk_weight_0) + chunk_list.append(chunk_weight_1) + else: + chunk_list.append(chunk_weight) + chunk_quantized_list = [] + chunk_scale_list = [] + for chunk in chunk_list: + chunk = chunk.transpose(0,1).contiguous() + chunk_weight_q, chunk_scale, zp = self.weight_quant_func( + chunk, + bits=self.orig_layer.bits, + group_size=self.orig_layer.group_size, + v=value, + min_scale=min_scale, + max_scale=max_scale, + scale_dtype=self.orig_layer.scale_dtype, + tensor_min=self.weight_min, + tensor_max=self.weight_max, + data_type=self.data_type, + q_scale_thresh=self.q_scale_thresh, + **quant_kwargs + ) + chunk_weight_q = chunk_weight_q.transpose(0,1).contiguous() + chunk_quantized_list.append(chunk_weight_q) + chunk_scale_list.append(chunk_scale) + chunk_weight_q = torch.cat(chunk_quantized_list, dim=1) + chunk_scale = torch.cat(chunk_scale_list, dim=0) + chunk_weight_q = chunk_weight_q.to(weight.dtype) + weight_q.append(chunk_weight_q.cpu()) + scale.append(chunk_scale.cpu()) + + weight_q = torch.stack(weight_q, dim=0) + scale = torch.stack(scale) + # if isinstance(self.orig_layer, transformers.modeling_utils.Conv1D): + # weight_q = weight_q.t() + return weight_q, scale, zp + + def unwrapper(self, best_params): + """Restores the original layer by applying the best tuning parameters. + + Args: + best_params (dict): Dictionary containing the best tuning parameters. + + Returns: + torch.nn.Module: The unwrapped and restored original layer. 
+ """ + best_params = best_params or {} + # v = best_params.get('value', torch.tensor(0.0)).to(self.device) + # min_scale = best_params.get('min_scale', torch.tensor(1.0)).to(self.device) + # max_scale = best_params.get('max_scale', torch.tensor(1.0)).to(self.device) + + if self.orig_layer.weight.device.type == 'meta': + self.orig_layer.to(self.device) + ##unwrapper weight for experts + qdq_weight, scale, zp = self._qdq_weight()#v, min_scale, max_scale + + self.orig_layer.weight.data.copy_(qdq_weight) + self.orig_layer.weight.grad = None + + shape = qdq_weight.shape + # if isinstance(self.orig_layer, transformers.modeling_utils.Conv1D): + # shape = qdq_weight.t().shape + + def _set_dict_attr(attr_dict, attr_name): + for key in attr_dict.keys(): + if key == attr_name: + setattr(self.orig_layer, attr_name, attr_dict[key].reshape(shape[0], -1).to("cpu")) + else: + name = "w_" + key + setattr(self.orig_layer, name, attr_dict[key].to("cpu")) + + if isinstance(scale, dict): + _set_dict_attr(scale, "scale") + else: + # self.orig_layer.scale = scale.reshape(shape[0], -1).to("cpu") + self.orig_layer.weight_scale = scale.to("cpu") + + if zp is not None: + if isinstance(zp, dict): + _set_dict_attr(zp, "zp") + else: + zp = zp.reshape(shape[0], -1) + self.orig_layer.zp = zp.to("cpu") if zp is not None else None + else: + self.orig_layer.zp = None + + return self.orig_layer + + # def linear_forward(self, x, weight, bias): + # """Performs the forward pass for a linear layer. + + # Args: + # x (torch.Tensor): Input tensor. + # weight (torch.Tensor): Weight tensor for the linear layer. + # bias (torch.Tensor): Bias tensor for the linear layer. + + # Returns: + # torch.Tensor: Output tensor after applying the linear layer. + # """ + # return F.linear(x, weight, bias) + + # def conv1d_forward(self, x, weight, bias): + # """Performs the forward pass for a Conv1D layer. + + # Args: + # x (torch.Tensor): Input tensor. + # weight (torch.Tensor): Weight tensor for the Conv1D layer. + # bias (torch.Tensor): Bias tensor for the Conv1D layer. + + # Returns: + # torch.Tensor: Output tensor after applying the Conv1D layer. + # """ + # size_out = x.size()[:-1] + (self.orig_layer.nf,) + # x = torch.addmm(bias, x.view(-1, x.size(-1)), weight) + # x = x.view(*size_out) + # return x + + # def forward(self, x): + # """Executes the forward pass with quantized weights and optional bias/activation quantization. + + # Args: + # x (torch.Tensor): Input tensor. + + # Returns: + # torch.Tensor: Output tensor after applying the wrapped layer. 
+ # """ + # x = x.to(self.device) + # weight_q, _, _ = self._qdq_weight(self.value, self.min_scale, self.max_scale) + + # if self.enable_act_quant: + # act_max = self.orig_layer.act_max if hasattr(self.orig_layer, "act_max") else None + # x, _, _ = self._qdq_act(x, act_max_scale=self.act_max_scale, act_max=act_max) + + # # pylint: disable=not-callable + # bias = self.orig_layer.bias + # if bias is not None and bias.device.type == 'meta': + # bias = self.orig_layer.get_bias().to(self.device) + # if self.enable_norm_bias_tuning: + # bias, _, _ = self._qdq_bias(bias, self.bias_v) + + # output = self.orig_forward(x, weight_q, bias).to(self.output_device) + # return output class WrapperLlamaNorm(torch.nn.Module): @@ -506,6 +751,11 @@ def forward(self, x, **kwargs): return hidden_states +# def set_parameter(block, name: str, new_param: torch.nn.Parameter): +# """Recursively set a parameter in a module by its dot-separated name.""" +# sub_module = getattr(block, name) +# setattr(block, name, new_param) + def wrapper_block(block, enable_minmax_tuning, enable_norm_bias_tuning, device='cpu', **kwargs): """Wraps the layers in the given block with a custom Wrapper module. @@ -518,15 +768,21 @@ def wrapper_block(block, enable_minmax_tuning, enable_norm_bias_tuning, device=' """ quantized_layers = [] unquantized_layers = [] + from .utils import ParamWrapper for n, m in block.named_modules(): - if isinstance(m, (torch.nn.Linear, transformers.modeling_utils.Conv1D)): + if isinstance(m, (torch.nn.Linear, transformers.modeling_utils.Conv1D, ParamWrapper)): if not check_to_quantized(m): unquantized_layers.append(n) continue - new_m = WrapperLinear(m, enable_minmax_tuning=enable_minmax_tuning, - enable_norm_bias_tuning=enable_norm_bias_tuning, device=device, - **kwargs, - ) + if "_fake" not in n: + new_m = WrapperLinear(m, enable_minmax_tuning=enable_minmax_tuning, + enable_norm_bias_tuning=enable_norm_bias_tuning, device=device, + **kwargs, + ) + else: + new_m = WrapperParameter(m, enable_minmax_tuning=enable_minmax_tuning, + enable_norm_bias_tuning=enable_norm_bias_tuning, device=device, + **kwargs,) set_module(block, n, new_m) quantized_layers.append(n) @@ -583,3 +839,4 @@ def unwrapper_block(block, best_params): best_param = None orig_layer = m.unwrapper(best_param) set_module(block, n, orig_layer) + diff --git a/run_llama4_quant.sh b/run_llama4_quant.sh new file mode 100644 index 00000000..f2d6c409 --- /dev/null +++ b/run_llama4_quant.sh @@ -0,0 +1,24 @@ +#!/bin/bash +set -x + +## build from source +# pip install -e . 
+ +model_dir=/PATH/TO/LLAMA4/MODEL/ +save_path=/PATH/TO/SAVE/MODEL/ +for model in Llama-4-Maverick-17B-128E-Instruct +do +python3 -m auto_round --mllm \ + --model /$model_dir/${model} \ + --bits 8 \ + --group_size -1 \ + --batch_size 1 \ + --iters 0 \ + --nsamples 8 \ + --format fake \ + --fp_layers "router,shared_expert,feed_forward.down_proj,feed_forward.gate_proj,feed_forward.up_proj,k_proj,o_proj,q_proj,v_proj" \ + --disable_minmax_tuning \ + --scale_dtype fp32 \ + --output_dir ${save_path}/${model} \ + 2>&1| tee -a ${save_path}/${model}_int8.txt +done From 9337aa724d45814cd9cb57f34377e3ca882feb08 Mon Sep 17 00:00:00 2001 From: "Zhang, Weiwei1" Date: Wed, 16 Apr 2025 23:00:05 +0800 Subject: [PATCH 2/8] add save config Signed-off-by: Zhang, Weiwei1 --- auto_round/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/auto_round/utils.py b/auto_round/utils.py index 549f5145..eb2242be 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -1260,6 +1260,9 @@ def pack_to_int8(model, output_dir): index_path = os.path.join(output_dir, "model.safetensors.index.json") with open(index_path, "w") as f: json.dump(index, f, indent=2) + if hasattr(model, config): + model.config.save_pretrained(output_dir) + return From e2689878eb36e701af9c165aa001110700137f1a Mon Sep 17 00:00:00 2001 From: "Zhang, Weiwei1" Date: Wed, 16 Apr 2025 23:03:09 +0800 Subject: [PATCH 3/8] fixtypo Signed-off-by: Zhang, Weiwei1 --- auto_round/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/utils.py b/auto_round/utils.py index eb2242be..7cf0ed09 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -1260,7 +1260,7 @@ def pack_to_int8(model, output_dir): index_path = os.path.join(output_dir, "model.safetensors.index.json") with open(index_path, "w") as f: json.dump(index, f, indent=2) - if hasattr(model, config): + if hasattr(model, "config"): model.config.save_pretrained(output_dir) return From 470927a065dbef4aa58b5649f5aed26904ddb02c Mon Sep 17 00:00:00 2001 From: "Zhang, Weiwei1" Date: Fri, 18 Apr 2025 13:30:23 +0800 Subject: [PATCH 4/8] refine script Signed-off-by: Zhang, Weiwei1 --- auto_round/autoround.py | 4 +++- run_llama4_quant.sh | 7 ++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 09074651..7af803de 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -1643,10 +1643,12 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **k return if format == "fake" or format == "qdq": ##TODO fix act quantizaiton later self.model = self.model.to("cpu") - os.makedirs(output_dir, exist_ok=True) if "llama4" not in str(self.model.__class__.__name__).lower(): + os.makedirs(output_dir, exist_ok=True) self.model.save_pretrained(output_dir) else: + output_dir = output_dir.replace("-fake","") + os.makedirs(output_dir, exist_ok=True) from .utils import pack_to_int8 pack_to_int8(self.model, output_dir) diff --git a/run_llama4_quant.sh b/run_llama4_quant.sh index f2d6c409..9fd41607 100644 --- a/run_llama4_quant.sh +++ b/run_llama4_quant.sh @@ -2,10 +2,10 @@ set -x ## build from source -# pip install -e . 
+# pip install -e .[cpu] -model_dir=/PATH/TO/LLAMA4/MODEL/ -save_path=/PATH/TO/SAVE/MODEL/ +model_dir=$1 +save_path=$2 for model in Llama-4-Maverick-17B-128E-Instruct do python3 -m auto_round --mllm \ @@ -22,3 +22,4 @@ python3 -m auto_round --mllm \ --output_dir ${save_path}/${model} \ 2>&1| tee -a ${save_path}/${model}_int8.txt done + From dd84a4d8ba4f74be063547870a702657c1161a0a Mon Sep 17 00:00:00 2001 From: "Zhang, Weiwei1" Date: Fri, 18 Apr 2025 13:58:02 +0800 Subject: [PATCH 5/8] typofix Signed-off-by: Zhang, Weiwei1 --- run_llama4_quant.sh | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/run_llama4_quant.sh b/run_llama4_quant.sh index 9fd41607..fe930b4b 100644 --- a/run_llama4_quant.sh +++ b/run_llama4_quant.sh @@ -3,11 +3,10 @@ set -x ## build from source # pip install -e .[cpu] +model=$1 +model_dir=$2 +save_path=$3 -model_dir=$1 -save_path=$2 -for model in Llama-4-Maverick-17B-128E-Instruct -do python3 -m auto_round --mllm \ --model /$model_dir/${model} \ --bits 8 \ @@ -21,5 +20,5 @@ python3 -m auto_round --mllm \ --scale_dtype fp32 \ --output_dir ${save_path}/${model} \ 2>&1| tee -a ${save_path}/${model}_int8.txt -done + From 87e161587472d31193bfbce201b8d84c202095d3 Mon Sep 17 00:00:00 2001 From: "Zhang, Weiwei1" Date: Fri, 18 Apr 2025 14:59:47 +0800 Subject: [PATCH 6/8] refine shell Signed-off-by: Zhang, Weiwei1 --- run_llama4_quant.sh | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/run_llama4_quant.sh b/run_llama4_quant.sh index fe930b4b..a062fa06 100644 --- a/run_llama4_quant.sh +++ b/run_llama4_quant.sh @@ -3,12 +3,11 @@ set -x ## build from source # pip install -e .[cpu] -model=$1 -model_dir=$2 -save_path=$3 +model_path=$1 +save_path=$2 python3 -m auto_round --mllm \ - --model /$model_dir/${model} \ + --model /$model_path/ \ --bits 8 \ --group_size -1 \ --batch_size 1 \ @@ -18,7 +17,7 @@ python3 -m auto_round --mllm \ --fp_layers "router,shared_expert,feed_forward.down_proj,feed_forward.gate_proj,feed_forward.up_proj,k_proj,o_proj,q_proj,v_proj" \ --disable_minmax_tuning \ --scale_dtype fp32 \ - --output_dir ${save_path}/${model} \ - 2>&1| tee -a ${save_path}/${model}_int8.txt + --output_dir ${save_path}/ \ + 2>&1| tee -a ${save_path}/int8_log.txt From 8a464a18f1b66d448db37d3d79192ca389d9dc5a Mon Sep 17 00:00:00 2001 From: "Zhang, Weiwei1" Date: Wed, 30 Apr 2025 09:38:40 +0800 Subject: [PATCH 7/8] enable qwen3 sglang int8 quantize Signed-off-by: Zhang, Weiwei1 --- auto_round/autoround.py | 2 +- auto_round/utils.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 7af803de..b9a34d29 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -1643,7 +1643,7 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **k return if format == "fake" or format == "qdq": ##TODO fix act quantizaiton later self.model = self.model.to("cpu") - if "llama4" not in str(self.model.__class__.__name__).lower(): + if "llama4" and "qwen3" not in str(self.model.__class__.__name__).lower(): os.makedirs(output_dir, exist_ok=True) self.model.save_pretrained(output_dir) else: diff --git a/auto_round/utils.py b/auto_round/utils.py index 7cf0ed09..9c87b1d2 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -1145,16 +1145,19 @@ def translate_2_sglang_int8(model): state_dict = model.state_dict() count=0 state_list = list(state_dict.keys()) - for name in state_list: - if ".experts." 
in name and "_fake" not in name: - state_dict.pop(name, None) + llama4_model_type = "llama4" in str(model.__class__.__name__).lower() + if llama4_model_type: + for name in state_list: + if ".experts." in name and "_fake" not in name: + state_dict.pop(name, None) gc.collect() for name, module in model.named_modules(): if hasattr(module, "weight_scale"): count+=1 state_dict[f"{name}.weight_scale"] = module.weight_scale state_dict[f"{name}.weight"] = state_dict[f"{name}.weight"].to(torch.int8) - gc.collect() + if llama4_model_type: + gc.collect() print(f"quantized_count: {count}") # handle specific large experts @@ -1162,7 +1165,7 @@ def translate_2_sglang_int8(model): from tqdm import tqdm state_list = list(state_dict.keys()) for name in tqdm(state_list): - if name.endswith("_fake.weight"): + if llama4_model_type and name.endswith("_fake.weight"): weight = state_dict[name] if weight.dim() != 3: continue # skip any unexpected format @@ -1194,7 +1197,7 @@ def translate_2_sglang_int8(model): state_dict.pop(name, None) state_dict.pop(scale_name, None) gc.collect() - elif ".experts." not in name: + elif not llama4_model_type or ".experts." not in name: new_state_dict[name] = state_dict[name] else: continue From 1b5b5e2fe8c409b6201c28a1ad8914a109a2249b Mon Sep 17 00:00:00 2001 From: "Zhang, Weiwei1" Date: Wed, 30 Apr 2025 10:41:03 +0800 Subject: [PATCH 8/8] add qwen3 shell script Signed-off-by: Zhang, Weiwei1 --- run_qwen3.sh | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 run_qwen3.sh diff --git a/run_qwen3.sh b/run_qwen3.sh new file mode 100644 index 00000000..fd565318 --- /dev/null +++ b/run_qwen3.sh @@ -0,0 +1,21 @@ +#!/bin/bash +set -x + +## build from source +# pip install -e .[cpu] +model_path=$1 +save_path=$2 + +python3 -m auto_round \ + --model /$model_path/ \ + --bits 8 \ + --group_size -1 \ + --iters 0 \ + --format fake \ + --fp_layers "mlp.gate" \ + --disable_minmax_tuning \ + --scale_dtype fp32 \ + --output_dir ${save_path}/ \ + 2>&1| tee -a ${save_path}/int8_log.txt + +
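Below is an illustrative sketch, not part of the patch series, of the per-row symmetric int8 scheme that the rewritten quant_tensor_sym and translate_2_sglang_int8 rely on: each output row is scaled by its abs-max divided by 127, weights are stored as int8 and scales as fp32, matching the weight / weight_scale pairs collected into the exported state dict. The helper names here (quantize_per_row_int8, dequantize_per_row_int8) are my own for illustration and do not exist in auto_round.

# Sketch only: per-output-row symmetric int8 quantization, assuming a 2-D
# [out_features, in_features] weight, as in the llama4/qwen3 int8 baseline.
import torch

def quantize_per_row_int8(weight: torch.Tensor):
    """Return (int8 weight, fp32 per-row scale) for a 2-D weight tensor."""
    assert weight.dim() == 2
    qmax = 127.0
    abs_max = weight.abs().max(dim=1, keepdim=True)[0]        # [rows, 1]
    scale = (abs_max / qmax).clamp(min=1e-5)                   # guard against all-zero rows
    q = torch.clamp(torch.round(weight / scale), -qmax, qmax)  # integer grid in [-127, 127]
    return q.to(torch.int8), scale.to(torch.float32)

def dequantize_per_row_int8(q: torch.Tensor, scale: torch.Tensor, dtype=torch.bfloat16):
    """Reconstruct an approximate floating-point weight from int8 values and scales."""
    return (q.to(torch.float32) * scale).to(dtype)

# Round-trip check: the reconstruction error per row is bounded by roughly scale / 2.
w = torch.randn(128, 512, dtype=torch.float32)
q, s = quantize_per_row_int8(w)
w_hat = dequantize_per_row_int8(q, s, dtype=torch.float32)
print((w - w_hat).abs().max().item())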