Commit 4c5ae67

add regularizer

1 parent 12fa56e commit 4c5ae67
File tree

8 files changed: +120 -5 lines changed

examples/quantization_aware_training/cifar10/basecase/main.py

Lines changed: 10 additions & 5 deletions

@@ -5,6 +5,7 @@
 import time
 import warnings
 from enum import Enum
+import math

 import torch
 import torch.nn as nn
@@ -21,6 +22,7 @@

 from model import resnet20
 from sparsebit.quantization import QuantModel, parse_qconfig
+from sparsebit.quantization.regularizers import build_regularizer


 parser = argparse.ArgumentParser(description="PyTorch Cifar Training")
@@ -147,8 +149,6 @@ def main():

     qconfig = parse_qconfig(args.config)

-    is_pact = qconfig.A.QUANTIZER.TYPE == "pact"
-
     qmodel = QuantModel(model, qconfig).cuda()  # convert the model into a quantized model to support the quantization operations QAT needs

     # set head and tail of the model to 8bit
@@ -181,6 +181,11 @@ def main():
         optimizer, milestones=[100, 150], last_epoch=args.start_epoch - 1
     )

+    if qconfig.REGULARIZER.ENABLE:
+        regularizer = build_regularizer(qconfig)
+    else:
+        regularizer = None
+
     best_acc1 = 0
     for epoch in range(args.start_epoch, args.epochs):
         # train for one epoch
@@ -190,7 +195,7 @@
             criterion,
             optimizer,
             epoch,
-            is_pact,
+            regularizer,
             args.regularizer_lambda,
             args.print_freq,
         )
@@ -247,7 +252,7 @@ def train(
     criterion,
     optimizer,
     epoch,
-    is_pact,
+    regularizer,
     regularizer_lambda,
     print_freq,
 ):
@@ -278,7 +283,7 @@ def train(
         # compute output
         output = model(images)
         ce_loss = criterion(output, target)
-        regular_loss = get_regularizer_loss(model, is_pact, scale=regularizer_lambda)
+        regular_loss = get_regularizer_loss(model, regularizer, regularizer_lambda)
         loss = ce_loss + regular_loss

         # measure accuracy and record loss
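
get_regularizer_loss is called on both sides of the last hunk, but its definition is not part of this commit. A minimal sketch of what the new call plausibly does, assuming the helper simply scales the regularizer's penalty and returns zero when none is configured:

import torch

# Hypothetical sketch; the actual helper in main.py is not shown in this diff.
def get_regularizer_loss(model, regularizer, scale):
    if regularizer is None:  # REGULARIZER.ENABLE was False
        return torch.zeros(1, device=next(model.parameters()).device)
    return scale * regularizer(model)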
Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+BACKEND: virtual
+W:
+  QSCHEME: per-channel-symmetric
+  QUANTIZER:
+    TYPE: lsq
+    BIT: 4
+A:
+  QSCHEME: per-tensor-affine
+  QUANTIZER:
+    TYPE: lsq
+    BIT: 4
+REGULARIZER:
+  ENABLE: True
+  TYPE: dampen
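
A usage sketch for this new config (the filename below is hypothetical; this page does not capture it): parse it, then build the regularizer the same way main.py now does:

from sparsebit.quantization import parse_qconfig
from sparsebit.quantization.regularizers import build_regularizer

qconfig = parse_qconfig("qconfig_dampen.yaml")  # hypothetical name for the YAML above
if qconfig.REGULARIZER.ENABLE:                  # True here
    regularizer = build_regularizer(qconfig)    # resolves TYPE "dampen" via the registry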

examples/quantization_aware_training/cifar10/basecase/qconfig_pact.yaml

Lines changed: 3 additions & 0 deletions

@@ -9,3 +9,6 @@ A:
   QUANTIZER:
     TYPE: pact
     BIT: 4
+REGULARIZER:
+  ENABLE: True
+  TYPE: pact

sparsebit/quantization/quant_config.py

Lines changed: 4 additions & 0 deletions

@@ -47,6 +47,10 @@
 _C.A.QADD.ENABLE_QUANT = False
 _C.A.SPECIFIC = []

+_C.REGULARIZER = CN()
+_C.REGULARIZER.ENABLE = False
+_C.REGULARIZER.TYPE = ""
+

 def parse_qconfig(cfg_file):
     qconfig = _parse_config(cfg_file, default_cfg=_C)
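
These defaults make the new section optional: parse_qconfig merges the user YAML over _C, so an existing qconfig with no REGULARIZER block still parses and main.py leaves regularizer as None. A quick check (filename hypothetical):

from sparsebit.quantization import parse_qconfig

qconfig = parse_qconfig("qconfig_no_regularizer.yaml")  # hypothetical legacy config
print(qconfig.REGULARIZER.ENABLE)                       # False, from the default above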
sparsebit/quantization/regularizers/__init__.py

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+REGULARIZERS_MAP = {}
+
+
+def register_regularizer(regularizer):
+    REGULARIZERS_MAP[regularizer.TYPE.lower()] = regularizer
+    return regularizer
+
+
+from .base import Regularizer
+from . import dampen, pact
+
+
+def build_regularizer(config):
+    regularizer = REGULARIZERS_MAP[config.REGULARIZER.TYPE.lower()](config)
+    return regularizer
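
The registry is a plain decorator pattern: each submodule's class is stored under its lower-cased TYPE at import time, so config lookup is case-insensitive. A quick check of the behavior the code above implies:

from sparsebit.quantization.regularizers import REGULARIZERS_MAP

# after `from . import dampen, pact` has run, both classes are registered
print(sorted(REGULARIZERS_MAP))  # ['dampen', 'pact']

# build_regularizer lower-cases config.REGULARIZER.TYPE, so "dampen",
# "Dampen", and "DAMPEN" in a YAML all resolve to the same class
cls = REGULARIZERS_MAP["dampen"]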
sparsebit/quantization/regularizers/base.py

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+class Regularizer(object):
+    def __init__(self, config):
+        self.config = config
+
+    def __call__(self):
+        pass
sparsebit/quantization/regularizers/dampen.py

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+import torch
+
+from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
+from sparsebit.quantization.regularizers import register_regularizer
+
+
+@register_regularizer
+class Regularizer(BaseRegularizer):
+    TYPE = "Dampen"
+
+    def __init__(self, config):
+        super(Regularizer, self).__init__(config)
+        self.config = config
+
+    def _get_loss(self, x, quantizer):
+
+        x_q = quantizer(x)
+
+        qmin, qmax = quantizer.qdesc.qrange
+
+        scale, zero_point = quantizer._qparams_preprocess(x)
+
+        scale = scale.detach()
+        zero_point = zero_point.detach()
+
+        min_val = (qmin - zero_point) * scale
+
+        max_val = (qmax - zero_point) * scale
+
+        x_c = torch.min(torch.max(x, min_val), max_val)
+
+        loss = (x_q - x_c) ** 2
+
+        loss = loss.sum()
+
+        return loss
+
+    def __call__(self, model):
+        loss = 0.0
+        for n, m in model.named_modules():
+            if (
+                hasattr(m, "weight")
+                and hasattr(m, "weight_quantizer")
+                and m.weight_quantizer
+                and m.weight_quantizer.is_enable
+            ):
+                loss += self._get_loss(m.weight, m.weight_quantizer)
+        return loss
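
The dampen penalty is sum((x_q - x_c)^2), where x_c is the weight clamped to the quantizer's representable range [(qmin - z) * s, (qmax - z) * s]: in-range weights contribute their rounding error, out-of-range weights their distance to the clamp boundary, nudging weights toward values the 4-bit grid can represent. A self-contained toy version with a stub uniform quantizer (assumption: sparsebit's real Quantizer exposes this via qdesc.qrange and _qparams_preprocess, as above):

import torch

def dampen_loss(x, scale, zero_point, qmin, qmax):
    # stub uniform quantizer: round-to-nearest, clamp to the integer range, dequantize
    x_q = (torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax) - zero_point) * scale
    min_val = (qmin - zero_point) * scale   # smallest representable value
    max_val = (qmax - zero_point) * scale   # largest representable value
    x_c = torch.clamp(x, min_val, max_val)  # same as torch.min(torch.max(x, ...), ...)
    return ((x_q - x_c) ** 2).sum()

x = torch.tensor([-2.0, -0.3, 0.26, 1.7])
print(dampen_loss(x, scale=0.5, zero_point=0.0, qmin=-2, qmax=1))  # tensor(0.0976)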
sparsebit/quantization/regularizers/pact.py

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+import torch
+
+from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
+from sparsebit.quantization.regularizers import register_regularizer
+
+
+@register_regularizer
+class Regularizer(BaseRegularizer):
+    TYPE = "Pact"
+
+    def __init__(self, config):
+        super(Regularizer, self).__init__(config)
+        self.config = config
+
+    def __call__(self, model):
+        loss = 0.0
+        for n, p in model.named_parameters():
+            if "alpha" in n:
+                loss += (p ** 2).sum()
+        return loss
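
This is the L2 decay term on PACT's learnable clipping threshold alpha, applied to every parameter whose name contains "alpha". A self-contained toy check of that loop (ToyPact and its alpha are stand-ins, not sparsebit's actual quantizer module, which is assumed to register its threshold under that name):

import torch
import torch.nn as nn

class ToyPact(nn.Module):
    def __init__(self):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(6.0))  # learnable clipping threshold

m = nn.Sequential(ToyPact(), ToyPact())
loss = sum((p ** 2).sum() for n, p in m.named_parameters() if "alpha" in n)
print(loss)  # tensor(72., grad_fn=...) = 6^2 + 6^2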
