Skip to content

Commit cfbad0b

Browse files
authored
Merge branch 'main' into hls4ml-optimization-api-part-2
2 parents 97c5347 + b4111c6 commit cfbad0b

File tree

9 files changed

+150
-145
lines changed

9 files changed

+150
-145
lines changed

.pre-commit-config.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@ exclude: (^hls4ml\/templates\/(vivado|quartus)\/(ap_types|ac_types)\/|^test/pyte
22

33
repos:
44
- repo: https://github.com/psf/black
5-
rev: 24.8.0
5+
rev: 24.10.0
66
hooks:
77
- id: black
88
language_version: python3
99
args: ['--line-length=125',
1010
'--skip-string-normalization']
1111

1212
- repo: https://github.com/pre-commit/pre-commit-hooks
13-
rev: v4.6.0
13+
rev: v5.0.0
1414
hooks:
1515
- id: check-added-large-files
1616
- id: check-case-conflict
@@ -30,13 +30,13 @@ repos:
3030
args: ["--profile", "black", --line-length=125]
3131

3232
- repo: https://github.com/asottile/pyupgrade
33-
rev: v3.17.0
33+
rev: v3.18.0
3434
hooks:
3535
- id: pyupgrade
3636
args: ["--py36-plus"]
3737

3838
- repo: https://github.com/asottile/setup-cfg-fmt
39-
rev: v2.5.0
39+
rev: v2.7.0
4040
hooks:
4141
- id: setup-cfg-fmt
4242

@@ -50,7 +50,7 @@ repos:
5050
'--extend-ignore=E203,T201'] # E203 is not PEP8 compliant
5151

5252
- repo: https://github.com/mgedmin/check-manifest
53-
rev: "0.49"
53+
rev: "0.50"
5454
hooks:
5555
- id: check-manifest
5656
stages: [manual]

Jenkinsfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ pipeline {
1616
sh '''#!/bin/bash --login
1717
conda activate hls4ml-py310
1818
conda install -y jupyterhub pydot graphviz pytest pytest-cov
19-
pip install pytest-randomly jupyter onnx>=1.4.0 matplotlib pandas seaborn pydigitalwavetools==1.1 pyyaml tensorflow==2.14 qonnx torch git+https://github.com/google/qkeras.git pyparsing
19+
pip install pytest-randomly jupyter onnx>=1.4.0 matplotlib pandas seaborn pydigitalwavetools==1.1 pyyaml tensorflow==2.14 qonnx torch git+https://github.com/jmitrevs/qkeras.git@qrecurrent_unstack pyparsing
2020
pip install -U ../ --user
2121
./convert-keras-models.sh -x -f keras-models.txt
2222
pip uninstall hls4ml -y'''

hls4ml/backends/fpga/fpga_backend.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
2-
import os
32
import re
3+
import subprocess
44
from bisect import bisect_left
55
from collections.abc import Iterable
66

@@ -131,19 +131,22 @@ def compile(self, model):
131131
Returns:
132132
string: Returns the name of the compiled library.
133133
"""
134-
curr_dir = os.getcwd()
135-
os.chdir(model.config.get_output_dir())
136134

137135
lib_name = None
138-
try:
139-
ret_val = os.system('bash build_lib.sh')
140-
if ret_val != 0:
141-
raise Exception(f'Failed to compile project "{model.config.get_project_name()}"')
142-
lib_name = '{}/firmware/{}-{}.so'.format(
143-
model.config.get_output_dir(), model.config.get_project_name(), model.config.get_config_value('Stamp')
144-
)
145-
finally:
146-
os.chdir(curr_dir)
136+
ret_val = subprocess.run(
137+
['./build_lib.sh'],
138+
shell=True,
139+
text=True,
140+
stdout=subprocess.PIPE,
141+
stderr=subprocess.STDOUT,
142+
cwd=model.config.get_output_dir(),
143+
)
144+
if ret_val.returncode != 0:
145+
print(ret_val.stdout)
146+
raise Exception(f'Failed to compile project "{model.config.get_project_name()}"')
147+
lib_name = '{}/firmware/{}-{}.so'.format(
148+
model.config.get_output_dir(), model.config.get_project_name(), model.config.get_config_value('Stamp')
149+
)
147150

148151
return lib_name
149152

hls4ml/templates/quartus/build_lib.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/bin/bash
2+
set -e
23

34
CC=g++
45
if [[ "$OSTYPE" == "linux-gnu" ]]; then

hls4ml/templates/vivado/build_lib.sh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/bin/bash
2+
set -e
23

34
CC=g++
45
if [[ "$OSTYPE" == "linux-gnu" ]]; then
@@ -10,8 +11,9 @@ LDFLAGS=
1011
INCFLAGS="-Ifirmware/ap_types/"
1112
PROJECT=myproject
1213
LIB_STAMP=mystamp
14+
WEIGHTS_DIR="\"weights\""
1315

14-
${CC} ${CFLAGS} ${INCFLAGS} -c firmware/${PROJECT}.cpp -o ${PROJECT}.o
15-
${CC} ${CFLAGS} ${INCFLAGS} -c ${PROJECT}_bridge.cpp -o ${PROJECT}_bridge.o
16+
${CC} ${CFLAGS} ${INCFLAGS} -D WEIGHTS_DIR=${WEIGHTS_DIR} -c firmware/${PROJECT}.cpp -o ${PROJECT}.o
17+
${CC} ${CFLAGS} ${INCFLAGS} -D WEIGHTS_DIR=${WEIGHTS_DIR} -c ${PROJECT}_bridge.cpp -o ${PROJECT}_bridge.o
1618
${CC} ${CFLAGS} ${INCFLAGS} -shared ${PROJECT}.o ${PROJECT}_bridge.o -o firmware/${PROJECT}-${LIB_STAMP}.so
1719
rm -f *.o

hls4ml/writer/catapult_writer.py

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import glob
22
import os
3+
import stat
34
import tarfile
45
from collections import OrderedDict
6+
from pathlib import Path
57
from shutil import copyfile, copytree, rmtree
68

79
import numpy as np
@@ -749,55 +751,50 @@ def write_build_script(self, model):
749751
model (ModelGraph): the hls4ml model.
750752
"""
751753

752-
filedir = os.path.dirname(os.path.abspath(__file__))
754+
filedir = Path(__file__).parent
753755

754756
# build_prj.tcl
755-
srcpath = os.path.join(filedir, '../templates/catapult/build_prj.tcl')
756-
dstpath = f'{model.config.get_output_dir()}/build_prj.tcl'
757-
# copyfile(srcpath, dstpath)
758-
f = open(srcpath)
759-
fout = open(dstpath, 'w')
760-
for line in f.readlines():
761-
indent = line[: len(line) - len(line.lstrip())]
762-
line = line.replace('myproject', model.config.get_project_name())
763-
line = line.replace('CATAPULT_DIR', model.config.get_project_dir())
764-
if '#hls-fpga-machine-learning insert techlibs' in line:
765-
if model.config.get_config_value('Technology') is None:
766-
if model.config.get_config_value('Part') is not None:
767-
line = indent + 'setup_xilinx_part {{{}}}\n'.format(model.config.get_config_value('Part'))
768-
elif model.config.get_config_value('ASICLibs') is not None:
769-
line = indent + 'setup_asic_libs {{{}}}\n'.format(model.config.get_config_value('ASICLibs'))
770-
else:
771-
if model.config.get_config_value('Technology') == 'asic':
772-
line = indent + 'setup_asic_libs {{{}}}\n'.format(model.config.get_config_value('ASICLibs'))
757+
srcpath = (filedir / '../templates/catapult/build_prj.tcl').resolve()
758+
dstpath = Path(f'{model.config.get_output_dir()}/build_prj.tcl').resolve()
759+
with open(srcpath) as src, open(dstpath, 'w') as dst:
760+
for line in src.readlines():
761+
indent = line[: len(line) - len(line.lstrip())]
762+
line = line.replace('myproject', model.config.get_project_name())
763+
line = line.replace('CATAPULT_DIR', model.config.get_project_dir())
764+
if '#hls-fpga-machine-learning insert techlibs' in line:
765+
if model.config.get_config_value('Technology') is None:
766+
if model.config.get_config_value('Part') is not None:
767+
line = indent + 'setup_xilinx_part {{{}}}\n'.format(model.config.get_config_value('Part'))
768+
elif model.config.get_config_value('ASICLibs') is not None:
769+
line = indent + 'setup_asic_libs {{{}}}\n'.format(model.config.get_config_value('ASICLibs'))
773770
else:
774-
line = indent + 'setup_xilinx_part {{{}}}\n'.format(model.config.get_config_value('Part'))
775-
elif '#hls-fpga-machine-learning insert invoke_args' in line:
776-
tb_in_file = model.config.get_config_value('InputData')
777-
tb_out_file = model.config.get_config_value('OutputPredictions')
778-
invoke_args = '$sfd/firmware/weights'
779-
if tb_in_file is not None:
780-
invoke_args = invoke_args + f' $sfd/tb_data/{tb_in_file}'
781-
if tb_out_file is not None:
782-
invoke_args = invoke_args + f' $sfd/tb_data/{tb_out_file}'
783-
line = indent + f'flow package option set /SCVerify/INVOKE_ARGS "{invoke_args}"\n'
784-
elif 'set hls_clock_period 5' in line:
785-
line = indent + 'set hls_clock_period {}\n'.format(model.config.get_config_value('ClockPeriod'))
786-
fout.write(line)
787-
f.close()
788-
fout.close()
771+
if model.config.get_config_value('Technology') == 'asic':
772+
line = indent + 'setup_asic_libs {{{}}}\n'.format(model.config.get_config_value('ASICLibs'))
773+
else:
774+
line = indent + 'setup_xilinx_part {{{}}}\n'.format(model.config.get_config_value('Part'))
775+
elif '#hls-fpga-machine-learning insert invoke_args' in line:
776+
tb_in_file = model.config.get_config_value('InputData')
777+
tb_out_file = model.config.get_config_value('OutputPredictions')
778+
invoke_args = '$sfd/firmware/weights'
779+
if tb_in_file is not None:
780+
invoke_args = invoke_args + f' $sfd/tb_data/{tb_in_file}'
781+
if tb_out_file is not None:
782+
invoke_args = invoke_args + f' $sfd/tb_data/{tb_out_file}'
783+
line = indent + f'flow package option set /SCVerify/INVOKE_ARGS "{invoke_args}"\n'
784+
elif 'set hls_clock_period 5' in line:
785+
line = indent + 'set hls_clock_period {}\n'.format(model.config.get_config_value('ClockPeriod'))
786+
dst.write(line)
789787

790788
# build_lib.sh
791-
f = open(os.path.join(filedir, '../templates/catapult/build_lib.sh'))
792-
fout = open(f'{model.config.get_output_dir()}/build_lib.sh', 'w')
793-
794-
for line in f.readlines():
795-
line = line.replace('myproject', model.config.get_project_name())
796-
line = line.replace('mystamp', model.config.get_config_value('Stamp'))
797-
798-
fout.write(line)
799-
f.close()
800-
fout.close()
789+
build_lib_src = (filedir / '../templates/catapult/build_lib.sh').resolve()
790+
build_lib_dst = Path(f'{model.config.get_output_dir()}/build_lib.sh').resolve()
791+
with open(build_lib_src) as src, open(build_lib_dst, 'w') as dst:
792+
for line in src.readlines():
793+
line = line.replace('myproject', model.config.get_project_name())
794+
line = line.replace('mystamp', model.config.get_config_value('Stamp'))
795+
796+
dst.write(line)
797+
build_lib_dst.chmod(build_lib_dst.stat().st_mode | stat.S_IEXEC)
801798

802799
def write_nnet_utils(self, model):
803800
"""Copy the nnet_utils, AP types headers and any custom source to the project output directory

hls4ml/writer/quartus_writer.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import glob
22
import os
3+
import stat
34
import tarfile
45
from collections import OrderedDict
6+
from pathlib import Path
57
from shutil import copyfile, copytree, rmtree
68

79
import numpy as np
@@ -877,32 +879,30 @@ def write_build_script(self, model):
877879
model (ModelGraph): the hls4ml model.
878880
"""
879881

880-
# Makefile
881-
filedir = os.path.dirname(os.path.abspath(__file__))
882-
f = open(os.path.join(filedir, '../templates/quartus/Makefile'))
883-
fout = open(f'{model.config.get_output_dir()}/Makefile', 'w')
882+
filedir = Path(__file__).parent
884883

885-
for line in f.readlines():
886-
line = line.replace('myproject', model.config.get_project_name())
884+
# Makefile
885+
makefile_src = (filedir / '../templates/quartus/Makefile').resolve()
886+
makefile_dst = Path(f'{model.config.get_output_dir()}/Makefile').resolve()
887+
with open(makefile_src) as src, open(makefile_dst, 'w') as dst:
888+
for line in src.readlines():
889+
line = line.replace('myproject', model.config.get_project_name())
887890

888-
if 'DEVICE :=' in line:
889-
line = 'DEVICE := {}\n'.format(model.config.get_config_value('Part'))
891+
if 'DEVICE :=' in line:
892+
line = 'DEVICE := {}\n'.format(model.config.get_config_value('Part'))
890893

891-
fout.write(line)
892-
f.close()
893-
fout.close()
894+
dst.write(line)
894895

895896
# build_lib.sh
896-
f = open(os.path.join(filedir, '../templates/quartus/build_lib.sh'))
897-
fout = open(f'{model.config.get_output_dir()}/build_lib.sh', 'w')
898-
899-
for line in f.readlines():
900-
line = line.replace('myproject', model.config.get_project_name())
901-
line = line.replace('mystamp', model.config.get_config_value('Stamp'))
902-
903-
fout.write(line)
904-
f.close()
905-
fout.close()
897+
build_lib_src = (filedir / '../templates/quartus/build_lib.sh').resolve()
898+
build_lib_dst = Path(f'{model.config.get_output_dir()}/build_lib.sh').resolve()
899+
with open(build_lib_src) as src, open(build_lib_dst, 'w') as dst:
900+
for line in src.readlines():
901+
line = line.replace('myproject', model.config.get_project_name())
902+
line = line.replace('mystamp', model.config.get_config_value('Stamp'))
903+
904+
dst.write(line)
905+
build_lib_dst.chmod(build_lib_dst.stat().st_mode | stat.S_IEXEC)
906906

907907
def write_nnet_utils(self, model):
908908
"""Copy the nnet_utils, AP types headers and any custom source to the project output directory

hls4ml/writer/symbolic_writer.py

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import glob
22
import os
3+
import stat
4+
from pathlib import Path
35
from shutil import copyfile, copytree, rmtree
46

57
from hls4ml.backends import get_backend
@@ -56,49 +58,48 @@ def write_build_script(self, model):
5658
model (ModelGraph): the hls4ml model.
5759
"""
5860

59-
filedir = os.path.dirname(os.path.abspath(__file__))
60-
61-
# build_prj.tcl
62-
f = open(f'{model.config.get_output_dir()}/project.tcl', 'w')
63-
f.write('variable project_name\n')
64-
f.write(f'set project_name "{model.config.get_project_name()}"\n')
65-
f.write('variable backend\n')
66-
f.write('set backend "vivado"\n')
67-
f.write('variable part\n')
68-
f.write('set part "{}"\n'.format(model.config.get_config_value('Part')))
69-
f.write('variable clock_period\n')
70-
f.write('set clock_period {}\n'.format(model.config.get_config_value('ClockPeriod')))
71-
f.write('variable clock_uncertainty\n')
72-
f.write('set clock_uncertainty {}\n'.format(model.config.get_config_value('ClockUncertainty', '0%')))
73-
f.write('variable version\n')
74-
f.write('set version "{}"\n'.format(model.config.get_config_value('Version', '1.0.0')))
75-
f.close()
61+
filedir = Path(__file__).parent
62+
63+
# project.tcl
64+
prj_tcl_dst = Path(f'{model.config.get_output_dir()}/project.tcl')
65+
with open(prj_tcl_dst, 'w') as f:
66+
f.write('variable project_name\n')
67+
f.write(f'set project_name "{model.config.get_project_name()}"\n')
68+
f.write('variable backend\n')
69+
f.write('set backend "vivado"\n')
70+
f.write('variable part\n')
71+
f.write('set part "{}"\n'.format(model.config.get_config_value('Part')))
72+
f.write('variable clock_period\n')
73+
f.write('set clock_period {}\n'.format(model.config.get_config_value('ClockPeriod')))
74+
f.write('variable clock_uncertainty\n')
75+
f.write('set clock_uncertainty {}\n'.format(model.config.get_config_value('ClockUncertainty', '0%')))
76+
f.write('variable version\n')
77+
f.write('set version "{}"\n'.format(model.config.get_config_value('Version', '1.0.0')))
7678

7779
# build_prj.tcl
78-
srcpath = os.path.join(filedir, '../templates/vivado/build_prj.tcl')
80+
srcpath = (filedir / '../templates/vivado/build_prj.tcl').resolve()
7981
dstpath = f'{model.config.get_output_dir()}/build_prj.tcl'
8082
copyfile(srcpath, dstpath)
8183

8284
# vivado_synth.tcl
83-
srcpath = os.path.join(filedir, '../templates/vivado/vivado_synth.tcl')
85+
srcpath = (filedir / '../templates/vivado/vivado_synth.tcl').resolve()
8486
dstpath = f'{model.config.get_output_dir()}/vivado_synth.tcl'
8587
copyfile(srcpath, dstpath)
8688

8789
# build_lib.sh
88-
f = open(os.path.join(filedir, '../templates/symbolic/build_lib.sh'))
89-
fout = open(f'{model.config.get_output_dir()}/build_lib.sh', 'w')
90-
91-
for line in f.readlines():
92-
line = line.replace('myproject', model.config.get_project_name())
93-
line = line.replace('mystamp', model.config.get_config_value('Stamp'))
94-
line = line.replace('mylibspath', model.config.get_config_value('HLSLibsPath'))
95-
96-
if 'LDFLAGS=' in line and not os.path.exists(model.config.get_config_value('HLSLibsPath')):
97-
line = 'LDFLAGS=\n'
98-
99-
fout.write(line)
100-
f.close()
101-
fout.close()
90+
build_lib_src = (filedir / '../templates/symbolic/build_lib.sh').resolve()
91+
build_lib_dst = Path(f'{model.config.get_output_dir()}/build_lib.sh').resolve()
92+
with open(build_lib_src) as src, open(build_lib_dst, 'w') as dst:
93+
for line in src.readlines():
94+
line = line.replace('myproject', model.config.get_project_name())
95+
line = line.replace('mystamp', model.config.get_config_value('Stamp'))
96+
line = line.replace('mylibspath', model.config.get_config_value('HLSLibsPath'))
97+
98+
if 'LDFLAGS=' in line and not os.path.exists(model.config.get_config_value('HLSLibsPath')):
99+
line = 'LDFLAGS=\n'
100+
101+
dst.write(line)
102+
build_lib_dst.chmod(build_lib_dst.stat().st_mode | stat.S_IEXEC)
102103

103104
def write_hls(self, model):
104105
print('Writing HLS project')

0 commit comments

Comments
 (0)