12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ import dataclasses
15
16
import datetime
16
17
import typing
17
18
@@ -81,7 +82,7 @@ def __getitem__(self, name):
81
82
return super ().__getitem__ (name )
82
83
83
84
84
- def _msg_to_cel (msg : message .Message ) -> dict [ str , celtypes .Value ] :
85
+ def _msg_to_cel (msg : message .Message ) -> celtypes .Value :
85
86
ctor = _MSG_TYPE_URL_TO_CTOR .get (msg .DESCRIPTOR .full_name )
86
87
if ctor is not None :
87
88
return ctor (msg )
@@ -230,43 +231,56 @@ def _set_path_element_map_key(
230
231
raise CompilationError (msg )
231
232
232
233
234
+ class Violation :
235
+ """A singular constraint violation."""
236
+
237
+ proto : validate_pb2 .Violation
238
+ field_value : typing .Any
239
+ rule_value : typing .Any
240
+
241
+ def __init__ (self , * , field_value : typing .Any = None , rule_value : typing .Any = None , ** kwargs ):
242
+ self .proto = validate_pb2 .Violation (** kwargs )
243
+ self .field_value = field_value
244
+ self .rule_value = rule_value
245
+
246
+
233
247
class ConstraintContext :
234
248
"""The state associated with a single constraint evaluation."""
235
249
236
- def __init__ (self , fail_fast : bool = False , violations : validate_pb2 . Violations = None ): # noqa: FBT001, FBT002
250
+ def __init__ (self , fail_fast : bool = False , violations : typing . Optional [ list [ Violation ]] = None ): # noqa: FBT001, FBT002
237
251
self ._fail_fast = fail_fast
238
252
if violations is None :
239
- violations = validate_pb2 . Violations ()
253
+ violations = []
240
254
self ._violations = violations
241
255
242
256
@property
243
257
def fail_fast (self ) -> bool :
244
258
return self ._fail_fast
245
259
246
260
@property
247
- def violations (self ) -> validate_pb2 . Violations :
261
+ def violations (self ) -> list [ Violation ] :
248
262
return self ._violations
249
263
250
- def add (self , violation : validate_pb2 . Violation ):
251
- self ._violations .violations . append (violation )
264
+ def add (self , violation : Violation ):
265
+ self ._violations .append (violation )
252
266
253
267
def add_errors (self , other_ctx ):
254
- self ._violations .violations . extend (other_ctx . violations .violations )
268
+ self ._violations .extend (other_ctx .violations )
255
269
256
270
def add_field_path_element (self , element : validate_pb2 .FieldPathElement ):
257
- for violation in self ._violations . violations :
258
- violation .field .elements .append (element )
271
+ for violation in self ._violations :
272
+ violation .proto . field .elements .append (element )
259
273
260
274
def add_rule_path_elements (self , elements : typing .Iterable [validate_pb2 .FieldPathElement ]):
261
- for violation in self ._violations . violations :
262
- violation .rule .elements .extend (elements )
275
+ for violation in self ._violations :
276
+ violation .proto . rule .elements .extend (elements )
263
277
264
278
@property
265
279
def done (self ) -> bool :
266
280
return self ._fail_fast and self .has_errors ()
267
281
268
282
def has_errors (self ) -> bool :
269
- return len (self ._violations . violations ) > 0
283
+ return len (self ._violations ) > 0
270
284
271
285
def sub_context (self ):
272
286
return ConstraintContext (self ._fail_fast )
@@ -277,55 +291,67 @@ class ConstraintRules:
277
291
278
292
def validate (self , ctx : ConstraintContext , message : message .Message ): # noqa: ARG002
279
293
"""Validate the message against the rules in this constraint."""
280
- ctx .add (validate_pb2 .Violation (constraint_id = "unimplemented" , message = "Unimplemented" ))
294
+ ctx .add (Violation (constraint_id = "unimplemented" , message = "Unimplemented" ))
295
+
296
+
297
+ @dataclasses .dataclass
298
+ class CelRunner :
299
+ runner : celpy .Runner
300
+ constraint : validate_pb2 .Constraint
301
+ rule_value : typing .Optional [typing .Any ] = None
302
+ rule_cel : typing .Optional [celtypes .Value ] = None
303
+ rule_path : typing .Optional [validate_pb2 .FieldPath ] = None
281
304
282
305
283
306
class CelConstraintRules (ConstraintRules ):
284
307
"""A constraint that has rules written in CEL."""
285
308
286
- _runners : list [
287
- tuple [
288
- celpy .Runner ,
289
- validate_pb2 .Constraint ,
290
- typing .Optional [celtypes .Value ],
291
- typing .Optional [validate_pb2 .FieldPath ],
292
- ]
293
- ]
294
- _rules_cel : celtypes .Value = None
309
+ _cel : list [CelRunner ]
310
+ _rules : typing .Optional [message .Message ] = None
311
+ _rules_cel : typing .Optional [celtypes .Value ] = None
295
312
296
313
def __init__ (self , rules : typing .Optional [message .Message ]):
297
- self ._runners = []
314
+ self ._cel = []
298
315
if rules is not None :
316
+ self ._rules = rules
299
317
self ._rules_cel = _msg_to_cel (rules )
300
318
301
319
def _validate_cel (
302
320
self ,
303
321
ctx : ConstraintContext ,
304
- activation : dict [str , typing .Any ],
305
322
* ,
323
+ this_value : typing .Optional [typing .Any ] = None ,
324
+ this_cel : typing .Optional [celtypes .Value ] = None ,
306
325
for_key : bool = False ,
307
326
):
327
+ activation : dict [str , celtypes .Value ] = {}
328
+ if this_cel is not None :
329
+ activation ["this" ] = this_cel
308
330
activation ["rules" ] = self ._rules_cel
309
331
activation ["now" ] = celtypes .TimestampType (datetime .datetime .now (tz = datetime .timezone .utc ))
310
- for runner , constraint , rule , rule_path in self ._runners :
311
- activation ["rule" ] = rule
312
- result = runner .evaluate (activation )
332
+ for cel in self ._cel :
333
+ activation ["rule" ] = cel . rule_cel
334
+ result = cel . runner .evaluate (activation )
313
335
if isinstance (result , celtypes .BoolType ):
314
336
if not result :
315
337
ctx .add (
316
- validate_pb2 .Violation (
317
- rule = rule_path ,
318
- constraint_id = constraint .id ,
319
- message = constraint .message ,
338
+ Violation (
339
+ field_value = this_value ,
340
+ rule = cel .rule_path ,
341
+ rule_value = cel .rule_value ,
342
+ constraint_id = cel .constraint .id ,
343
+ message = cel .constraint .message ,
320
344
for_key = for_key ,
321
345
),
322
346
)
323
347
elif isinstance (result , celtypes .StringType ):
324
348
if result :
325
349
ctx .add (
326
- validate_pb2 .Violation (
327
- rule = rule_path ,
328
- constraint_id = constraint .id ,
350
+ Violation (
351
+ field_value = this_value ,
352
+ rule = cel .rule_path ,
353
+ rule_value = cel .rule_value ,
354
+ constraint_id = cel .constraint .id ,
329
355
message = result ,
330
356
for_key = for_key ,
331
357
),
@@ -339,19 +365,32 @@ def add_rule(
339
365
funcs : dict [str , celpy .CELFunction ],
340
366
rules : validate_pb2 .Constraint ,
341
367
* ,
342
- rule : typing .Optional [celtypes . Value ] = None ,
368
+ rule_field : typing .Optional [descriptor . FieldDescriptor ] = None ,
343
369
rule_path : typing .Optional [validate_pb2 .FieldPath ] = None ,
344
370
):
345
371
ast = env .compile (rules .expression )
346
372
prog = env .program (ast , functions = funcs )
347
- self ._runners .append ((prog , rules , rule , rule_path ))
373
+ rule_value = None
374
+ rule_cel = None
375
+ if rule_field is not None and self ._rules is not None :
376
+ rule_value = _proto_message_get_field (self ._rules , rule_field )
377
+ rule_cel = _field_to_cel (self ._rules , rule_field )
378
+ self ._cel .append (
379
+ CelRunner (
380
+ runner = prog ,
381
+ constraint = rules ,
382
+ rule_value = rule_value ,
383
+ rule_cel = rule_cel ,
384
+ rule_path = rule_path ,
385
+ )
386
+ )
348
387
349
388
350
389
class MessageConstraintRules (CelConstraintRules ):
351
390
"""Message-level rules."""
352
391
353
392
def validate (self , ctx : ConstraintContext , message : message .Message ):
354
- self ._validate_cel (ctx , { "this" : _msg_to_cel (message )} )
393
+ self ._validate_cel (ctx , this_cel = _msg_to_cel (message ))
355
394
356
395
357
396
def check_field_type (field : descriptor .FieldDescriptor , expected : int , wrapper_name : typing .Optional [str ] = None ):
@@ -445,7 +484,7 @@ def __init__(
445
484
env ,
446
485
funcs ,
447
486
cel ,
448
- rule = _field_to_cel ( rules , list_field ) ,
487
+ rule_field = list_field ,
449
488
rule_path = validate_pb2 .FieldPath (
450
489
elements = [
451
490
_field_to_element (list_field ),
@@ -465,13 +504,14 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
465
504
if _is_empty_field (message , self ._field ):
466
505
if self ._required :
467
506
ctx .add (
468
- validate_pb2 . Violation (
507
+ Violation (
469
508
field = validate_pb2 .FieldPath (
470
509
elements = [
471
510
_field_to_element (self ._field ),
472
511
],
473
512
),
474
513
rule = FieldConstraintRules ._required_rule_path ,
514
+ rule_value = self ._required ,
475
515
constraint_id = "required" ,
476
516
message = "value is required" ,
477
517
),
@@ -485,15 +525,15 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
485
525
return
486
526
sub_ctx = ctx .sub_context ()
487
527
self ._validate_value (sub_ctx , val )
488
- self ._validate_cel (sub_ctx , { "this" : cel_val } )
528
+ self ._validate_cel (sub_ctx , this_value = _proto_message_get_field ( message , self . _field ), this_cel = cel_val )
489
529
if sub_ctx .has_errors ():
490
530
element = _field_to_element (self ._field )
491
531
sub_ctx .add_field_path_element (element )
492
532
ctx .add_errors (sub_ctx )
493
533
494
534
def validate_item (self , ctx : ConstraintContext , val : typing .Any , * , for_key : bool = False ):
495
535
self ._validate_value (ctx , val , for_key = for_key )
496
- self ._validate_cel (ctx , { "this" : _scalar_field_value_to_cel (val , self ._field )} , for_key = for_key )
536
+ self ._validate_cel (ctx , this_value = val , this_cel = _scalar_field_value_to_cel (val , self ._field ), for_key = for_key )
497
537
498
538
def _validate_value (self , ctx : ConstraintContext , val : typing .Any , * , for_key : bool = False ):
499
539
pass
@@ -546,17 +586,19 @@ def _validate_value(self, ctx: ConstraintContext, value: any_pb2.Any, *, for_key
546
586
if len (self ._in ) > 0 :
547
587
if value .type_url not in self ._in :
548
588
ctx .add (
549
- validate_pb2 . Violation (
589
+ Violation (
550
590
rule = AnyConstraintRules ._in_rule_path ,
591
+ rule_value = self ._in ,
551
592
constraint_id = "any.in" ,
552
593
message = "type URL must be in the allow list" ,
553
594
for_key = for_key ,
554
595
)
555
596
)
556
597
if value .type_url in self ._not_in :
557
598
ctx .add (
558
- validate_pb2 . Violation (
599
+ Violation (
559
600
rule = AnyConstraintRules ._not_in_rule_path ,
601
+ rule_value = self ._not_in ,
560
602
constraint_id = "any.not_in" ,
561
603
message = "type URL must not be in the block list" ,
562
604
for_key = for_key ,
@@ -603,13 +645,14 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
603
645
value = getattr (message , self ._field .name )
604
646
if value not in self ._field .enum_type .values_by_number :
605
647
ctx .add (
606
- validate_pb2 . Violation (
648
+ Violation (
607
649
field = validate_pb2 .FieldPath (
608
650
elements = [
609
651
_field_to_element (self ._field ),
610
652
],
611
653
),
612
654
rule = EnumConstraintRules ._defined_only_rule_path ,
655
+ rule_value = self ._defined_only ,
613
656
constraint_id = "enum.defined_only" ,
614
657
message = "value must be one of the defined enum values" ,
615
658
),
@@ -742,7 +785,7 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
742
785
if not message .WhichOneof (self ._oneof .name ):
743
786
if self .required :
744
787
ctx .add (
745
- validate_pb2 . Violation (
788
+ Violation (
746
789
field = validate_pb2 .FieldPath (
747
790
elements = [_oneof_to_element (self ._oneof )],
748
791
),
0 commit comments