Skip to content

Commit adf2768

Browse files
authored
add merge_all_horizontal_groups cinn flag for llama (#10629)
1 parent 64afc81 commit adf2768

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

tests/test_tipc/static/auto_parallel/llama2/benchmark_common/run_benchmark.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ export FLAGS_use_cinn=1
265265
export FLAGS_dist_prim_all=1
266266
export FLAGS_prim_forward_blacklist="pd_op.stack;pd_op.squeeze;pd_op.swiglu;pd_op.squared_l2_norm"
267267
export FLAGS_prim_backward_blacklist="swiglu_grad"
268+
export FLAGS_merge_all_horizontal_groups=1
268269

269270
source ${BENCHMARK_ROOT}/scripts/run_model.sh # 在该脚本中会对符合benchmark规范的log使用analysis.py 脚本进行性能数据解析;如果不联调只想要产出训练log可以注掉本行,提交时需打开
270271
_set_params $@

0 commit comments

Comments
 (0)