|
1 | 1 | import glob
|
2 | 2 | import os
|
| 3 | +import stat |
| 4 | +from pathlib import Path |
3 | 5 | from shutil import copyfile, copytree, rmtree
|
4 | 6 |
|
5 | 7 | from hls4ml.backends import get_backend
|
@@ -56,49 +58,48 @@ def write_build_script(self, model):
|
56 | 58 | model (ModelGraph): the hls4ml model.
|
57 | 59 | """
|
58 | 60 |
|
59 |
| - filedir = os.path.dirname(os.path.abspath(__file__)) |
60 |
| - |
61 |
| - # build_prj.tcl |
62 |
| - f = open(f'{model.config.get_output_dir()}/project.tcl', 'w') |
63 |
| - f.write('variable project_name\n') |
64 |
| - f.write(f'set project_name "{model.config.get_project_name()}"\n') |
65 |
| - f.write('variable backend\n') |
66 |
| - f.write('set backend "vivado"\n') |
67 |
| - f.write('variable part\n') |
68 |
| - f.write('set part "{}"\n'.format(model.config.get_config_value('Part'))) |
69 |
| - f.write('variable clock_period\n') |
70 |
| - f.write('set clock_period {}\n'.format(model.config.get_config_value('ClockPeriod'))) |
71 |
| - f.write('variable clock_uncertainty\n') |
72 |
| - f.write('set clock_uncertainty {}\n'.format(model.config.get_config_value('ClockUncertainty', '0%'))) |
73 |
| - f.write('variable version\n') |
74 |
| - f.write('set version "{}"\n'.format(model.config.get_config_value('Version', '1.0.0'))) |
75 |
| - f.close() |
| 61 | + filedir = Path(__file__).parent |
| 62 | + |
| 63 | + # project.tcl |
| 64 | + prj_tcl_dst = Path(f'{model.config.get_output_dir()}/project.tcl') |
| 65 | + with open(prj_tcl_dst, 'w') as f: |
| 66 | + f.write('variable project_name\n') |
| 67 | + f.write(f'set project_name "{model.config.get_project_name()}"\n') |
| 68 | + f.write('variable backend\n') |
| 69 | + f.write('set backend "vivado"\n') |
| 70 | + f.write('variable part\n') |
| 71 | + f.write('set part "{}"\n'.format(model.config.get_config_value('Part'))) |
| 72 | + f.write('variable clock_period\n') |
| 73 | + f.write('set clock_period {}\n'.format(model.config.get_config_value('ClockPeriod'))) |
| 74 | + f.write('variable clock_uncertainty\n') |
| 75 | + f.write('set clock_uncertainty {}\n'.format(model.config.get_config_value('ClockUncertainty', '0%'))) |
| 76 | + f.write('variable version\n') |
| 77 | + f.write('set version "{}"\n'.format(model.config.get_config_value('Version', '1.0.0'))) |
76 | 78 |
|
77 | 79 | # build_prj.tcl
|
78 |
| - srcpath = os.path.join(filedir, '../templates/vivado/build_prj.tcl') |
| 80 | + srcpath = (filedir / '../templates/vivado/build_prj.tcl').resolve() |
79 | 81 | dstpath = f'{model.config.get_output_dir()}/build_prj.tcl'
|
80 | 82 | copyfile(srcpath, dstpath)
|
81 | 83 |
|
82 | 84 | # vivado_synth.tcl
|
83 |
| - srcpath = os.path.join(filedir, '../templates/vivado/vivado_synth.tcl') |
| 85 | + srcpath = (filedir / '../templates/vivado/vivado_synth.tcl').resolve() |
84 | 86 | dstpath = f'{model.config.get_output_dir()}/vivado_synth.tcl'
|
85 | 87 | copyfile(srcpath, dstpath)
|
86 | 88 |
|
87 | 89 | # build_lib.sh
|
88 |
| - f = open(os.path.join(filedir, '../templates/symbolic/build_lib.sh')) |
89 |
| - fout = open(f'{model.config.get_output_dir()}/build_lib.sh', 'w') |
90 |
| - |
91 |
| - for line in f.readlines(): |
92 |
| - line = line.replace('myproject', model.config.get_project_name()) |
93 |
| - line = line.replace('mystamp', model.config.get_config_value('Stamp')) |
94 |
| - line = line.replace('mylibspath', model.config.get_config_value('HLSLibsPath')) |
95 |
| - |
96 |
| - if 'LDFLAGS=' in line and not os.path.exists(model.config.get_config_value('HLSLibsPath')): |
97 |
| - line = 'LDFLAGS=\n' |
98 |
| - |
99 |
| - fout.write(line) |
100 |
| - f.close() |
101 |
| - fout.close() |
| 90 | + build_lib_src = (filedir / '../templates/symbolic/build_lib.sh').resolve() |
| 91 | + build_lib_dst = Path(f'{model.config.get_output_dir()}/build_lib.sh').resolve() |
| 92 | + with open(build_lib_src) as src, open(build_lib_dst, 'w') as dst: |
| 93 | + for line in src.readlines(): |
| 94 | + line = line.replace('myproject', model.config.get_project_name()) |
| 95 | + line = line.replace('mystamp', model.config.get_config_value('Stamp')) |
| 96 | + line = line.replace('mylibspath', model.config.get_config_value('HLSLibsPath')) |
| 97 | + |
| 98 | + if 'LDFLAGS=' in line and not os.path.exists(model.config.get_config_value('HLSLibsPath')): |
| 99 | + line = 'LDFLAGS=\n' |
| 100 | + |
| 101 | + dst.write(line) |
| 102 | + build_lib_dst.chmod(build_lib_dst.stat().st_mode | stat.S_IEXEC) |
102 | 103 |
|
103 | 104 | def write_hls(self, model):
|
104 | 105 | print('Writing HLS project')
|
|
0 commit comments