From 1c27bdde83081cc272594125d0c0f4e3adcfa00a Mon Sep 17 00:00:00 2001 From: Facundo Tuesca Date: Sat, 21 Mar 2026 14:57:56 +0100 Subject: [PATCH] asn1: Add support for `SET` Signed-off-by: Facundo Tuesca --- src/cryptography/hazmat/asn1/__init__.py | 2 + src/cryptography/hazmat/asn1/asn1.py | 48 +++++++++- .../bindings/_rust/declarative_asn1.pyi | 1 + src/rust/src/declarative_asn1/decode.rs | 15 ++++ src/rust/src/declarative_asn1/encode.rs | 16 ++++ src/rust/src/declarative_asn1/types.rs | 5 ++ tests/hazmat/asn1/test_api.py | 70 +++++++++++++++ tests/hazmat/asn1/test_serialization.py | 88 +++++++++++++++++++ 8 files changed, 244 insertions(+), 1 deletion(-) diff --git a/src/cryptography/hazmat/asn1/__init__.py b/src/cryptography/hazmat/asn1/__init__.py index 223e5e16f64f..ac3d4bd8590b 100644 --- a/src/cryptography/hazmat/asn1/__init__.py +++ b/src/cryptography/hazmat/asn1/__init__.py @@ -19,6 +19,7 @@ decode_der, encode_der, sequence, + set, ) __all__ = [ @@ -38,4 +39,5 @@ "decode_der", "encode_der", "sequence", + "set", ] diff --git a/src/cryptography/hazmat/asn1/asn1.py b/src/cryptography/hazmat/asn1/asn1.py index 8a017b903e25..94ac06f8b9cb 100644 --- a/src/cryptography/hazmat/asn1/asn1.py +++ b/src/cryptography/hazmat/asn1/asn1.py @@ -164,7 +164,10 @@ def _normalize_field_type( if hasattr(field_type, "__asn1_root__"): root_type = field_type.__asn1_root__ - if not isinstance(root_type, declarative_asn1.Type.Sequence): + if not isinstance( + root_type, + (declarative_asn1.Type.Sequence, declarative_asn1.Type.Set), + ): raise TypeError(f"unsupported root type: {root_type}") return declarative_asn1.AnnotatedType( typing.cast(declarative_asn1.Type, root_type), annotation @@ -325,6 +328,13 @@ def _register_asn1_sequence(cls: type[U]) -> None: setattr(cls, "__asn1_root__", root) +def _register_asn1_set(cls: type[U]) -> None: + raw_fields = get_type_hints(cls, include_extras=True) + root = declarative_asn1.Type.Set(cls, _annotate_fields(raw_fields)) + + setattr(cls, "__asn1_root__", root) + + # Due to https://github.com/python/mypy/issues/19731, we can't define an alias # for `dataclass_transform` that conditionally points to `typing` or # `typing_extensions` depending on the Python version (like we do for @@ -356,6 +366,29 @@ def sequence(cls: type[U]) -> type[U]: _register_asn1_sequence(dataclass_cls) return dataclass_cls + @typing_extensions.dataclass_transform(kw_only_default=True) + def set(cls: type[U]) -> type[U]: + # We use `dataclasses.dataclass` to add an __init__ method + # to the class with keyword-only parameters. + if sys.version_info >= (3, 10): + dataclass_cls = dataclasses.dataclass( + repr=False, + eq=False, + # `match_args` was added in Python 3.10 and defaults + # to True + match_args=False, + # `kw_only` was added in Python 3.10 and defaults to + # False + kw_only=True, + )(cls) + else: + dataclass_cls = dataclasses.dataclass( + repr=False, + eq=False, + )(cls) + _register_asn1_set(dataclass_cls) + return dataclass_cls + else: @typing.dataclass_transform(kw_only_default=True) @@ -371,6 +404,19 @@ def sequence(cls: type[U]) -> type[U]: _register_asn1_sequence(dataclass_cls) return dataclass_cls + @typing.dataclass_transform(kw_only_default=True) + def set(cls: type[U]) -> type[U]: + # Only add an __init__ method, with keyword-only + # parameters. + dataclass_cls = dataclasses.dataclass( + repr=False, + eq=False, + match_args=False, + kw_only=True, + )(cls) + _register_asn1_set(dataclass_cls) + return dataclass_cls + # TODO: replace with `Default[U]` once the min Python version is >= 3.12 @dataclasses.dataclass(frozen=True) diff --git a/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi b/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi index 5f5062495a82..a4c0cb6d66cd 100644 --- a/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi +++ b/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi @@ -14,6 +14,7 @@ def non_root_python_to_rust(cls: type) -> Type: ... class Type: Sequence: typing.ClassVar[type] SequenceOf: typing.ClassVar[type] + Set: typing.ClassVar[type] SetOf: typing.ClassVar[type] Option: typing.ClassVar[type] Choice: typing.ClassVar[type] diff --git a/src/rust/src/declarative_asn1/decode.rs b/src/rust/src/declarative_asn1/decode.rs index d88b16a1e18e..fb31adb18073 100644 --- a/src/rust/src/declarative_asn1/decode.rs +++ b/src/rust/src/declarative_asn1/decode.rs @@ -286,6 +286,21 @@ pub(crate) fn decode_annotated_type<'a>( Ok(list.into_any()) })? } + Type::Set(cls, fields) => { + let set_parse_result = read_value::>(parser, encoding)?; + + set_parse_result.parse(|d| -> ParseResult> { + let kwargs = pyo3::types::PyDict::new(py); + let fields = fields.bind(py); + for (name, ann_type) in fields.into_iter() { + let ann_type = ann_type.cast::()?; + let value = decode_annotated_type(py, d, ann_type.get())?; + kwargs.set_item(name, value)?; + } + let val = cls.call(py, (), Some(&kwargs))?.into_bound(py); + Ok(val) + })? + } Type::SetOf(cls) => { let setof_parse_result = read_value::>(parser, encoding)?; diff --git a/src/rust/src/declarative_asn1/encode.rs b/src/rust/src/declarative_asn1/encode.rs index f17b85895441..e22162083c17 100644 --- a/src/rust/src/declarative_asn1/encode.rs +++ b/src/rust/src/declarative_asn1/encode.rs @@ -78,6 +78,22 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> { write_value(writer, &asn1::SequenceOfWriter::new(values), encoding) } + Type::Set(_cls, fields) => write_value( + writer, + &asn1::SetWriter::new(&|w| { + for (name, ann_type) in fields.bind(py).into_iter() { + let name = name.cast::()?; + let ann_type = ann_type.cast::()?; + let object = AnnotatedTypeObject { + annotated_type: ann_type.get(), + value: self.value.getattr(name)?, + }; + w.write_element(&object)?; + } + Ok(()) + }), + encoding, + ), Type::SetOf(cls) => { let setof = value.cast::()?; let values: Vec> = setof diff --git a/src/rust/src/declarative_asn1/types.rs b/src/rust/src/declarative_asn1/types.rs index ea89aaf8c459..487991f678ff 100644 --- a/src/rust/src/declarative_asn1/types.rs +++ b/src/rust/src/declarative_asn1/types.rs @@ -23,6 +23,10 @@ pub enum Type { Sequence(pyo3::Py, pyo3::Py), /// SEQUENCE OF (`list[`T`]`) SequenceOf(pyo3::Py), + /// SET(`class`, `dict`) + /// The first element is the Python class that represents the set, + /// the second element is a dict of the (already converted) fields of the class. + Set(pyo3::Py, pyo3::Py), /// SET OF (`list[`T`]`) SetOf(pyo3::Py), /// OPTIONAL (`T | None`) @@ -650,6 +654,7 @@ pub(crate) fn is_tag_valid_for_type( ) -> bool { match type_ { Type::Sequence(_, _) => check_tag_with_encoding(asn1::Sequence::TAG, encoding, tag), + Type::Set(_, _) => check_tag_with_encoding(asn1::Set::TAG, encoding, tag), Type::SequenceOf(_) => check_tag_with_encoding(asn1::Sequence::TAG, encoding, tag), Type::SetOf(_) => check_tag_with_encoding(asn1::SetOf::<()>::TAG, encoding, tag), Type::Option(t) => is_tag_valid_for_type(py, tag, t.get().inner.get(), encoding), diff --git a/tests/hazmat/asn1/test_api.py b/tests/hazmat/asn1/test_api.py index a08ba3811249..878b4cd7d301 100644 --- a/tests/hazmat/asn1/test_api.py +++ b/tests/hazmat/asn1/test_api.py @@ -362,6 +362,10 @@ def test_fields_of_variant_type(self) -> None: assert seq._0 is type(None) assert seq._1 == {} + set = declarative_asn1.Type.Set(type(None), {}) + assert set._0 is type(None) + assert set._1 == {} + ann_type = declarative_asn1.AnnotatedType( seq, declarative_asn1.Annotation() ) @@ -461,3 +465,69 @@ def test_fail_optional_tlv(self) -> None: @asn1.sequence class Example: invalid: typing.Union[asn1.TLV, None] + + +class TestSetAPI: + def test_fail_unsupported_field(self) -> None: + class Unsupported: + foo: int + + with pytest.raises(TypeError, match="cannot handle type"): + + @asn1.set + class Example: + foo: Unsupported + + def test_fail_init_incorrect_field_name(self) -> None: + @asn1.set + class Example: + foo: int + + with pytest.raises( + TypeError, match="got an unexpected keyword argument 'bar'" + ): + Example(bar=3) # type: ignore[call-arg] + + def test_fail_init_missing_field_name(self) -> None: + @asn1.set + class Example: + foo: int + + expected_err = ( + "missing 1 required keyword-only argument: 'foo'" + if sys.version_info >= (3, 10) + else "missing 1 required positional argument: 'foo'" + ) + + with pytest.raises(TypeError, match=expected_err): + Example() # type: ignore[call-arg] + + def test_fail_positional_field_initialization(self) -> None: + @asn1.set + class Example: + foo: int + + # The kw-only init is only enforced in Python >= 3.10, which is + # when the parameter `kw_only` for `dataclasses.datalass` was + # added. + if sys.version_info < (3, 10): + assert Example(5).foo == 5 # type: ignore[misc] + else: + with pytest.raises( + TypeError, + match="takes 1 positional argument but 2 were given", + ): + Example(5) # type: ignore[misc] + + def test_fail_malformed_root_type(self) -> None: + @asn1.set + class Invalid: + foo: int + + setattr(Invalid, "__asn1_root__", int) + + with pytest.raises(TypeError, match="unsupported root type"): + + @asn1.set + class Example: + foo: Invalid diff --git a/tests/hazmat/asn1/test_serialization.py b/tests/hazmat/asn1/test_serialization.py index 435d346ab898..9edb89963415 100644 --- a/tests/hazmat/asn1/test_serialization.py +++ b/tests/hazmat/asn1/test_serialization.py @@ -528,10 +528,16 @@ def test_ok_sequence_all_types_optional(self) -> None: class MyField: a: int + @asn1.set + @_comparable_dataclass + class MySetField: + a: int + @asn1.sequence @_comparable_dataclass class Example: a: typing.Union[MyField, None] + a2: typing.Union[MySetField, None] b: typing.Union[int, None] c: typing.Union[bytes, None] d: typing.Union[asn1.PrintableString, None] @@ -553,6 +559,7 @@ class Example: ( Example( a=None, + a2=None, b=None, c=None, d=None, @@ -589,6 +596,11 @@ class MyField: ) default_oid = x509.ObjectIdentifier("1.3.6.1.4.1.343") + @asn1.set + @_comparable_dataclass + class MySetField: + a: int + @asn1.sequence @_comparable_dataclass class Example: @@ -628,6 +640,10 @@ class Example: MyField, asn1.Default(MyField(a=9)), ] + k3: Annotated[ + MySetField, + asn1.Default(MySetField(a=9)), + ] z: Annotated[str, asn1.Default("a"), asn1.Implicit(0)] only_field_present: Annotated[ str, asn1.Default("a"), asn1.Implicit(1) @@ -649,6 +665,7 @@ class Example: j=3, k=asn1.Null(), k2=MyField(a=9), + k3=MySetField(a=9), z="a", only_field_present="b", ), @@ -1047,6 +1064,77 @@ class Example: ) +class TestSet: + def test_ok_set_single_field(self) -> None: + @asn1.set + @_comparable_dataclass + class Example: + foo: int + + assert_roundtrips([(Example(foo=9), b"\x31\x03\x02\x01\x09")]) + + def test_ok_set_multiple_fields(self) -> None: + @asn1.set + @_comparable_dataclass + class Example: + foo: int + bar: int + + assert_roundtrips( + [(Example(foo=6, bar=9), b"\x31\x06\x02\x01\x06\x02\x01\x09")] + ) + + def test_fail_set_multiple_fields_wrong_order(self) -> None: + @asn1.set + @_comparable_dataclass + class Example: + foo: int + bar: int + + with pytest.raises( + ValueError, + match=re.escape( + "invalid SET ordering while performing ASN.1 serialization" + ), + ): + assert_roundtrips( + [(Example(foo=9, bar=6), b"\x31\x06\x02\x01\x06\x02\x01\x09")] + ) + + def test_ok_nested_set(self) -> None: + @asn1.set + @_comparable_dataclass + class Child: + foo: int + + @asn1.set + @_comparable_dataclass + class Parent: + foo: Child + + assert_roundtrips( + [(Parent(foo=Child(foo=9)), b"\x31\x05\x31\x03\x02\x01\x09")] + ) + + def test_ok_set_multiple_types(self) -> None: + @asn1.set + @_comparable_dataclass + class Example: + a: bool + b: int + c: bytes + d: str + + assert_roundtrips( + [ + ( + Example(a=True, b=9, c=b"c", d="d"), + b"\x31\x0c\x01\x01\xff\x02\x01\x09\x04\x01c\x0c\x01d", + ) + ] + ) + + class TestSize: def test_ok_sequenceof_size_restriction(self) -> None: @asn1.sequence