Skip to content

Commit 2f6443a

Browse files
committed
Fix SR writer
1 parent b7c767b commit 2f6443a

File tree

1 file changed

+34
-33
lines changed

1 file changed

+34
-33
lines changed

hls4ml/writer/symbolic_writer.py

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import glob
22
import os
3+
import stat
4+
from pathlib import Path
35
from shutil import copyfile, copytree, rmtree
46

57
from hls4ml.backends import get_backend
@@ -56,49 +58,48 @@ def write_build_script(self, model):
5658
model (ModelGraph): the hls4ml model.
5759
"""
5860

59-
filedir = os.path.dirname(os.path.abspath(__file__))
60-
61-
# build_prj.tcl
62-
f = open(f'{model.config.get_output_dir()}/project.tcl', 'w')
63-
f.write('variable project_name\n')
64-
f.write(f'set project_name "{model.config.get_project_name()}"\n')
65-
f.write('variable backend\n')
66-
f.write('set backend "vivado"\n')
67-
f.write('variable part\n')
68-
f.write('set part "{}"\n'.format(model.config.get_config_value('Part')))
69-
f.write('variable clock_period\n')
70-
f.write('set clock_period {}\n'.format(model.config.get_config_value('ClockPeriod')))
71-
f.write('variable clock_uncertainty\n')
72-
f.write('set clock_uncertainty {}\n'.format(model.config.get_config_value('ClockUncertainty', '0%')))
73-
f.write('variable version\n')
74-
f.write('set version "{}"\n'.format(model.config.get_config_value('Version', '1.0.0')))
75-
f.close()
61+
filedir = Path(__file__).parent
62+
63+
# project.tcl
64+
prj_tcl_dst = Path(f'{model.config.get_output_dir()}/project.tcl')
65+
with open(prj_tcl_dst, 'w') as f:
66+
f.write('variable project_name\n')
67+
f.write(f'set project_name "{model.config.get_project_name()}"\n')
68+
f.write('variable backend\n')
69+
f.write('set backend "vivado"\n')
70+
f.write('variable part\n')
71+
f.write('set part "{}"\n'.format(model.config.get_config_value('Part')))
72+
f.write('variable clock_period\n')
73+
f.write('set clock_period {}\n'.format(model.config.get_config_value('ClockPeriod')))
74+
f.write('variable clock_uncertainty\n')
75+
f.write('set clock_uncertainty {}\n'.format(model.config.get_config_value('ClockUncertainty', '0%')))
76+
f.write('variable version\n')
77+
f.write('set version "{}"\n'.format(model.config.get_config_value('Version', '1.0.0')))
7678

7779
# build_prj.tcl
78-
srcpath = os.path.join(filedir, '../templates/vivado/build_prj.tcl')
80+
srcpath = (filedir / '../templates/vivado/build_prj.tcl').resolve()
7981
dstpath = f'{model.config.get_output_dir()}/build_prj.tcl'
8082
copyfile(srcpath, dstpath)
8183

8284
# vivado_synth.tcl
83-
srcpath = os.path.join(filedir, '../templates/vivado/vivado_synth.tcl')
85+
srcpath = (filedir / '../templates/vivado/vivado_synth.tcl').resolve()
8486
dstpath = f'{model.config.get_output_dir()}/vivado_synth.tcl'
8587
copyfile(srcpath, dstpath)
8688

8789
# build_lib.sh
88-
f = open(os.path.join(filedir, '../templates/symbolic/build_lib.sh'))
89-
fout = open(f'{model.config.get_output_dir()}/build_lib.sh', 'w')
90-
91-
for line in f.readlines():
92-
line = line.replace('myproject', model.config.get_project_name())
93-
line = line.replace('mystamp', model.config.get_config_value('Stamp'))
94-
line = line.replace('mylibspath', model.config.get_config_value('HLSLibsPath'))
95-
96-
if 'LDFLAGS=' in line and not os.path.exists(model.config.get_config_value('HLSLibsPath')):
97-
line = 'LDFLAGS=\n'
98-
99-
fout.write(line)
100-
f.close()
101-
fout.close()
90+
build_lib_src = (filedir / '../templates/symbolic/build_lib.sh').resolve()
91+
build_lib_dst = Path(f'{model.config.get_output_dir()}/build_lib.sh').resolve()
92+
with open(build_lib_src) as src, open(build_lib_dst, 'w') as dst:
93+
for line in src.readlines():
94+
line = line.replace('myproject', model.config.get_project_name())
95+
line = line.replace('mystamp', model.config.get_config_value('Stamp'))
96+
line = line.replace('mylibspath', model.config.get_config_value('HLSLibsPath'))
97+
98+
if 'LDFLAGS=' in line and not os.path.exists(model.config.get_config_value('HLSLibsPath')):
99+
line = 'LDFLAGS=\n'
100+
101+
dst.write(line)
102+
build_lib_dst.chmod(build_lib_dst.stat().st_mode | stat.S_IEXEC)
102103

103104
def write_hls(self, model):
104105
print('Writing HLS project')

0 commit comments

Comments
 (0)