Skip to content

Commit 5e6ac7b

Browse files
Add rule and field value to violations (#224)
Adds the ability to access the captured rule and field value from a `Violation`. **This is a breaking change.** The API changes in the following ways: - `ValidationError` has changed: - Old accesses to `ValidationError.violations` should call `ValidationError.to_proto()` instead, to get a `buf.validate.Violations` message. - `ValidationError.errors` was removed. Switch to using `ValidationError.violations` instead. - `ValidationError.violations` provides a list of the new `Violation` wrapper type instead of a list of `buf.validate.Violation`. - The new `Violation` wrapper type contains the `buf.validate.Violation` message under the `proto` field, as well as `field_value` and `rule_value` properties that capture the field and rule values, respectively. - `Validator.collect_violations` now operates on and returns `list[Violation]` instead of the protobuf `buf.validate.Violations` message. This API mirrors the changes being made in protovalidate-go in bufbuild/protovalidate-go#154.
1 parent 41a4661 commit 5e6ac7b

File tree

4 files changed

+135
-73
lines changed

4 files changed

+135
-73
lines changed

protovalidate/internal/constraints.py

Lines changed: 88 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import dataclasses
1516
import datetime
1617
import typing
1718

@@ -81,7 +82,7 @@ def __getitem__(self, name):
8182
return super().__getitem__(name)
8283

8384

84-
def _msg_to_cel(msg: message.Message) -> dict[str, celtypes.Value]:
85+
def _msg_to_cel(msg: message.Message) -> celtypes.Value:
8586
ctor = _MSG_TYPE_URL_TO_CTOR.get(msg.DESCRIPTOR.full_name)
8687
if ctor is not None:
8788
return ctor(msg)
@@ -230,43 +231,56 @@ def _set_path_element_map_key(
230231
raise CompilationError(msg)
231232

232233

234+
class Violation:
235+
"""A singular constraint violation."""
236+
237+
proto: validate_pb2.Violation
238+
field_value: typing.Any
239+
rule_value: typing.Any
240+
241+
def __init__(self, *, field_value: typing.Any = None, rule_value: typing.Any = None, **kwargs):
242+
self.proto = validate_pb2.Violation(**kwargs)
243+
self.field_value = field_value
244+
self.rule_value = rule_value
245+
246+
233247
class ConstraintContext:
234248
"""The state associated with a single constraint evaluation."""
235249

236-
def __init__(self, fail_fast: bool = False, violations: validate_pb2.Violations = None): # noqa: FBT001, FBT002
250+
def __init__(self, fail_fast: bool = False, violations: typing.Optional[list[Violation]] = None): # noqa: FBT001, FBT002
237251
self._fail_fast = fail_fast
238252
if violations is None:
239-
violations = validate_pb2.Violations()
253+
violations = []
240254
self._violations = violations
241255

242256
@property
243257
def fail_fast(self) -> bool:
244258
return self._fail_fast
245259

246260
@property
247-
def violations(self) -> validate_pb2.Violations:
261+
def violations(self) -> list[Violation]:
248262
return self._violations
249263

250-
def add(self, violation: validate_pb2.Violation):
251-
self._violations.violations.append(violation)
264+
def add(self, violation: Violation):
265+
self._violations.append(violation)
252266

253267
def add_errors(self, other_ctx):
254-
self._violations.violations.extend(other_ctx.violations.violations)
268+
self._violations.extend(other_ctx.violations)
255269

256270
def add_field_path_element(self, element: validate_pb2.FieldPathElement):
257-
for violation in self._violations.violations:
258-
violation.field.elements.append(element)
271+
for violation in self._violations:
272+
violation.proto.field.elements.append(element)
259273

260274
def add_rule_path_elements(self, elements: typing.Iterable[validate_pb2.FieldPathElement]):
261-
for violation in self._violations.violations:
262-
violation.rule.elements.extend(elements)
275+
for violation in self._violations:
276+
violation.proto.rule.elements.extend(elements)
263277

264278
@property
265279
def done(self) -> bool:
266280
return self._fail_fast and self.has_errors()
267281

268282
def has_errors(self) -> bool:
269-
return len(self._violations.violations) > 0
283+
return len(self._violations) > 0
270284

271285
def sub_context(self):
272286
return ConstraintContext(self._fail_fast)
@@ -277,55 +291,67 @@ class ConstraintRules:
277291

278292
def validate(self, ctx: ConstraintContext, message: message.Message): # noqa: ARG002
279293
"""Validate the message against the rules in this constraint."""
280-
ctx.add(validate_pb2.Violation(constraint_id="unimplemented", message="Unimplemented"))
294+
ctx.add(Violation(constraint_id="unimplemented", message="Unimplemented"))
295+
296+
297+
@dataclasses.dataclass
298+
class CelRunner:
299+
runner: celpy.Runner
300+
constraint: validate_pb2.Constraint
301+
rule_value: typing.Optional[typing.Any] = None
302+
rule_cel: typing.Optional[celtypes.Value] = None
303+
rule_path: typing.Optional[validate_pb2.FieldPath] = None
281304

282305

283306
class CelConstraintRules(ConstraintRules):
284307
"""A constraint that has rules written in CEL."""
285308

286-
_runners: list[
287-
tuple[
288-
celpy.Runner,
289-
validate_pb2.Constraint,
290-
typing.Optional[celtypes.Value],
291-
typing.Optional[validate_pb2.FieldPath],
292-
]
293-
]
294-
_rules_cel: celtypes.Value = None
309+
_cel: list[CelRunner]
310+
_rules: typing.Optional[message.Message] = None
311+
_rules_cel: typing.Optional[celtypes.Value] = None
295312

296313
def __init__(self, rules: typing.Optional[message.Message]):
297-
self._runners = []
314+
self._cel = []
298315
if rules is not None:
316+
self._rules = rules
299317
self._rules_cel = _msg_to_cel(rules)
300318

301319
def _validate_cel(
302320
self,
303321
ctx: ConstraintContext,
304-
activation: dict[str, typing.Any],
305322
*,
323+
this_value: typing.Optional[typing.Any] = None,
324+
this_cel: typing.Optional[celtypes.Value] = None,
306325
for_key: bool = False,
307326
):
327+
activation: dict[str, celtypes.Value] = {}
328+
if this_cel is not None:
329+
activation["this"] = this_cel
308330
activation["rules"] = self._rules_cel
309331
activation["now"] = celtypes.TimestampType(datetime.datetime.now(tz=datetime.timezone.utc))
310-
for runner, constraint, rule, rule_path in self._runners:
311-
activation["rule"] = rule
312-
result = runner.evaluate(activation)
332+
for cel in self._cel:
333+
activation["rule"] = cel.rule_cel
334+
result = cel.runner.evaluate(activation)
313335
if isinstance(result, celtypes.BoolType):
314336
if not result:
315337
ctx.add(
316-
validate_pb2.Violation(
317-
rule=rule_path,
318-
constraint_id=constraint.id,
319-
message=constraint.message,
338+
Violation(
339+
field_value=this_value,
340+
rule=cel.rule_path,
341+
rule_value=cel.rule_value,
342+
constraint_id=cel.constraint.id,
343+
message=cel.constraint.message,
320344
for_key=for_key,
321345
),
322346
)
323347
elif isinstance(result, celtypes.StringType):
324348
if result:
325349
ctx.add(
326-
validate_pb2.Violation(
327-
rule=rule_path,
328-
constraint_id=constraint.id,
350+
Violation(
351+
field_value=this_value,
352+
rule=cel.rule_path,
353+
rule_value=cel.rule_value,
354+
constraint_id=cel.constraint.id,
329355
message=result,
330356
for_key=for_key,
331357
),
@@ -339,19 +365,32 @@ def add_rule(
339365
funcs: dict[str, celpy.CELFunction],
340366
rules: validate_pb2.Constraint,
341367
*,
342-
rule: typing.Optional[celtypes.Value] = None,
368+
rule_field: typing.Optional[descriptor.FieldDescriptor] = None,
343369
rule_path: typing.Optional[validate_pb2.FieldPath] = None,
344370
):
345371
ast = env.compile(rules.expression)
346372
prog = env.program(ast, functions=funcs)
347-
self._runners.append((prog, rules, rule, rule_path))
373+
rule_value = None
374+
rule_cel = None
375+
if rule_field is not None and self._rules is not None:
376+
rule_value = _proto_message_get_field(self._rules, rule_field)
377+
rule_cel = _field_to_cel(self._rules, rule_field)
378+
self._cel.append(
379+
CelRunner(
380+
runner=prog,
381+
constraint=rules,
382+
rule_value=rule_value,
383+
rule_cel=rule_cel,
384+
rule_path=rule_path,
385+
)
386+
)
348387

349388

350389
class MessageConstraintRules(CelConstraintRules):
351390
"""Message-level rules."""
352391

353392
def validate(self, ctx: ConstraintContext, message: message.Message):
354-
self._validate_cel(ctx, {"this": _msg_to_cel(message)})
393+
self._validate_cel(ctx, this_cel=_msg_to_cel(message))
355394

356395

357396
def check_field_type(field: descriptor.FieldDescriptor, expected: int, wrapper_name: typing.Optional[str] = None):
@@ -445,7 +484,7 @@ def __init__(
445484
env,
446485
funcs,
447486
cel,
448-
rule=_field_to_cel(rules, list_field),
487+
rule_field=list_field,
449488
rule_path=validate_pb2.FieldPath(
450489
elements=[
451490
_field_to_element(list_field),
@@ -465,13 +504,14 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
465504
if _is_empty_field(message, self._field):
466505
if self._required:
467506
ctx.add(
468-
validate_pb2.Violation(
507+
Violation(
469508
field=validate_pb2.FieldPath(
470509
elements=[
471510
_field_to_element(self._field),
472511
],
473512
),
474513
rule=FieldConstraintRules._required_rule_path,
514+
rule_value=self._required,
475515
constraint_id="required",
476516
message="value is required",
477517
),
@@ -485,15 +525,15 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
485525
return
486526
sub_ctx = ctx.sub_context()
487527
self._validate_value(sub_ctx, val)
488-
self._validate_cel(sub_ctx, {"this": cel_val})
528+
self._validate_cel(sub_ctx, this_value=_proto_message_get_field(message, self._field), this_cel=cel_val)
489529
if sub_ctx.has_errors():
490530
element = _field_to_element(self._field)
491531
sub_ctx.add_field_path_element(element)
492532
ctx.add_errors(sub_ctx)
493533

494534
def validate_item(self, ctx: ConstraintContext, val: typing.Any, *, for_key: bool = False):
495535
self._validate_value(ctx, val, for_key=for_key)
496-
self._validate_cel(ctx, {"this": _scalar_field_value_to_cel(val, self._field)}, for_key=for_key)
536+
self._validate_cel(ctx, this_value=val, this_cel=_scalar_field_value_to_cel(val, self._field), for_key=for_key)
497537

498538
def _validate_value(self, ctx: ConstraintContext, val: typing.Any, *, for_key: bool = False):
499539
pass
@@ -546,17 +586,19 @@ def _validate_value(self, ctx: ConstraintContext, value: any_pb2.Any, *, for_key
546586
if len(self._in) > 0:
547587
if value.type_url not in self._in:
548588
ctx.add(
549-
validate_pb2.Violation(
589+
Violation(
550590
rule=AnyConstraintRules._in_rule_path,
591+
rule_value=self._in,
551592
constraint_id="any.in",
552593
message="type URL must be in the allow list",
553594
for_key=for_key,
554595
)
555596
)
556597
if value.type_url in self._not_in:
557598
ctx.add(
558-
validate_pb2.Violation(
599+
Violation(
559600
rule=AnyConstraintRules._not_in_rule_path,
601+
rule_value=self._not_in,
560602
constraint_id="any.not_in",
561603
message="type URL must not be in the block list",
562604
for_key=for_key,
@@ -603,13 +645,14 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
603645
value = getattr(message, self._field.name)
604646
if value not in self._field.enum_type.values_by_number:
605647
ctx.add(
606-
validate_pb2.Violation(
648+
Violation(
607649
field=validate_pb2.FieldPath(
608650
elements=[
609651
_field_to_element(self._field),
610652
],
611653
),
612654
rule=EnumConstraintRules._defined_only_rule_path,
655+
rule_value=self._defined_only,
613656
constraint_id="enum.defined_only",
614657
message="value must be one of the defined enum values",
615658
),
@@ -742,7 +785,7 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
742785
if not message.WhichOneof(self._oneof.name):
743786
if self.required:
744787
ctx.add(
745-
validate_pb2.Violation(
788+
Violation(
746789
field=validate_pb2.FieldPath(
747790
elements=[_oneof_to_element(self._oneof)],
748791
),

protovalidate/validator.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import typing
16+
1517
from google.protobuf import message
1618

1719
from buf.validate import validate_pb2 # type: ignore
@@ -20,6 +22,7 @@
2022

2123
CompilationError = _constraints.CompilationError
2224
Violations = validate_pb2.Violations
25+
Violation = _constraints.Violation
2326

2427

2528
class Validator:
@@ -54,7 +57,7 @@ def validate(
5457
ValidationError: If the message is invalid.
5558
"""
5659
violations = self.collect_violations(message, fail_fast=fail_fast)
57-
if violations.violations:
60+
if len(violations) > 0:
5861
msg = f"invalid {message.DESCRIPTOR.name}"
5962
raise ValidationError(msg, violations)
6063

@@ -63,8 +66,8 @@ def collect_violations(
6366
message: message.Message,
6467
*,
6568
fail_fast: bool = False,
66-
into: validate_pb2.Violations = None,
67-
) -> validate_pb2.Violations:
69+
into: typing.Optional[list[Violation]] = None,
70+
) -> list[Violation]:
6871
"""
6972
Validates the given message against the static constraints defined in
7073
the message's descriptor. Compared to validate, collect_violations is
@@ -84,12 +87,12 @@ def collect_violations(
8487
constraint.validate(ctx, message)
8588
if ctx.done:
8689
break
87-
for violation in ctx.violations.violations:
88-
if violation.HasField("field"):
89-
violation.field.elements.reverse()
90-
if violation.HasField("rule"):
91-
violation.rule.elements.reverse()
92-
violation.field_path = field_path.string(violation.field)
90+
for violation in ctx.violations:
91+
if violation.proto.HasField("field"):
92+
violation.proto.field.elements.reverse()
93+
if violation.proto.HasField("rule"):
94+
violation.proto.rule.elements.reverse()
95+
violation.proto.field_path = field_path.string(violation.proto.field)
9396
return ctx.violations
9497

9598

@@ -98,15 +101,25 @@ class ValidationError(ValueError):
98101
An error raised when a message fails to validate.
99102
"""
100103

101-
violations: validate_pb2.Violations
104+
_violations: list[_constraints.Violation]
102105

103-
def __init__(self, msg: str, violations: validate_pb2.Violations):
106+
def __init__(self, msg: str, violations: list[_constraints.Violation]):
104107
super().__init__(msg)
105-
self.violations = violations
108+
self._violations = violations
109+
110+
def to_proto(self) -> validate_pb2.Violations:
111+
"""
112+
Provides the Protobuf form of the validation errors.
113+
"""
114+
result = validate_pb2.Violations()
115+
for violation in self._violations:
116+
result.violations.append(violation.proto)
117+
return result
106118

107-
def errors(self) -> list[validate_pb2.Violation]:
119+
@property
120+
def violations(self) -> list[Violation]:
108121
"""
109-
Returns the validation errors as a simple Python list, rather than the
122+
Provides the validation errors as a simple Python list, rather than the
110123
Protobuf-specific collection type used by Violations.
111124
"""
112-
return list(self.violations.violations)
125+
return self._violations

tests/conformance/runner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def run_test_case(tc: typing.Any, result: typing.Optional[harness_pb2.TestResult
6262
result = harness_pb2.TestResult()
6363
# Run the validator
6464
try:
65-
protovalidate.collect_violations(tc, into=result.validation_error)
65+
violations = protovalidate.collect_violations(tc)
66+
for violation in violations:
67+
result.validation_error.violations.append(violation.proto)
6668
if len(result.validation_error.violations) == 0:
6769
result.success = True
6870
except celpy.CELEvalError as e:

0 commit comments

Comments
 (0)