Skip to content

Commit dc68aac

Browse files
committed
Merge branch 'develop' into feat/add-swanlab-logger
2 parents 98c3242 + ddcb722 commit dc68aac

File tree

8 files changed

+718
-17
lines changed

8 files changed

+718
-17
lines changed

llm/devices/intel_hpu/tests/README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# setup custom tpc library
unset FLAGS_selected_intel_hpus
export GC_KERNEL_PATH=/workspace/pdpd_automation/repo/PaddleCustomDevice/backends/intel_hpu/build/libcustom_tpc_perf_lib.so:/usr/lib/habanalabs/libtpc_kernels.so

# test cmdline examples
# PR test cases
python e2e-test-run.py --context pr --data /data/ckpt/ --filter stable --device intel_hpu --junit test_result.xml --platform gaudi2d
# BAT test cases
python e2e-test-run.py --context bat --data /data/ckpt/ --filter stable --device intel_hpu --junit test_result.xml --platform gaudi2d
# smoke test cases
python e2e-test-run.py --context sanity --data /data/ckpt/ --filter stable --device intel_hpu --junit test_result.xml --platform gaudi2d
python e2e-test-run.py --context sanity --data /data/ckpt/ --filter stable --device intel_hpu:2 --junit test_result.xml --platform gaudi2d

# static graph mode inference
export PYTHONPATH=$PYTHONPATH:/workspace/pdpd_automation/repo/PaddleNLP/
export FLAGS_intel_hpu_execution_queue_size=10
# export the model files used for static-mode inference
python export_model.py --model_name_or_path /data/ckpt/meta-llama/Llama-2-7b-chat/ --inference_model --output_path ./inference --dtype bfloat16 --device intel_hpu
# run static mode inference
python e2e-test-run.py --context sanity --data /data/ckpt/ --filter stable --device intel_hpu:2 --mode static --junit test_result.xml --platform gaudi2d
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import itertools
16+
17+
18+
def create_variants(
    scr_length: list, max_length: list, total_max_length: list, batch_size: list, decode_strategy: list
):
    """Expand per-option value lists into a list of variant dicts.

    Builds the Cartesian product of the five option lists and labels each
    combination with the predictor option names, yielding one dict per
    variant (e.g. {"src_length": "128", ..., "decode_strategy": "sampling"}).
    """
    option_keys = ("src_length", "max_length", "total_max_length", "batch_size", "decode_strategy")
    option_axes = (scr_length, max_length, total_max_length, batch_size, decode_strategy)
    return [dict(zip(option_keys, combo)) for combo in itertools.product(*option_axes)]
25+
26+
27+
def add_testcase(
    dict_lst: dict,
    model_full_name,
    scr_length: list,
    max_length: list,
    total_max_length: list,
    batch_size: list,
    decode_strategy: list,
):
    """Register (or overwrite) the inference test variants for one model.

    The entry for *model_full_name* in *dict_lst* is created if missing; its
    "variants"/"inference" list is rebuilt from the Cartesian product of the
    option lists, and "output_file" is derived from the last path component
    of the (lower-cased) model name.
    """
    entry = dict_lst.setdefault(model_full_name, {})
    entry["variants"] = {
        "inference": create_variants(scr_length, max_length, total_max_length, batch_size, decode_strategy)
    }
    short_name = model_full_name.lower().split("/")[-1]
    entry["output_file"] = f"{short_name}.json"
41+
42+
43+
# Context-keyed tables consumed by e2e-test-run.py:
#   test_case_lst[context][model_full_name] -> case dict (variants + output_file)
#   skip_case_lst[context][filter]          -> list of cases to skip
test_case_lst = {}
skip_case_lst = {}

# Llama-2-7b-chat is registered in every context (bat / pr / sanity / full).
for i in ["bat", "pr", "sanity", "full"]:
    test_case_lst.setdefault(i, {})
    skip_case_lst.setdefault(i, {})
    add_testcase(
        test_case_lst[i],
        "meta-llama/Llama-2-7b-chat",
        ["128"],
        ["128"],
        ["256"],
        ["1", "16"],
        ["sampling", "greedy_search", "beam_search"],
    )

# testcase list + model name + scr_length(list) + max_length(list) + total_max_length (list) + batch_size(list) + decode_strategy (list)
add_testcase(
    test_case_lst["sanity"],
    "meta-llama/Llama-2-13b-chat",
    ["128"],
    ["128"],
    ["256"],
    ["1", "16"],
    # NOTE(review): "greedy_search" appears twice, producing duplicate
    # variants; the 7b case above uses ["sampling", "greedy_search",
    # "beam_search"] — confirm whether the first entry was meant to be
    # "sampling".
    ["greedy_search", "greedy_search", "beam_search"],
)
add_testcase(
    test_case_lst["sanity"], "meta-llama/Llama-2-70b-chat", ["128"], ["128"], ["256"], ["1", "16"], ["greedy_search"]
)

# When the 'stable' filter is passed down, this list is loaded:
# it names the unstable test cases to skip.
skip_case_lst["sanity"]["stable"] = []

# When the 'unstable' filter is passed down, this list is loaded:
# it names the stable test cases to skip.
skip_case_lst["sanity"]["unstable"] = []
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
#!/usr/bin/python3
2+
3+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import argparse
18+
import os
19+
import subprocess
20+
21+
from util import get_model_path
22+
23+
# Path layout: this script lives under llm/devices/intel_hpu/tests, so the
# PaddleNLP `llm` directory is two levels up from here.
realwd = os.path.dirname(os.path.realpath(__file__))
paddlenlp_path = os.path.realpath(f"{realwd}/../../")

# Default data root; overridable via --data or the DATA_DIR env var.
data_path = os.environ.get("DATA_DIR", "/data")

os.environ.setdefault("GAUDI2_CI", "1")

cmd_args = {}
parser = argparse.ArgumentParser(
    # Fixed typo: "scriopt" -> "script".
    description="help script for pdpd e2e run test on intel hpu",
    add_help=True,
    formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
    "--context",
    choices=["pr", "bat", "sanity", "full"],
    help="which test suites to be used: pr(PR testing), bat (BAT testing), sanity (Smoke testing), full (full testing); default: bat",
    default="bat",
)
parser.add_argument("--data", type=str, help="data folder which should include huggingface folder", default=data_path)
parser.add_argument(
    "--filter",
    choices=["stable", "unstable", "all"],
    help="filter test case list: stable/unstable/all",
    default="all",
)
parser.add_argument("--device", type=str, help="device name", default="intel_hpu")
parser.add_argument("--mode", type=str, help="it should be one of [dynamic, static]", default="dynamic")
parser.add_argument("--junit", type=str, help="junit result file")
parser.add_argument("--platform", type=str, help="platform name")

cmd_args.update(vars(parser.parse_args()))

# The junitxml helper lives two directories above this script; import it only
# when a junit result file was requested and the helper file exists.
# NOTE(review): if --junit is given but junitxml.py is absent, jTestSuite is
# never imported and the junit-writing code later fails with NameError —
# confirm whether a hard error message here would be preferable.
if cmd_args["junit"]:
    libpath = os.path.dirname(os.path.dirname(realwd))
    if os.path.exists(f"{libpath}/junitxml.py"):
        import sys

        sys.path.append(libpath)
        from junitxml import jTestCase, jTestSuite

# Normalize the platform name to lowercase ("" when not supplied).
cmd_args["platform"] = "" if cmd_args["platform"] is None else cmd_args["platform"].lower()

script_path = os.path.dirname(os.path.realpath(__file__))

if os.path.exists(cmd_args["data"]) is False:
    print("data path not exist, please check for parameter: data / environment DATA_DIR")
    exit(2)

data_path = f"{cmd_args['data']}"
# DATA_HOME is read by PaddleNLP to locate model data.
os.environ.setdefault("DATA_HOME", data_path)
# NOTE(review): this re-checks the same path verified above, so the branch is
# effectively dead; kept for safety, confirm whether it can be removed.
if os.path.exists(data_path) is False:
    print(f"Couldn't find mode data path, please confirm this folder under the {cmd_args['data']} folder")
    exit(2)
77+
78+
79+
def case_command(command, test_case=None, test_suite=None):
    """Run *command* in a shell, streaming and capturing combined output.

    Args:
        command: shell command line to execute (runs with shell=True; callers
            build it from in-repo configuration, not untrusted input).
        test_case: optional junit test case; marked failed on a read error or
            non-zero exit code, and given the captured output.
        test_suite: optional junit test suite; when set, *test_case* is
            appended to it after the command finishes.

    Returns:
        (returncode, output): process exit code and the captured combined
        stdout/stderr text, prefixed with a "Command: ..." header line.
    """
    output = ""
    output += f"Command: {command}\n"
    proc = subprocess.Popen(command, bufsize=0, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    try:
        # Stream line by line so long-running tests show live progress.
        line = proc.stdout.readline().decode()
        while len(line) > 0:
            print(f"{line[:-1]}")  # line includes the trailing '\n'
            output += line
            line = proc.stdout.readline().decode()
    except Exception as e:
        if test_case:
            test_case.setFail("Command abnormally")
        print(e)
    finally:
        proc.communicate()  # reap the process so returncode is populated
    if proc.returncode != 0 and test_case:
        test_case.setFail(f"Command return code:{proc.returncode}")
    if test_case:
        test_case.AddOutput(output)
    if test_suite:
        # Bug fix: previously this called ts.addCase(...) on the module-level
        # global suite instead of honoring the test_suite parameter.
        test_suite.addCase(test_case)

    return proc.returncode, output
103+
104+
105+
def run_e2e_test_case(
    model,
    model_case_name,
    src_length,
    max_length,
    total_max_length,
    batch_size,
    decode_strategy,
    output_file,
    skip_lst=None,
    test_suite=None,
    cmd_env="",
):
    """Build and run one predictor.py inference command for a single variant.

    Args:
        model: model path or name forwarded as --model_name_or_path.
        model_case_name: short model name used to build the test-case name.
        src_length, max_length, total_max_length, batch_size, decode_strategy:
            single variant values forwarded as predictor.py options.
        output_file: NOTE(review): unused — the output file written is derived
            from the test-case name instead; confirm intent.
        skip_lst: NOTE(review): unused inside this function; confirm whether
            skip filtering was meant to happen here.
        test_suite: optional junit suite; when set, a jTestCase is created and
            attached via case_command.
        cmd_env: space-separated KEY=VALUE pairs; see NOTE below.

    Returns:
        (returncode, output) as returned by case_command.
    """
    # mode option: dynamic or static (falls back to "dynamic" in the name)
    mode_opt = ""
    mode_param = ""
    if cmd_args["mode"] and cmd_args["mode"] in ["dynamic", "static"]:
        mode_opt = f"--mode {cmd_args['mode']}"
        mode_param = cmd_args["mode"]
    else:
        mode_param = "dynamic"

    # device option (from --device, default intel_hpu)
    device_opt = "intel_hpu"
    if cmd_args["device"]:
        device_opt = cmd_args["device"]

    # gaudi2d platform only support bfloat16
    float_opt = "--dtype bfloat16" if cmd_args["platform"] == "gaudi2d" else ""

    testcase = None
    # Test-case name encodes every variant dimension plus the run mode.
    testcase_name = (
        f"{model_case_name}-{src_length}-{max_length}-{total_max_length}-{batch_size}-{decode_strategy}-{mode_param}"
    )
    if test_suite:
        testcase = jTestCase(testcase_name)
    else:
        testcase = None

    cmd_line = f"python {paddlenlp_path}/predict/predictor.py --model_name_or_path {model} --inference_model --device {device_opt} {mode_opt} {float_opt} "
    cmd_line += f"--src_length {src_length} --max_length {max_length} --total_max_length {total_max_length} --batch_size {batch_size} --decode_strategy {decode_strategy} --output_file result/{testcase_name}.json"
    # NOTE(review): `_env` is built from cmd_env but never passed to the
    # subprocess (case_command takes no env argument), so cmd_env has no
    # effect beyond appearing in the printed line — confirm intended behavior.
    _env = os.environ.copy()
    for opt in cmd_env.split():
        _env.setdefault(opt.split("=")[0], opt.split("=")[1])
    print(f"RUN shell CMD: {cmd_env} {cmd_line}")
    ret, output = case_command(cmd_line, testcase, test_suite)

    return ret, output
153+
154+
155+
case_dict_lst = {}

skip_lst = []

# Test tables are defined in config/llm.py: test_case_lst[context] maps model
# name -> case dict, and skip_case_lst[context][filter] lists cases to skip.
from config.llm import skip_case_lst, test_case_lst

case_dict_lst = test_case_lst[cmd_args["context"]]
skip_lst = skip_case_lst.get(cmd_args["context"], {}).get(cmd_args["filter"], [])

# Build a junit suite only when a result file was requested.
ts = None
if cmd_args["junit"]:
    ts = jTestSuite("E2E Test")
    ts.setPlatform(cmd_args["platform"])

total_case_num = 0
pass_case_num = 0
fail_case_num = 0

for model_name, case_dict in case_dict_lst.items():
    model_case_name = model_name.split("/")[-1]
    model_path_or_name = get_model_path(model_name)
    variants = case_dict.get("variants", dict()).get("inference", [])
    mode_param = "dynamic"
    if cmd_args["mode"] and cmd_args["mode"] in ["dynamic", "static"]:
        mode_param = cmd_args["mode"]

    for variant_dict in variants:
        case_tag = f"{model_name} src:{variant_dict['src_length']}-max_length:{variant_dict['max_length']}-total_max_length:{variant_dict['total_max_length']}-bs:{variant_dict['batch_size']}-decode_strategy:{variant_dict['decode_strategy']}-mode:{mode_param}"
        ret_test, _ = run_e2e_test_case(
            model_path_or_name,
            model_case_name,
            variant_dict["src_length"],
            variant_dict["max_length"],
            variant_dict["total_max_length"],
            variant_dict["batch_size"],
            variant_dict["decode_strategy"],
            case_dict["output_file"],
            skip_lst,
            ts,
        )
        # NOTE(review): `_` is the captured command output, which always
        # begins with "Command: ..." and so can never equal "skip" — this
        # branch is dead and every variant is counted. Confirm the intended
        # skip mechanism (skip_lst is forwarded but unused downstream).
        if "skip" == _:
            pass
        else:
            total_case_num = total_case_num + 1
            if ret_test == 0:
                pass_case_num = pass_case_num + 1
                print(f"\033[0;32mtest case {total_case_num} : {case_tag} pass \033[0m")
            else:
                fail_case_num = fail_case_num + 1
                print(f"\033[0;31mtest case {total_case_num} : {case_tag} fail \033[0m")

if cmd_args["junit"]:
    with open(cmd_args["junit"], "w+") as f:
        # Fixed: removed a stray extra ts.toString() call whose result was
        # discarded; serialize the suite once and write it.
        f.write(ts.toString())

print("...............................Summary.......................................")
print(f"\033[0;37mE2E total {total_case_num} test case running \033[0m")
print(f"\033[0;32mE2E total {pass_case_num} test case pass \033[0m")
if fail_case_num != 0:
    print(f"\033[0;31mE2E total {fail_case_num} test case fail \033[0m")

# Exit code equals the number of failed cases so CI can detect failures.
exit(fail_case_num)

0 commit comments

Comments
 (0)