enable llama4 int8 quantization baseline #522

Open
wants to merge 8 commits into main
Changes from 4 commits
33 changes: 26 additions & 7 deletions auto_round/autoround.py
@@ -513,6 +513,7 @@ def remove_duplicates(lst):

return model, folders


@torch.inference_mode
def quantize_rtn(self):
if self.amp:
@@ -529,7 +530,12 @@ def quantize_rtn(self):
m = get_module(self.model, name)

m.to(self.device)
m = WrapperLinear(m, enable_minmax_tuning=False, enable_norm_bias_tuning=False, enable_round_tuning=False)
if "_fake" not in name:
m = WrapperLinear(m, enable_minmax_tuning=False, enable_norm_bias_tuning=False, enable_round_tuning=False)
else:
from .wrapper import WrapperParameter
m = WrapperParameter(m, enable_minmax_tuning=False,
enable_norm_bias_tuning=False)
m = m.unwrapper({})
m.to("cpu")
if self.is_packing_immediate:
@@ -542,6 +548,7 @@
self.quantized = True
return self.model, self.layer_config


def quantize(self):
"""Quantize the model and return the quantized model along with layer configurations.
the entry of AutoRound.
@@ -754,10 +761,11 @@ def set_layerwise_config(self, layer_config):
# If the layer is outside a block and requires quantization, mark it as a quantized layer outside the block
if n not in layers_in_blocks and check_to_quantized(layer_config[n]):
has_qlayer_outside_block = True

in_features, out_features = get_layer_features(m)
if in_features <= layer_config[n]["group_size"]:
layer_config[n]["group_size"] = -1
from .utils import ParamWrapper
if not isinstance(m , ParamWrapper):
in_features, out_features = get_layer_features(m)
if in_features <= layer_config[n]["group_size"]:
layer_config[n]["group_size"] = -1

# Apply the configuration to the corresponding layer in the model
for key in keys:
@@ -1391,7 +1399,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
mse_reduction = "sum"
mse_loss = torch.nn.MSELoss(reduction=mse_reduction).to(device)
scaler = self.get_scaler() # pylint: disable=assignment-from-none
init_loss = None
init_loss = 0
best_params = {}
total_loss = 0

@@ -1635,13 +1643,23 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **k
return
if format == "fake" or format == "qdq": ##TODO fix act quantization later
self.model = self.model.to("cpu")
self.model.save_pretrained(output_dir)
if "llama4" not in str(self.model.__class__.__name__).lower():
os.makedirs(output_dir, exist_ok=True)
self.model.save_pretrained(output_dir)
else:
output_dir = output_dir.replace("-fake","")
os.makedirs(output_dir, exist_ok=True)
from .utils import pack_to_int8
pack_to_int8(self.model, output_dir)

if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
processor = kwargs.get("processor", None)
if processor is not None:
processor.save_pretrained(output_dir)

return

if self.act_bits <= 8 and format == "qdq":
logger.warning(
"Support for exporting activation quantization is limited. "
@@ -2159,3 +2177,4 @@ def __init__(
super_group_size=super_group_size,
**kwargs,
)

32 changes: 10 additions & 22 deletions auto_round/data_type/int.py
@@ -38,29 +38,16 @@ def quant_tensor_sym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scal
Returns:
Quantized and de-quantized tensor, scale, zero-point
"""

tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
maxq = 2 ** (bits - 1)
if tensor_min is None or tensor_max is None:
wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0)
wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0)
else:
wmin_tmp = tensor_min
wmax_tmp = tensor_max

wmin_abs = -(wmin_tmp * min_scale) # pylint: disable=E1130
wmax_abs = wmax_tmp * max_scale
max_v = (2 * (wmax_abs < wmin_abs).int() - 1) * torch.max(wmax_abs, wmin_abs)
scale = (max_v / maxq).to(scale_dtype)
scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh))
zp = torch.full_like(scale, maxq) # pylint: disable=E1130
scale = scale.unsqueeze(dim=-1)
zp = zp.unsqueeze(dim=-1)
int_w = round_ste(tensor / scale + v)
q = torch.clamp(int_w + zp, 0, 2 ** bits - 1)
qdq_result = (scale * (q - zp)).to(tensor.dtype)
qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len)
return qdq_result, scale, zp
assert tensor.dim() == 2
qmax = 127.0
abs_max = torch.abs(tensor).max(dim=1, keepdim=True)[0] # [rows, 1]
scale = abs_max / qmax # [rows, 1]
assert scale.shape == (tensor.shape[0], 1)
quantized = torch.round(tensor / scale)
quantized = torch.clamp(quantized, -qmax, qmax)
quantized = revert_tensor_by_pad(quantized, orig_shape=orig_shape, pad_len=pad_len)
return quantized, scale.to(torch.float32), None
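The rewritten `quant_tensor_sym` above drops the group-wise 4-bit symmetric path in favor of per-row abs-max int8: each row's scale is its absolute maximum divided by 127, values are rounded and clamped to [-127, 127], and no zero-point is returned. Below is a minimal standalone sketch of that scheme; the function name, the zero-division guard, and the round-trip check are illustrative additions, and the PR itself keeps the rounded values in floating point (the int8 cast happens later during packing).

```python
import torch


def int8_per_row_sym(w: torch.Tensor):
    """Per-row symmetric quantization: scale = row abs-max / 127 (illustrative sketch)."""
    assert w.dim() == 2
    qmax = 127.0
    abs_max = w.abs().max(dim=1, keepdim=True)[0]         # [rows, 1]
    scale = (abs_max / qmax).clamp(min=1e-8)              # zero-row guard, not in the PR
    q = torch.clamp(torch.round(w / scale), -qmax, qmax)
    return q.to(torch.int8), scale.to(torch.float32)


w = torch.randn(4, 16)
q, scale = int8_per_row_sym(w)
w_dq = q.to(torch.float32) * scale                        # dequantize
assert (w - w_dq).abs().max() <= scale.max() / 2 + 1e-6   # error bounded by half a quantization step
```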


## the values should be positive
@@ -276,3 +263,4 @@ def quant_tensor_asym_wo_round(tensor, bits=4, group_size=-1, v=0, min_scale=1.0
qdq_result = (scale * (q - zp)).to(tensor.dtype)
qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len)
return qdq_result, scale, zp

14 changes: 12 additions & 2 deletions auto_round/script/mllm.py
@@ -327,11 +327,20 @@ def tune(args):
model_name,
torch_dtype=torch_dtype,
use_auto_mapping=use_auto_mapping,
trust_remote_code=not args.disable_trust_remote_code)
trust_remote_code=not args.disable_trust_remote_code,
model_dtype=args.model_dtype)

from auto_round import AutoRoundMLLM

model = model.eval()

from auto_round.utils import (set_module, ParamWrapper)
if "llama4" in str(model.__class__.__name__).lower():
for n, p in model.named_parameters():
if '.experts.gate_up_proj' in n or '.experts.down_proj' in n:
name = f"{n}_fake"
set_module(model, name, ParamWrapper(p))
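This llama4 branch wraps each fused 3-D expert parameter (`.experts.gate_up_proj`, `.experts.down_proj`) in a `ParamWrapper` registered under a sibling `<name>_fake` module name, so the layer-oriented quantization code can see it as a quantizable layer. A toy sketch of the idea follows; the `ToyExperts` module and the plain `setattr` standing in for `set_module` are illustrative, not code from this PR.

```python
import torch


class ParamWrapper(torch.nn.Module):
    """Expose a raw nn.Parameter as a module with a .weight attribute
    (mirrors the wrapper this PR adds to auto_round/utils.py)."""

    def __init__(self, param: torch.nn.Parameter):
        super().__init__()
        self.weight = param


class ToyExperts(torch.nn.Module):
    """Stand-in for a fused MoE expert block holding a 3-D weight."""

    def __init__(self, num_experts=4, in_features=8, out_features=16):
        super().__init__()
        self.gate_up_proj = torch.nn.Parameter(torch.randn(num_experts, in_features, out_features))


model = torch.nn.ModuleDict({"experts": ToyExperts()})

# Register a sibling "<name>_fake" module so code that walks named_modules()
# sees the fused parameter as a quantizable "layer".
for name, param in list(model.named_parameters()):
    if name.endswith("experts.gate_up_proj"):
        setattr(model["experts"], "gate_up_proj_fake", ParamWrapper(param))

print([n for n, _ in model.named_modules()])  # now includes 'experts.gate_up_proj_fake'
```

During RTN quantization these `_fake` modules then take the `WrapperParameter` path added to `quantize_rtn` earlier in this diff.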


round = AutoRoundMLLM

@@ -349,7 +358,7 @@
if args.fp_layers != "":
fp_layers = args.fp_layers.replace(" ", "").split(",")
for n, m in model.named_modules():
if not isinstance(m, (torch.nn.Linear, transformers.modeling_utils.Conv1D)):
if not isinstance(m, (torch.nn.Linear, transformers.modeling_utils.Conv1D, ParamWrapper)):
continue
for fp_layer in fp_layers:
if fp_layer in n:
@@ -564,3 +573,4 @@ def lmms_eval(args):
apply_chat_template=False,
)
return results

5 changes: 4 additions & 1 deletion auto_round/special_model_handler.py
@@ -22,7 +22,9 @@
"qwen2_vl",
"deepseek_vl_v2",
"chatglm",
"idefics3"
"idefics3",
"llama4",
"phi4mm"
]

SPECIAL_SHARED_CACHE_KEYS = {
@@ -104,3 +106,4 @@ def check_mllm_model_batch(model, batch_size, gradient_accumulate_steps=1):
f"batch_size=1. As an alternative, set the gradient_accumulate_steps={accumulate_steps}")
return 1, accumulate_steps
return batch_size, gradient_accumulate_steps

136 changes: 133 additions & 3 deletions auto_round/utils.py
@@ -40,7 +40,12 @@

supported_formats = supported_formats + tuple(GGUF_CONFIG.keys())

supported_layer_types = (torch.nn.Linear, transformers.modeling_utils.Conv1D)
class ParamWrapper(torch.nn.Module):
def __init__(self, param: torch.nn.Parameter):
super().__init__()
self.weight = param

supported_layer_types = (torch.nn.Linear, transformers.modeling_utils.Conv1D, ParamWrapper)


@lru_cache(None)
@@ -768,7 +773,7 @@ def check_memory_availability(device, inputs, weight, org_seqlen, org_bs):


def get_layer_names_in_block(model, supported_types=(torch.nn.Linear,
transformers.modeling_utils.Conv1D), quant_block_list=None):
transformers.modeling_utils.Conv1D, ParamWrapper), quant_block_list=None):
"""Retrieves the names of layers within each block of the model.

Returns:
@@ -1062,7 +1067,7 @@ def get_fp_layer_names(model, fp_layers):
fp_layers = fp_layers.replace(" ", "").split(",")
all_layer_names = []
for n, m in model.named_modules():
if isinstance(m, (torch.nn.Linear, transformers.modeling_utils.Conv1D)):
if isinstance(m, (torch.nn.Linear, transformers.modeling_utils.Conv1D, ParamWrapper)):
all_layer_names.append(n)
not_to_quantized_layers = []

@@ -1136,6 +1141,130 @@ def get_device_and_parallelism(device):
parallelism = False
return device, parallelism

def translate_2_sglang_int8(model):
state_dict = model.state_dict()
count=0
state_list = list(state_dict.keys())
for name in state_list:
if ".experts." in name and "_fake" not in name:
state_dict.pop(name, None)
gc.collect()
for name, module in model.named_modules():
if hasattr(module, "weight_scale"):
count+=1
state_dict[f"{name}.weight_scale"] = module.weight_scale
state_dict[f"{name}.weight"] = state_dict[f"{name}.weight"].to(torch.int8)
gc.collect()
print(f"quantized_count: {count}")

# handle specific large experts
new_state_dict = {}
from tqdm import tqdm
state_list = list(state_dict.keys())
for name in tqdm(state_list):
if name.endswith("_fake.weight"):
weight = state_dict[name]
if weight.dim() != 3:
continue # skip any unexpected format
for id in range(int(weight.size(0))):
scale_name = f"{name}_scale"
weight_name_expert = name.replace("_fake", "")
weight_name_expert = weight_name_expert.replace("experts.", "experts."+str(id)+".")
weight_expert = weight[id].transpose(0,1).contiguous()
scale_expert = state_dict[scale_name][id]
if "gate_up_proj" in name:
weight_expert_0, weight_expert_1 = weight_expert.chunk(2,dim=0)
weight_expert_0 = weight_expert_0.contiguous()
weight_expert_1 = weight_expert_1.contiguous()
scale_0, scale_1 = scale_expert.chunk(2)
scale_0 = scale_0.contiguous()
scale_1 = scale_1.contiguous()
weight_name_expert_0 = weight_name_expert.replace("gate_up_proj", "gate_proj")
weight_name_expert_1 = weight_name_expert.replace("gate_up_proj", "up_proj")
new_scale_name_0 = f"{weight_name_expert_0}_scale"
new_scale_name_1 = f"{weight_name_expert_1}_scale"
new_state_dict[weight_name_expert_0] = weight_expert_0
new_state_dict[new_scale_name_0] = scale_0
new_state_dict[weight_name_expert_1] = weight_expert_1
new_state_dict[new_scale_name_1] = scale_1
else:
new_scale_name = f"{weight_name_expert}_scale"
new_state_dict[weight_name_expert] = weight_expert
new_state_dict[new_scale_name] = scale_expert
state_dict.pop(name, None)
state_dict.pop(scale_name, None)
gc.collect()
elif ".experts." not in name:
new_state_dict[name] = state_dict[name]
else:
continue
return new_state_dict


def pack_to_int8(model, output_dir):
import json
from safetensors.torch import save_file
with torch.no_grad():
state_dict = translate_2_sglang_int8(model)
max_shard_size = 40 * 1024**3 # 40GB
shards = {}
current_shard = {}
[Contributor review comment] This part should be refined and supported in the main branch; better to follow the original code style if possible.

current_size = 0
shard_id = 1

for name, param in state_dict.items():
param_size = param.numel() * param.element_size() # count param size

# limit split size and save to files
if current_size + param_size > max_shard_size:
shard_name = f"model-{shard_id:05d}-of-00000.safetensors"
shard_path = os.path.join(output_dir, shard_name)
save_file(current_shard, shard_path)

shards[shard_name] = list(current_shard.keys()) # record shard names
current_shard = {}
current_size = 0
shard_id += 1

current_shard[name] = param
current_size += param_size

# save last shard
if current_shard:
shard_name = f"model-{shard_id:05d}-of-00000.safetensors"
shard_path = os.path.join(output_dir, shard_name)
save_file(current_shard, shard_path)
shards[shard_name] = list(current_shard.keys())

# update files number
total_shards = shard_id
for old_name in list(shards.keys()):
new_name = old_name.replace("00000", f"{total_shards:05d}")
old_path = os.path.join(output_dir, old_name)
new_path = os.path.join(output_dir, new_name)
os.rename(old_path, new_path)
shards[new_name] = shards.pop(old_name)

# build weight_map (param -> split file)
weight_map = {}
for shard_file, param_names in shards.items():
for param_name in param_names:
weight_map[param_name] = shard_file

# generate the model.safetensors.index.json
index = {
"metadata": {"total_size": sum(os.path.getsize(os.path.join(output_dir, f)) for f in shards.keys())},
"weight_map": weight_map
}

index_path = os.path.join(output_dir, "model.safetensors.index.json")
with open(index_path, "w") as f:
json.dump(index, f, indent=2)
if hasattr(model, "config"):
model.config.save_pretrained(output_dir)

return
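`pack_to_int8` above writes the converted state dict as `model-XXXXX-of-NNNNN.safetensors` shards plus a `model.safetensors.index.json` whose `weight_map` points each tensor name to its shard. A small sketch of how such a directory could be read back follows; the loader function and the example directory name are assumptions, not part of the PR.

```python
import json
import os

from safetensors.torch import load_file


def load_sharded_state_dict(output_dir: str) -> dict:
    """Reassemble a state dict from the shards and index written by pack_to_int8 (sketch)."""
    with open(os.path.join(output_dir, "model.safetensors.index.json")) as f:
        index = json.load(f)
    state_dict = {}
    # weight_map maps each tensor name to the shard file that holds it; load each shard once
    for shard_file in sorted(set(index["weight_map"].values())):
        state_dict.update(load_file(os.path.join(output_dir, shard_file)))
    return state_dict


# e.g. sd = load_sharded_state_dict("Llama-4-int8")  # hypothetical output directory
```

Each int8 weight in the output is paired with a float32 scale tensor (`weight_scale` / `*_scale`) for dequantization on the inference side.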


def set_cuda_visible_devices(device):
devices = device.replace(" ", "").split(',')
@@ -1439,3 +1568,4 @@ def get_shared_keys(model):
shared_keys = shared_cache_keys
shared_keys += SPECIAL_SHARED_CACHE_KEYS.get(model.__class__.__name__, ())
return shared_keys
