support to export static afp8 model #662

Merged: 16 commits merged on Jul 24, 2025

78 changes: 49 additions & 29 deletions auto_round/autoround.py
@@ -539,24 +539,6 @@ def parse_format_to_list(self, format: str) -> list:
self.scale_dtype = torch.float32
logger.info(f"change `scale_dtype` to `torch.float32`")

# only support to export afp8
if self.act_bits <= 8:
if "fp8" not in self.act_data_type:
if len(formats) > 1 or "fake" not in formats:
logger.warning(
f"Currently only support to export auto_round format quantized model"
" with fp8 dtype activation for activation quantization."
" Change format to fake and save."
)
formats = ["fake"]
else:
if len(formats) > 1 or "auto_round" not in formats:
logger.warning(
f"Currently only support to export auto_round format for W{self.bits}AFP8 model,"
" change format to auto_round"
)
formats = ["auto_round"]

# Adjust format settings based on compatibility
for index in range(len(formats)):
format = formats[index]
@@ -581,8 +563,9 @@ def remove_duplicates(lst):
return [x for x in lst if not (x in seen or seen.add(x))]

formats = remove_duplicates(formats)
for format in formats:
self._check_supported_format(format)
for i in range(len(formats)):
formats[i] = self._check_supported_format(formats[i])
formats = remove_duplicates(formats)
return formats

def _check_supported_format(self, format: str) -> bool:
@@ -631,6 +614,7 @@ def _check_supported_format(self, format: str) -> bool:
)
sys.exit(-1)

return format
def quantize_and_save(self, output_dir: str = "tmp_autoround", format: str = "auto_round", inplace=True, **kwargs):
"""Quantizes the model and saves it in the specified format(s).

@@ -1098,7 +1082,31 @@ def quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
"""
if self.amp:
self.model.to(self.amp_dtype)
self.model.to("cpu")

# all_blocks = get_block_names(self.model)
if self.act_bits <= 8:
all_blocks = get_block_names(self.model)
hook_handles = self.register_act_max_hook(self.model, [blocks[-1] for blocks in all_blocks])
try:
self.calib(self.nsamples, self.batch_size)
except RuntimeError as e:
if "CUDA out of memory" in str(e) or "MODULE:PT_DEVMEM" in str(e):
if len(os.environ.get("CUDA_VISIBLE_DEVICES")) > 1:
raise RuntimeError("Out of memory, please consider reducing nsamples or batch_size")
try:
self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
from auto_round.utils import register_per_layer_to_device
hook_handles.extend(register_per_layer_to_device(self.model, self.device))
self.calib(self.nsamples, self.batch_size, device=self.device)
except RuntimeError as e:
if "CUDA out of memory" in str(e) or "MODULE:PT_DEVMEM" in str(e):
raise RuntimeError("Out of memory, please consider reducing nsamples or batch_size")
else:
raise
else:
raise
for handle in hook_handles:
handle.remove()

all_to_quantized_module_names: list[str] = [
n for n, m in self.model.named_modules() if check_to_quantized(m)
@@ -1109,6 +1117,7 @@ def quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
self.quantize_embedding_layer()

if has_gguf_k and not self.disable_opt_rtn:
self.model.to("cpu")
self.quant_rtn_with_imatrix(all_to_quantized_module_names)
else:
pbar = tqdm(all_to_quantized_module_names)
@@ -1576,7 +1585,7 @@ def get_block_outputs(self, block, input_ids, input_others, bs, device, cache_de
return output

@torch.no_grad()
def calib(self, nsamples, bs):
def calib(self, nsamples, bs, device=None):
"""Perform calibration for quantization.

This method calibrates the model for quantization by processing a specified
@@ -1611,11 +1620,12 @@ def calib(self, nsamples, bs):
for n, m in embed_layers:
m = m.to(self.device)

device = self.model.device if device is None else device
for data in self.dataloader:
if data is None:
continue
if isinstance(data, torch.Tensor):
input_ids = data.to(self.model.device)
input_ids = data.to(device)
data_new = input_ids
elif isinstance(data, str):
if self.tokenizer is None:
@@ -1624,15 +1634,15 @@ def calib(self, nsamples, bs):
data = self.tokenizer(data, truncation=True, max_length=self.seqlen, return_tensors="pt").data
data_new = {}
for key in data.keys():
data_new[key] = data[key].to(self.model.device)
data_new[key] = data[key].to(device)
input_ids = data_new["input_ids"]
elif isinstance(data, tuple) or isinstance(data, list):
data_new = to_device(data)
input_ids = data_new[0]
else:
data_new = {}
for key in data.keys():
data_new[key] = to_device(data[key], self.model.device)
data_new[key] = to_device(data[key], device)
if key == 'images':
data_new[key] = to_dtype(data_new[key], self.model.dtype)
input_ids = data_new["input_ids"]
@@ -2055,20 +2065,30 @@ def quant_layer(self, layer_name, inputs, q_inputs=None, device=torch.device("cp
dump_info = f"quantized {layer_name}, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
logger.info(dump_info)

def register_act_max_hook(self, model):
def register_act_max_hook(self, model, last_block_name=None):
def get_act_max_hook(module, input, output):
if isinstance(input, (tuple, list)):
input = input[0]
if not hasattr(module, "act_max"):
module.act_max = torch.abs(input).max().item()
else:
module.act_max = max(torch.abs(input).max().item(), module.act_max)

def early_quit_hook(module, input, output):
raise NotImplementedError

hook_handles = []

for n, m in model.named_modules():
if hasattr(m, "act_dynamic") and m.act_dynamic == False and check_to_quantized(m):
hook = m.register_forward_hook(get_act_max_hook)
# if hasattr(m, "act_dynamic") and m.act_dynamic == False and check_to_quantized(m):
if n in self.layer_config:
config = self.layer_config[n]
if "act_dynamic" in config and config["act_dynamic"] is False and check_to_quantized(config):
hook = m.register_forward_hook(get_act_max_hook)
hook_handles.append(hook)
if (isinstance(last_block_name, list) and n in last_block_name) or \
n == last_block_name:
hook = m.register_forward_hook(early_quit_hook)
hook_handles.append(hook)
return hook_handles

@@ -2395,7 +2415,7 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **k
Returns:
object: The compressed model object.
"""
self._check_supported_format(format)
format = self._check_supported_format(format)

if self.low_cpu_mem_usage:
self.model = self.model.to('cpu')
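
The calibration path added to quantize_rtn relies on two kinds of forward hooks: one records each quantized layer's absolute activation maximum, and one on the last block raises to cut the forward pass short once those statistics exist. Below is a minimal, self-contained sketch of that pattern rather than the PR's code: the toy model, the module filter, and the EarlyStop exception are illustrative stand-ins (the PR raises NotImplementedError and selects modules through self.layer_config).

```python
import torch
import torch.nn as nn


class EarlyStop(Exception):
    """Raised from the hook on the last block to cut the forward pass short."""


def register_hooks(model: nn.Module, last_block_name: str):
    handles = []

    def record_act_max(module, inputs, output):
        # Track the running absolute maximum of this module's input activations.
        x = inputs[0] if isinstance(inputs, (tuple, list)) else inputs
        module.act_max = max(x.abs().max().item(), getattr(module, "act_max", 0.0))

    def early_quit(module, inputs, output):
        raise EarlyStop  # everything after this block is skipped

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            handles.append(module.register_forward_hook(record_act_max))
        if name == last_block_name:
            handles.append(module.register_forward_hook(early_quit))
    return handles


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8), nn.Linear(8, 2))
handles = register_hooks(model, last_block_name="2")  # stop after the third submodule

for _ in range(4):  # stand-in for the calibration dataloader
    try:
        model(torch.randn(4, 8))
    except EarlyStop:
        pass  # expected: the statistics we need were already recorded

for h in handles:
    h.remove()

# The layer after the early-quit block never ran, so it has no act_max.
print({n: round(m.act_max, 3) for n, m in model.named_modules() if hasattr(m, "act_max")})
```
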
9 changes: 7 additions & 2 deletions auto_round/export/export_to_autoround/export_to_fp8_woq.py
@@ -47,7 +47,7 @@ def check_neq_config(config, data_type, bits, group_size, sym):


class FP8WOQLinear(torch.nn.Module):
def __init__(self, in_features, out_features, weight, weight_scale, bias=None, weight_zp=None):
def __init__(self, in_features, out_features, weight, weight_scale, bias=None, weight_zp=None, act_scale=None):
super().__init__()
self.in_features = in_features
self.out_features = out_features
@@ -63,6 +63,9 @@ def __init__(self, in_features, out_features, weight, weight_scale, bias=None, w
if weight_zp:
self.register_buffer('weight_zp', weight_zp.to(torch.bfloat16))

if act_scale:
self.register_buffer('act_scale', act_scale.to(torch.bfloat16))


def pack_layer(layer_name, model, data_type, packing_device=None):
"""
@@ -101,6 +104,7 @@ def pack_layer(layer_name, model, data_type, packing_device=None):
scale = layer.scale
zp = layer.zp
weight = layer.weight
act_scale = layer.act_scale if hasattr(layer, "act_scale") else None
torch_dtype = torch.float8_e4m3fn
if "fp8_e5m2" in data_type:
torch_dtype = torch.float8_e5m2
@@ -121,7 +125,8 @@ def pack_layer(layer_name, model, data_type, packing_device=None):
in_features = layer.weight.shape[0]
out_features = layer.weight.shape[1]
bias = layer.bias
my_linear = FP8WOQLinear(in_features, out_features, q_weight, scale, bias, zp)
my_linear = FP8WOQLinear(
in_features, out_features, weight=q_weight, weight_scale=scale, bias=bias, weight_zp=zp, act_scale=act_scale)

my_linear.to(device)
set_module(model, layer_name, my_linear)
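
For orientation, here is a rough sketch of per-tensor FP8 weight packing with the extra activation-scale buffer that FP8WOQLinear now accepts. The scale arithmetic, the example act_scale value, and the bare nn.Module container are assumptions for illustration rather than the repository's exact packing code; PyTorch 2.1+ is assumed for the float8 dtypes.

```python
import torch

weight = torch.randn(16, 32, dtype=torch.bfloat16)
fp8_max = torch.finfo(torch.float8_e4m3fn).max            # 448.0 for e4m3fn

# Per-tensor weight scale, then cast the scaled weight down to FP8.
weight_scale = weight.abs().max().float() / fp8_max
q_weight = (weight.float() / weight_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)

# Stand-in for the calibrated activation maximum collected by the act-max hooks.
act_scale = torch.tensor(3.5) / fp8_max

container = torch.nn.Module()
container.register_buffer("weight", q_weight)
container.register_buffer("weight_scale", weight_scale.to(torch.bfloat16))
container.register_buffer("act_scale", act_scale.to(torch.bfloat16))

# Dequantized reference: W is approximately q_weight * weight_scale.
w_deq = container.weight.float() * container.weight_scale.float()
print(container.weight.dtype, w_deq.shape)
```
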
21 changes: 18 additions & 3 deletions auto_round/script/llm.py
@@ -486,15 +486,30 @@ def tune(args):
model_name = args.model.rstrip("/")

if model_name.split('/')[-1].strip('.') == "" and "gguf" not in args.format:
export_dir = os.path.join(args.output_dir, f"w{autoround.bits}g{autoround.group_size}")
if autoround.group_size == -1:
if "fp" in autoround.act_data_type:
suffix = f"afp{autoround.act_bits}"
else:
suffix = f"a{autoround.act_bits}"
else:
suffix = f"g{autoround.group_size}"
export_dir = os.path.join(args.output_dir, f"w{autoround.bits}{suffix}")
elif model_name.split('/')[-1].strip('.') == "" and "gguf" in args.format:
export_dir = args.output_dir
elif model_name.split('./')[-1].strip('./') != "" and "gguf" in args.format:
export_dir = os.path.join(args.output_dir,
model_name.split('/')[-1] + "-gguf")
else:
export_dir = os.path.join(args.output_dir,
model_name.split('/')[-1] + f"-w{autoround.bits}g{autoround.group_size}")
if autoround.group_size == -1:
if "fp" in autoround.act_data_type:
suffix = f"afp{autoround.act_bits}"
else:
suffix = f"a{autoround.act_bits}"
else:
suffix = f"g{autoround.group_size}"
export_dir = os.path.join(
args.output_dir,
model_name.split('/')[-1] + f"-w{autoround.bits}{suffix}")

model, folders = autoround.quantize_and_save(export_dir, format=args.format)
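
To illustrate the new naming rule in tune(): when group_size is -1, the export directory suffix now encodes the activation quantization instead of the group size, so a W8 run with fp8_sym activations is saved under a ...-w8afp8 folder. The small helper below merely restates that branch for clarity; it is not part of the PR.

```python
def export_suffix(bits, group_size, act_bits, act_data_type):
    """Mirror of the suffix logic added to tune() in auto_round/script/llm.py."""
    if group_size == -1:
        suffix = f"afp{act_bits}" if "fp" in act_data_type else f"a{act_bits}"
    else:
        suffix = f"g{group_size}"
    return f"w{bits}{suffix}"

print(export_suffix(8, -1, 8, "fp8_sym"))   # w8afp8 -> the static afp8 export case
print(export_suffix(4, 128, 16, "int"))     # w4g128 -> unchanged grouped-weight case
```
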

27 changes: 27 additions & 0 deletions auto_round/utils.py
@@ -1920,3 +1920,30 @@ def clean_module_parameter(submodule, parameter):
submodule._buffers[parameter] = None
else:
submodule._parameters[parameter] = None

def get_named_children(model, pre=[]):
"""Get all the name and children of given model."""
module_list = []
if len(list(model.children())) == 0:
return [(".".join(pre), model)]
for name, module in model.named_children():
module_list += get_named_children(module, pre=pre + [name])
return module_list

def register_per_layer_to_device(model, device):
def forward_pre_hook(module, input):
module = module.to(device)

def forward_hook(module, input, output):
module = mv_module_from_gpu(module)
clear_memory()

hook_handels = []
for n, m in get_named_children(model):
hook = m.register_forward_pre_hook(forward_pre_hook)
hook_handels.append(hook)
hook = m.register_forward_hook(forward_hook)
hook_handels.append(hook)

return hook_handels
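
register_per_layer_to_device gives calibration a low-memory fallback: a pre-forward hook moves each leaf module to the target device just before it runs, and a post-forward hook moves it back and frees memory afterwards. The sketch below shows the same idea with plain PyTorch calls; .cpu() and torch.cuda.empty_cache() stand in for the repository's mv_module_from_gpu and clear_memory helpers.

```python
import torch
import torch.nn as nn


def register_per_layer_to_device(model: nn.Module, device: torch.device):
    """Shuttle each leaf module onto `device` only for the duration of its forward."""

    def pre_hook(module, inputs):
        module.to(device)                      # bring this layer in just before it runs

    def post_hook(module, inputs, output):
        module.cpu()                           # evict it again once its output exists
        if device.type == "cuda":
            torch.cuda.empty_cache()

    handles = []
    for module in model.modules():
        if len(list(module.children())) == 0:  # leaf modules only
            handles.append(module.register_forward_pre_hook(pre_hook))
            handles.append(module.register_forward_hook(post_hook))
    return handles


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))  # parameters stay on CPU
handles = register_per_layer_to_device(model, device)

out = model(torch.randn(4, 8, device=device))      # calibration batches go to the target device
print(out.device, next(model.parameters()).device)  # output on `device`, weights back on CPU

for h in handles:
    h.remove()
```
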

18 changes: 18 additions & 0 deletions test/test_cpu/test_export.py
@@ -203,6 +203,24 @@ def test_autoround_3bit_sym_format(self):
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
shutil.rmtree(quantized_model_path, ignore_errors=True)


def test_static_afp8_export(self):
autoround = AutoRound(
self.model,
self.tokenizer,
bits=8,
group_size=-1,
iters=0,
act_bits=8,
nsamples=2,
data_type="fp8_sym",
act_data_type="fp8_sym",
act_dynamic=False,
)
quantized_model_path = "./saved"
autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round")
shutil.rmtree(quantized_model_path, ignore_errors=True)


if __name__ == "__main__":
unittest.main()