
Commit 0eff170

fix
Signed-off-by: Icey <1790571317@qq.com>
1 parent d42b607

File tree

2 files changed: +12, -13 lines

tests/e2e/singlecard/models/report_template.md
1 addition, 0 deletions

@@ -16,6 +16,7 @@ lm_eval --model {{ model_type }} --model_args $MODEL_ARGS --tasks {{ datasets }}
   --apply_chat_template --fewshot_as_multiturn {% if num_fewshot is defined and num_fewshot != "N/A" %} --num_fewshot {{ num_fewshot }} {% endif %} \
   --limit {{ limit }} --batch_size {{ batch_size}}
 ```
+
 | Task | Metric | Value | Stderr |
 |-----------------------|-------------|----------:|-------:|
 {% for row in rows -%}
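The added blank line separates the fenced command block from the results table that the test's generate_report helper fills in via Jinja2. A minimal sketch of that rendering step, assuming a rows payload like the one appended in the test file below; only the "task" key is visible in that diff, so the remaining field names are inferred from the table header and are illustrative:

```python
from jinja2 import Environment, FileSystemLoader

# Load the template relative to the working directory, as generate_report does.
env = Environment(loader=FileSystemLoader('.'))
template = env.get_template('tests/e2e/singlecard/models/report_template.md')

# Hypothetical payload: the real test appends one row per (task, metric) pair.
# Unset template variables (model_type, limit, ...) render as empty strings
# under Jinja2's default Undefined behavior.
report_data = {
    "rows": [{
        "task": "gsm8k",
        "metric": "exact_match",
        "value": 0.85,
        "stderr": 0.01,
    }]
}
print(template.render(**report_data))
```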

tests/e2e/singlecard/models/test_lm_eval_correctness.py
11 additions, 13 deletions

@@ -1,12 +1,11 @@
 import os
 from dataclasses import dataclass
-from numpy import np
-from typing import Any, Dict, List

 import lm_eval
 import pytest
 import yaml
 from jinja2 import Environment, FileSystemLoader
+from numpy import np

 RTOL = 0.02

@@ -68,9 +67,8 @@ def build_eval_args(eval_config, tp_size):
     return eval_params


-def generate_report(tp_size, eval_config,
-                    report_data,
-                    report_template, output_path, env_config):
+def generate_report(tp_size, eval_config, report_data, report_template,
+                    output_path, env_config):
     env = Environment(loader=FileSystemLoader('.'))
     template = env.get_template(str(report_template))
     model_args = build_model_args(eval_config, tp_size)
@@ -97,20 +95,21 @@ def generate_report(tp_size, eval_config,


 def test_lm_eval_correctness_param(config_filename, tp_size, report_template,
-                                output_path, env_config):
+                                   output_path, env_config):
     eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))
     eval_params = build_eval_args(eval_config, tp_size)
     results = lm_eval.simple_evaluate(**eval_params)
     success = True
     report_data: dict[str, list[dict]] = {"rows": []}
-
+
     for task in eval_config["tasks"]:
         for metric in task["metrics"]:
-            ground_truth = metric["value"]
-            measured_value = results["results"][task["name"]][metric["name"]]
-            print(f"{task['name']} | {metric['name']}: "
-                f"ground_truth={ground_truth} | measured={measured_value}")
-            success = success and bool(np.isclose(ground_truth, measured_value, rtol=RTOL))
+            ground_truth = metric["value"]
+            measured_value = results["results"][task["name"]][metric["name"]]
+            print(f"{task['name']} | {metric['name']}: "
+                  f"ground_truth={ground_truth} | measured={measured_value}")
+            success = success and bool(
+                np.isclose(ground_truth, measured_value, rtol=RTOL))

             report_data["rows"].append({
                 "task":
@@ -126,4 +125,3 @@ def test_lm_eval_correctness_param(config_filename, tp_size, report_template,
     generate_report(tp_size, eval_config, report_data, report_template,
                     output_path, env_config)
     assert success
-
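The reflowed assertion compares each measured metric against its recorded ground truth with np.isclose at a 2% relative tolerance. A minimal self-contained sketch of that check with hypothetical scores; note it uses `import numpy as np`, since the `from numpy import np` form kept by the diff would raise ImportError at runtime:

```python
import numpy as np  # "from numpy import np" would raise ImportError

RTOL = 0.02  # relative tolerance used by the test


def metric_matches(ground_truth: float, measured: float) -> bool:
    # np.isclose(a, b, rtol=..., atol=...) tests |a - b| <= atol + rtol * |b|,
    # with atol defaulting to 1e-8; bool() unwraps the NumPy scalar.
    return bool(np.isclose(ground_truth, measured, rtol=RTOL))


# Hypothetical scores: a 1% drift passes at rtol=0.02, a 5% drift fails.
print(metric_matches(0.80, 0.808))  # True
print(metric_matches(0.80, 0.84))   # False
```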
