Skip to content

Commit 95deafd

Browse files
Add handling for protobuf==7
`protobuf` v7 was released last week, which dropped support for some of the APIs that we use for checking field types. This fixes those accesses so that we can still support all the way back to v5. Fixes #410. Ref: https://pypi.org/project/protobuf/ Ref: https://github.com/protocolbuffers/protobuf/releases/tag/v34.0
1 parent 4bac257 commit 95deafd

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

protovalidate/internal/rules.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@
2424
from buf.validate import validate_pb2
2525
from protovalidate.internal.cel_field_presence import InterpretedRunner, in_has
2626

27+
# protobuf 7+ removed FieldDescriptor.label / LABEL_REPEATED in favour of is_repeated.
28+
if hasattr(descriptor.FieldDescriptor, "LABEL_REPEATED"):
29+
30+
def _is_repeated(field: descriptor.FieldDescriptor) -> bool:
31+
return field.label == descriptor.FieldDescriptor.LABEL_REPEATED # type: ignore[attr-defined]
32+
33+
else:
34+
35+
def _is_repeated(field: descriptor.FieldDescriptor) -> bool:
36+
return field.is_repeated # type: ignore[attr-defined]
37+
2738

2839
class CompilationError(Exception):
2940
pass
@@ -155,7 +166,7 @@ def _scalar_field_value_to_cel(val: typing.Any, field: descriptor.FieldDescripto
155166

156167

157168
def _field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> celtypes.Value:
158-
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
169+
if _is_repeated(field):
159170
if field.message_type is not None and field.message_type.GetOptions().map_entry:
160171
return _map_field_value_to_cel(val, field)
161172
return _repeated_field_value_to_cel(val, field)
@@ -165,7 +176,7 @@ def _field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> c
165176
def _is_empty_field(msg: message.Message, field: descriptor.FieldDescriptor) -> bool:
166177
if field.has_presence:
167178
return not _proto_message_has_field(msg, field)
168-
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
179+
if _is_repeated(field):
169180
return len(_proto_message_get_field(msg, field)) == 0
170181
return _proto_message_get_field(msg, field) == field.default_value
171182

@@ -194,7 +205,7 @@ def _map_field_to_cel(msg: message.Message, field: descriptor.FieldDescriptor) -
194205

195206

196207
def field_to_cel(msg: message.Message, field: descriptor.FieldDescriptor) -> celtypes.Value:
197-
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
208+
if _is_repeated(field):
198209
return _repeated_field_to_cel(msg, field)
199210
elif field.message_type is not None and not _proto_message_has_field(msg, field):
200211
return None
@@ -493,18 +504,18 @@ def check_field_type(field: descriptor.FieldDescriptor, expected: int, wrapper_n
493504

494505
def _is_map(field: descriptor.FieldDescriptor):
495506
return (
496-
field.label == descriptor.FieldDescriptor.LABEL_REPEATED
507+
_is_repeated(field)
497508
and field.message_type is not None
498509
and field.message_type.GetOptions().map_entry
499510
)
500511

501512

502513
def _is_list(field: descriptor.FieldDescriptor):
503-
return field.label == descriptor.FieldDescriptor.LABEL_REPEATED and not _is_map(field)
514+
return _is_repeated(field) and not _is_map(field)
504515

505516

506517
def _zero_value(field: descriptor.FieldDescriptor):
507-
if field.message_type is not None and field.label != descriptor.FieldDescriptor.LABEL_REPEATED:
518+
if field.message_type is not None and not _is_repeated(field):
508519
return _field_value_to_cel(message_factory.GetMessageClass(field.message_type)(), field)
509520
else:
510521
return _field_value_to_cel(field.default_value, field)
@@ -1030,7 +1041,7 @@ def _new_field_rule(
10301041
field: descriptor.FieldDescriptor,
10311042
rules: validate_pb2.FieldRules,
10321043
) -> FieldRules:
1033-
if field.label != descriptor.FieldDescriptor.LABEL_REPEATED:
1044+
if not _is_repeated(field):
10341045
return self._new_scalar_field_rule(field, rules)
10351046
if field.message_type is not None and field.message_type.GetOptions().map_entry:
10361047
key_rules = None
@@ -1084,7 +1095,7 @@ def _new_rules(self, desc: descriptor.Descriptor) -> list[Rules]:
10841095
if value_field.type != descriptor.FieldDescriptor.TYPE_MESSAGE:
10851096
continue
10861097
result.append(MapValMsgRule(self, field, key_field, value_field))
1087-
elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
1098+
elif _is_repeated(field):
10881099
result.append(RepeatedMsgRule(self, field))
10891100
else:
10901101
result.append(SubMsgRule(self, field))

0 commit comments

Comments
 (0)