diff --git a/protovalidate/internal/string_format.py b/protovalidate/internal/string_format.py index 84ef70b..922aff2 100644 --- a/protovalidate/internal/string_format.py +++ b/protovalidate/internal/string_format.py @@ -13,6 +13,7 @@ # limitations under the License. import math +import re from decimal import Decimal from typing import Optional, Union @@ -186,7 +187,9 @@ def __format_string(self, arg: celtypes.Value) -> str: # True -> true return str(arg).lower() if isinstance(arg, celtypes.BytesType): - return str(arg, "utf-8") + decoded = arg.decode("utf-8", errors="replace") + # Collapse any contiguous placeholders into one + return re.sub("\\ufffd+", "\ufffd", decoded) if isinstance(arg, celtypes.DoubleType): result = self.__validate_number(arg) if result is not None: diff --git a/tests/format_test.py b/tests/format_test.py index 55e2cd8..353a847 100644 --- a/tests/format_test.py +++ b/tests/format_test.py @@ -13,6 +13,8 @@ # limitations under the License. import unittest +from collections.abc import MutableMapping +from itertools import chain from typing import Any, Optional import celpy @@ -45,15 +47,15 @@ ] -def read_textproto() -> simple_pb2.SimpleTestFile: +def load_test_data(file_name: str) -> simple_pb2.SimpleTestFile: msg = simple_pb2.SimpleTestFile() - with open(f"tests/testdata/string_ext_{CEL_SPEC_VERSION}.textproto") as file: + with open(file_name) as file: text_data = file.read() text_format.Parse(text_data, msg) return msg -def build_binding(bindings: dict[str, eval_pb2.ExprValue]) -> dict[Any, Any]: +def build_variables(bindings: MutableMapping[str, eval_pb2.ExprValue]) -> dict[Any, Any]: binder = {} for key, value in bindings.items(): if value.HasField("value"): @@ -82,25 +84,33 @@ def get_eval_error_message(test: simple_pb2.SimpleTest) -> Optional[str]: class TestFormat(unittest.TestCase): @classmethod def setUpClass(cls): - test_data = read_textproto() - cls._format_test_section = next((x for x in test_data.section if x.name == "format"), None) - cls._format_error_test_section = next((x for x in test_data.section if x.name == "format_errors"), None) + # The test data from the cel-spec conformance tests + cel_test_data = load_test_data(f"tests/testdata/string_ext_{CEL_SPEC_VERSION}.textproto") + # Our supplemental tests of functionality not in the cel conformance file, but defined in the spec. + supplemental_test_data = load_test_data("tests/testdata/string_ext_supplemental.textproto") + + # Combine the test data from both files into one + sections = cel_test_data.section + sections.extend(supplemental_test_data.section) + + # Find the format tests which test successful formatting + cls._format_tests = chain.from_iterable(x.test for x in sections if x.name == "format") + # Find the format error tests which test errors during formatting + cls._format_error_tests = chain.from_iterable(x.test for x in sections if x.name == "format_errors") + cls._env = celpy.Environment(runner_class=InterpretedRunner) def test_format_successes(self): """ Tests success scenarios for string.format """ - section = self._format_test_section - if section is None: - return - for test in section.test: + for test in self._format_tests: if test.name in skipped_tests: continue ast = self._env.compile(test.expr) prog = self._env.program(ast, functions=extra_func.EXTRA_FUNCS) - bindings = build_binding(test.bindings) + bindings = build_variables(test.bindings) # Ideally we should use pytest parametrize instead of subtests, but # that would require refactoring other tests also. with self.subTest(test.name): @@ -118,16 +128,13 @@ def test_format_errors(self): """ Tests error scenarios for string.format """ - section = self._format_error_test_section - if section is None: - return - for test in section.test: + for test in self._format_error_tests: if test.name in skipped_error_tests: continue ast = self._env.compile(test.expr) prog = self._env.program(ast, functions=extra_func.EXTRA_FUNCS) - bindings = build_binding(test.bindings) + bindings = build_variables(test.bindings) # Ideally we should use pytest parametrize instead of subtests, but # that would require refactoring other tests also. with self.subTest(test.name): diff --git a/tests/testdata/string_ext_supplemental.textproto b/tests/testdata/string_ext_supplemental.textproto new file mode 100644 index 0000000..dde7b83 --- /dev/null +++ b/tests/testdata/string_ext_supplemental.textproto @@ -0,0 +1,26 @@ +# proto-file: ../../../proto/cel/expr/conformance/test/simple.proto +# proto-message: cel.expr.conformance.test.SimpleTestFile + +# Ideally these tests should be in the cel-spec conformance test suite. +# Until they are added, we can use this to test for additional functionality +# listed in the spec. + +name: "string_ext_supplemental" +description: "Supplemental tests for the strings extension library." +section: { + name: "format" + test: { + name: "bytes support for string with invalid utf-8 encoding" + expr: '"%s".format([b"\\xF0abc\\x8C\\xF0xyz"])' + value: { + string_value: '\ufffdabc\ufffdxyz', + } + } + test: { + name: "bytes support for string with only invalid utf-8 sequences" + expr: '"%s".format([b"\\xF0\\x8C\\xF0"])' + value: { + string_value: '\ufffd', + } + } +}