Commit 279f9c3

Merge of 2 parents: 0e3d80a + 88a679a

15 files changed: +1929 -190 lines

docs/index.rst (+3)

@@ -63,6 +63,9 @@ Install in editable mode in a venv:

     pip install -e .[testing, docs, notebooks]


+Test suite
+++++++++++
+
 Run entire test suite, parallelized across CPU cores:

 ::
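The body of that literal block lies outside the hunk; a typical invocation, assuming the parallelization comes from the pytest-xdist plugin (not confirmed by this diff), would be:

::

    pytest -n auto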

docs/license.rst (+2 -2)

@@ -1,7 +1,7 @@
 .. _license:

-=======
+========
 License
-=======
+========

 .. include:: ../LICENSE

notebooks/4_quant_lstm.ipynb (+933)

Large diffs are not rendered by default.

New file (+335)

@@ -0,0 +1,335 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import torch
from brevitas.export.onnx import onnx_export_opset
from torch.autograd import Function

# The `axis` attribute on QuantizeLinear/DequantizeLinear (per-channel
# quantization) was introduced in ONNX opset 13
AXIS_OPSET = 13
DOMAIN_STRING = "onnx.brevitas"


class DequantizeLinearFn(Function):

    @staticmethod
    def symbolic(g, x, input_scale, input_zero_point, input_axis):
        opset_version = onnx_export_opset()

        if input_axis is not None and opset_version < AXIS_OPSET:
            raise RuntimeError("ONNX Opset 13 is required for per-channel quantization")
        elif input_axis is not None and opset_version >= AXIS_OPSET:
            ret = g.op("DequantizeLinear", x, input_scale, input_zero_point, axis_i=input_axis)
        else:
            ret = g.op("DequantizeLinear", x, input_scale, input_zero_point)
        return ret

    @staticmethod
    def forward(ctx, int_x, input_scale, input_zero_point, input_axis):
        # Eager-mode placeholder: only the float dtype is restored; the
        # actual rescaling is performed by the exported DequantizeLinear node
        return int_x.float()
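
# Editor's sketch (not part of the diff): eager execution only restores the
# float dtype, while scale and zero-point are consumed exclusively by the
# exported DequantizeLinear node, so values pass through unscaled.
_int_q = torch.tensor([-3, 0, 5], dtype=torch.int8)
_deq = DequantizeLinearFn.apply(
    _int_q, torch.tensor(0.1), torch.tensor(0, dtype=torch.int8), None)
assert _deq.dtype == torch.float32  # dtype restored, values not rescaled
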
class IntClipFn(Function):

    @staticmethod
    def symbolic(g, int_x, min_int_val, max_int_val):
        # Since opset 11, Clip takes min/max as optional tensor inputs
        ret = g.op("Clip", int_x, min_int_val, max_int_val)
        return ret

    @staticmethod
    def forward(ctx, int_x, min_int_val, max_int_val):
        # Identity in eager mode; the exported Clip node performs the clamping
        return int_x

class QuantizeLinearFn(Function):

    @staticmethod
    def symbolic(g, x, output_scale, output_zero_point, output_dtype, output_axis):
        opset_version = onnx_export_opset()

        if output_axis is not None and opset_version < AXIS_OPSET:
            raise RuntimeError("ONNX Opset 13 is required for per-channel quantization")
        elif output_axis is not None and opset_version >= AXIS_OPSET:
            ret = g.op("QuantizeLinear", x, output_scale, output_zero_point, axis_i=output_axis)
        else:
            ret = g.op("QuantizeLinear", x, output_scale, output_zero_point)
        return ret

    @staticmethod
    def forward(ctx, x, output_scale, output_zero_point, output_dtype, output_axis):
        return x.type(output_dtype)
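
# Editor's sketch (not part of the diff): the eager path is just a dtype
# cast; true QuantizeLinear semantics (scale division, zero-point addition,
# saturation) exist only in the exported graph.
_x = torch.tensor([0.7, -0.2, 1.3])
_q = QuantizeLinearFn.apply(
    _x, torch.tensor(0.05), torch.tensor(0, dtype=torch.int8), torch.int8, None)
assert _q.dtype == torch.int8  # cast only; no rescaling was applied
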
class BrevitasQuantLSTMCellFn(Function):

    @staticmethod
    def symbolic(
            g,  # args and kwargs passed from _QuantLSTMLayer
            quant_input,
            quant_hidden_state,
            quant_cell_state,
            quant_weight_ii,
            quant_weight_if,
            quant_weight_ic,
            quant_weight_io,
            quant_weight_hi,
            quant_weight_hf,
            quant_weight_hc,
            quant_weight_ho,
            quant_bias_input,
            quant_bias_forget,
            quant_bias_cell,
            quant_bias_output,  # Symbolic kwargs passed from BrevitasQuantLSTMLayerHandler
            batch_first,
            reverse_input,
            cifg,  # Output quant
            output_scale,
            output_zero_point,
            output_bit_width,
            output_narrow_range,
            output_signed,
            output_rounding_mode,  # Cell state quant
            cell_state_scale,
            cell_state_zero_point,
            cell_state_bit_width,
            cell_state_narrow_range,
            cell_state_signed,
            cell_state_rounding_mode,  # Input gate accumulator quant
            input_acc_scale,
            input_acc_zero_point,
            input_acc_bit_width,
            input_acc_narrow_range,
            input_acc_signed,
            input_acc_rounding_mode,  # Forget gate accumulator quant
            forget_acc_scale,
            forget_acc_zero_point,
            forget_acc_bit_width,
            forget_acc_narrow_range,
            forget_acc_signed,
            forget_acc_rounding_mode,  # Cell gate accumulator quant
            cell_acc_scale,
            cell_acc_zero_point,
            cell_acc_bit_width,
            cell_acc_narrow_range,
            cell_acc_signed,
            cell_acc_rounding_mode,  # Output gate accumulator quant
            output_acc_scale,
            output_acc_zero_point,
            output_acc_bit_width,
            output_acc_narrow_range,
            output_acc_signed,
            output_acc_rounding_mode,  # Input gate sigmoid quant
            input_sigmoid_scale,
            input_sigmoid_zero_point,
            input_sigmoid_bit_width,
            input_sigmoid_narrow_range,
            input_sigmoid_signed,
            input_sigmoid_rounding_mode,  # Forget gate sigmoid quant
            forget_sigmoid_scale,
            forget_sigmoid_zero_point,
            forget_sigmoid_bit_width,
            forget_sigmoid_narrow_range,
            forget_sigmoid_signed,
            forget_sigmoid_rounding_mode,  # Cell gate tanh quant
            cell_tanh_scale,
            cell_tanh_zero_point,
            cell_tanh_bit_width,
            cell_tanh_narrow_range,
            cell_tanh_signed,
            cell_tanh_rounding_mode,  # Output gate sigmoid quant
            output_sigmoid_scale,
            output_sigmoid_zero_point,
            output_sigmoid_bit_width,
            output_sigmoid_narrow_range,
            output_sigmoid_signed,
            output_sigmoid_rounding_mode,  # Hidden state tanh quant
            hidden_state_tanh_scale,
            hidden_state_tanh_zero_point,
            hidden_state_tanh_bit_width,
            hidden_state_tanh_narrow_range,
            hidden_state_tanh_signed,
            hidden_state_tanh_rounding_mode,
    ):
        return g.op(
            f"{DOMAIN_STRING}::QuantLSTMCell",  # Tensors
            # Input values
            quant_input,
            quant_hidden_state,
            quant_cell_state,
            quant_weight_ii,
            quant_weight_if,
            quant_weight_ic,
            quant_weight_io,
            quant_weight_hi,
            quant_weight_hf,
            quant_weight_hc,
            quant_weight_ho,
            quant_bias_input,
            quant_bias_forget,
            quant_bias_cell,
            quant_bias_output,  # Output quant
            output_scale,
            output_zero_point,
            output_bit_width,  # Cell state quant
            cell_state_scale,
            cell_state_zero_point,
            cell_state_bit_width,  # Input gate accumulator quant
            input_acc_scale,
            input_acc_zero_point,
            input_acc_bit_width,  # Forget gate accumulator quant
            forget_acc_scale,
            forget_acc_zero_point,
            forget_acc_bit_width,  # Cell gate accumulator quant
            cell_acc_scale,
            cell_acc_zero_point,
            cell_acc_bit_width,  # Output gate accumulator quant
            output_acc_scale,
            output_acc_zero_point,
            output_acc_bit_width,  # Input gate sigmoid quant
            input_sigmoid_scale,
            input_sigmoid_zero_point,
            input_sigmoid_bit_width,  # Forget gate sigmoid quant
            forget_sigmoid_scale,
            forget_sigmoid_zero_point,
            forget_sigmoid_bit_width,  # Cell gate tanh quant
            cell_tanh_scale,
            cell_tanh_zero_point,
            cell_tanh_bit_width,  # Output gate sigmoid quant
            output_sigmoid_scale,
            output_sigmoid_zero_point,
            output_sigmoid_bit_width,  # Hidden state tanh quant
            hidden_state_tanh_scale,
            hidden_state_tanh_zero_point,
            hidden_state_tanh_bit_width,
            # Attributes
            batch_first_i=batch_first,
            reverse_input_i=reverse_input,
            cifg_i=cifg,
            output_narrow_i=output_narrow_range,
            output_signed_i=output_signed,
            output_rounding_mode_s=output_rounding_mode,
            cell_state_narrow_i=cell_state_narrow_range,
            cell_state_signed_i=cell_state_signed,
            cell_state_rounding_mode_s=cell_state_rounding_mode,
            input_acc_narrow_i=input_acc_narrow_range,
            input_acc_signed_i=input_acc_signed,
            input_acc_rounding_mode_s=input_acc_rounding_mode,
            forget_acc_narrow_i=forget_acc_narrow_range,
            forget_acc_signed_i=forget_acc_signed,
            forget_acc_rounding_mode_s=forget_acc_rounding_mode,
            cell_acc_narrow_i=cell_acc_narrow_range,
            cell_acc_signed_i=cell_acc_signed,
            cell_acc_rounding_mode_s=cell_acc_rounding_mode,
            output_acc_narrow_i=output_acc_narrow_range,
            output_acc_signed_i=output_acc_signed,
            output_acc_rounding_mode_s=output_acc_rounding_mode,
            input_sigmoid_narrow_i=input_sigmoid_narrow_range,
            input_sigmoid_signed_i=input_sigmoid_signed,
            input_sigmoid_rounding_mode_s=input_sigmoid_rounding_mode,
            forget_sigmoid_narrow_i=forget_sigmoid_narrow_range,
            forget_sigmoid_signed_i=forget_sigmoid_signed,
            forget_sigmoid_rounding_mode_s=forget_sigmoid_rounding_mode,
            cell_tanh_narrow_i=cell_tanh_narrow_range,
            cell_tanh_signed_i=cell_tanh_signed,
            cell_tanh_rounding_mode_s=cell_tanh_rounding_mode,
            output_sigmoid_narrow_range_i=output_sigmoid_narrow_range,
            output_sigmoid_signed_i=output_sigmoid_signed,
            output_sigmoid_rounding_mode_s=output_sigmoid_rounding_mode,
            hidden_state_tanh_narrow_i=hidden_state_tanh_narrow_range,
            hidden_state_tanh_signed_i=hidden_state_tanh_signed,
            hidden_state_tanh_rounding_mode_s=hidden_state_tanh_rounding_mode,
            # PyTorch requires the number of outputs to be specified manually
            outputs=3,
        )

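    # Editor's note: in torch.onnx symbolic functions, the kwarg suffix
    # encodes the ONNX attribute type: `_i` marks an integer attribute and
    # `_s` a string attribute, so e.g. `cifg_i=cifg` becomes an integer
    # attribute named `cifg` on the custom node. `outputs=3` tells the
    # tracer that onnx.brevitas::QuantLSTMCell produces three values:
    # the outputs, the hidden state, and the cell state.
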
    @staticmethod
    def forward(
            ctx,  # args and kwargs passed from _QuantLSTMLayer
            quant_input,
            quant_hidden_state,
            quant_cell_state,
            quant_weight_ii,
            quant_weight_if,
            quant_weight_ic,
            quant_weight_io,
            quant_weight_hi,
            quant_weight_hf,
            quant_weight_hc,
            quant_weight_ho,
            quant_bias_input,
            quant_bias_forget,
            quant_bias_cell,
            quant_bias_output,  # Symbolic kwargs passed from BrevitasQuantLSTMLayerHandler
            batch_first,
            reverse_input,
            cifg,  # Output quant
            output_scale,
            output_zero_point,
            output_bit_width,
            output_narrow_range,
            output_signed,
            output_rounding_mode,  # Cell state quant
            cell_state_scale,
            cell_state_zero_point,
            cell_state_bit_width,
            cell_state_narrow_range,
            cell_state_signed,
            cell_state_rounding_mode,  # Input gate accumulator quant
            input_acc_scale,
            input_acc_zero_point,
            input_acc_bit_width,
            input_acc_narrow_range,
            input_acc_signed,
            input_acc_rounding_mode,  # Forget gate accumulator quant
            forget_acc_scale,
            forget_acc_zero_point,
            forget_acc_bit_width,
            forget_acc_narrow_range,
            forget_acc_signed,
            forget_acc_rounding_mode,  # Cell gate accumulator quant
            cell_acc_scale,
            cell_acc_zero_point,
            cell_acc_bit_width,
            cell_acc_narrow_range,
            cell_acc_signed,
            cell_acc_rounding_mode,  # Output gate accumulator quant
            output_acc_scale,
            output_acc_zero_point,
            output_acc_bit_width,
            output_acc_narrow_range,
            output_acc_signed,
            output_acc_rounding_mode,  # Input gate sigmoid quant
            input_sigmoid_scale,
            input_sigmoid_zero_point,
            input_sigmoid_bit_width,
            input_sigmoid_narrow_range,
            input_sigmoid_signed,
            input_sigmoid_rounding_mode,  # Forget gate sigmoid quant
            forget_sigmoid_scale,
            forget_sigmoid_zero_point,
            forget_sigmoid_bit_width,
            forget_sigmoid_narrow_range,
            forget_sigmoid_signed,
            forget_sigmoid_rounding_mode,  # Cell gate tanh quant
            cell_tanh_scale,
            cell_tanh_zero_point,
            cell_tanh_bit_width,
            cell_tanh_narrow_range,
            cell_tanh_signed,
            cell_tanh_rounding_mode,  # Output gate sigmoid quant
            output_sigmoid_scale,
            output_sigmoid_zero_point,
            output_sigmoid_bit_width,
            output_sigmoid_narrow_range,
            output_sigmoid_signed,
            output_sigmoid_rounding_mode,  # Hidden state tanh quant
            hidden_state_tanh_scale,
            hidden_state_tanh_zero_point,
            hidden_state_tanh_bit_width,
            hidden_state_tanh_narrow_range,
            hidden_state_tanh_signed,
            hidden_state_tanh_rounding_mode,
    ):
        # To simplify things, we return the outputs here as if they were
        # already concatenated; the scale/zero-point/bit-width outputs are
        # omitted too. This preserves output shapes but not values.
        # See _QuantLSTMCell for the actual implementation.
        quant_outputs = torch.zeros(
            quant_input.size(0),
            quant_input.size(1),
            quant_hidden_state.size(1),
            device=quant_hidden_state.device)
        return quant_outputs, quant_hidden_state, quant_cell_state
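
For orientation, a hedged end-to-end sketch of how BrevitasQuantLSTMCellFn is typically reached: a quantized LSTM module is run through one of Brevitas' export entry points, whose handlers (see the BrevitasQuantLSTMLayerHandler referenced in the comments above) invoke this Function during tracing. QuantLSTM and export_onnx_qcdq exist in recent Brevitas releases, but the exact entry point at this commit may differ, and all sizes below are illustrative rather than taken from the commit:

import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq

# Illustrative sizes; QuantLSTM defaults to (seq_len, batch, input_size) inputs
model = QuantLSTM(input_size=8, hidden_size=16)
inp = torch.randn(5, 2, 8)
# The exported graph contains standard QuantizeLinear/DequantizeLinear/Clip
# nodes plus the custom onnx.brevitas::QuantLSTMCell op defined above
export_onnx_qcdq(model, inp, export_path='quant_lstm.onnx')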
