From 98336d89951391dd906865ff728b9e7fe1727bef Mon Sep 17 00:00:00 2001 From: John Chadwick Date: Tue, 22 Apr 2025 02:09:54 -0400 Subject: [PATCH] Implement getField CEL function I'm proposing this as an eventual replacement (before 1.0) of our hack around the fact that the `in` identifier is reserved in CEL. This is especially urgent for protovalidate-cc which is currently carrying patches to the CEL implementation in order to enable it, since cel-cpp doesn't allow this sort of functionality to be added in at runtime. --- protovalidate/internal/constraints.py | 8 ++++---- protovalidate/internal/extra_func.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/protovalidate/internal/constraints.py b/protovalidate/internal/constraints.py index 2519108..48b1ca4 100644 --- a/protovalidate/internal/constraints.py +++ b/protovalidate/internal/constraints.py @@ -40,7 +40,7 @@ def make_timestamp(msg: message.Message) -> celtypes.TimestampType: def unwrap(msg: message.Message) -> celtypes.Value: - return _field_to_cel(msg, msg.DESCRIPTOR.fields_by_name["value"]) + return field_to_cel(msg, msg.DESCRIPTOR.fields_by_name["value"]) _MSG_TYPE_URL_TO_CTOR: dict[str, typing.Callable[..., celtypes.Value]] = { @@ -70,7 +70,7 @@ def __init__(self, msg: message.Message): for field in self.desc.fields: if field.containing_oneof is not None and not self.msg.HasField(field.name): continue - self[field.name] = _field_to_cel(self.msg, field) + self[field.name] = field_to_cel(self.msg, field) def __getitem__(self, name): field = self.desc.fields_by_name[name] @@ -175,7 +175,7 @@ def _map_field_to_cel(msg: message.Message, field: descriptor.FieldDescriptor) - return _map_field_value_to_cel(_proto_message_get_field(msg, field), field) -def _field_to_cel(msg: message.Message, field: descriptor.FieldDescriptor) -> celtypes.Value: +def field_to_cel(msg: message.Message, field: descriptor.FieldDescriptor) -> celtypes.Value: if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: return _repeated_field_to_cel(msg, field) elif field.message_type is not None and not _proto_message_has_field(msg, field): @@ -374,7 +374,7 @@ def add_rule( rule_cel = None if rule_field is not None and self._rules is not None: rule_value = _proto_message_get_field(self._rules, rule_field) - rule_cel = _field_to_cel(self._rules, rule_field) + rule_cel = field_to_cel(self._rules, rule_field) self._cel.append( CelRunner( runner=prog, diff --git a/protovalidate/internal/extra_func.py b/protovalidate/internal/extra_func.py index cf211c7..96fe5e3 100644 --- a/protovalidate/internal/extra_func.py +++ b/protovalidate/internal/extra_func.py @@ -22,6 +22,7 @@ from celpy import celtypes from protovalidate.internal import string_format +from protovalidate.internal.constraints import MessageType, field_to_cel def _validate_hostname(host): @@ -112,6 +113,19 @@ def validate_ip(val: typing.Union[str, bytes], version: typing.Optional[int] = N return False +def get_field(message: celtypes.Value, field_name: celtypes.Value) -> celpy.Result: + if not isinstance(message, MessageType): + msg = "invalid argument, expected message" + raise celpy.CELEvalError(msg) + if not isinstance(field_name, celtypes.StringType): + msg = "invalid argument, expected string" + raise celpy.CELEvalError(msg) + if field_name not in message.desc.fields_by_name: + msg = f"no such field: {field_name}" + raise celpy.CELEvalError(msg) + return field_to_cel(message.msg, message.desc.fields_by_name[field_name]) + + def is_ip(val: celtypes.Value, version: typing.Optional[celtypes.Value] = None) -> celpy.Result: if not isinstance(val, (celtypes.BytesType, celtypes.StringType)): msg = "invalid argument, expected string or bytes" @@ -245,6 +259,7 @@ def make_extra_funcs(locale: str) -> dict[str, celpy.CELFunction]: # Missing standard functions "format": string_fmt.format, # protovalidate specific functions + "getField": get_field, "isNan": is_nan, "isInf": is_inf, "isIp": is_ip,