Commit 46b7a88

calad0i, JanFSchulte, pre-commit-ci[bot], and HamzaEzzRa authored
Distributed Arithmetic strategy for Dense, Conv1/2D, and EinsumDense (#1191)
* skip oneapi test if icpx doesnt exist
* bit-exact-possible multidim softmax
* softmax fix softmax fix fix softmax parsing issue ckpt softmax fix fix table size after overriding inv_inp_t
* move softmax attr to fpga backend, post rebase fix
* add keras v3 object parser
* add keras v3 layer handlers
* einsumdense and einsum
* add einsum templates
* bit-exact-possible multidim softmax
* symbolic bitwidth infer util
* add qinterval test
* keras v2-v3 reshape fn compability patch
* hgq2 layer handlers
* add bit-exact enforcement pass
* fix softmax accum fractional bits derivation
* add qeinsum test
* env update
* remove hgq v1 rewire behavier (superseded)
* fix del method in config class
* distributed arithmetic impl
* distributed arithmetic impl w/ conv
* distributed arithmetic templates
* prevent pointwise override for DA strategy
* add test for DA
* update proj conf
* disable delay_constraint in da4ml
* add hgq2 mha test
* update ci template
* require da4ml version
* pre-commit fix
* proper skip qeinsum test when condition not met
* softmax and activation fix
* hgq2 api change, prevent zero bw activation crash syn
* qinterval and bn bit-exactness fix
* fix einsum ax expansion and 0d output handling
* fix merge templates
* converter and bit-exact pass for ops layers
* use pointwise 2d for conv2d due to unknown flow changing
* fix einsum dense da impl corner case
* qinterval type fix
* fix corner case in qkeras converted proxy
* support mha def in (q,v) format
* update da4ml binding syntax
* update da4ml binding syntax x2
* use fixedvararr obj for da codegen
* more general build_lib script
* bring back hgq proxy embedded properties excl. pecision
* fix streaming conv1/2d da regression
* streaming template support for DA fix
* allow non-po-2 avg pooling
* ignore batch dim in parse_data_format
* keras v3 native pooling layer parser
* globalpooling handler fix
* unary lut bw derivation update
* keras 3.10 api change
* namespace fix for pointwise conv
* use constexpr for dim def
* conv pf handling
* keras v3 api change
* quality-of-life changes
* kv3 parser update
* shut up!
* post-rebase import conflicts
* remaining post-rebase fix
* bit-exactness corner case
* quantizer shrink corner case fix (sign bit)
* allow 0 bit activation...
* template and test fix
* intel/ac_types/ac_int.hpp:156:30: error: unsigned _BitInt must have a bit size of at least 1
* static_cast<typename ExtractPipeType<res_pipe>::value_type::value_type>
* doc leftover
* comment
* style
* [pre-commit.ci] auto fixes from pre-commit hooks
* model opt pass fix and avg pool fix
* squashed cosmetic and minor changes
* multi graph dimname fix
* Add Cropping layers support (#1309)
  * added Cropping1D and Cropping2D keras layers support
  * removed .bak templates files
  * added cropping layers tests for vivado and vitis
  * [pre-commit.ci] auto fixes from pre-commit hooks
  ---------
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* bump da4ml version
* bit-exact algorithm minor change

---------
Co-authored-by: Jan-Frederik Schulte <jschulte@cern.ch>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Hamza Ezzaoui Rahali <hamzaezzaouirahali@gmail.com>
1 parent 71bf4ae commit 46b7a88


76 files changed: +3254 -352 lines
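For orientation, a minimal usage sketch (not part of this commit) of how the new strategy would typically be selected through the standard hls4ml configuration flow. The strategy string matches the `distributed_arithmetic` attribute value checked in the templates below; the exact user-facing spelling accepted by the config parser, and whether it is set per model or per layer, are assumptions here.

```python
# Hypothetical sketch: enabling the Distributed Arithmetic strategy on a small Keras model.
import hls4ml
from tensorflow import keras

model = keras.Sequential([keras.layers.Dense(16, activation='relu', input_shape=(8,))])

config = hls4ml.utils.config_from_keras_model(model, granularity='model')
# Assumption: the strategy name mirrors the attribute value used in the backend templates.
config['Model']['Strategy'] = 'distributed_arithmetic'

hls_model = hls4ml.converters.convert_from_keras_model(
    model,
    hls_config=config,
    io_type='io_parallel',  # the DA implementation provides a dedicated io_parallel entry point
    output_dir='da_dense_prj',  # hypothetical output directory
)
hls_model.compile()
```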

.pre-commit-config.yaml

Lines changed: 4 additions & 1 deletion
@@ -47,7 +47,10 @@ repos:
         exclude: docs/conf.py
         additional_dependencies: [flake8-bugbear, flake8-print]
         args: ['--max-line-length=125',  # github viewer width
-               '--extend-ignore=E203,T201']  # E203 is not PEP8 compliant
+               '--extend-ignore=E203,T201',  # E203 is not PEP8 compliant
+               '--per-file-ignores=hls4ml/model/optimizer/passes/bit_exact.py:E741',
+               # i for #int w/o sign, I for #int w/ sign when massively processing bw conversions
+               ]

   - repo: https://github.com/mgedmin/check-manifest
     rev: "0.50"

Jenkinsfile

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ pipeline {
                 sh '''#!/bin/bash --login
                     conda activate hls4ml-py310
                     conda install -y jupyterhub pydot graphviz pytest pytest-cov
-                    pip install pytest-randomly jupyter onnx>=1.4.0 matplotlib pandas seaborn pydigitalwavetools==1.1 pyyaml tensorflow==2.14 qonnx torch git+https://github.com/jmitrevs/qkeras.git@qrecurrent_unstack pyparsing
+                    pip install pytest-randomly jupyter onnx>=1.4.0 matplotlib pandas seaborn pydigitalwavetools==1.1 pyyaml tensorflow==2.14 qonnx torch git+https://github.com/jmitrevs/qkeras.git@qrecurrent_unstack pyparsing quantizers da4ml
                     pip install -U ../ --user
                     ./convert-keras-models.sh -x -f keras-models.txt
                     pip uninstall hls4ml -y'''

README.md

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ config = hls4ml.utils.fetch_example_model('KERAS_3layer.json')
 print(config)

 # Convert it to a hls project
-hls_model = hls4ml.converters.keras_to_hls(config)
+hls_model = hls4ml.converters.keras_v2_to_hls(config)

 # Print full list of example models if you want to explore more
 hls4ml.utils.fetch_example_list()

docs/advanced/profiling.rst

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@ You will need to initialise these objects by using a trained model, loading a mo
 .. code-block:: python

     from hls4ml.model.profiling import numerical
-    from hls4ml.converters import keras_to_hls
+    from hls4ml.converters import keras_v2_to_hls
     import matplotlib.pyplot as plt
     import yaml

@@ -27,7 +27,7 @@ You will need to initialise these objects by using a trained model, loading a mo
     with open("keras-config.yml", 'r') as ymlfile:
         config = yaml.load(ymlfile)

-    hls_model = keras_to_hls(config)
+    hls_model = keras_v2_to_hls(config)

     # produce 4 plots
     plots = numerical(model=model, hls_model = hls_model, X=X)

hls4ml/backends/fpga/fpga_backend.py

Lines changed: 19 additions & 14 deletions
@@ -7,7 +7,7 @@
 import numpy as np

 from hls4ml.backends.backend import Backend
-from hls4ml.model.attributes import ChoiceAttribute, ConfigurableAttribute, TypeAttribute
+from hls4ml.model.attributes import Attribute, ChoiceAttribute, ConfigurableAttribute, TypeAttribute
 from hls4ml.model.layers import (
     GRU,
     LSTM,
@@ -109,32 +109,37 @@ def __init__(self, name):
         act_attrs.append(TypeAttribute('table', default=FixedPrecisionType(18, 8), description=descriptions.table_type))
         self.attribute_map[Activation] = act_attrs

-        softmax_attrs = self.attribute_map.get(Softmax, [])
-        softmax_attrs.append(
+        softmax_attrs = [
+            Attribute('n_in'),
+            Attribute('activation', value_type=str),
+            Attribute('n_outer', value_type=int, default=1),
+            Attribute('n_inner', value_type=int, default=1),
             ChoiceAttribute(
                 'implementation',
                 ['latency', 'stable', 'argmax', 'legacy'],
                 default='stable',
                 description=descriptions.softmax_implementation,
-            )
-        )
-        softmax_attrs.append(
-            ConfigurableAttribute('skip', value_type=bool, default=False, description=descriptions.softmax_skip)
-        )
-        softmax_attrs.append(
+            ),
+            ConfigurableAttribute('skip', value_type=bool, default=False, description=descriptions.softmax_skip),
             TypeAttribute(
                 'exp_table',
                 default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT),
                 description=descriptions.table_type,
-            )
-        )
-        softmax_attrs.append(
+            ),
             TypeAttribute(
                 'inv_table',
                 default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT),
                 description=descriptions.table_type,
-            )
-        )
+            ),
+            TypeAttribute(
+                'inv_inp',
+                default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT),
+            ),
+            TypeAttribute(
+                'accum',
+                default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT),
+            ),
+        ]
         self.attribute_map[Softmax] = softmax_attrs

     def create_layer_class(self, layer_class):
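A hedged sketch (not part of the diff) of how the Softmax type attributes registered above (`exp_table`, `inv_table`, and the new `inv_inp` and `accum`) could be pinned per layer through named precisions in a name-granularity config. The layer name, the precision strings, and the assumption that the config parser accepts these key names are all illustrative.

```python
# Hypothetical sketch: overriding the softmax precisions for a named layer.
import hls4ml
from tensorflow import keras

model = keras.Sequential([
    keras.layers.Dense(10, input_shape=(16,)),
    keras.layers.Activation('softmax', name='softmax'),  # assumed layer name 'softmax'
])

config = hls4ml.utils.config_from_keras_model(model, granularity='name')
# Assumption: precision keys mirror the TypeAttribute names (without the '_t' suffix).
config['LayerName']['softmax']['Precision'] = {
    'result': 'ap_fixed<16,6>',
    'exp_table': 'ap_fixed<18,8,AP_RND,AP_SAT>',
    'inv_table': 'ap_fixed<18,8,AP_RND,AP_SAT>',
    'inv_inp': 'ap_fixed<18,8,AP_RND,AP_SAT>',
    'accum': 'ap_fixed<24,14>',
}
```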

hls4ml/backends/fpga/passes/fix_softmax_table_size.py

Lines changed: 5 additions & 1 deletion
@@ -6,7 +6,11 @@

 class FixSoftmaxTableSize(OptimizerPass):
     def match(self, node):
-        return isinstance(node, Softmax)
+        if not isinstance(node, Softmax):
+            return False
+        if 'inv_table_size' in node.attributes:
+            return False  # handler generating inv_table_size sets it properly
+        return True

     def transform(self, model, node: Layer):
         inp_layer = node.get_input_node()  # type: ignore

hls4ml/backends/fpga/passes/hgq_proxy_model.py

Lines changed: 0 additions & 5 deletions
@@ -52,10 +52,6 @@ def match(self, node: Layer):
         return isinstance(node, FixedPointQuantizer)

     def transform(self, model, node: FixedPointQuantizer):
-        if node.fusible:
-            model.remove_node(node)
-            return True
-
         if model.config.config['IOType'] != 'io_parallel':
             raise NotImplementedError('Heterogenous quantization for activations is only supported with IOType=io_parallel')

@@ -96,7 +92,6 @@ def __init__(self):

     def format(self, node):
         params = self._default_function_params(node)
-        node.attributes['result_t'].precision = node.attributes['table_t'].precision
         params['config'] = f'unary_lut_config{node.index}'
         params['table'] = node.get_weights('table').name


hls4ml/backends/vivado/passes/convolution_templates.py

Lines changed: 36 additions & 0 deletions
@@ -154,11 +154,21 @@ def format(self, node):
             mult_params['dense_function'] = 'nnet::DenseResource_rf_gt_nin'
         elif node.get_attr('strategy').lower() == 'resource_unrolled':
             mult_params['dense_function'] = f'{namespace}::dense_resource_unrolled_{node.index}'
+        elif node.get_attr('strategy').lower() == 'distributed_arithmetic':
+            mult_params['dense_function'] = f'{namespace}::dense_da_wrapper_{node.index}'

         mult_config = self.mult_template.format(**mult_params)

         return mult_config + '\n' + conv_config

+    def match(self, node):
+        if node.get_attr('strategy') == 'distributed_arithmetic':
+            io_type = node.model.config.get_config_value("IOType")
+            if io_type == 'io_parallel':
+                # DA impl use alternate entry point for io_parallel conv
+                return False
+        return super().match(node)
+

 class Conv1DFunctionTemplate(FunctionCallTemplate):
     def __init__(self):
@@ -173,6 +183,14 @@ def format(self, node):

         return self.template.format(**params)

+    def match(self, node):
+        if node.get_attr('strategy') == 'distributed_arithmetic':
+            io_type = node.model.config.get_config_value("IOType")
+            if io_type == 'io_parallel':
+                # DA impl use alternate entry point for io_parallel conv
+                return False
+        return super().match(node)
+

 class DepthwiseConv1DFunctionTemplate(Conv1DFunctionTemplate):
     def __init__(self):
@@ -299,11 +317,21 @@ def format(self, node):
             mult_params['dense_function'] = 'nnet::DenseResource_rf_gt_nin'
         elif node.get_attr('strategy').lower() == 'resource_unrolled':
             mult_params['dense_function'] = f'{namespace}::dense_resource_unrolled_{node.index}'
+        elif node.get_attr('strategy').lower() == 'distributed_arithmetic':
+            mult_params['dense_function'] = f'{namespace}::dense_da_wrapper_{node.index}'

         mult_config = self.mult_template.format(**mult_params)

         return mult_config + '\n' + conv_config

+    def match(self, node):
+        if node.get_attr('strategy') == 'distributed_arithmetic':
+            io_type = node.model.config.get_config_value("IOType")
+            if io_type == 'io_parallel':
+                # DA impl use alternate entry point for io_parallel conv
+                return False
+        return super().match(node)
+

 class Conv2DFunctionTemplate(FunctionCallTemplate):
     def __init__(self):
@@ -318,6 +346,14 @@ def format(self, node):

         return self.template.format(**params)

+    def match(self, node):
+        if node.get_attr('strategy') == 'distributed_arithmetic':
+            io_type = node.model.config.get_config_value("IOType")
+            if io_type == 'io_parallel':
+                # DA impl use alternate entry point for io_parallel conv
+                return False
+        return super().match(node)
+

 class DepthwiseConv2DFunctionTemplate(Conv2DFunctionTemplate):
     def __init__(self):

hls4ml/backends/vivado/passes/core_templates.py

Lines changed: 82 additions & 2 deletions
@@ -1,3 +1,5 @@
+from math import ceil, log2
+
 from hls4ml.backends.backend import get_backend
 from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
 from hls4ml.model.layers import Activation, BatchNormalization, Dense, HardActivation, ParametrizedActivation, PReLU, Softmax
@@ -55,9 +57,17 @@ def format(self, node):
         # The 3rd case is never used
         elif node.get_attr('strategy').lower() == 'resource_unrolled':
             params['dense_function'] = f'{namespace}::dense_resource_unrolled_{node.index}'
+        elif node.get_attr('strategy').lower() == 'distributed_arithmetic':
+            # Only triggered in io_streaming mode
+            params['dense_function'] = f'{namespace}::dense_da_wrapper_{node.index}'

         return self.template.format(**params)

+    def match(self, node):
+        if node.get_attr('strategy') == 'distributed_arithmetic':
+            return False  # DA does not use common dense template
+        return super().match(node)
+

 class DenseFunctionTemplate(FunctionCallTemplate):
     def __init__(self):
@@ -71,6 +81,11 @@ def format(self, node):

         return self.template.format(**params)

+    def match(self, node):
+        if node.get_attr('strategy') == 'distributed_arithmetic':
+            return False  # DA does not use common dense template
+        return super().match(node)
+

 # BatchNormalization templates

@@ -152,13 +167,22 @@ def format(self, node):

 softmax_config_template = """struct {type}_config{index} : nnet::activ_config {{
     static const unsigned n_in = {n_in};
-    static const unsigned table_size = {table_size};
+    static const unsigned n_slice = {n_slice};
+    static const unsigned n_outer = {n_outer};
+    static const unsigned n_inner = {n_inner};
+    static const unsigned parallelization_factor = {parallelization_factor};
+    static const unsigned exp_table_size = {exp_table_size};
+    static const unsigned inv_table_size = {inv_table_size};
     static const unsigned io_type = nnet::{iotype};
     static const unsigned reuse_factor = {reuse};
     static const unsigned axis = {axis};
     static const nnet::softmax_implementation implementation = nnet::softmax_implementation::{implementation};
+    static constexpr float exp_scale = {exp_scale};
     typedef {exp_table_t.name} exp_table_t;
     typedef {inv_table_t.name} inv_table_t;
+    typedef {accum_t.name} accum_t;
+    typedef {inv_inp_t.name} inv_inp_t;
+    typedef {inp_norm_t_str} inp_norm_t;
 }};\n"""

 activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {output});'
@@ -210,10 +234,66 @@ def __init__(self):
         super(ActivationConfigTemplate, self).__init__(Softmax)  # Skip ActivationConfigTemplate's __init__
         self.template = softmax_config_template

+    def format(self, node):
+        params = self._default_config_params(node)
+        params['type'] = node.get_attr('activation')
+        params.setdefault('exp_table_size', params['table_size'])
+        params.setdefault('inv_table_size', params['table_size'])
+        params.setdefault('n_inner', 1)
+        params.setdefault('n_outer', 1)
+        params.setdefault('exp_scale', 1.0)
+        params.setdefault('parallelization_factor', -1)
+
+        n_slice = params['n_in'] // params['n_inner'] // params['n_outer']  # type: ignore
+        params['n_slice'] = n_slice
+
+        if params['accum_t'].name == 'model_default_t':  # type: ignore
+            scale = ceil(log2(n_slice))
+            exp_table_t = node.attributes['exp_table_t'].precision
+            signed, width, integers = exp_table_t.signed, exp_table_t.width, exp_table_t.integer
+            params['accum_t_str'] = f'ap_{"" if signed else "u"}fixed<{width + scale}, {integers + scale}>'
+        else:
+            params['accum_t_str'] = params['accum_t'].name  # type: ignore
+        if params['inv_inp_t'].name == 'model_default_t':  # type: ignore
+            params['inv_inp_t'] = params['exp_table_t']
+
+        if params['implementation'] == 'stable':
+            if 'inp_norm_t' not in params:
+                # Only used in stable (max-normalized) implementation
+                input_t = node.get_input_variable().type.precision
+                width, iwidth, signed = input_t.width, input_t.integer, input_t.signed  # noqa: F841
+                width, iwidth = width - signed, iwidth - signed
+                if signed:
+                    # Fix table size if too large
+                    exp_table_size = params['inv_table_size']
+                    params['exp_table_size'] = str(min(int(exp_table_size), 2**width))
+                params['inp_norm_t_str'] = f'ap_ufixed<{width}, {iwidth}>'
+            else:
+                params['inp_norm_t_str'] = params['inp_norm_t'].name  # type: ignore
+        else:
+            params['inp_norm_t_str'] = 'ap_fixed<1,0>'
+
+        return self.template.format(**params)
+
+
+class SoftmaxFunctionTemplate(FunctionCallTemplate):
+    def __init__(self):
+        super().__init__(Softmax, include_header=activ_include_list)
+        self.template = activ_function_template
+
+    def format(self, node):
+        params = self._default_function_params(node)
+        use_multidim = node.get_attr('n_inner', 1) > 1 or node.get_attr('n_outer', 1) > 1
+        use_multidim = use_multidim and node.model.config.get_config_value('IOType') == 'io_parallel'
+        params['activation'] = 'softmax' if not use_multidim else 'softmax_multidim'
+        params['config'] = f'softmax_config{node.index}'
+
+        return self.template.format(**params)
+

 class ActivationFunctionTemplate(FunctionCallTemplate):
     def __init__(self):
-        super().__init__((Activation, HardActivation, Softmax), include_header=activ_include_list)
+        super().__init__((Activation, HardActivation), include_header=activ_include_list)
         self.template = activ_function_template

     def format(self, node):