Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions protovalidate/internal/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@
from buf.validate import validate_pb2
from protovalidate.internal.cel_field_presence import InterpretedRunner, in_has

# protobuf 7+ removed FieldDescriptor.label / LABEL_REPEATED in favour of is_repeated.
if hasattr(descriptor.FieldDescriptor, "is_repeated"):

def _is_repeated(field: descriptor.FieldDescriptor) -> bool:
return field.is_repeated # type: ignore[attr-defined]

else:

def _is_repeated(field: descriptor.FieldDescriptor) -> bool:
return field.label == descriptor.FieldDescriptor.LABEL_REPEATED # type: ignore[attr-defined]


class CompilationError(Exception):
pass
Expand Down Expand Up @@ -155,7 +166,7 @@ def _scalar_field_value_to_cel(val: typing.Any, field: descriptor.FieldDescripto


def _field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> celtypes.Value:
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
if _is_repeated(field):
if field.message_type is not None and field.message_type.GetOptions().map_entry:
return _map_field_value_to_cel(val, field)
return _repeated_field_value_to_cel(val, field)
Expand All @@ -165,7 +176,7 @@ def _field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> c
def _is_empty_field(msg: message.Message, field: descriptor.FieldDescriptor) -> bool:
if field.has_presence:
return not _proto_message_has_field(msg, field)
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
if _is_repeated(field):
return len(_proto_message_get_field(msg, field)) == 0
return _proto_message_get_field(msg, field) == field.default_value

Expand Down Expand Up @@ -194,7 +205,7 @@ def _map_field_to_cel(msg: message.Message, field: descriptor.FieldDescriptor) -


def field_to_cel(msg: message.Message, field: descriptor.FieldDescriptor) -> celtypes.Value:
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
if _is_repeated(field):
return _repeated_field_to_cel(msg, field)
elif field.message_type is not None and not _proto_message_has_field(msg, field):
return None
Expand Down Expand Up @@ -492,19 +503,15 @@ def check_field_type(field: descriptor.FieldDescriptor, expected: int, wrapper_n


def _is_map(field: descriptor.FieldDescriptor):
return (
field.label == descriptor.FieldDescriptor.LABEL_REPEATED
and field.message_type is not None
and field.message_type.GetOptions().map_entry
)
return _is_repeated(field) and field.message_type is not None and field.message_type.GetOptions().map_entry


def _is_list(field: descriptor.FieldDescriptor):
return field.label == descriptor.FieldDescriptor.LABEL_REPEATED and not _is_map(field)
return _is_repeated(field) and not _is_map(field)


def _zero_value(field: descriptor.FieldDescriptor):
if field.message_type is not None and field.label != descriptor.FieldDescriptor.LABEL_REPEATED:
if field.message_type is not None and not _is_repeated(field):
return _field_value_to_cel(message_factory.GetMessageClass(field.message_type)(), field)
else:
return _field_value_to_cel(field.default_value, field)
Expand Down Expand Up @@ -1030,7 +1037,7 @@ def _new_field_rule(
field: descriptor.FieldDescriptor,
rules: validate_pb2.FieldRules,
) -> FieldRules:
if field.label != descriptor.FieldDescriptor.LABEL_REPEATED:
if not _is_repeated(field):
return self._new_scalar_field_rule(field, rules)
if field.message_type is not None and field.message_type.GetOptions().map_entry:
key_rules = None
Expand Down Expand Up @@ -1084,7 +1091,7 @@ def _new_rules(self, desc: descriptor.Descriptor) -> list[Rules]:
if value_field.type != descriptor.FieldDescriptor.TYPE_MESSAGE:
continue
result.append(MapValMsgRule(self, field, key_field, value_field))
elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
elif _is_repeated(field):
result.append(RepeatedMsgRule(self, field))
else:
result.append(SubMsgRule(self, field))
Expand Down
18 changes: 9 additions & 9 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.