13
13
# limitations under the License.
14
14
15
15
import unittest
16
+ from collections .abc import MutableMapping
17
+ from itertools import chain
16
18
from typing import Any , Optional
17
19
18
20
import celpy
45
47
]
46
48
47
49
48
- def read_textproto ( ) -> simple_pb2 .SimpleTestFile :
50
+ def load_test_data ( file_name : str ) -> simple_pb2 .SimpleTestFile :
49
51
msg = simple_pb2 .SimpleTestFile ()
50
- with open (f"tests/testdata/string_ext_ { CEL_SPEC_VERSION } .textproto" ) as file :
52
+ with open (file_name ) as file :
51
53
text_data = file .read ()
52
54
text_format .Parse (text_data , msg )
53
55
return msg
54
56
55
57
56
- def build_binding (bindings : dict [str , eval_pb2 .ExprValue ]) -> dict [Any , Any ]:
58
+ def build_variables (bindings : MutableMapping [str , eval_pb2 .ExprValue ]) -> dict [Any , Any ]:
57
59
binder = {}
58
60
for key , value in bindings .items ():
59
61
if value .HasField ("value" ):
@@ -82,25 +84,33 @@ def get_eval_error_message(test: simple_pb2.SimpleTest) -> Optional[str]:
82
84
class TestFormat (unittest .TestCase ):
83
85
@classmethod
84
86
def setUpClass (cls ):
85
- test_data = read_textproto ()
86
- cls ._format_test_section = next ((x for x in test_data .section if x .name == "format" ), None )
87
- cls ._format_error_test_section = next ((x for x in test_data .section if x .name == "format_errors" ), None )
87
+ # The test data from the cel-spec conformance tests
88
+ cel_test_data = load_test_data (f"tests/testdata/string_ext_{ CEL_SPEC_VERSION } .textproto" )
89
+ # Our supplemental tests of functionality not in the cel conformance file, but defined in the spec.
90
+ supplemental_test_data = load_test_data ("tests/testdata/string_ext_supplemental.textproto" )
91
+
92
+ # Combine the test data from both files into one
93
+ sections = cel_test_data .section
94
+ sections .extend (supplemental_test_data .section )
95
+
96
+ # Find the format tests which test successful formatting
97
+ cls ._format_tests = chain .from_iterable (x .test for x in sections if x .name == "format" )
98
+ # Find the format error tests which test errors during formatting
99
+ cls ._format_error_tests = chain .from_iterable (x .test for x in sections if x .name == "format_errors" )
100
+
88
101
cls ._env = celpy .Environment (runner_class = InterpretedRunner )
89
102
90
103
def test_format_successes (self ):
91
104
"""
92
105
Tests success scenarios for string.format
93
106
"""
94
- section = self ._format_test_section
95
- if section is None :
96
- return
97
- for test in section .test :
107
+ for test in self ._format_tests :
98
108
if test .name in skipped_tests :
99
109
continue
100
110
ast = self ._env .compile (test .expr )
101
111
prog = self ._env .program (ast , functions = extra_func .EXTRA_FUNCS )
102
112
103
- bindings = build_binding (test .bindings )
113
+ bindings = build_variables (test .bindings )
104
114
# Ideally we should use pytest parametrize instead of subtests, but
105
115
# that would require refactoring other tests also.
106
116
with self .subTest (test .name ):
@@ -118,16 +128,13 @@ def test_format_errors(self):
118
128
"""
119
129
Tests error scenarios for string.format
120
130
"""
121
- section = self ._format_error_test_section
122
- if section is None :
123
- return
124
- for test in section .test :
131
+ for test in self ._format_error_tests :
125
132
if test .name in skipped_error_tests :
126
133
continue
127
134
ast = self ._env .compile (test .expr )
128
135
prog = self ._env .program (ast , functions = extra_func .EXTRA_FUNCS )
129
136
130
- bindings = build_binding (test .bindings )
137
+ bindings = build_variables (test .bindings )
131
138
# Ideally we should use pytest parametrize instead of subtests, but
132
139
# that would require refactoring other tests also.
133
140
with self .subTest (test .name ):
0 commit comments