From 4731e17a173b3a44a1316331f65d9daa62ed3966 Mon Sep 17 00:00:00 2001 From: Jonathan-Shoemaker Date: Mon, 20 Jun 2022 12:27:03 -0500 Subject: [PATCH 01/11] attempt to add support for conv1d transpose add new files for conv1dtranspose resource clean up so that conv code is reached. Still need to get the actual implementation matching keras implement conv1dtranspose super inefficiently (gets correct answer though) try to fix indices to make code work make the c code work for conv1dtranspose reduce weight dimensions to properly reflect transposed kernel size clean up so that transpose filter width is passes around from config fix code such that simple transpose layer gets synthesized move variables out of loops, optimize slightly and add in alternative method of computation to compute by kernel (that option is not optimized as of now) add in conv1d transpose linebuffer format code. seems to work, unsure of if it is optimized yet trying to fix stream behavior get transpose compilation working mostly as expected. weird jump in latency from reuse 1 to 2 still exists initial conv2dtranspose addition. Output is permuted as of now. output in correct order. using large array to buffer output though fix up conv1dtranspose a bit to pad correctly. fix up stream instructions for both 1d and 2d transposes fix allowed reuse factors for transpose layers update to new conv methods for io_parallel. Still some issues with multiple filters as well as some padding issues clean up error with multiple filters and larger kernels optimize conv transpose resource to get it working reasonably well. may still have slight optimization left fix output to conv1d transpose resource add conv2dtranspose io_parallel implementation. Can still be optimized small changeup to data storage in conv1d parallel fix zero padding pass addition for transpose stream layers move transposing of weight matrix to resource_strategy for transpose layers change how stream loads in weights to be like parallel for conv transposes. unroll all stride steps completely fix output of 1d transpose parallel to be faster change 1d transpose weight input to be 2-dimensional (passed from python code) change 2d transpose weight input to be 3-dimensional (passed from python code) small changes to transposes Revert "fix nondefault project name handling (#626)". The commit breaks the Vivado Accelerator workflow, and the fix is unclear to me right now. This reverts commit e8f048ad2a49c067eb5e49740a5d94c7c1e33b24. steps towards getting integer inputs to work --- hls4ml/backends/fpga/fpga_backend.py | 164 ++++++++++++++ hls4ml/backends/fpga/fpga_types.py | 12 +- hls4ml/backends/fpga/passes/codegen.py | 41 +++- .../backends/vivado/passes/conv_same_pad.py | 110 ++++++++- hls4ml/backends/vivado/passes/conv_stream.py | 26 ++- .../vivado/passes/convolution_templates.py | 194 +++++++++++++++- hls4ml/backends/vivado/vivado_backend.py | 73 ++++++ hls4ml/converters/keras/convolution.py | 80 ++++++- hls4ml/converters/utils.py | 43 ++++ hls4ml/model/layers.py | 138 +++++++++++- hls4ml/model/types.py | 6 +- .../vivado/nnet_utils/nnet_conv1dtranspose.h | 45 ++++ .../nnet_conv1dtranspose_resource.h | 125 +++++++++++ .../nnet_utils/nnet_conv1dtranspose_stream.h | 133 +++++++++++ .../vivado/nnet_utils/nnet_conv2dtranspose.h | 56 +++++ .../nnet_conv2dtranspose_resource.h | 144 ++++++++++++ .../nnet_utils/nnet_conv2dtranspose_stream.h | 210 ++++++++++++++++++ .../vivado/nnet_utils/nnet_helpers.h | 76 +++++++ hls4ml/writer/vivado_writer.py | 39 +++- 19 files changed, 1686 insertions(+), 29 deletions(-) create mode 100644 hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose.h create mode 100644 hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_resource.h create mode 100644 hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_stream.h create mode 100644 hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose.h create mode 100644 hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_resource.h create mode 100644 hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_stream.h diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py index 8cfaec8b3f..47d6d056a4 100644 --- a/hls4ml/backends/fpga/fpga_backend.py +++ b/hls4ml/backends/fpga/fpga_backend.py @@ -158,6 +158,25 @@ def get_layer_mult_size(self, layer): n_out = layer.get_attr('n_out') return n_in, n_out + if 'Conv1DTranspose' in layer.class_name: + trfilt_width = (layer.get_attr('filt_width') + layer.get_attr('stride_width') - 1) // layer.get_attr( + 'stride_width' + ) + n_in = layer.get_attr('n_chan') * trfilt_width + n_out = layer.get_attr('n_filt') + return n_in, n_out + + if 'Conv2DTranspose' in layer.class_name: + trfilt_width = (layer.get_attr('filt_width') + layer.get_attr('stride_width') - 1) // layer.get_attr( + 'stride_width' + ) + trfilt_height = (layer.get_attr('filt_height') + layer.get_attr('stride_height') - 1) // layer.get_attr( + 'stride_height' + ) + n_in = layer.get_attr('n_chan') * trfilt_height * trfilt_width + n_out = layer.get_attr('n_filt') + return n_in, n_out + if 'Conv1D' in layer.class_name: n_in = layer.get_attr('n_chan') * layer.get_attr('filt_width') n_out = layer.get_attr('n_filt') @@ -711,7 +730,65 @@ def generate_conv1d_line_buffer_fn(self, layer_idx, n_partitions, in_W, in_C, ke " ) {{\n" ).format(index=layer_idx) indent = ' ' + for partition_idx, partition in enumerate(np.split(im2col_matrix, n_partitions)): + generated_code += indent * 2 + f'if (partition == {partition_idx:>3}) {{\n' + for pixel_idx, arr in enumerate(partition): + buffer_stmts = [] + for j, v in enumerate(arr): + if v == 0: + val = '0' + else: + val = f'data[{int(v - 1)}]' + buffer_stmts.append(f'buffer[{pixel_idx}][{j}] = {val:>10};') + generated_code += indent * 3 + ' '.join(buffer_stmts) + '\n' + generated_code += '\n' + indent * 2 + '}\n' + + generated_code += indent + '}\n' + generated_code += '};\n' + return generated_code + + def _compute_conv1d_tr_im2col(self, input_shape, out_w, kernel=3, stride=1): + W, C = input_shape + + tr_kernel = (kernel + stride - 1) // stride + + input_img = np.arange(1, W * C + 1) + im_matrix = np.zeros((tr_kernel * C * out_w,)) + + index = 0 + for i_ow in range(out_w): + for i_kw in range(tr_kernel): + for i_c in range(C): + # input column is just the output column shifted + input_col = i_ow - (tr_kernel - 1) + i_kw + if input_col >= 0 and input_col < W: + im_matrix[index] = input_img[input_col * C + i_c] + else: + im_matrix[index] = 0 + index += 1 + im_matrix = im_matrix.reshape(out_w, -1) + return im_matrix + + def generate_conv1d_tr_line_buffer_fn(self, layer_idx, n_partitions, in_W, in_C, out_W, kernel=3, stride=1): + im2col_matrix = self._compute_conv1d_tr_im2col( + (in_W, in_C), + out_W, + kernel, + stride, + ) + + generated_code = ( + "template\n" + "class fill_buffer_{index} : public FillConv1DBuffer {{\n" + " public:\n" + " static void fill_buffer(\n" + " data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],\n" + " data_T buffer[CONFIG_T::n_pixels][CONFIG_T::trfilt_width * CONFIG_T::n_chan],\n" + " const unsigned partition\n" + " ) {{\n" + ).format(index=layer_idx) + indent = ' ' for partition_idx, partition in enumerate(np.split(im2col_matrix, n_partitions)): generated_code += indent * 2 + f'if (partition == {partition_idx:>3}) {{\n' for pixel_idx, arr in enumerate(partition): @@ -860,6 +937,93 @@ def generate_conv2d_line_buffer_fn( return generated_code + def _compute_conv2d_tr_im2col(self, input_shape, out_shape, kernel=(3, 3), stride=(1, 1)): + H, W, C = input_shape + kernel_h, kernel_w = kernel + stride_h, stride_w = stride + out_h, out_w = out_shape + + tr_kernel_h = (kernel_h + stride_h - 1) // stride_h + tr_kernel_w = (kernel_w + stride_w - 1) // stride_w + + input_img = np.arange(1, H * W * C + 1) + im_matrix = np.zeros((tr_kernel_h * tr_kernel_w * C * out_h * out_w,)) + + index = 0 + for i_oh in range(out_h): + for i_ow in range(out_w): + for i_kh in range(tr_kernel_h): + input_row = i_oh - (tr_kernel_h - 1) + i_kh + for i_kw in range(tr_kernel_w): + for i_c in range(C): + if input_row < 0 or input_row >= H: + im_matrix[index] = 0 + else: + input_col = i_ow - (tr_kernel_w - 1) + i_kw + if input_col >= 0 and input_col < W: + im_matrix[index] = input_img[input_row * W * C + input_col * C + i_c] + else: + im_matrix[index] = 0 + index += 1 + + im_matrix = im_matrix.reshape(out_h * out_w, -1) + return im_matrix + + def generate_conv2d_tr_line_buffer_fn( + self, layer_idx, n_partitions, in_H, in_W, in_C, out_H, out_W, kernel=(3, 3), stride=(1, 1) + ): + if isinstance(kernel, Iterable): + kernel_height = kernel[0] + kernel_width = kernel[1] + else: + kernel_height = kernel + kernel_width = kernel + + if isinstance(stride, Iterable): + stride_height = stride[0] + stride_width = stride[1] + else: + stride_height = stride + stride_width = stride + + im2col_matrix = self._compute_conv2d_tr_im2col( + (in_H, in_W, in_C), + (out_W, out_W), + (kernel_height, kernel_width), + (stride_height, stride_width), + ) + + generated_code = ( + "template\n" + "class fill_buffer_{index} : public FillConv2DBuffer {{\n" + " public:\n" + " static void fill_buffer(\n" + " data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan],\n" + " data_T " + "buffer[CONFIG_T::n_pixels][CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_chan],\n" + " const unsigned partition\n" + " ) {{\n" + ).format(index=layer_idx) + indent = ' ' + + for partition_idx, partition in enumerate(np.split(im2col_matrix, n_partitions)): + generated_code += indent * 2 + f'if (partition == {partition_idx:>3}) {{\n' + for pixel_idx, arr in enumerate(partition): + buffer_stmts = [] + for j, v in enumerate(arr): + if v == 0: + val = '0' + else: + val = f'data[{int(v - 1)}]' + buffer_stmts.append(f'buffer[{pixel_idx}][{j}] = {val:>10};') + generated_code += indent * 3 + ' '.join(buffer_stmts) + '\n' + generated_code += '\n' + indent * 2 + '}\n' + + generated_code += indent + '}\n' + generated_code += '};\n' + + return generated_code + @model_optimizer() def write_hls(self, model): self.writer.write_hls(model) diff --git a/hls4ml/backends/fpga/fpga_types.py b/hls4ml/backends/fpga/fpga_types.py index ceac0b5e4d..160739d8de 100644 --- a/hls4ml/backends/fpga/fpga_types.py +++ b/hls4ml/backends/fpga/fpga_types.py @@ -428,7 +428,17 @@ def __init__(self, type_converter): class StaticWeightVariableDefinition(VariableDefinition): def definition_cpp(self, name_suffix='', as_reference=False): - return f'{self.type.name} {self.name}[{self.data_length}]' + if self.keep_dims > 0: + size_str = '' + for dim in range(self.keep_dims): + size_str += f'[{self.shape[dim]}]' + final_dim = 1 + for dim in range(self.keep_dims, len(self.shape)): + final_dim *= self.shape[dim] + size_str += f'[{final_dim}]' + return f'{self.type.name} {self.name}{size_str}' + else: + return f'{self.type.name} {self.name}[{self.data_length}]' class StaticWeightVariableConverter: diff --git a/hls4ml/backends/fpga/passes/codegen.py b/hls4ml/backends/fpga/passes/codegen.py index f1f1080996..f3b03ab472 100644 --- a/hls4ml/backends/fpga/passes/codegen.py +++ b/hls4ml/backends/fpga/passes/codegen.py @@ -1,4 +1,4 @@ -from hls4ml.model.layers import Conv1D, Conv2D +from hls4ml.model.layers import Conv1D, Conv1DTranspose, Conv2D, Conv2DTranspose from hls4ml.model.optimizer import OptimizerPass from hls4ml.model.types import Source @@ -7,12 +7,19 @@ class GenerateConvIm2col(OptimizerPass): '''Generates tcode for im2col step of 1D/2d convolution''' def match(self, node): - return isinstance(node, (Conv1D, Conv2D)) and node.model.config.get_config_value('IOType') == 'io_parallel' + return ( + isinstance(node, (Conv1D, Conv2D, Conv1DTranspose, Conv2DTranspose)) + and node.model.config.get_config_value('IOType') == 'io_parallel' + ) def transform(self, model, node): node_class = node.__class__.__name__ - if '1D' in node_class: + if '1DTranspose' in node_class: + self._generate_im2col_1d_transpose(node) + elif '1D' in node_class: self._generate_im2col_1d(node) + elif '2DTranspose' in node_class: + self._generate_im2col_2d_transpose(node) elif '2D' in node_class: self._generate_im2col_2d(node) else: @@ -31,6 +38,19 @@ def _generate_im2col_1d(self, node): node.set_attr('line_buffer_codegen', Source(code_str)) + def _generate_im2col_1d_transpose(self, node): + code_str = node.model.config.backend.generate_conv1d_tr_line_buffer_fn( + node.get_attr('index'), + node.get_attr('n_partitions'), + node.get_input_variable().shape[0], + node.get_input_variable().shape[1], + node.get_attr('proc_width'), + kernel=node.get_attr('filt_width'), + stride=node.get_attr('stride_width'), + ) + + node.set_attr('line_buffer_codegen', Source(code_str)) + def _generate_im2col_2d(self, node): code_str = node.model.config.backend.generate_conv2d_line_buffer_fn( node.get_attr('index'), @@ -49,3 +69,18 @@ def _generate_im2col_2d(self, node): ) node.set_attr('line_buffer_codegen', Source(code_str)) + + def _generate_im2col_2d_transpose(self, node): + code_str = node.model.config.backend.generate_conv2d_tr_line_buffer_fn( + node.get_attr('index'), + node.get_attr('n_partitions'), + node.get_input_variable().shape[0], + node.get_input_variable().shape[1], + node.get_input_variable().shape[2], + node.get_attr('proc_height'), + node.get_attr('proc_width'), + kernel=(node.get_attr('filt_height'), node.get_attr('filt_width')), + stride=(node.get_attr('stride_height'), node.get_attr('stride_width')), + ) + + node.set_attr('line_buffer_codegen', Source(code_str)) diff --git a/hls4ml/backends/vivado/passes/conv_same_pad.py b/hls4ml/backends/vivado/passes/conv_same_pad.py index bb8354a3d0..1bbdb327ca 100644 --- a/hls4ml/backends/vivado/passes/conv_same_pad.py +++ b/hls4ml/backends/vivado/passes/conv_same_pad.py @@ -1,4 +1,4 @@ -from hls4ml.model.layers import Conv1D, Conv2D, SeparableConv1D, SeparableConv2D +from hls4ml.model.layers import Conv1D, Conv1DTranspose, Conv2D, Conv2DTranspose, SeparableConv1D, SeparableConv2D from hls4ml.model.optimizer import OptimizerPass @@ -50,6 +50,53 @@ def transform(self, model, node): return True +class InsertZeroPaddingBeforeConv1DTranspose(OptimizerPass): + name = 'insert_zero_padding_before_conv1dtranspose' + + def match(self, node): + is_match = ( + isinstance(node, (Conv1DTranspose)) and node.get_attr('padding') == 'same' and node.get_attr('filt_width') != 1 + ) + return is_match + + def transform(self, model, node): + if model.config.get_config_value('IOType') != 'io_stream': + return False + + # Get the padding parameters from Conv1D layer + pad_left = node.get_attr('pad_left') + # pad_right = node.get_attr('pad_right') + convtr_out_width = node.get_attr('out_width') + in_width = node.get_attr('in_width') + stride_width = node.get_attr('stride_width') + trfilt_width = (node.get_attr('filt_width') + node.get_attr('stride_width') - 1) // node.get_attr('stride_width') + + add_right = (convtr_out_width + pad_left) // stride_width - (in_width - 1) + + out_width = in_width + add_right + trfilt_width - 1 + + attrs = { + 'pad_left': trfilt_width - 1, + 'pad_right': add_right, + 'in_width': in_width, + 'out_width': out_width, + 'n_chan': node.get_attr('n_chan'), + 'data_format': node.get_attr('data_format', 'channels_last'), + } + + # Switch Conv1DTranspose to be 'valid'. I think this is wrong + node.set_attr('padding', 'valid') + node.set_attr('in_width', out_width) + node.set_attr('pad_left', pad_left + (trfilt_width - 1) * stride_width) + + # Insert new ZeroPadding1D node above Conv1DTranspose + padding_layer = model.make_node('ZeroPadding1D', 'zp1d_' + node.name, attrs, node.inputs.copy()) + padding_layer.get_output_variable().type.precision = node.get_input_variable().type.precision + model.insert_node(padding_layer) + + return True + + class InsertZeroPaddingBeforeConv2D(OptimizerPass): name = 'insert_zero_padding_before_conv2d' @@ -107,3 +154,64 @@ def transform(self, model, node): model.insert_node(padding_layer, before=node) return True + + +class InsertZeroPaddingBeforeConv2DTranspose(OptimizerPass): + name = 'insert_zero_padding_before_conv2dtranspose' + + def match(self, node): + is_match = ( + isinstance(node, Conv2DTranspose) and node.get_attr('padding') == 'same' and node.get_attr('filt_width') != 1 + ) + return is_match + + def transform(self, model, node): + if model.config.get_config_value('IOType') != 'io_stream': + return False + + # Get the padding parameters from Conv2DTranspose layer + pad_left = node.get_attr('pad_left') + # pad_right = node.get_attr('pad_right') + pad_top = node.get_attr('pad_top') + # pad_bottom = node.get_attr('pad_bottom') + convtr_out_width = node.get_attr('out_width') + convtr_out_height = node.get_attr('out_height') + in_width = node.get_attr('in_width') + in_height = node.get_attr('in_height') + stride_width = node.get_attr('stride_width') + stride_height = node.get_attr('stride_height') + trfilt_width = (node.get_attr('filt_width') + node.get_attr('stride_width') - 1) // node.get_attr('stride_width') + trfilt_height = (node.get_attr('filt_height') + node.get_attr('stride_height') - 1) // node.get_attr('stride_height') + + add_right = (convtr_out_width + pad_left) // stride_width - (in_width - 1) + add_bottom = (convtr_out_height + pad_top) // stride_height - (in_height - 1) + + out_width = in_width + add_right + trfilt_width - 1 + out_height = in_height + add_bottom + trfilt_height - 1 + + attrs = { + 'pad_left': trfilt_width - 1, + 'pad_right': add_right, + 'pad_top': trfilt_height - 1, + 'pad_bottom': add_bottom, + 'in_width': in_width, + 'in_height': in_height, + 'out_width': out_width, + 'out_height': out_height, + 'n_chan': node.get_attr('n_chan'), + 'data_format': node.get_attr('data_format', 'channels_last'), + } + + # switch Conv2DTranspose to be 'valid'. This is technically not true though + node.set_attr('padding', 'valid') + node.set_attr('in_width', out_width) + node.set_attr('in_height', out_height) + node.set_attr('pad_left', pad_left + (trfilt_width - 1) * stride_width) + node.set_attr('pad_top', pad_top + (trfilt_height - 1) * stride_height) + + # insert new ZeroPadding2D ndoe above Conv2DTranspose + padding_layer = model.make_node('ZeroPadding2D', 'zp2d_' + node.name, attrs, node.inputs.copy()) + padding_layer.get_output_variable().type.precision = node.get_input_variable().type.precision + model.insert_node(padding_layer, before=node) + + return True diff --git a/hls4ml/backends/vivado/passes/conv_stream.py b/hls4ml/backends/vivado/passes/conv_stream.py index e0bb853d83..2bac452b31 100644 --- a/hls4ml/backends/vivado/passes/conv_stream.py +++ b/hls4ml/backends/vivado/passes/conv_stream.py @@ -1,4 +1,4 @@ -from hls4ml.model.layers import Conv1D, Conv2D, SeparableConv1D, SeparableConv2D +from hls4ml.model.layers import Conv1D, Conv1DTranspose, Conv2D, Conv2DTranspose, SeparableConv1D, SeparableConv2D from hls4ml.model.optimizer import OptimizerPass @@ -6,7 +6,7 @@ class GenerateConvStreamingInstructions(OptimizerPass): '''Generates the instructions for streaming implementation of CNNs''' def match(self, node): - return isinstance(node, (Conv1D, SeparableConv1D, Conv2D, SeparableConv2D)) + return isinstance(node, (Conv1D, Conv1DTranspose, SeparableConv1D, Conv2D, SeparableConv2D, Conv2DTranspose)) def transform(self, model, node): node_class = node.__class__.__name__ @@ -18,12 +18,18 @@ def transform(self, model, node): raise Exception(f'Cannot generate instructions for node {node.name} ({node_class})') def _generate_1d_instructions(self, node): + kernel_width = node.get_attr('filt_width') + stride_width = node.get_attr('stride_width') + if isinstance(node, Conv1DTranspose): + # set kernel width to trfilt_width and set stride to 1 (effective kernel dimensions in transpose) + kernel_width = (node.get_attr('filt_width') + node.get_attr('stride_width') - 1) // node.get_attr('stride_width') + stride_width = 1 if node.model.config.get_config_value('IOType') == 'io_stream': min_w, instructions = node.model.config.backend.compute_conv1d_instructions( node.get_input_variable().shape[0], node.get_input_variable().shape[1], - node.get_attr('filt_width'), - node.get_attr('stride_width'), + kernel_width, + stride_width, ) instructions_str = ','.join(str(i) for i in instructions) node.set_attr('min_width', min_w) @@ -34,13 +40,21 @@ def _generate_1d_instructions(self, node): node.set_attr('instructions', '0') def _generate_2d_instructions(self, node): + kernel_height = node.get_attr('filt_height') + stride_height = node.get_attr('stride_height') + if isinstance(node, Conv2DTranspose): + # set actual kernel height to trfilt_height and set stride to 1 (effective kernel in transpose) + kernel_height = (node.get_attr('filt_height') + node.get_attr('stride_height') - 1) // node.get_attr( + 'stride_height' + ) + stride_height = 1 if node.model.config.get_config_value('IOType') == 'io_stream': min_h, min_w, instructions = node.model.config.backend.compute_conv2d_instructions( node.get_input_variable().shape[0], node.get_input_variable().shape[1], node.get_input_variable().shape[2], - node.get_attr('filt_height'), - node.get_attr('stride_height'), + kernel_height, + stride_height, ) instructions_str = ','.join(str(i) for i in instructions) node.set_attr('min_height', min_h) diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index 9a7b10a6f4..11083fa069 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -1,6 +1,15 @@ from hls4ml.backends.backend import get_backend from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate -from hls4ml.model.layers import Conv1D, Conv2D, Conv2DBatchnorm, DepthwiseConv2D, SeparableConv1D, SeparableConv2D +from hls4ml.model.layers import ( + Conv1D, + Conv1DTranspose, + Conv2D, + Conv2DBatchnorm, + Conv2DTranspose, + DepthwiseConv2D, + SeparableConv1D, + SeparableConv2D, +) # Shared multiplication template @@ -106,6 +115,93 @@ def format(self, node): return self.template.format(**params) +# Conv1DTranspose Templates + +conv1dtranspose_config_template = """struct config{index} : nnet::conv1dtranspose_config {{ + static const unsigned pad_left = {pad_left}; + static const unsigned pad_right = {pad_right}; + static const unsigned in_width = {in_width}; + static const unsigned n_chan = {n_chan}; + static const unsigned filt_width = {filt_width}; + static const unsigned kernel_size = filt_width; + static const unsigned n_filt = {n_filt}; + static const unsigned stride_width = {stride_width}; + static const unsigned dilation = {dilation}; + static const unsigned out_width = {out_width}; + static const unsigned reuse_factor = {reuse}; + static const unsigned n_zeros = {nzeros}; + static const unsigned trfilt_width = {trfilt_width}; + static const bool store_weights_in_bram = false; + static const unsigned strategy = nnet::{strategy}; + static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation}; + static const unsigned min_width = {min_width}; + static const ap_uint pixels[min_width]; + static const unsigned n_partitions = {n_partitions}; + static const unsigned proc_width = {proc_width}; + static const unsigned n_pixels = proc_width / n_partitions; + template + using fill_buffer = nnet::{fill_fn}; + typedef {accum_t.name} accum_t; + typedef {bias_t.name} bias_t; + typedef {weight_t.name} weight_t; + typedef {config_t} mult_config; +}}; +const ap_uint config{index}::pixels[] = {{{instructions}}};\n""" + +conv1dtranspose_function_template = ( + 'nnet::conv_1d_transpose_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});' +) + +conv1dtranspose_include_list = ['nnet_utils/nnet_conv1dtranspose.h', 'nnet_utils/nnet_conv1dtranspose_stream.h'] + + +class Conv1DTransposeConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(Conv1DTranspose) + self.template = conv1dtranspose_config_template + self.mult_template = conv_mult_config_template + + def format(self, node): + params = self._default_config_params(node) + params['dilation'] = node.get_attr('dilation', 1) + params['nzeros'] = node.get_weights('weight').nzeros + + params['config_t'] = f'config{node.index}_mult' + if node.model.config.get_config_value('IOType') == 'io_parallel': + params['fill_fn'] = f'fill_buffer_{node.index}' + else: + params['fill_fn'] = 'FillConv1DBuffer' + conv_config = self.template.format(**params) + + mult_params = self._default_config_params(node) + mult_params['n_in'] = ( + node.get_attr('n_chan') + * (node.get_attr('filt_width') + node.get_attr('stride_width') - 1) + // node.get_attr('stride_width') + ) + mult_params['n_out'] = node.get_attr('n_filt') + mult_params['product_type'] = get_backend('vivado').product_type( + node.get_input_variable().type.precision, node.get_weights('weight').type.precision + ) + mult_config = self.mult_template.format(**mult_params) + + return mult_config + '\n' + conv_config + + +class Conv1DTransposeFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(Conv1DTranspose, include_header=conv1dtranspose_include_list) + self.template = conv1dtranspose_function_template + + def format(self, node): + params = self._default_function_params(node) + params['data_format'] = 'cf' if node.get_attr('data_format') == 'channels_first' else 'cl' + params['w'] = node.get_weights('weight').name + params['b'] = node.get_weights('bias').name + + return self.template.format(**params) + + # Conv2D Templates conv2d_config_template = """struct config{index} : nnet::conv2d_config {{ @@ -219,6 +315,102 @@ def __init__(self): self.template = depthconv2d_function_template +# Conv2DTranspose Templates +conv2dtranspose_config_template = """struct config{index} : nnet::conv2dtranspose_config {{ + static const unsigned pad_top = {pad_top}; + static const unsigned pad_bottom = {pad_bottom}; + static const unsigned pad_left = {pad_left}; + static const unsigned pad_right = {pad_right}; + static const unsigned in_height = {in_height}; + static const unsigned in_width = {in_width}; + static const unsigned n_chan = {n_chan}; + static const unsigned filt_height = {filt_height}; + static const unsigned filt_width = {filt_width}; + static const unsigned kernel_size = filt_height * filt_width; + static const unsigned n_filt = {n_filt}; + static const unsigned stride_height = {stride_height}; + static const unsigned stride_width = {stride_width}; + static const unsigned out_height = {out_height}; + static const unsigned out_width = {out_width}; + static const unsigned reuse_factor = {reuse}; + static const unsigned n_zeros = {nzeros}; + static const unsigned trfilt_width = {trfilt_width}; + static const unsigned trfilt_height = {trfilt_height}; + static const bool store_weights_in_bram = false; + static const unsigned strategy = nnet::{strategy}; + static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation}; + static const unsigned min_height = {min_height}; + static const unsigned min_width = {min_width}; + static const ap_uint pixels[min_height * min_width]; + static const unsigned n_partitions = {n_partitions}; + static const unsigned proc_height = {proc_height}; + static const unsigned proc_width = {proc_width}; + static const unsigned n_pixels = proc_height * proc_width / n_partitions; + template + using fill_buffer = nnet::{fill_fn}; + typedef {accum_t.name} accum_t; + typedef {bias_t.name} bias_t; + typedef {weight_t.name} weight_t; + typedef {config_t} mult_config; +}}; +const ap_uint config{index}::pixels[] = {{{instructions}}};\n""" + +conv2dtranspose_function_template = ( + 'nnet::conv_2d_transpose_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});' +) + +conv2dtranspose_include_list = ['nnet_utils/nnet_conv2dtranspose.h', 'nnet_utils/nnet_conv2dtranspose_stream.h'] + + +class Conv2DTransposeConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(Conv2DTranspose) + self.template = conv2dtranspose_config_template + self.mult_template = conv_mult_config_template + + def format(self, node): + params = self._default_config_params(node) + params['dilation'] = node.get_attr('dilation', 1) + params['nzeros'] = node.get_weights('weight').nzeros + params['trfilt_width'] = (node.get_attr('filt_width') + node.get_attr('stride_width') - 1) // node.get_attr( + 'stride_width' + ) + params['trfilt_height'] = (node.get_attr('filt_height') + node.get_attr('stride_height') - 1) // node.get_attr( + 'stride_height' + ) + + params['config_t'] = f'config{node.index}_mult' + if node.model.config.get_config_value('IOType') == 'io_parallel': + params['fill_fn'] = f'fill_buffer_{node.index}' + else: + params['fill_fn'] = 'FillConv2DBuffer' + conv_config = self.template.format(**params) + + mult_params = self._default_config_params(node) + mult_params['n_in'] = node.get_attr('n_chan') * params['trfilt_width'] * params['trfilt_height'] + mult_params['n_out'] = node.get_attr('n_filt') + mult_params['product_type'] = get_backend('vivado').product_type( + node.get_input_variable().type.precision, node.get_weights('weight').type.precision + ) + mult_config = self.mult_template.format(**mult_params) + + return mult_config + '\n' + conv_config + + +class Conv2DTransposeFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(Conv2DTranspose, include_header=conv2dtranspose_include_list) + self.template = conv2dtranspose_function_template + + def format(self, node): + params = self._default_function_params(node) + params['data_format'] = 'cf' if node.get_attr('data_format') == 'channels_first' else 'cl' + params['w'] = node.get_weights('weight').name + params['b'] = node.get_weights('bias').name + + return self.template.format(**params) + + # SeparableConv1D/2D Templates sepconv_config_template = """struct config{index} {{ diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 1d4c96d982..6c68a13c7e 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -11,7 +11,9 @@ GRU, LSTM, Conv1D, + Conv1DTranspose, Conv2D, + Conv2DTranspose, Dense, DepthwiseConv2D, Embedding, @@ -84,6 +86,8 @@ def _register_flows(self): 'vivado:clone_output', 'vivado:insert_zero_padding_before_conv1d', 'vivado:insert_zero_padding_before_conv2d', + 'vivado:insert_zero_padding_before_conv1dtranspose', + 'vivado:insert_zero_padding_before_conv2dtranspose', 'vivado:broadcast_stream', ] streaming_flow = register_flow('streaming', streaming_passes, requires=[init_flow], backend=self.name) @@ -273,6 +277,37 @@ def init_conv1d(self, layer): self._validate_conv_strategy(layer) + @layer_optimizer(Conv1DTranspose) + def init_conv1dtranspose(self, layer): + if layer.model.config.is_resource_strategy(layer): + layer.set_attr('strategy', 'resource') + n_in, n_out = self.get_layer_mult_size(layer) + self.set_target_reuse_factor(layer) + self.set_closest_reuse_factor(layer, n_in, n_out) + else: + layer.set_attr('strategy', 'latency') + + proc_width = ( + layer.get_output_variable().shape[0] + layer.get_attr('pad_left') + layer.get_attr('stride_width') - 1 + ) // layer.get_attr('stride_width') + chosen_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1) + valid_pf = self.get_valid_conv_partition_splits(1, proc_width) + if chosen_pf not in valid_pf: + closest_pf = self.get_closest_reuse_factor(valid_pf, chosen_pf) + valid_pf_str = ','.join(map(str, valid_pf)) + print( + f'WARNING: Invalid ParallelizationFactor={chosen_pf} in layer "{layer.name}".' + f'Using ParallelizationFactor={closest_pf} instead. Valid ParallelizationFactor(s): {valid_pf_str}.' + ) + else: + closest_pf = chosen_pf + layer.set_attr('n_partitions', proc_width // closest_pf) + layer.set_attr('proc_width', proc_width) + + layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) + + self._validate_conv_strategy(layer) + @layer_optimizer(SeparableConv1D) def init_sepconv1d(self, layer): if layer.model.config.is_resource_strategy(layer): @@ -319,6 +354,44 @@ def init_conv2d(self, layer): self._validate_conv_strategy(layer) + @layer_optimizer(Conv2DTranspose) + def init_conv2dtranspose(self, layer): + if len(layer.weights['weight'].data.shape) == 2: # This can happen if we assign weights of Dense layer to 1x1 Conv2D + layer.weights['weight'].data = np.expand_dims(layer.weights['weight'].data, axis=(0, 1)) + + if layer.model.config.is_resource_strategy(layer): + layer.set_attr('strategy', 'resource') + self.set_target_reuse_factor(layer) + n_in, n_out = self.get_layer_mult_size(layer) + self.set_closest_reuse_factor(layer, n_in, n_out) + else: + layer.set_attr('strategy', 'latency') + + proc_height = ( + layer.get_output_variable().shape[0] + layer.get_attr('pad_top') + layer.get_attr('stride_height') - 1 + ) // layer.get_attr('stride_height') + proc_width = ( + layer.get_output_variable().shape[1] + layer.get_attr('pad_left') + layer.get_attr('stride_width') - 1 + ) // layer.get_attr('stride_width') + chosen_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1) + valid_pf = self.get_valid_conv_partition_splits(proc_height, proc_width) + if chosen_pf not in valid_pf: + closest_pf = self.get_closest_reuse_factor(valid_pf, chosen_pf) + valid_pf_str = ','.join(map(str, valid_pf)) + print( + f'WARNING: Invalid ParallelizationFactor={chosen_pf} in layer "{layer.name}".' + f'Using ParallelizationFactor={closest_pf} instead. Valid ParallelizationFactor(s): {valid_pf_str}.' + ) + else: + closest_pf = chosen_pf + layer.set_attr('n_partitions', proc_height * proc_width // closest_pf) + layer.set_attr('proc_height', proc_height) + layer.set_attr('proc_width', proc_width) + + layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) + + self._validate_conv_strategy(layer) + @layer_optimizer(SeparableConv2D) def init_sepconv2d(self, layer): if layer.model.config.is_resource_strategy(layer): diff --git a/hls4ml/converters/keras/convolution.py b/hls4ml/converters/keras/convolution.py index 5ebd2abee1..b2f3ea7611 100644 --- a/hls4ml/converters/keras/convolution.py +++ b/hls4ml/converters/keras/convolution.py @@ -1,5 +1,11 @@ from hls4ml.converters.keras_to_hls import get_weights_data, keras_handler, parse_default_keras_layer -from hls4ml.converters.utils import compute_padding_1d, compute_padding_2d, parse_data_format +from hls4ml.converters.utils import ( + compute_padding_1d, + compute_padding_1d_transpose, + compute_padding_2d, + compute_padding_2d_transpose, + parse_data_format, +) @keras_handler('Conv1D', 'SeparableConv1D') @@ -36,6 +42,33 @@ def parse_conv1d_layer(keras_layer, input_names, input_shapes, data_reader): return layer, output_shape +@keras_handler('Conv1DTranspose') +def parse_conv1dtranspose_layer(keras_layer, input_names, input_shapes, data_reader): + assert 'Conv1DTranspose' in keras_layer['class_name'] + layer = parse_default_keras_layer(keras_layer, input_names) + + (layer['in_width'], layer['n_chan']) = parse_data_format(input_shapes[0], layer['data_format']) + + layer['n_filt'] = keras_layer['config']['filters'] + layer['filt_width'] = keras_layer['config']['kernel_size'][0] + layer['stride_width'] = keras_layer['config']['strides'][0] + layer['padding'] = keras_layer['config']['padding'] + layer['trfilt_width'] = (layer['filt_width'] + layer['stride_width'] - 1) // layer['stride_width'] + + ( + layer['out_width'], + layer['pad_left'], + layer['pad_right'], + ) = compute_padding_1d_transpose(layer['padding'], layer['in_width'], layer['stride_width'], layer['filt_width']) + + if layer['data_format'] == 'channels_last': + output_shape = [input_shapes[0][0], layer['out_width'], layer['n_filt']] + elif layer['data_format'] == 'channels_first': + output_shape = [input_shapes[0][0], layer['n_filt'], layer['out_width']] + + return layer, output_shape + + @keras_handler('Conv2D', 'SeparableConv2D', 'DepthwiseConv2D') def parse_conv2d_layer(keras_layer, input_names, input_shapes, data_reader): assert 'Conv2D' in keras_layer['class_name'] @@ -88,3 +121,48 @@ def parse_conv2d_layer(keras_layer, input_names, input_shapes, data_reader): output_shape = [input_shapes[0][0], layer['out_height'], layer['out_width'], layer['n_filt']] return layer, output_shape + + +@keras_handler('Conv2DTranspose') +def parse_conv2dtranspose_layer(keras_layer, input_names, input_shapes, data_reader): + assert 'Conv2DTranspose' in keras_layer['class_name'] + + layer = parse_default_keras_layer(keras_layer, input_names) + + (layer['in_height'], layer['in_width'], layer['n_chan']) = parse_data_format(input_shapes[0], layer['data_format']) + + if 'filters' in keras_layer['config']: + layer['n_filt'] = keras_layer['config']['filters'] + else: + layer['n_filt'] = layer['n_chan'] + layer['filt_height'] = keras_layer['config']['kernel_size'][0] + layer['filt_width'] = keras_layer['config']['kernel_size'][1] + layer['stride_height'] = keras_layer['config']['strides'][0] + layer['stride_width'] = keras_layer['config']['strides'][1] + layer['padding'] = keras_layer['config']['padding'] + layer['trfilt_height'] = (layer['filt_height'] + layer['stride_height'] - 1) // layer['stride_height'] + layer['trfilt_width'] = (layer['filt_width'] + layer['stride_width'] - 1) // layer['stride_width'] + + ( + layer['out_height'], + layer['out_width'], + layer['pad_top'], + layer['pad_bottom'], + layer['pad_left'], + layer['pad_right'], + ) = compute_padding_2d_transpose( + layer['padding'], + layer['in_height'], + layer['in_width'], + layer['stride_height'], + layer['stride_width'], + layer['filt_height'], + layer['filt_width'], + ) + + if layer['data_format'] == 'channels_first': + output_shape = [input_shapes[0][0], layer['n_filt'], layer['out_height'], layer['out_width']] + else: + output_shape = [input_shapes[0][0], layer['out_height'], layer['out_width'], layer['n_filt']] + + return layer, output_shape diff --git a/hls4ml/converters/utils.py b/hls4ml/converters/utils.py index d1c9e050d5..9ca2938996 100644 --- a/hls4ml/converters/utils.py +++ b/hls4ml/converters/utils.py @@ -82,6 +82,22 @@ def compute_padding_1d(pad_type, in_size, stride, filt_size): return (n_out, pad_left, pad_right) +def compute_padding_1d_transpose(pad_type, in_size, stride, filt_size): + if pad_type.lower() == 'same': + n_out = stride * in_size + pad_along_size = max(filt_size - stride, 0) + pad_left = pad_along_size // 2 + pad_right = pad_along_size - pad_left + elif pad_type.lower() == 'valid': + n_out = stride * (in_size - 1) + filt_size + pad_left = 0 + pad_right = 0 + else: + raise Exception(f'Unknown padding type: {pad_type}') + + return (n_out, pad_left, pad_right) + + def compute_padding_2d(pad_type, in_height, in_width, stride_height, stride_width, filt_height, filt_width): """Computes the amount of padding required on each side of the 2D input tensor. @@ -134,6 +150,33 @@ def compute_padding_2d(pad_type, in_height, in_width, stride_height, stride_widt return (out_height, out_width, pad_top, pad_bottom, pad_left, pad_right) +def compute_padding_2d_transpose(pad_type, in_height, in_width, stride_height, stride_width, filt_height, filt_width): + if pad_type.lower() == 'same': + # Height + out_height = stride_height * in_height + pad_along_height = max(filt_height - stride_height, 0) + pad_top = pad_along_height // 2 + pad_bottom = pad_along_height - pad_top + # Width + out_width = stride_width * in_width + pad_along_width = max(filt_width - stride_width, 0) + pad_left = pad_along_width // 2 + pad_right = pad_along_width - pad_left + elif pad_type.lower() == 'valid': + # something + out_height = stride_height * in_height + out_width = stride_width * in_width + + pad_top = 0 + pad_bottom = 0 + pad_left = 0 + pad_right = 0 + else: + raise Exception(f'Unknown padding type: {pad_type}') + + return (out_height, out_width, pad_top, pad_bottom, pad_left, pad_right) + + def compute_padding_1d_pytorch(pad_type, in_size, stride, filt_size, dilation): if isinstance(pad_type, str): if pad_type.lower() == 'same': diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index d9da2cc741..10ced7ef01 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -249,9 +249,14 @@ def add_output_variable( self.set_attr(out_name, out) - def add_weights(self, quantizer=None, compression=False): + def add_weights(self, quantizer=None, compression=False, keep_dims=0): self.add_weights_variable( - name='weight', var_name='w{index}', data='weight', quantizer=quantizer, compression=compression + name='weight', + var_name='w{index}', + data='weight', + quantizer=quantizer, + compression=compression, + keep_dims=keep_dims, ) def add_bias(self, quantizer=None): @@ -269,7 +274,7 @@ def add_bias(self, quantizer=None): ) def add_weights_variable( - self, name, var_name=None, type_name=None, precision=None, data=None, quantizer=None, compression=False + self, name, var_name=None, type_name=None, precision=None, data=None, quantizer=None, compression=False, keep_dims=0 ): if var_name is None: var_name = name + '{index}' @@ -315,7 +320,13 @@ def add_weights_variable( ) else: var = WeightVariable( - var_name, type_name=type_name, precision=precision, quantizer=quantizer, data=data, index=self.index + var_name, + type_name=type_name, + precision=precision, + quantizer=quantizer, + data=data, + index=self.index, + keep_dims=keep_dims, ) var.data_unquantized = data_unquantized @@ -435,6 +446,55 @@ def initialize(self): self.add_bias(quantizer=self.get_attr('bias_quantizer')) +class Conv1DTranspose(Layer): + _expected_attributes = [ + Attribute('in_width'), + Attribute('out_width'), + Attribute('n_chan'), + Attribute('n_filt'), + Attribute('filt_width'), + Attribute('stride_width'), + Attribute('pad_left'), + Attribute('pad_right'), + WeightAttribute('weight'), + WeightAttribute('bias'), + TypeAttribute('weight'), + TypeAttribute('bias'), + ] + + def initialize(self): + if self.get_attr('data_format') == 'channels_last': + shape = [self.attributes['out_width'], self.attributes['n_filt']] + dims = [f'N_OUTPUTS_{self.index}', f'N_FILT_{self.index}'] + else: + shape = [self.attributes['n_filt'], self.attributes['out_width']] + dims = [f'N_FILT_{self.index}', f'N_OUTPUTS_{self.index}'] + + data = self.model.get_weights_data(self.name, 'kernel') + # now we transform the entire kernel + + # (W,F,C) => (F,W,C) + data = np.transpose(data, axes=[1, 0, 2]) + # now split the kernel into stride width kernels (F, W, C) -> (S, F, W/S, C) + n_filts, kern_width, n_chan = data.shape + new_weights = np.zeros((self.attributes['stride_width'], n_filts, self.attributes['trfilt_width'], n_chan)) + for i_sw in range(self.attributes['stride_width']): + for i_fw in range(self.attributes['trfilt_width']): + filt_ind = i_sw + (self.attributes['trfilt_width'] - i_fw - 1) * self.attributes['stride_width'] + for i_nf in range(n_filts): + for i_nc in range(n_chan): + if filt_ind < kern_width: + new_weights[i_sw][i_nf][i_fw][i_nc] = data[i_nf][filt_ind][i_nc] + data = new_weights + + self.add_output_variable(shape, dims) + # self.add_weights(quantizer = self.get_attr('weight_quantizer'), keep_dims=1) + self.add_weights_variable( + name='weight', var_name='w{index}', data=data, quantizer=self.get_attr('weight_quantizer'), keep_dims=1 + ) + self.add_bias(quantizer=self.get_attr('bias_quantizer')) + + class SeparableConv1D(Layer): _expected_attributes = [ Attribute('in_width'), @@ -506,6 +566,74 @@ def initialize(self): self.add_bias(quantizer=self.get_attr('bias_quantizer')) +class Conv2DTranspose(Layer): + _expected_attributes = [ + Attribute('in_height'), + Attribute('in_width'), + Attribute('out_height'), + Attribute('out_width'), + Attribute('n_chan'), + Attribute('n_filt'), + Attribute('filt_height'), + Attribute('filt_width'), + Attribute('stride_height'), + Attribute('stride_width'), + Attribute('pad_top'), + Attribute('pad_bottom'), + Attribute('pad_left'), + Attribute('pad_right'), + WeightAttribute('weight'), + WeightAttribute('bias'), + TypeAttribute('weight'), + TypeAttribute('bias'), + ] + + def initialize(self): + if self.get_attr('data_format') == 'channels_last': + shape = [self.attributes['out_height'], self.attributes['out_width'], self.attributes['n_filt']] + dims = [f'OUT_HEIGHT_{self.index}', f'OUT_WIDTH_{self.index}', f'N_FILT_{self.index}'] + else: + shape = [self.attributes['n_filt'], self.attributes['out_height'], self.attributes['out_width']] + dims = [f'N_FILT_{self.index}', f'OUT_HEIGHT_{self.index}', f'OUT_WIDTH_{self.index}'] + + data = self.model.get_weights_data(self.name, 'kernel') + # now we transform the entire kernel + + # (H,W,F,C) => (F,H,W,C) + data = np.transpose(data, axes=[2, 0, 1, 3]) + # now split the kernel into stride width kernels (F, W, C) -> (Sh, Sw, F, H/Sh, W/Sw, C) + n_filts, kern_height, kern_width, n_chan = data.shape + new_weights = np.zeros( + ( + self.attributes['stride_height'], + self.attributes['stride_width'], + n_filts, + self.attributes['trfilt_height'], + self.attributes['trfilt_width'], + n_chan, + ) + ) + for i_sh in range(self.attributes['stride_height']): + for i_sw in range(self.attributes['stride_width']): + for i_fh in range(self.attributes['trfilt_height']): + for i_fw in range(self.attributes['trfilt_width']): + filt_h_ind = i_sh + (self.attributes['trfilt_height'] - i_fh - 1) * self.attributes['stride_height'] + filt_w_ind = i_sw + (self.attributes['trfilt_width'] - i_fw - 1) * self.attributes['stride_width'] + for i_nf in range(n_filts): + for i_nc in range(n_chan): + if filt_h_ind < kern_height and filt_w_ind < kern_width: + new_weights[i_sh][i_sw][i_nf][i_fh][i_fw][i_nc] = data[i_nf][filt_h_ind][filt_w_ind][ + i_nc + ] + data = new_weights + + self.add_output_variable(shape, dims) + self.add_weights_variable( + name='weight', var_name='w{index}', data=data, quantizer=self.get_attr('weight_quantizer'), keep_dims=2 + ) + self.add_bias(quantizer=self.get_attr('bias_quantizer')) + + class Conv2DBatchnorm(Conv2D): def _get_folded_weights(self): """ @@ -1312,6 +1440,8 @@ def initialize(self): 'Conv2D': Conv2D, 'BinaryConv2D': Conv2D, 'QConv2D': Conv2D, + 'Conv1DTranspose': Conv1DTranspose, + 'Conv2DTranspose': Conv2DTranspose, 'QConv2DBatchnorm': Conv2DBatchnorm, 'SeparableConv1D': SeparableConv1D, 'SeparableConv2D': SeparableConv2D, diff --git a/hls4ml/model/types.py b/hls4ml/model/types.py index b6f2e42a01..eccd653126 100644 --- a/hls4ml/model/types.py +++ b/hls4ml/model/types.py @@ -301,6 +301,8 @@ def __str__(self): return typestring def __eq__(self, other): + if not isinstance(other, FixedPrecisionType): + return False eq = self.width == other.width eq = eq and self.integer == other.integer eq = eq and self.fractional == other.fractional @@ -512,9 +514,10 @@ class WeightVariable(Variable): precision (PrecisionType, optional): Precision data type. data (ndarray): The data array. quantizer (_type_, optional): Quantizer to apply to the data array. Defaults to ``None``. + keep_dims (int, optional): ADD A DESCRIPTION HERE. """ - def __init__(self, var_name, type_name, precision, data, quantizer=None, **kwargs): + def __init__(self, var_name, type_name, precision, data, quantizer=None, keep_dims=0, **kwargs): super().__init__(var_name, NamedType(type_name, precision, **kwargs), **kwargs) self.data = data self.nzeros = -1 @@ -527,6 +530,7 @@ def __init__(self, var_name, type_name, precision, data, quantizer=None, **kwarg self._iterator = None self.update_precision(precision) self.quantizer = quantizer + self.keep_dims = keep_dims def __iter__(self): self._iterator = np.nditer(self.data, order='C') diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose.h new file mode 100644 index 0000000000..fd3836f6fe --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose.h @@ -0,0 +1,45 @@ +#ifndef NNET_CONV1DTRANSPOSE_H_ +#define NNET_CONV1DTRANSPOSE_H_ + +#include "nnet_common.h" +#include "nnet_conv1dtranspose_resource.h" +#include + +namespace nnet { + +struct conv1dtranspose_config { + // Internal data type definitions + typedef float bias_t; + typedef float weight_t; + typedef float accum_t; + + // Convolutional parameters + static const unsigned pad_left = 0; + static const unsigned pad_right = 0; + static const unsigned in_width = 10; + static const unsigned n_chan = 0; + static const unsigned filt_width = 1; + static const unsigned kernel_size = filt_width; + static const unsigned stride_width = 1; + static const unsigned dilation = 1; + static const unsigned out_width = 10; + + static const unsigned reuse_factor = 1; + static const bool store_weights_in_bram = false; + static const unsigned n_zeros = 0; +}; + +template +void conv_1d_transpose_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::stride_width] + [CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + #pragma HLS INLINE region + // for now, we are only adding resource strategy + conv_1d_transpose_resource_cl(data, res, weights, biases); +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_resource.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_resource.h new file mode 100644 index 0000000000..c43e6380a1 --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_resource.h @@ -0,0 +1,125 @@ +#ifndef NNET_CONV1DTRANSPOSE_RESOURCE_H_ +#define NNET_CONV1DTRANSPOSE_RESOURCE_H_ + +#include "nnet_common.h" +#include "nnet_dense.h" + +namespace nnet { + +template +void conv_1d_transpose_resource_cl( + data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::stride_width] + [CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + constexpr unsigned mult_n_in = CONFIG_T::trfilt_width * CONFIG_T::n_chan; + constexpr unsigned mult_n_out = CONFIG_T::n_filt; + constexpr unsigned block_factor = DIV_ROUNDUP(mult_n_in * mult_n_out, CONFIG_T::reuse_factor); + constexpr unsigned multscale = block_factor / mult_n_out; + + assert((block_factor % mult_n_out == 0 || CONFIG_T::reuse_factor >= mult_n_in) && + "The current Reuse Factor is not allowed"); + assert((CONFIG_T::reuse_factor <= CONFIG_T::trfilt_width * CONFIG_T::n_chan) && + "This function is correct only for RF <= TRFILT_WIDTH * N_CHAN"); + + data_T data_buf[CONFIG_T::n_pixels][mult_n_in]; + #pragma HLS ARRAY_PARTITION variable=data_buf complete dim=0 + #pragma HLS ARRAY_PARTITION variable=biases complete + + typename CONFIG_T::accum_t acc[CONFIG_T::n_pixels][mult_n_out][CONFIG_T::stride_width]; + #pragma HLS ARRAY_PARTITION variable=acc complete dim=0 + + #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor dim=2 + +PartitionLoop: + for (unsigned i_part = 0; i_part < CONFIG_T::n_partitions; i_part++) { + CONFIG_T::template fill_buffer::fill_buffer(data, data_buf, i_part); + + PixelInitAccumLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { + #pragma HLS UNROLL + + InitAccumLoop: + for (unsigned i_acc = 0; i_acc < mult_n_out; i_acc++) { + #pragma HLS UNROLL + + InitStrideLoop: + for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) { + #pragma HLS UNROLL + acc[i_pxl][i_acc][i_sw] = (typename CONFIG_T::accum_t)biases[i_acc]; + } + } + } + + ReuseLoop: + for (unsigned i_rf = 0; i_rf < CONFIG_T::reuse_factor; i_rf++) { + #pragma HLS PIPELINE II=1 rewind + + unsigned i_w = i_rf; + unsigned i_in = i_rf; + unsigned i_out = 0; + unsigned i_acc = 0; + + MultLoop: + for (unsigned i_blk = 0; i_blk < block_factor; i_blk++) { + #pragma HLS UNROLL + + PixelMultLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { + #pragma HLS UNROLL + + StrideMultLoop: + for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) { + #pragma HLS UNROLL + + acc[i_pxl][i_out][i_sw] += static_cast( + CONFIG_T::mult_config::template product< + data_T, typename CONFIG_T::mult_config::weight_t>::product(data_buf[i_pxl][i_in], + weights[i_sw][i_w])); + } + } + + // Increment i_w + i_w += CONFIG_T::reuse_factor; + // Increment i_in + i_in += CONFIG_T::reuse_factor; + if (i_in >= mult_n_in) { + i_in = i_rf; + } + // Increment i_out + if (i_acc + 1 >= multscale) { + i_acc = 0; + i_out++; + } else { + i_acc++; + } + } + } + + PixelResultLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { + #pragma HLS UNROLL + + StrideResultLoop: + for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) { + #pragma HLS UNROLL + + unsigned output_index = + i_pxl * CONFIG_T::n_partitions * CONFIG_T::stride_width + i_part * CONFIG_T::stride_width + i_sw; + + if (output_index >= CONFIG_T::pad_left && output_index < CONFIG_T::out_width + CONFIG_T::pad_left) { + ResultLoop: + for (unsigned i_res = 0; i_res < mult_n_out; i_res++) { + #pragma HLS UNROLL + + res[(output_index - CONFIG_T::pad_left) * mult_n_out + i_res] = + cast(acc[i_pxl][i_res][i_sw]); + } + } + } + } + } +} + +} // namespace nnet +#endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_stream.h new file mode 100644 index 0000000000..3a38f8e7d2 --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_stream.h @@ -0,0 +1,133 @@ +#ifndef NNET_CONV1DTRANSPOSE_STREAM_H +#define NNET_CONV1DTRANSPOSE_STREAM_H + +#include "hls_stream.h" +#include "nnet_common.h" +#include "nnet_conv_stream.h" + +namespace nnet { + +template +void kernel_shift_tr_1d(const data_T &in_elem, + typename data_T::value_type kernel_window[CONFIG_T::trfilt_width * CONFIG_T::n_chan]) { + #pragma HLS INLINE + + // Shift kernel_window by one step to the left (manual shift operation) + static const int filt_width = CONFIG_T::trfilt_width - 1; +KernelShiftWidth: + for (int i_iw = 0; i_iw < filt_width; i_iw++) { + #pragma HLS PIPELINE II = 1 + KernelShiftChannel: + for (unsigned i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) { + #pragma HLS UNROLL + // Shift every element in kernel_window to the left + kernel_window[i_iw * CONFIG_T::n_chan + i_ic] = kernel_window[(i_iw + 1) * CONFIG_T::n_chan + i_ic]; + } + } + + // Insert shift_buffer column into right-most column of kernel + static const int lastheight = (CONFIG_T::trfilt_width - 1) * CONFIG_T::n_chan; +KernelPushChannel: + for (int i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) { + #pragma HLS UNROLL + kernel_window[lastheight + i_ic] = in_elem[i_ic]; + } +} + +// Conv 1D transpose compute output +template +void compute_output_buffer_tr_1d( + const data_T &in_elem, hls::stream &res_stream, + typename CONFIG_T::weight_t weights[CONFIG_T::stride_width] + [CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + #pragma HLS INLINE + + // Thresholds + const static int lShiftX = CONFIG_T::trfilt_width - 1; + + // Counters + static int pX = 0; // pixel counter + static int oX = 0; // output counter (deals with 'padding') + + static typename data_T::value_type kernel_data[CONFIG_T::trfilt_width * CONFIG_T::n_chan]; + #pragma HLS ARRAY_PARTITION variable=kernel_data complete + + typename res_T::value_type res_out[CONFIG_T::n_filt]; + #pragma HLS ARRAY_PARTITION variable=res_out complete dim = 0 + + res_T res_pack; + #pragma HLS DATA_PACK variable=res_pack + + // Add pixel to buffer + nnet::kernel_shift_tr_1d(in_elem, kernel_data); + +// always do stride number of multiplications +StrideLoop: + for (int idx = 0; idx < CONFIG_T::stride_width; idx++) { + #pragma HLS UNROLL + #pragma HLS INLINE region + // Dense multiply + if (CONFIG_T::strategy == nnet::latency) { + dense_latency( + kernel_data, res_out, weights[idx], biases); + } else { + dense_resource( + kernel_data, res_out, weights[idx], biases); + } + + // Pack output + if (oX >= CONFIG_T::pad_left && oX < CONFIG_T::pad_left + CONFIG_T::out_width) { + CastLoop: + for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) { + #pragma HLS UNROLL + res_pack[i_ic] = res_out[i_ic]; + } + res_stream.write(res_pack); + } + // Write output to stream when output ready + oX++; + } + + // static var housekeeping + if (pX + 1 == CONFIG_T::in_width) // done with all of the inputs + { + pX = 0; + oX = 0; + } else { + pX = pX + 1; + } +} + +template +void conv_1d_transpose_buffer_cl( + hls::stream &data, hls::stream &res, + typename CONFIG_T::weight_t weights[CONFIG_T::stride_width] + [CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { +ReadInputWidth: + for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width; i_iw++) { + #pragma HLS LOOP_FLATTEN + // if (CONFIG_T::strategy == nnet::latency) { + // #pragma HLS PIPELINE II=CONFIG_T::reuse_factor + // } + compute_output_buffer_tr_1d(data.read(), res, weights, biases); + } +} + +template +void conv_1d_transpose_cl(hls::stream &data, hls::stream &res, + typename CONFIG_T::weight_t weights[CONFIG_T::stride_width] + [CONFIG_T::trfilt_width * CONFIG_T::n_filt * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + switch (CONFIG_T::implementation) { + #pragma HLS inline region + case conv_implementation::linebuffer: + conv_1d_transpose_buffer_cl(data, res, weights, biases); + break; + } +} + +} // namespace nnet +#endif +// NEED TO PAD INPUT OR CLEAR KERNEL diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose.h new file mode 100644 index 0000000000..2ef3435a9d --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose.h @@ -0,0 +1,56 @@ +#ifndef NNET_CONV2DTRANSPOSE_H +#define NNET_CONV2DTRANSPOSE_H + +#include "nnet_common.h" +#include "nnet_conv2dtranspose_resource.h" +#include + +namespace nnet { + +struct conv2dtranspose_config { + // Internal data type definitions + typedef float bias_t; + typedef float weight_t; + typedef float accum_t; + + // Convolutional parameters + static const unsigned pad_top = 0; + static const unsigned pad_bottom = 0; + static const unsigned pad_left = 0; + static const unsigned pad_right = 0; + static const unsigned in_height = 10; + static const unsigned in_width = 10; + static const unsigned n_chan = 1; + static const unsigned filt_height = 1; + static const unsigned filt_width = 1; + static const unsigned kernel_size = filt_height * filt_width; + static const unsigned n_filt = 1; + static const unsigned stride_height = 1; + static const unsigned stride_width = 1; + static const unsigned out_height = 10; + static const unsigned out_width = 10; + static const unsigned dilation_height = 1; + static const unsigned dilation_width = 1; + static const unsigned trfilt_height = 1; + static const unsigned trfilt_width = 1; + + static const unsigned reuse_factor = 1; + static const bool store_weights_in_bram = false; + static const unsigned n_zeros = 0; // not used yet +}; + +template +void conv_2d_transpose_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::stride_height][CONFIG_T::stride_width] + [CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * + CONFIG_T::n_filt * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + #pragma HLS INLINE region + // only have resource strategy as of now + conv_2d_transpose_resource_cl(data, res, weights, biases); +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_resource.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_resource.h new file mode 100644 index 0000000000..e2954fd64b --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_resource.h @@ -0,0 +1,144 @@ +#ifndef NNET_CONV2DTRANSPOSE_RESOURCE_H +#define NNET_CONV2DTRANSPOSE_RESOURCE_H + +#include "nnet_common.h" +#include "nnet_dense.h" + +namespace nnet { + +template +void conv_2d_transpose_resource_cl(data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::stride_height][CONFIG_T::stride_width] + [CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * + CONFIG_T::n_filt * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + constexpr unsigned mult_n_in = CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_chan; + constexpr unsigned mult_n_out = CONFIG_T::n_filt; + constexpr unsigned block_factor = DIV_ROUNDUP(mult_n_in * mult_n_out, CONFIG_T::reuse_factor); + + constexpr unsigned multiplier_limit = DIV_ROUNDUP(mult_n_in * mult_n_out, CONFIG_T::reuse_factor); + constexpr unsigned multscale = multiplier_limit / mult_n_out; + + assert((multiplier_limit % mult_n_out == 0 || CONFIG_T::reuse_factor >= mult_n_in) && + "The current Reuse Factor is not allowed"); + assert((multiplier_limit == block_factor) && + "This function is correct only for RF <= TRFILT_HEIGHT * TRFILT_WIDTH * N_CHAN"); + + data_T data_buf[CONFIG_T::n_pixels][mult_n_in]; + #pragma HLS ARRAY_PARTITION variable=data_buf complete dim=0 + + #pragma HLS ARRAY_PARTITION variable=biases complete + + typename CONFIG_T::accum_t acc[CONFIG_T::n_pixels][mult_n_out][CONFIG_T::stride_height][CONFIG_T::stride_width]; + #pragma HLS ARRAY_PARTITION variable=acc complete dim=0 + + #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor dim=3 + +PartitionLoop: + for (unsigned i_part = 0; i_part < CONFIG_T::n_partitions; i_part++) { + CONFIG_T::template fill_buffer::fill_buffer(data, data_buf, i_part); + + PixelInitAccumLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { + #pragma HLS UNROLL + + InitAccumLoop: + for (unsigned i_acc = 0; i_acc < mult_n_out; i_acc++) { + #pragma HLS UNROLL + + InitStrideHeightLoop: + for (unsigned i_sh = 0; i_sh < CONFIG_T::stride_height; i_sh++) { + #pragma HLS UNROLL + + InitStrideWidthLoop: + for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) { + #pragma HLS UNROLL + acc[i_pxl][i_acc][i_sh][i_sw] = (typename CONFIG_T::accum_t)biases[i_acc]; + } + } + } + } + + ReuseLoop: + for (unsigned i_rf = 0; i_rf < CONFIG_T::reuse_factor; i_rf++) { + #pragma HLS PIPELINE II=1 rewind + + unsigned i_w = i_rf; + unsigned i_in = i_rf; + unsigned i_out = 0; + unsigned i_acc = 0; + + MultLoop: + for (unsigned i_blk = 0; i_blk < block_factor; i_blk++) { + #pragma HLS UNROLL + PixelMultLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { + #pragma HLS UNROLL + StrideHeightMultLoop: + for (unsigned i_sh = 0; i_sh < CONFIG_T::stride_height; i_sh++) { + #pragma HLS UNROLL + StrideWidthMultLoop: + for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) { + #pragma HLS UNROLL + + acc[i_pxl][i_out][i_sh][i_sw] += static_cast( + CONFIG_T::mult_config::template product< + data_T, typename CONFIG_T::mult_config::weight_t>::product(data_buf[i_pxl][i_in], + weights[i_sh][i_sw][i_w])); + } + } + } + + // Increment i_w + i_w += CONFIG_T::reuse_factor; + // Increment i_in + i_in += CONFIG_T::reuse_factor; + if (i_in >= mult_n_in) { + i_in = i_rf; + } + // Increment i_out + if (i_acc + 1 >= multscale) { + i_acc = 0; + i_out++; + } else { + i_acc++; + } + } + } + + PixelResultLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { + #pragma HLS UNROLL + + StrideHeightResultLoop: + for (unsigned i_sh = 0; i_sh < CONFIG_T::stride_height; i_sh++) { + #pragma HLS UNROLL + StrideWidthResultLoop: + for (unsigned i_sw = 0; i_sw < CONFIG_T::stride_width; i_sw++) { + #pragma HLS UNROLL + + unsigned px_ind = i_pxl * CONFIG_T::n_partitions + i_part; + unsigned height_ind = (px_ind / CONFIG_T::proc_width) * CONFIG_T::stride_height + i_sh; + unsigned width_ind = (px_ind % CONFIG_T::proc_width) * CONFIG_T::stride_width + i_sw; + + if (height_ind >= CONFIG_T::pad_top && height_ind < CONFIG_T::out_height + CONFIG_T::pad_top && + width_ind >= CONFIG_T::pad_left && width_ind < CONFIG_T::out_width + CONFIG_T::pad_left) { + ResultLoop: + for (unsigned i_res = 0; i_res < mult_n_out; i_res++) { + #pragma HLS UNROLL + + res[((height_ind - CONFIG_T::pad_top) * CONFIG_T::out_width + width_ind - CONFIG_T::pad_left) * + CONFIG_T::n_filt + + i_res] = cast(acc[i_pxl][i_res][i_sh][i_sw]); + } + } + } + } + } + } +} + +} // namespace nnet + +#endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_stream.h new file mode 100644 index 0000000000..555f5ae2be --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_stream.h @@ -0,0 +1,210 @@ +#ifndef NNET_CONV2DTRANSPOSE_STREAM_H +#define NNET_CONV2DTRANSPOSE_STREAM_H + +#include "ap_shift_reg.h" +#include "hls_stream.h" +#include "nnet_common.h" +#include "nnet_conv_stream.h" + +namespace nnet { + +template +void kernel_shift_tr_2d( + typename data_T::value_type shift_buffer[CONFIG_T::trfilt_height][CONFIG_T::n_chan], + typename data_T::value_type kernel_window[CONFIG_T::trfilt_width * CONFIG_T::trfilt_height * CONFIG_T::n_chan]) { + #pragma HLS inline + + // Shift kernel_window by one step to the left (manual shift operation) + static const int filt_width = CONFIG_T::trfilt_width - 1; +KernelShiftWidth: + for (int i_iw = 0; i_iw < filt_width; i_iw++) { + #pragma HLS PIPELINE II = 1 + KernelShiftHeight: + for (unsigned i_ih = 0; i_ih < CONFIG_T::trfilt_height; i_ih++) { + KernelShiftChannel: + for (unsigned i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) { + // Shift every element in kernel_window to the left + kernel_window[i_ih * CONFIG_T::trfilt_width * CONFIG_T::n_chan + i_iw * CONFIG_T::n_chan + i_ic] = + kernel_window[i_ih * CONFIG_T::trfilt_width * CONFIG_T::n_chan + (i_iw + 1) * CONFIG_T::n_chan + i_ic]; + } + } + } + + // Insert shift_buffer column into right-most column of kernel + static const int lastheight = (CONFIG_T::trfilt_width - 1) * CONFIG_T::n_chan; +KernelPushHeight: + for (int i_ih = 0; i_ih < CONFIG_T::trfilt_height; i_ih++) { + #pragma HLS UNROLL + KernelPushChannel: + for (int i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) { + kernel_window[lastheight + i_ih * CONFIG_T::trfilt_width * CONFIG_T::n_chan + i_ic] = shift_buffer[i_ih][i_ic]; + } + } +} + +template +void shift_line_buffer_tr( + const data_T &in_elem, + ap_shift_reg line_buffer[MAX(CONFIG_T::trfilt_height - 1, 1)] + [CONFIG_T::n_chan], + typename data_T::value_type kernel_window[CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_chan]) { + + #pragma HLS PIPELINE + + // Temporary buffer for popped (shifted) elements + typename data_T::value_type shift_buffer[CONFIG_T::trfilt_height][CONFIG_T::n_chan]; + #pragma HLS ARRAY_PARTITION variable = shift_buffer complete dim = 0 + +UpdateBuffer: + for (int i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) { + #pragma HLS UNROLL + + // Insert pixel(s) at end of shift buffer + shift_buffer[CONFIG_T::trfilt_height - 1][i_ic] = in_elem[i_ic]; + } + +LineBufferDataIn: + for (int i_ic = 0; i_ic < CONFIG_T::n_chan; i_ic++) { + // Shift the shift buffer into the line buffer + LineBufferShift: + for (unsigned i_ih = 1; i_ih < CONFIG_T::trfilt_height; i_ih++) { + #pragma HLS UNROLL + typename data_T::value_type pop_elem = line_buffer[i_ih - 1][i_ic].shift( + shift_buffer[CONFIG_T::trfilt_height - i_ih][i_ic]); // Shift the line buffer, return the popped pixel + shift_buffer[CONFIG_T::trfilt_height - i_ih - 1][i_ic] = + pop_elem; // Popped element placed back into shift_buffer, one row up. + } + } + kernel_shift_tr_2d(shift_buffer, kernel_window); +} + +template +void compute_output_buffer_tr_2d(const data_T &in_elem, + ap_shift_reg + line_buffer[MAX(CONFIG_T::trfilt_height - 1, 1)][CONFIG_T::n_chan], + hls::stream &res_stream, + typename CONFIG_T::weight_t weights[CONFIG_T::stride_height][CONFIG_T::stride_width] + [CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * + CONFIG_T::n_filt * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + #pragma HLS INLINE + + // Counters + static int pX = 0; // pixel counters + static int pY = 0; + + static typename data_T::value_type kernel_data[CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * CONFIG_T::n_chan]; + #pragma HLS ARRAY_PARTITION variable=kernel_data complete + + typename res_T::value_type res_out[CONFIG_T::n_filt]; + #pragma HLS ARRAY_PARTITION variable=res_out complete dim = 0 + + static typename res_T::value_type + output_buffer[CONFIG_T::in_width * CONFIG_T::stride_width * CONFIG_T::stride_height * CONFIG_T::n_filt]; + #pragma HLS ARRAY_PARTITION variable=output_buffer complete dim = 0 + + res_T res_pack; + #pragma HLS DATA_PACK variable = res_pack + + // Add pixel to the buffer + nnet::shift_line_buffer_tr(in_elem, line_buffer, kernel_data); + +HeightStrideLoop: + for (int w_idx = 0; w_idx < CONFIG_T::stride_width; w_idx++) { + // #pragma HLS PIPELINE + #pragma HLS UNROLL + WidthStrideLoop: + for (int h_idx = 0; h_idx < CONFIG_T::stride_height; h_idx++) { + #pragma HLS UNROLL + + #pragma HLS INLINE region + + if (CONFIG_T::strategy == nnet::latency) { + dense_latency( + kernel_data, res_out, weights[h_idx][w_idx], biases); + } else { + dense_resource( + kernel_data, res_out, weights[h_idx][w_idx], biases); + } + + BufferOutputLoop: + for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) { + #pragma HLS UNROLL + output_buffer[(pX * CONFIG_T::stride_width + w_idx) * CONFIG_T::stride_height * CONFIG_T::n_filt + + h_idx * CONFIG_T::n_filt + i_ic] = res_out[i_ic]; + } + } + } + + // Counter Housekeeping and printing buffered output + if (pX + 1 == CONFIG_T::in_width) { + pX = 0; + // write all of the buffered output for outputs we want + HeightOutputLoop: + for (unsigned h_idx = 0; h_idx < CONFIG_T::stride_height; h_idx++) { + // #pragma HLS PIPELINE + if (pY * CONFIG_T::stride_height + h_idx >= CONFIG_T::pad_top && + pY * CONFIG_T::stride_height + h_idx < CONFIG_T::pad_top + CONFIG_T::out_height) { + WidthOutputLoop: + for (unsigned oX = CONFIG_T::pad_left; oX < CONFIG_T::pad_left + CONFIG_T::out_width; oX++) { + #pragma HLS PIPELINE + CastLoop: + for (unsigned i_ic = 0; i_ic < CONFIG_T::n_filt; i_ic++) { + #pragma HLS UNROLL + res_pack[i_ic] = + output_buffer[oX * CONFIG_T::stride_height * CONFIG_T::n_filt + h_idx * CONFIG_T::n_filt + i_ic]; + } + res_stream.write(res_pack); + } + } + } + + if (pY + 1 == CONFIG_T::in_height) { + pY = 0; + } else { + pY = pY + 1; + } + } else { + pX = pX + 1; + } +} + +template +void conv_2d_transpose_buffer_cl(hls::stream &data, hls::stream &res, + typename CONFIG_T::weight_t weights[CONFIG_T::stride_height][CONFIG_T::stride_width] + [CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * + CONFIG_T::n_filt * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + static ap_shift_reg line_buffer[MAX(CONFIG_T::trfilt_height - 1, 1)] + [CONFIG_T::n_chan]; + #pragma HLS ARRAY_PARTITION variable = line_buffer complete dim = 2 + +ReadInputHeight: + for (unsigned i_ih = 0; i_ih < CONFIG_T::in_height; i_ih++) { + ReadInputWidth: + for (unsigned i_iw = 0; i_iw < CONFIG_T::in_width; i_iw++) { + #pragma HLS LOOP_FLATTEN + if (CONFIG_T::strategy == nnet::latency) { + #pragma HLS PIPELINE II=CONFIG_T::reuse_factor + } + compute_output_buffer_tr_2d(data.read(), line_buffer, res, weights, biases); + } + } +} + +template +void conv_2d_transpose_cl(hls::stream &data, hls::stream &res, + typename CONFIG_T::weight_t weights[CONFIG_T::stride_height][CONFIG_T::stride_width] + [CONFIG_T::trfilt_height * CONFIG_T::trfilt_width * + CONFIG_T::n_filt * CONFIG_T::n_chan], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + #pragma HLS INLINE region + switch (CONFIG_T::implementation) { + case conv_implementation::linebuffer: + conv_2d_transpose_buffer_cl(data, res, weights, biases); + break; + } +} + +} // namespace nnet +#endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_helpers.h b/hls4ml/templates/vivado/nnet_utils/nnet_helpers.h index b8c2a48d19..c4e44b904a 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_helpers.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_helpers.h @@ -47,6 +47,82 @@ template void load_weights_from_txt(T *w, const char *fna } } +template void load_weights_from_txt(T w[DIM_1][DIM_2], const char *fname) { + + std::string full_path = std::string(WEIGHTS_DIR) + "/" + std::string(fname); + std::ifstream infile(full_path.c_str(), std::ios::binary); + + if (infile.fail()) { + std::cerr << "ERROR: file " << std::string(fname) << " does not exist" << std::endl; + exit(1); + } + + std::string line; + if (std::getline(infile, line)) { + std::istringstream iss(line); + std::string token; + + size_t i = 0; + size_t j = 0; + size_t tot = 0; + while (std::getline(iss, token, ',')) { + std::istringstream(token) >> w[i][j]; + j++; + if (j == DIM_2) { + j = 0; + i++; + } + tot++; + } + + if (DIM_1 * DIM_2 != tot) { + std::cerr << "ERROR: Expected " << DIM_1 * DIM_2 << " values"; + std::cerr << " but read only " << tot << " values" << std::endl; + } + } +} + +template +void load_weights_from_txt(T w[DIM_1][DIM_2][DIM_3], const char *fname) { + + std::string full_path = std::string(WEIGHTS_DIR) + "/" + std::string(fname); + std::ifstream infile(full_path.c_str(), std::ios::binary); + + if (infile.fail()) { + std::cerr << "ERROR: file " << std::string(fname) << " does not exist" << std::endl; + exit(1); + } + + std::string line; + if (std::getline(infile, line)) { + std::istringstream iss(line); + std::string token; + + size_t i = 0; + size_t j = 0; + size_t k = 0; + size_t tot = 0; + while (std::getline(iss, token, ',')) { + std::istringstream(token) >> w[i][j][k]; + k++; + if (k == DIM_3) { + k = 0; + j++; + if (j == DIM_2) { + j = 0; + i++; + } + } + tot++; + } + + if (DIM_1 * DIM_2 * DIM_3 != tot) { + std::cerr << "ERROR: Expected " << DIM_1 * DIM_2 * DIM_3 << " values"; + std::cerr << " but read only " << tot << " values" << std::endl; + } + } +} + template void load_compressed_weights_from_txt(T *w, const char *fname) { std::string full_path = std::string(WEIGHTS_DIR) + "/" + std::string(fname); diff --git a/hls4ml/writer/vivado_writer.py b/hls4ml/writer/vivado_writer.py index 2fbe3d9438..b2c2b8506e 100644 --- a/hls4ml/writer/vivado_writer.py +++ b/hls4ml/writer/vivado_writer.py @@ -42,17 +42,27 @@ def print_array_to_cpp(self, var, odir, write_txt_file=True): h_file.write(var.definition_cpp() + ";\n") h_file.write("#else\n") - h_file.write(var.definition_cpp() + " = {") - - # fill c++ array. - # not including internal brackets for multidimensional case - sep = '' - for x in var: - h_file.write(sep + x) + h_file.write(var.definition_cpp() + " = ") + + factors = np.ones(len(var.shape) + 1) + for idx in range(len(var.shape) - 1, -1, -1): + factors[idx] = var.shape[idx] * factors[idx + 1] + # fill c++ array, keeping the first keep_dims dimensions in-tact. + for idx, x in enumerate(var): + for dim in range(var.keep_dims + 1): + if idx % factors[dim] == 0: + h_file.write("{") + h_file.write(x) if write_txt_file: - txt_file.write(sep + x) - sep = ", " - h_file.write("};\n") + txt_file.write(x) + for dim in range(var.keep_dims + 1): + if idx % factors[dim] == factors[dim] - 1: + h_file.write("}") + if idx < factors[0] - 1: # only don't put comma at the end + h_file.write(", ") + if write_txt_file: + txt_file.write(", ") + h_file.write(";\n") if write_txt_file: h_file.write("#endif\n") txt_file.close() @@ -149,8 +159,15 @@ def write_project_cpp(self, model): w.type.name, w.data_length, w.name, w.name ) else: + dim_info = w.data_length + if w.keep_dims == 1: + dim_info = f'{w.shape[0]}, {w.data_length // w.shape[0]}' + if w.keep_dims == 2: + dim_info = '{}, {}, {}'.format( + w.shape[0], w.shape[1], w.data_length // (w.shape[0] * w.shape[1]) + ) newline += indent + ' nnet::load_weights_from_txt<{}, {}>({}, "{}.txt");\n'.format( - w.type.name, w.data_length, w.name, w.name + w.type.name, dim_info, w.name, w.name ) # Add input/output type From 72df8e7dbf31f8f978aa97caf96a920f68a376e5 Mon Sep 17 00:00:00 2001 From: Javier Duarte Date: Sat, 18 Mar 2023 08:24:21 -0700 Subject: [PATCH 02/11] apply resource --- .../backends/vivado/passes/resource_strategy.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/hls4ml/backends/vivado/passes/resource_strategy.py b/hls4ml/backends/vivado/passes/resource_strategy.py index 63e6e0b4db..b752b1817d 100644 --- a/hls4ml/backends/vivado/passes/resource_strategy.py +++ b/hls4ml/backends/vivado/passes/resource_strategy.py @@ -1,6 +1,16 @@ import numpy as np -from hls4ml.model.layers import GRU, LSTM, Conv1D, Conv2D, Dense, SeparableConv1D, SeparableConv2D +from hls4ml.model.layers import ( + GRU, + LSTM, + Conv1D, + Conv1DTranspose, + Conv2D, + Conv2DTranspose, + Dense, + SeparableConv1D, + SeparableConv2D, +) from hls4ml.model.optimizer import OptimizerPass @@ -8,7 +18,9 @@ class ApplyResourceStrategy(OptimizerPass): '''Transposes the weights to use the dense_resource matrix multiply routine''' def match(self, node): - node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU)) + node_matches = isinstance( + node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU, Conv1DTranspose, Conv2DTranspose) + ) is_resource_strategy = node.get_attr('strategy', '').lower() == 'resource' already_transformed = node.get_attr('_weights_transposed', False) is True From c67f7ff958902269003a50c5da45faf0a9124537 Mon Sep 17 00:00:00 2001 From: Javier Duarte Date: Sat, 18 Mar 2023 10:00:36 -0700 Subject: [PATCH 03/11] fix accum_t; no transpose for resource? --- hls4ml/backends/fpga/fpga_backend.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py index 47d6d056a4..41df148b66 100644 --- a/hls4ml/backends/fpga/fpga_backend.py +++ b/hls4ml/backends/fpga/fpga_backend.py @@ -14,7 +14,9 @@ Activation, BatchNormalization, Conv1D, + Conv1DTranspose, Conv2D, + Conv2DTranspose, Dense, Dot, Embedding, @@ -52,7 +54,9 @@ def __init__(self, name): accum_layers = [ Dense, Conv1D, + Conv1DTranspose, Conv2D, + Conv2DTranspose, SeparableConv1D, SeparableConv2D, Pooling1D, From 1b02451cc825c38d4c13d78c83e4c28d790435e2 Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Wed, 9 Aug 2023 23:20:48 -0500 Subject: [PATCH 04/11] expand testing, start fixing errors --- hls4ml/converters/keras/convolution.py | 6 ++ hls4ml/model/layers.py | 4 +- test/pytest/test_conv_transpose.py | 92 ++++++++++++++++++++++++++ 3 files changed, 100 insertions(+), 2 deletions(-) create mode 100644 test/pytest/test_conv_transpose.py diff --git a/hls4ml/converters/keras/convolution.py b/hls4ml/converters/keras/convolution.py index b2f3ea7611..1d2fa72f0b 100644 --- a/hls4ml/converters/keras/convolution.py +++ b/hls4ml/converters/keras/convolution.py @@ -55,6 +55,9 @@ def parse_conv1dtranspose_layer(keras_layer, input_names, input_shapes, data_rea layer['padding'] = keras_layer['config']['padding'] layer['trfilt_width'] = (layer['filt_width'] + layer['stride_width'] - 1) // layer['stride_width'] + layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'kernel') + layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') + ( layer['out_width'], layer['pad_left'], @@ -143,6 +146,9 @@ def parse_conv2dtranspose_layer(keras_layer, input_names, input_shapes, data_rea layer['trfilt_height'] = (layer['filt_height'] + layer['stride_height'] - 1) // layer['stride_height'] layer['trfilt_width'] = (layer['filt_width'] + layer['stride_width'] - 1) // layer['stride_width'] + layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'kernel') + layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') + ( layer['out_height'], layer['out_width'], diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index 10ced7ef01..b62be95c0d 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -470,7 +470,7 @@ def initialize(self): shape = [self.attributes['n_filt'], self.attributes['out_width']] dims = [f'N_FILT_{self.index}', f'N_OUTPUTS_{self.index}'] - data = self.model.get_weights_data(self.name, 'kernel') + data = self.get_attr("weight_data") # now we transform the entire kernel # (W,F,C) => (F,W,C) @@ -596,7 +596,7 @@ def initialize(self): shape = [self.attributes['n_filt'], self.attributes['out_height'], self.attributes['out_width']] dims = [f'N_FILT_{self.index}', f'OUT_HEIGHT_{self.index}', f'OUT_WIDTH_{self.index}'] - data = self.model.get_weights_data(self.name, 'kernel') + data = self.get_attr("weight_data") # now we transform the entire kernel # (H,W,F,C) => (F,H,W,C) diff --git a/test/pytest/test_conv_transpose.py b/test/pytest/test_conv_transpose.py new file mode 100644 index 0000000000..b8ab0a3223 --- /dev/null +++ b/test/pytest/test_conv_transpose.py @@ -0,0 +1,92 @@ +from pathlib import Path + +import numpy as np +import pytest +from tensorflow.keras.layers import Conv1DTranspose, Conv2DTranspose +from tensorflow.keras.models import Sequential + +import hls4ml + +test_root_path = Path(__file__).parent + + +@pytest.fixture(scope='module') +def data2D(): + X = np.random.rand(10, 5, 5, 3) + return X + + +@pytest.fixture(scope='module') +def data1D(): + X = np.random.rand(10, 5, 5, 3) + return X + + +@pytest.fixture(scope='module') +def model2D(): + model = Sequential() + model.add(Conv2DTranspose(4, (3, 3), input_shape=(5, 5, 3))) + model.compile() + return model + + +@pytest.fixture(scope='module') +def model1D(): + model = Sequential() + model.add(Conv1DTranspose(4, 3, input_shape=(5, 3))) + model.compile() + return model + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +@pytest.mark.parametrize('strategy', ['Latency', 'Resource']) +@pytest.mark.filterwarnings("error") +def test_conv1dtranspose(data1D, model1D, io_type, backend, strategy): + ''' + Check that the implementation does not have leftover data. + ''' + + X = data1D + model = model1D + + output_dir = str(test_root_path / f'hls4mlprj_conv1Dtranspose_{backend}_{io_type}_{strategy}') + + config = hls4ml.utils.config_from_keras_model(model) + config['Model']['Strategy'] = strategy + + hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, io_type=io_type, output_dir=output_dir) + hls_model.compile() + + # model under test predictions and accuracy + y_keras = model.predict(X) + y_hls4ml = hls_model.predict(X) + + np.testing.assert_allclose(y_keras, y_hls4ml, atol=0.05) + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +@pytest.mark.parametrize('strategy', ['Latency', 'Resource']) +@pytest.mark.filterwarnings("error") +def test_conv2dtranspose(data2D, model2D, io_type, backend, strategy): + ''' + Check that the implementation does not have leftover data. + ''' + + X = data2D + model = model2D + + output_dir = str(test_root_path / f'hls4mlprj_conv2Dtranspose_{backend}_{io_type}_{strategy}') + + config = hls4ml.utils.config_from_keras_model(model) + config['Model']['Strategy'] = strategy + + hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, io_type=io_type, output_dir=output_dir) + hls_model.compile() + + # model under test predictions and accuracy + y_keras = model.predict(X) + y_hls4ml = hls_model.predict(X) + + np.testing.assert_allclose(y_keras, y_hls4ml, atol=0.05) From 77e2a70ab8e4bdab6956422a8f09b3f439eac773 Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Fri, 11 Aug 2023 10:45:43 -0400 Subject: [PATCH 05/11] add nzeros parameter to mult_params --- hls4ml/backends/vivado/passes/convolution_templates.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index 11083fa069..dba3629f12 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -180,6 +180,7 @@ def format(self, node): // node.get_attr('stride_width') ) mult_params['n_out'] = node.get_attr('n_filt') + mult_params['nzeros'] = node.get_weights('weight').nzeros mult_params['product_type'] = get_backend('vivado').product_type( node.get_input_variable().type.precision, node.get_weights('weight').type.precision ) @@ -389,6 +390,7 @@ def format(self, node): mult_params = self._default_config_params(node) mult_params['n_in'] = node.get_attr('n_chan') * params['trfilt_width'] * params['trfilt_height'] mult_params['n_out'] = node.get_attr('n_filt') + mult_params['nzeros'] = node.get_weights('weight').nzeros mult_params['product_type'] = get_backend('vivado').product_type( node.get_input_variable().type.precision, node.get_weights('weight').type.precision ) From 9c53e16f0af8e4d431c38340b80a07da3ae05775 Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Sun, 13 Aug 2023 18:29:17 -0400 Subject: [PATCH 06/11] clean up conv2dtranspose init, fix output dimensions --- hls4ml/converters/utils.py | 9 +++++---- hls4ml/model/layers.py | 6 ++---- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/hls4ml/converters/utils.py b/hls4ml/converters/utils.py index 9ca2938996..25137c5012 100644 --- a/hls4ml/converters/utils.py +++ b/hls4ml/converters/utils.py @@ -89,7 +89,8 @@ def compute_padding_1d_transpose(pad_type, in_size, stride, filt_size): pad_left = pad_along_size // 2 pad_right = pad_along_size - pad_left elif pad_type.lower() == 'valid': - n_out = stride * (in_size - 1) + filt_size + # n_out = stride * (in_size - 1) + filt_size + n_out = in_size * stride + max(filt_size - stride, 0) # from Keras source code pad_left = 0 pad_right = 0 else: @@ -164,15 +165,15 @@ def compute_padding_2d_transpose(pad_type, in_height, in_width, stride_height, s pad_right = pad_along_width - pad_left elif pad_type.lower() == 'valid': # something - out_height = stride_height * in_height - out_width = stride_width * in_width + out_height = in_height * stride_height + max(filt_height - stride_height, 0) + out_width = in_width * stride_width + max(filt_width - stride_width, 0) pad_top = 0 pad_bottom = 0 pad_left = 0 pad_right = 0 else: - raise Exception(f'Unknown padding type: {pad_type}') + raise ValueError(f'Unknown padding type: {pad_type}') return (out_height, out_width, pad_top, pad_bottom, pad_left, pad_right) diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index b62be95c0d..07b0b41dbe 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -625,12 +625,10 @@ def initialize(self): new_weights[i_sh][i_sw][i_nf][i_fh][i_fw][i_nc] = data[i_nf][filt_h_ind][filt_w_ind][ i_nc ] - data = new_weights + self.set_attr("weight_data", new_weights) self.add_output_variable(shape, dims) - self.add_weights_variable( - name='weight', var_name='w{index}', data=data, quantizer=self.get_attr('weight_quantizer'), keep_dims=2 - ) + self.add_weights(quantizer=self.get_attr('weight_quantizer'), keep_dims=2) self.add_bias(quantizer=self.get_attr('bias_quantizer')) From 554be071dda7bf5888be9c59db7a179d9fdb95f6 Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Thu, 24 Aug 2023 13:51:19 -0500 Subject: [PATCH 07/11] update resource strategy to support conv*dtranspose --- hls4ml/backends/vivado/passes/resource_strategy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/hls4ml/backends/vivado/passes/resource_strategy.py b/hls4ml/backends/vivado/passes/resource_strategy.py index b752b1817d..ce97d9b5c3 100644 --- a/hls4ml/backends/vivado/passes/resource_strategy.py +++ b/hls4ml/backends/vivado/passes/resource_strategy.py @@ -52,6 +52,8 @@ def transform(self, model, node): elif isinstance(node, (LSTM, GRU)): node.weights['weight'].data = np.transpose(node.weights['weight'].data) node.weights['recurrent_weight'].data = np.transpose(node.weights['recurrent_weight'].data) + elif isinstance(node, (Conv2DTranspose, Conv1DTranspose)): + pass else: raise Exception(f'Unexpected layer {node.class_name} with resource strategy') From 6daa5cc40fd0c3875613c165599bb63b85cf3141 Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Thu, 24 Aug 2023 14:02:19 -0500 Subject: [PATCH 08/11] update test --- test/pytest/test_conv_transpose.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/pytest/test_conv_transpose.py b/test/pytest/test_conv_transpose.py index b8ab0a3223..07e6c2b960 100644 --- a/test/pytest/test_conv_transpose.py +++ b/test/pytest/test_conv_transpose.py @@ -38,9 +38,9 @@ def model1D(): return model -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) -@pytest.mark.parametrize('strategy', ['Latency', 'Resource']) +@pytest.mark.parametrize('strategy', ['Resource']) @pytest.mark.filterwarnings("error") def test_conv1dtranspose(data1D, model1D, io_type, backend, strategy): ''' @@ -62,12 +62,12 @@ def test_conv1dtranspose(data1D, model1D, io_type, backend, strategy): y_keras = model.predict(X) y_hls4ml = hls_model.predict(X) - np.testing.assert_allclose(y_keras, y_hls4ml, atol=0.05) + np.testing.assert_allclose(y_keras.ravel(), y_hls4ml.ravel(), atol=0.05) -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) -@pytest.mark.parametrize('strategy', ['Latency', 'Resource']) +@pytest.mark.parametrize('strategy', ['Resource']) @pytest.mark.filterwarnings("error") def test_conv2dtranspose(data2D, model2D, io_type, backend, strategy): ''' @@ -89,4 +89,4 @@ def test_conv2dtranspose(data2D, model2D, io_type, backend, strategy): y_keras = model.predict(X) y_hls4ml = hls_model.predict(X) - np.testing.assert_allclose(y_keras, y_hls4ml, atol=0.05) + np.testing.assert_allclose(y_keras.ravel(), y_hls4ml.ravel(), atol=0.05) From 056ec30972d70fb19b6143a44b99084d7095f115 Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Thu, 24 Aug 2023 14:12:14 -0500 Subject: [PATCH 09/11] fix conv1dtranspose test --- test/pytest/test_conv_transpose.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/pytest/test_conv_transpose.py b/test/pytest/test_conv_transpose.py index 07e6c2b960..1db8107184 100644 --- a/test/pytest/test_conv_transpose.py +++ b/test/pytest/test_conv_transpose.py @@ -18,7 +18,7 @@ def data2D(): @pytest.fixture(scope='module') def data1D(): - X = np.random.rand(10, 5, 5, 3) + X = np.random.rand(10, 5, 3) return X From 3c92d0e11bcb093a22887b237ce5aea6b7708b1a Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Thu, 24 Aug 2023 14:29:45 -0500 Subject: [PATCH 10/11] reverse resource strategy changes --- .../vivado/passes/resource_strategy.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/hls4ml/backends/vivado/passes/resource_strategy.py b/hls4ml/backends/vivado/passes/resource_strategy.py index ce97d9b5c3..63e6e0b4db 100644 --- a/hls4ml/backends/vivado/passes/resource_strategy.py +++ b/hls4ml/backends/vivado/passes/resource_strategy.py @@ -1,16 +1,6 @@ import numpy as np -from hls4ml.model.layers import ( - GRU, - LSTM, - Conv1D, - Conv1DTranspose, - Conv2D, - Conv2DTranspose, - Dense, - SeparableConv1D, - SeparableConv2D, -) +from hls4ml.model.layers import GRU, LSTM, Conv1D, Conv2D, Dense, SeparableConv1D, SeparableConv2D from hls4ml.model.optimizer import OptimizerPass @@ -18,9 +8,7 @@ class ApplyResourceStrategy(OptimizerPass): '''Transposes the weights to use the dense_resource matrix multiply routine''' def match(self, node): - node_matches = isinstance( - node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU, Conv1DTranspose, Conv2DTranspose) - ) + node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU)) is_resource_strategy = node.get_attr('strategy', '').lower() == 'resource' already_transformed = node.get_attr('_weights_transposed', False) is True @@ -52,8 +40,6 @@ def transform(self, model, node): elif isinstance(node, (LSTM, GRU)): node.weights['weight'].data = np.transpose(node.weights['weight'].data) node.weights['recurrent_weight'].data = np.transpose(node.weights['recurrent_weight'].data) - elif isinstance(node, (Conv2DTranspose, Conv1DTranspose)): - pass else: raise Exception(f'Unexpected layer {node.class_name} with resource strategy') From 1554435bd5ec25ac6efb561c24188b1626a80b0a Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Sun, 17 Sep 2023 16:09:39 -0500 Subject: [PATCH 11/11] replace data pack variables compatible with Vitis HLS --- .../templates/vivado/nnet_utils/nnet_conv1dtranspose_stream.h | 2 +- .../templates/vivado/nnet_utils/nnet_conv2dtranspose_stream.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_stream.h index 3a38f8e7d2..afe35f4369 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1dtranspose_stream.h @@ -57,7 +57,7 @@ void compute_output_buffer_tr_1d( #pragma HLS ARRAY_PARTITION variable=res_out complete dim = 0 res_T res_pack; - #pragma HLS DATA_PACK variable=res_pack + PRAGMA_DATA_PACK(res_pack) // Add pixel to buffer nnet::kernel_shift_tr_1d(in_elem, kernel_data); diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_stream.h index 555f5ae2be..cf304fcb90 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2dtranspose_stream.h @@ -104,7 +104,7 @@ void compute_output_buffer_tr_2d(const data_T &in_elem, #pragma HLS ARRAY_PARTITION variable=output_buffer complete dim = 0 res_T res_pack; - #pragma HLS DATA_PACK variable = res_pack + PRAGMA_DATA_PACK(res_pack) // Add pixel to the buffer nnet::shift_line_buffer_tr(in_elem, line_buffer, kernel_data);