55import time
66import warnings
77from enum import Enum
8+ import math
89
910import torch
1011import torch .nn as nn
2122
2223from model import resnet20
2324from sparsebit .quantization import QuantModel , parse_qconfig
25+ from sparsebit .quantization .regularizers import build_regularizer
2426
2527
2628parser = argparse .ArgumentParser (description = "PyTorch Cifar Training" )
@@ -147,8 +149,6 @@ def main():
147149
148150 qconfig = parse_qconfig (args .config )
149151
150- is_pact = qconfig .A .QUANTIZER .TYPE == "pact"
151-
152152 qmodel = QuantModel (model , qconfig ).cuda () # 将model转化为量化模型,以支持后续QAT的各种量化操作
153153
154154 # set head and tail of model is 8bit
@@ -181,6 +181,11 @@ def main():
181181 optimizer , milestones = [100 , 150 ], last_epoch = args .start_epoch - 1
182182 )
183183
184+ if qconfig .REGULARIZER .ENABLE :
185+ regularizer = build_regularizer (qconfig )
186+ else :
187+ regularizer = None
188+
184189 best_acc1 = 0
185190 for epoch in range (args .start_epoch , args .epochs ):
186191 # train for one epoch
@@ -190,7 +195,7 @@ def main():
190195 criterion ,
191196 optimizer ,
192197 epoch ,
193- is_pact ,
198+ regularizer ,
194199 args .regularizer_lambda ,
195200 args .print_freq ,
196201 )
@@ -247,7 +252,7 @@ def train(
247252 criterion ,
248253 optimizer ,
249254 epoch ,
250- is_pact ,
255+ regularizer ,
251256 regularizer_lambda ,
252257 print_freq ,
253258):
@@ -278,7 +283,7 @@ def train(
278283 # compute output
279284 output = model (images )
280285 ce_loss = criterion (output , target )
281- regular_loss = get_regularizer_loss (model , is_pact , scale = regularizer_lambda )
286+ regular_loss = get_regularizer_loss (model , regularizer , regularizer_lambda )
282287 loss = ce_loss + regular_loss
283288
284289 # measure accuracy and record loss
0 commit comments