support weight int16 convert and fake quant #370

Draft: wants to merge 2 commits into base: main
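
Summary of the changes in this draft: TFLiteConverter gains an fp_weight_dict argument whose float weights are stashed on the converter graph (common_graph.temp_store), and prim::GetAttr parsing records a relations map so the quantized op converter (parse_common) can look a layer's float weight back up; that weight is then re-quantized symmetrically to int16 (scale = max|w| / 32767, zero point 0) and written into the TFLite tensor. On the quantization side, PTQFakeQuantize now mirrors the observer's q-params into its scale/zero_point buffers on every observation pass, the PTQ activation preset switches from HistogramObserver to a symmetric MinMaxObserver, and rescale_activations_with_quant_min_max no longer skips weight fake-quant modules, forcing them to the [-32768, 32767] range instead. Two example scripts, examples/quantization/w16_post.py and examples/quantization/w16_post_fq.py, demonstrate the conversion and the fake-quant error analysis.

Condensed from w16_post.py below, the converter-facing usage looks roughly like this (a sketch; qat_model, qat_converted_model and dummy_input are assumed to come from the PTQ flow in that script):

fp_weight_dict = {n: m.weight.data for n, m in qat_model.named_children() if hasattr(m, 'weight')}
converter = TFLiteConverter(
    qat_converted_model,
    dummy_input,
    strict_symmetric_check=True,
    quantize_target_type='int16',
    tflite_path='out/qat_model.tflite',
    fp_weight_dict=fp_weight_dict,  # new in this PR: float weights to re-quantize as int16
)
converter.convert()
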
123 changes: 123 additions & 0 deletions examples/quantization/w16_post.py
@@ -0,0 +1,123 @@
import argparse
import os
import sys

CURRENT_PATH = os.path.abspath(os.path.dirname(__file__))

sys.path.insert(1, os.path.join(CURRENT_PATH, '../../'))

import torch
import numpy as np

from tinynn.converter import TFLiteConverter
from tinynn.graph.tracer import model_tracer
from tinynn.util.train_util import DLContext
from tinynn.graph.quantization.quantizer import QATQuantizer
from tinynn.graph.quantization.fake_quantize import set_ptq_fake_quantize

device = torch.device('cuda', 0)


class TensorDataset:
def __init__(self, path):
        assert os.path.exists(path), "%s is not a valid path" % path
self.path = path
self.data_list = [fname for fname in os.listdir(self.path) if fname.lower().endswith('.npy')]

def __getitem__(self, index):
input_path = os.path.join(self.path, self.data_list[index])
input_npy = np.load(input_path)
input_npy = input_npy.reshape((1, 58))
return torch.from_numpy(input_npy)

def __len__(self):
return len(self.data_list)


def main_worker(args):
    # NOTE: change this to the path of your calibration data
dataloader = TensorDataset('/data/zhouye/0603/TinyNeuralNetwork/examples/pr_solve/sun_0930/new_data')

with model_tracer():
dummy_input = torch.rand(1, 58)

# from graphmodule_q import QGraphModule
from examples.pr_solve.sun_0930.graphmodule_q_v1 import QGraphModule

model = QGraphModule()
model.load_state_dict(
torch.load("/data/zhouye/0603/TinyNeuralNetwork/examples/pr_solve/sun_0930/graphmodule_v1.pth")
)

model.eval()

quantizer = QATQuantizer(
model,
dummy_input,
work_dir='out',
config={
'extra_tracer_opts': {'patch_torch_size': True},
'force_overwrite': False,
'rewrite_graph': False,
'asymmetric': False,
'per_tensor': False,
'override_qconfig_func': set_ptq_fake_quantize,
},
)
qat_model = quantizer.quantize()
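        # With the quantizer.py change in this PR, this call also forces weight fake-quant
        # modules to the signed int16 range [-32768, 32767]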
quantizer.rescale_activations_with_quant_min_max(0, 65535)

# Move model to the appropriate device
qat_model.to(device=device)
qat_model.eval()
context = DLContext()
context.device = device

qat_model.apply(torch.quantization.disable_fake_quant)
qat_model.apply(torch.quantization.enable_observer)
for i in range(100):
qat_model(dataloader[i].to(device=device))
# Disable observer and enable fake quantization to validate model with quantization error
qat_model.apply(torch.quantization.disable_observer)
qat_model.apply(torch.quantization.enable_fake_quant)
qat_model(dummy_input.to(device=device))

print(qat_model)
quantizer.rescale_activations_with_quant_min_max(0, 255)

with torch.no_grad():
qat_model.eval()
qat_model.cpu()

# The step below converts the model to an actual quantized model, which uses the quantized kernels.
qat_converted_model = quantizer.convert(qat_model)

# When converting quantized models, please ensure the quantization backend is set.
torch.backends.quantized.engine = quantizer.backend

fp_weight_dict = {}
for n, m in qat_model.named_children():
if hasattr(m, 'weight'):
fp_weight_dict[n] = m.weight.data
converter = TFLiteConverter(
qat_converted_model,
dummy_input,
strict_symmetric_check=True,
quantize_target_type='int16',
tflite_path='out/qat_model.tflite',
output_transpose=False,
fp_weight_dict=fp_weight_dict,
)
converter.convert()


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data-path', metavar='DIR', default="/data/datasets/cifar10", help='path to dataset')
parser.add_argument('--config', type=str, default=os.path.join(CURRENT_PATH, 'config.yml'))
parser.add_argument('--workers', type=int, default=8)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--cle', type=bool, default=False)

args = parser.parse_args()
main_worker(args)
121 changes: 121 additions & 0 deletions examples/quantization/w16_post_fq.py
@@ -0,0 +1,121 @@
import argparse
import os
import sys

CURRENT_PATH = os.path.abspath(os.path.dirname(__file__))

sys.path.insert(1, os.path.join(CURRENT_PATH, '../../'))

import torch
import numpy as np

from tinynn.graph.tracer import model_tracer
from tinynn.util.train_util import DLContext
from tinynn.graph.quantization.quantizer import QATQuantizer
from tinynn.graph.quantization.fake_quantize import set_ptq_fake_quantize
from tinynn.util.quantization_analysis_util import graph_error_analysis, layer_error_analysis

device = torch.device('cuda', 0)


class TensorDataset:
def __init__(self, path):
        assert os.path.exists(path), "%s is not a valid path" % path
self.path = path
self.data_list = [fname for fname in os.listdir(self.path) if fname.lower().endswith('.npy')]

def __getitem__(self, index):
input_path = os.path.join(self.path, self.data_list[index])
input_npy = np.load(input_path)
input_npy = input_npy.reshape((1, 58))
return torch.from_numpy(input_npy)

def __len__(self):
return len(self.data_list)


def main_worker(args):
    # NOTE: change this to the path of your calibration data
dataloader = TensorDataset('/data/zhouye/0603/TinyNeuralNetwork/examples/pr_solve/sun_0930/new_data')

test_data = dataloader[2]

with model_tracer():
dummy_input = torch.rand(1, 58)

# load qmodel
from examples.pr_solve.sun_0930.graphmodule_q_v1 import QGraphModule

model = QGraphModule()
model.load_state_dict(
torch.load("/data/zhouye/0603/TinyNeuralNetwork/examples/pr_solve/sun_0930/graphmodule_v1.pth")
)

model.eval()
output_float = model(test_data)

quantizer = QATQuantizer(
model,
dummy_input,
work_dir='out',
config={
'extra_tracer_opts': {'patch_torch_size': True},
'force_overwrite': False,
'rewrite_graph': False,
'asymmetric': False,
'per_tensor': False,
'override_qconfig_func': set_ptq_fake_quantize,
},
)
qat_model = quantizer.quantize()
        # rescale_activations_with_quant_min_max is patched in this PR (see quantizer.py)
        # to force weight fake-quant modules to the int16 range [-32768, 32767]
quantizer.rescale_activations_with_quant_min_max(0, 65535)

# Move model to the appropriate device
qat_model.to(device=device)
qat_model.eval()
context = DLContext()
context.device = device

qat_model.apply(torch.quantization.disable_fake_quant)
qat_model.apply(torch.quantization.enable_observer)
for i in range(100):
qat_model(dataloader[i].to(device=device))
# Disable observer and enable fake quantization to validate model with quantization error
qat_model.apply(torch.quantization.disable_observer)
qat_model.apply(torch.quantization.enable_fake_quant)
qat_model(dummy_input.to(device=device))

print(qat_model)
dummy_input_real = test_data
output = qat_model(test_data.to(device=device))
print("quant: ", output)
print("fp: ", output_float)
print("diff: ", output.detach().cpu() - output_float.detach().cpu())
graph_error_analysis(qat_model, dummy_input_real, metric='cosine')
layer_error_analysis(qat_model, dummy_input_real, metric='cosine')

for n, m in qat_model.named_children():
print(n)
if hasattr(m, 'weight_fake_quant'):
print(
f"|weight q_param: scale:{float(m.weight_fake_quant.scale)}, zp:{int(m.weight_fake_quant.zero_point)}"
)
if hasattr(m, 'activation_post_process'):
print(
f"|activation q_param: scale:{float(m.activation_post_process.scale)},"
f" zp:{int(m.activation_post_process.zero_point)}"
)
# exit()


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data-path', metavar='DIR', default="/data/datasets/cifar10", help='path to dataset')
parser.add_argument('--config', type=str, default=os.path.join(CURRENT_PATH, 'config.yml'))
parser.add_argument('--workers', type=int, default=8)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--cle', type=bool, default=False)

args = parser.parse_args()
main_worker(args)
2 changes: 2 additions & 0 deletions tinynn/converter/base.py
@@ -61,6 +61,7 @@ def __init__(
group_tensors: bool = False,
missing_outputs_as_constants: bool = False,
legacy_gelu: bool = False,
fp_weight_dict: dict = None,
) -> None:
""" The TFLiteConverter class

@@ -124,6 +125,7 @@ def __init__(
self.tensor_map = {}
self.tensor_map_copies = {}
self.common_graph = CommonGraph()
self.common_graph.temp_store = fp_weight_dict

if type(dummy_input) in (tuple, list):
self.dummy_input = dummy_input
2 changes: 2 additions & 0 deletions tinynn/converter/operators/graph.py
@@ -39,6 +39,8 @@ def __init__(self) -> None:
self.rev_q_mapping = {}
self.transform_store = {}
self.constant_mapping = {}
self.temp_store = {}
self.relations = {}

def add_transform_store(self, tensor_name: str, transform_name: str, new_tensor_name: str):
self.transform_store.setdefault(tensor_name, {})
1 change: 1 addition & 0 deletions tinynn/converter/operators/torch/prim.py
@@ -67,6 +67,7 @@ def parse(self, node, attrs, args, graph_converter):
class PrimGetAttrConverter(PrimOperatorConverter):
def parse(self, node, attrs, args, graph_converter):
name, name_type = attrs.get('name', (None, None))
graph_converter.relations[self.output_names[0]] = self.input_names[0]
if name is not None and name_type == 's':
v = getattr(self.input_tensors[0], name)
self.output_tensors.append(v)
14 changes: 14 additions & 0 deletions tinynn/converter/operators/torch/quantized.py
@@ -235,6 +235,20 @@ def parse_common(self, graph_converter, fusedActivation=tfl_schema.ActivationFun
bias = state[0][1]

weight_tensor = self.create_attr_tensor(weight)
        # Look up the original float weight for this module in graph_converter.temp_store (fp_weight_dict)
fp_name = ''
for module_name in graph_converter.temp_store.keys():
for relation_v in graph_converter.relations.values():
if module_name in relation_v:
fp_name = module_name
fp_weight = graph_converter.temp_store[fp_name]
scale = float(fp_weight.abs().max() / 32767)
int16_weight = torch.quantize_per_tensor(fp_weight, scale, 0, torch.qint32)
weight_tensor.quantization.scale = scale
weight_tensor.tensor = torch.int_repr(int16_weight.detach()).numpy().astype(np.int16)
weight_tensor.dtype = weight_tensor.tensor.dtype
weight_tensor.buffer = tfl.Buffer(weight_tensor.tensor.tobytes())

outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
output_tensor = outputs[0]

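
For reference, the hunk above maps each float weight tensor to int16 with a symmetric per-tensor scale. A minimal standalone sketch of that mapping (not part of the PR; quantize_weight_int16 is a hypothetical helper name):

import numpy as np
import torch

def quantize_weight_int16(fp_weight: torch.Tensor):
    # Per-tensor symmetric scale; using 32767 keeps the zero point at 0.
    scale = float(fp_weight.abs().max() / 32767)
    q = torch.quantize_per_tensor(fp_weight, scale, 0, torch.qint32)
    # Same as the converter change: take the integer representation and narrow to int16.
    return scale, torch.int_repr(q.detach()).numpy().astype(np.int16)

scale, w16 = quantize_weight_int16(torch.randn(58, 58))
assert w16.dtype == np.int16 and int(np.abs(w16).max()) <= 32767
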
10 changes: 9 additions & 1 deletion tinynn/graph/quantization/fake_quantize.py
@@ -84,6 +84,13 @@ class PTQFakeQuantize(torch.quantization.FakeQuantize):
def forward(self, X):
if self.observer_enabled[0] == 1:
self.activation_post_process(X.detach())
_scale, _zero_point = self.calculate_qparams()
_scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
if self.scale.shape != _scale.shape:
self.scale.resize_(_scale.shape)
self.zero_point.resize_(_zero_point.shape)
self.scale.copy_(_scale)
self.zero_point.copy_(_zero_point)

if self.fake_quant_enabled[0] == 1:
if self.scale == 1 and self.zero_point == 0:
@@ -115,9 +122,10 @@ def set_ptq_fake_quantize(name, module):
reduce_range=False,
)
asym_fq = PTQFakeQuantize.with_args(
observer=torch.quantization.HistogramObserver,
observer=torch.quantization.MinMaxObserver,
quant_min=0,
quant_max=255,
qscheme=torch.per_tensor_symmetric,
dtype=torch.quint8,
reduce_range=False,
)
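
With this change, a PTQFakeQuantize module keeps self.scale and self.zero_point in sync with its observer during calibration, so the q-params read out later (for example in w16_post_fq.py's per-layer printout) reflect the observed range. A minimal sketch, assuming tinynn is importable and mirroring the asym_fq settings from the hunk above:

import torch
from tinynn.graph.quantization.fake_quantize import PTQFakeQuantize

fq = PTQFakeQuantize.with_args(
    observer=torch.quantization.MinMaxObserver,
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_symmetric,
    dtype=torch.quint8,
    reduce_range=False,
)()

fq.apply(torch.quantization.disable_fake_quant)  # observe only, as in the calibration loops above
for _ in range(100):
    fq(torch.randn(1, 58))
print(float(fq.scale), int(fq.zero_point))  # q-params already track the observed range
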
4 changes: 3 additions & 1 deletion tinynn/graph/quantization/quantizer.py
@@ -1440,7 +1440,9 @@ def rescale_activations_with_quant_min_max(self, quant_min: int, quant_max: int)
"""Rescales activations with provided quant_min and quant_max"""
for n, m in self.model.named_modules():
if '.weight_fake_quant' in n:
continue
quant_min = -32768
quant_max = 32767
# continue

if isinstance(m, torch.quantization.FakeQuantize):
observer = getattr(m, 'activation_post_process', None)
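
The practical effect of the range override above: assuming the observer computes the symmetric per-tensor scale as max|w| over half the quantization range, the weight quantization step shrinks by roughly 256x compared with an int8 range. A quick back-of-the-envelope check (not PR code):

def symmetric_scale(max_abs: float, quant_min: int, quant_max: int) -> float:
    # scale = max|w| / ((quant_max - quant_min) / 2), the symmetric MinMaxObserver formula
    return max_abs / ((quant_max - quant_min) / 2)

print(symmetric_scale(0.8, -128, 127))      # ~6.3e-3, int8 step
print(symmetric_scale(0.8, -32768, 32767))  # ~2.4e-5, int16 step
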