1
1
import os
from dataclasses import dataclass

import lm_eval
# BUG FIX: was `from numpy import np`, which raises ImportError (numpy has no
# attribute `np`). The canonical alias import is required for the `np.isclose`
# calls used when comparing measured metrics below.
import numpy as np
import pytest
import yaml
from jinja2 import Environment, FileSystemLoader

# Relative tolerance when comparing a measured eval metric against its
# recorded ground-truth value (see np.isclose(..., rtol=RTOL) usage below).
RTOL = 0.02
@@ -68,9 +67,8 @@ def build_eval_args(eval_config, tp_size):
68
67
return eval_params
69
68
70
69
71
- def generate_report (tp_size , eval_config ,
72
- report_data ,
73
- report_template , output_path , env_config ):
70
+ def generate_report (tp_size , eval_config , report_data , report_template ,
71
+ output_path , env_config ):
74
72
env = Environment (loader = FileSystemLoader ('.' ))
75
73
template = env .get_template (str (report_template ))
76
74
model_args = build_model_args (eval_config , tp_size )
@@ -97,20 +95,21 @@ def generate_report(tp_size, eval_config,
97
95
98
96
99
97
def test_lm_eval_correctness_param (config_filename , tp_size , report_template ,
100
- output_path , env_config ):
98
+ output_path , env_config ):
101
99
eval_config = yaml .safe_load (config_filename .read_text (encoding = "utf-8" ))
102
100
eval_params = build_eval_args (eval_config , tp_size )
103
101
results = lm_eval .simple_evaluate (** eval_params )
104
102
success = True
105
103
report_data : dict [str , list [dict ]] = {"rows" : []}
106
-
104
+
107
105
for task in eval_config ["tasks" ]:
108
106
for metric in task ["metrics" ]:
109
- ground_truth = metric ["value" ]
110
- measured_value = results ["results" ][task ["name" ]][metric ["name" ]]
111
- print (f"{ task ['name' ]} | { metric ['name' ]} : "
112
- f"ground_truth={ ground_truth } | measured={ measured_value } " )
113
- success = success and bool (np .isclose (ground_truth , measured_value , rtol = RTOL ))
107
+ ground_truth = metric ["value" ]
108
+ measured_value = results ["results" ][task ["name" ]][metric ["name" ]]
109
+ print (f"{ task ['name' ]} | { metric ['name' ]} : "
110
+ f"ground_truth={ ground_truth } | measured={ measured_value } " )
111
+ success = success and bool (
112
+ np .isclose (ground_truth , measured_value , rtol = RTOL ))
114
113
115
114
report_data ["rows" ].append ({
116
115
"task" :
@@ -126,4 +125,3 @@ def test_lm_eval_correctness_param(config_filename, tp_size, report_template,
126
125
generate_report (tp_size , eval_config , report_data , report_template ,
127
126
output_path , env_config )
128
127
assert success
129
-
0 commit comments