Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rule and field value to violations #224

Merged
merged 4 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 88 additions & 45 deletions protovalidate/internal/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import datetime
import typing

Expand Down Expand Up @@ -81,7 +82,7 @@ def __getitem__(self, name):
return super().__getitem__(name)


def _msg_to_cel(msg: message.Message) -> dict[str, celtypes.Value]:
def _msg_to_cel(msg: message.Message) -> celtypes.Value:
ctor = _MSG_TYPE_URL_TO_CTOR.get(msg.DESCRIPTOR.full_name)
if ctor is not None:
return ctor(msg)
Expand Down Expand Up @@ -230,43 +231,56 @@ def _set_path_element_map_key(
raise CompilationError(msg)


class Violation:
"""A singular constraint violation."""

proto: validate_pb2.Violation
field_value: typing.Any
rule_value: typing.Any

def __init__(self, *, field_value: typing.Any = None, rule_value: typing.Any = None, **kwargs):
self.proto = validate_pb2.Violation(**kwargs)
self.field_value = field_value
self.rule_value = rule_value


class ConstraintContext:
"""The state associated with a single constraint evaluation."""

def __init__(self, fail_fast: bool = False, violations: validate_pb2.Violations = None): # noqa: FBT001, FBT002
def __init__(self, fail_fast: bool = False, violations: typing.Optional[list[Violation]] = None): # noqa: FBT001, FBT002
self._fail_fast = fail_fast
if violations is None:
violations = validate_pb2.Violations()
violations = []
self._violations = violations

@property
def fail_fast(self) -> bool:
return self._fail_fast

@property
def violations(self) -> validate_pb2.Violations:
def violations(self) -> list[Violation]:
return self._violations

def add(self, violation: validate_pb2.Violation):
self._violations.violations.append(violation)
def add(self, violation: Violation):
self._violations.append(violation)

def add_errors(self, other_ctx):
self._violations.violations.extend(other_ctx.violations.violations)
self._violations.extend(other_ctx.violations)

def add_field_path_element(self, element: validate_pb2.FieldPathElement):
for violation in self._violations.violations:
violation.field.elements.append(element)
for violation in self._violations:
violation.proto.field.elements.append(element)

def add_rule_path_elements(self, elements: typing.Iterable[validate_pb2.FieldPathElement]):
for violation in self._violations.violations:
violation.rule.elements.extend(elements)
for violation in self._violations:
violation.proto.rule.elements.extend(elements)

@property
def done(self) -> bool:
return self._fail_fast and self.has_errors()

def has_errors(self) -> bool:
return len(self._violations.violations) > 0
return len(self._violations) > 0

def sub_context(self):
return ConstraintContext(self._fail_fast)
Expand All @@ -277,55 +291,67 @@ class ConstraintRules:

def validate(self, ctx: ConstraintContext, message: message.Message): # noqa: ARG002
"""Validate the message against the rules in this constraint."""
ctx.add(validate_pb2.Violation(constraint_id="unimplemented", message="Unimplemented"))
ctx.add(Violation(constraint_id="unimplemented", message="Unimplemented"))


@dataclasses.dataclass
class CelRunner:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could be a @dataclasses.dataclass, I think?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this could just be a dataclass.

runner: celpy.Runner
constraint: validate_pb2.Constraint
rule_value: typing.Optional[typing.Any] = None
rule_cel: typing.Optional[celtypes.Value] = None
rule_path: typing.Optional[validate_pb2.FieldPath] = None


class CelConstraintRules(ConstraintRules):
"""A constraint that has rules written in CEL."""

_runners: list[
tuple[
celpy.Runner,
validate_pb2.Constraint,
typing.Optional[celtypes.Value],
typing.Optional[validate_pb2.FieldPath],
]
]
_rules_cel: celtypes.Value = None
_cel: list[CelRunner]
_rules: typing.Optional[message.Message] = None
_rules_cel: typing.Optional[celtypes.Value] = None

def __init__(self, rules: typing.Optional[message.Message]):
self._runners = []
self._cel = []
if rules is not None:
self._rules = rules
self._rules_cel = _msg_to_cel(rules)

def _validate_cel(
self,
ctx: ConstraintContext,
activation: dict[str, typing.Any],
*,
this_value: typing.Optional[typing.Any] = None,
this_cel: typing.Optional[celtypes.Value] = None,
for_key: bool = False,
):
activation: dict[str, celtypes.Value] = {}
if this_cel is not None:
activation["this"] = this_cel
activation["rules"] = self._rules_cel
activation["now"] = celtypes.TimestampType(datetime.datetime.now(tz=datetime.timezone.utc))
for runner, constraint, rule, rule_path in self._runners:
activation["rule"] = rule
result = runner.evaluate(activation)
for cel in self._cel:
activation["rule"] = cel.rule_cel
result = cel.runner.evaluate(activation)
if isinstance(result, celtypes.BoolType):
if not result:
ctx.add(
validate_pb2.Violation(
rule=rule_path,
constraint_id=constraint.id,
message=constraint.message,
Violation(
field_value=this_value,
rule=cel.rule_path,
rule_value=cel.rule_value,
constraint_id=cel.constraint.id,
message=cel.constraint.message,
for_key=for_key,
),
)
elif isinstance(result, celtypes.StringType):
if result:
ctx.add(
validate_pb2.Violation(
rule=rule_path,
constraint_id=constraint.id,
Violation(
field_value=this_value,
rule=cel.rule_path,
rule_value=cel.rule_value,
constraint_id=cel.constraint.id,
message=result,
for_key=for_key,
),
Expand All @@ -339,19 +365,32 @@ def add_rule(
funcs: dict[str, celpy.CELFunction],
rules: validate_pb2.Constraint,
*,
rule: typing.Optional[celtypes.Value] = None,
rule_field: typing.Optional[descriptor.FieldDescriptor] = None,
rule_path: typing.Optional[validate_pb2.FieldPath] = None,
):
ast = env.compile(rules.expression)
prog = env.program(ast, functions=funcs)
self._runners.append((prog, rules, rule, rule_path))
rule_value = None
rule_cel = None
if rule_field is not None and self._rules is not None:
rule_value = _proto_message_get_field(self._rules, rule_field)
rule_cel = _field_to_cel(self._rules, rule_field)
self._cel.append(
CelRunner(
runner=prog,
constraint=rules,
rule_value=rule_value,
rule_cel=rule_cel,
rule_path=rule_path,
)
)


class MessageConstraintRules(CelConstraintRules):
"""Message-level rules."""

def validate(self, ctx: ConstraintContext, message: message.Message):
self._validate_cel(ctx, {"this": _msg_to_cel(message)})
self._validate_cel(ctx, this_cel=_msg_to_cel(message))


def check_field_type(field: descriptor.FieldDescriptor, expected: int, wrapper_name: typing.Optional[str] = None):
Expand Down Expand Up @@ -445,7 +484,7 @@ def __init__(
env,
funcs,
cel,
rule=_field_to_cel(rules, list_field),
rule_field=list_field,
rule_path=validate_pb2.FieldPath(
elements=[
_field_to_element(list_field),
Expand All @@ -465,13 +504,14 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
if _is_empty_field(message, self._field):
if self._required:
ctx.add(
validate_pb2.Violation(
Violation(
field=validate_pb2.FieldPath(
elements=[
_field_to_element(self._field),
],
),
rule=FieldConstraintRules._required_rule_path,
rule_value=self._required,
constraint_id="required",
message="value is required",
),
Expand All @@ -485,15 +525,15 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
return
sub_ctx = ctx.sub_context()
self._validate_value(sub_ctx, val)
self._validate_cel(sub_ctx, {"this": cel_val})
self._validate_cel(sub_ctx, this_value=_proto_message_get_field(message, self._field), this_cel=cel_val)
if sub_ctx.has_errors():
element = _field_to_element(self._field)
sub_ctx.add_field_path_element(element)
ctx.add_errors(sub_ctx)

def validate_item(self, ctx: ConstraintContext, val: typing.Any, *, for_key: bool = False):
self._validate_value(ctx, val, for_key=for_key)
self._validate_cel(ctx, {"this": _scalar_field_value_to_cel(val, self._field)}, for_key=for_key)
self._validate_cel(ctx, this_value=val, this_cel=_scalar_field_value_to_cel(val, self._field), for_key=for_key)

def _validate_value(self, ctx: ConstraintContext, val: typing.Any, *, for_key: bool = False):
pass
Expand Down Expand Up @@ -546,17 +586,19 @@ def _validate_value(self, ctx: ConstraintContext, value: any_pb2.Any, *, for_key
if len(self._in) > 0:
if value.type_url not in self._in:
ctx.add(
validate_pb2.Violation(
Violation(
rule=AnyConstraintRules._in_rule_path,
rule_value=self._in,
constraint_id="any.in",
message="type URL must be in the allow list",
for_key=for_key,
)
)
if value.type_url in self._not_in:
ctx.add(
validate_pb2.Violation(
Violation(
rule=AnyConstraintRules._not_in_rule_path,
rule_value=self._not_in,
constraint_id="any.not_in",
message="type URL must not be in the block list",
for_key=for_key,
Expand Down Expand Up @@ -603,13 +645,14 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
value = getattr(message, self._field.name)
if value not in self._field.enum_type.values_by_number:
ctx.add(
validate_pb2.Violation(
Violation(
field=validate_pb2.FieldPath(
elements=[
_field_to_element(self._field),
],
),
rule=EnumConstraintRules._defined_only_rule_path,
rule_value=self._defined_only,
constraint_id="enum.defined_only",
message="value must be one of the defined enum values",
),
Expand Down Expand Up @@ -742,7 +785,7 @@ def validate(self, ctx: ConstraintContext, message: message.Message):
if not message.WhichOneof(self._oneof.name):
if self.required:
ctx.add(
validate_pb2.Violation(
Violation(
field=validate_pb2.FieldPath(
elements=[_oneof_to_element(self._oneof)],
),
Expand Down
43 changes: 28 additions & 15 deletions protovalidate/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import typing

from google.protobuf import message

from buf.validate import validate_pb2 # type: ignore
Expand All @@ -20,6 +22,7 @@

CompilationError = _constraints.CompilationError
Violations = validate_pb2.Violations
Violation = _constraints.Violation


class Validator:
Expand Down Expand Up @@ -54,7 +57,7 @@ def validate(
ValidationError: If the message is invalid.
"""
violations = self.collect_violations(message, fail_fast=fail_fast)
if violations.violations:
if len(violations) > 0:
msg = f"invalid {message.DESCRIPTOR.name}"
raise ValidationError(msg, violations)

Expand All @@ -63,8 +66,8 @@ def collect_violations(
message: message.Message,
*,
fail_fast: bool = False,
into: validate_pb2.Violations = None,
) -> validate_pb2.Violations:
into: typing.Optional[list[Violation]] = None,
) -> list[Violation]:
"""
Validates the given message against the static constraints defined in
the message's descriptor. Compared to validate, collect_violations is
Expand All @@ -84,12 +87,12 @@ def collect_violations(
constraint.validate(ctx, message)
if ctx.done:
break
for violation in ctx.violations.violations:
if violation.HasField("field"):
violation.field.elements.reverse()
if violation.HasField("rule"):
violation.rule.elements.reverse()
violation.field_path = field_path.string(violation.field)
for violation in ctx.violations:
if violation.proto.HasField("field"):
violation.proto.field.elements.reverse()
if violation.proto.HasField("rule"):
violation.proto.rule.elements.reverse()
violation.proto.field_path = field_path.string(violation.proto.field)
return ctx.violations


Expand All @@ -98,15 +101,25 @@ class ValidationError(ValueError):
An error raised when a message fails to validate.
"""

violations: validate_pb2.Violations
_violations: list[_constraints.Violation]

def __init__(self, msg: str, violations: validate_pb2.Violations):
def __init__(self, msg: str, violations: list[_constraints.Violation]):
super().__init__(msg)
self.violations = violations
self._violations = violations

def to_proto(self) -> validate_pb2.Violations:
"""
Provides the Protobuf form of the validation errors.
"""
result = validate_pb2.Violations()
for violation in self._violations:
result.violations.append(violation.proto)
return result

def errors(self) -> list[validate_pb2.Violation]:
@property
def violations(self) -> list[Violation]:
"""
Returns the validation errors as a simple Python list, rather than the
Provides the validation errors as a simple Python list, rather than the
Protobuf-specific collection type used by Violations.
"""
return list(self.violations.violations)
return self._violations
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this makes sense, because you've aliased _constraints.Violation to Violation; assuming that the constructor takes a _constraints.Violation as it's supposed to be initialized internally, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is supposed to return protovalidate.Violation and not buf.validate.Violation. It's probably a bit confusing due to the diff in the API but the naming scheme here now matches closer to the other runtimes.

4 changes: 3 additions & 1 deletion tests/conformance/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def run_test_case(tc: typing.Any, result: typing.Optional[harness_pb2.TestResult
result = harness_pb2.TestResult()
# Run the validator
try:
protovalidate.collect_violations(tc, into=result.validation_error)
violations = protovalidate.collect_violations(tc)
for violation in violations:
result.validation_error.violations.append(violation.proto)
if len(result.validation_error.violations) == 0:
result.success = True
except celpy.CELEvalError as e:
Expand Down
Loading
Loading