Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/cryptography/hazmat/asn1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
decode_der,
encode_der,
sequence,
set,
)

__all__ = [
Expand All @@ -38,4 +39,5 @@
"decode_der",
"encode_der",
"sequence",
"set",
]
48 changes: 47 additions & 1 deletion src/cryptography/hazmat/asn1/asn1.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ def _normalize_field_type(

if hasattr(field_type, "__asn1_root__"):
root_type = field_type.__asn1_root__
if not isinstance(root_type, declarative_asn1.Type.Sequence):
if not isinstance(
root_type,
(declarative_asn1.Type.Sequence, declarative_asn1.Type.Set),
):
raise TypeError(f"unsupported root type: {root_type}")
return declarative_asn1.AnnotatedType(
typing.cast(declarative_asn1.Type, root_type), annotation
Expand Down Expand Up @@ -325,6 +328,13 @@ def _register_asn1_sequence(cls: type[U]) -> None:
setattr(cls, "__asn1_root__", root)


def _register_asn1_set(cls: type[U]) -> None:
raw_fields = get_type_hints(cls, include_extras=True)
root = declarative_asn1.Type.Set(cls, _annotate_fields(raw_fields))

setattr(cls, "__asn1_root__", root)


# Due to https://github.com/python/mypy/issues/19731, we can't define an alias
# for `dataclass_transform` that conditionally points to `typing` or
# `typing_extensions` depending on the Python version (like we do for
Expand Down Expand Up @@ -356,6 +366,29 @@ def sequence(cls: type[U]) -> type[U]:
_register_asn1_sequence(dataclass_cls)
return dataclass_cls

@typing_extensions.dataclass_transform(kw_only_default=True)
def set(cls: type[U]) -> type[U]:
# We use `dataclasses.dataclass` to add an __init__ method
# to the class with keyword-only parameters.
if sys.version_info >= (3, 10):
dataclass_cls = dataclasses.dataclass(
repr=False,
eq=False,
# `match_args` was added in Python 3.10 and defaults
# to True
match_args=False,
# `kw_only` was added in Python 3.10 and defaults to
# False
kw_only=True,
)(cls)
else:
dataclass_cls = dataclasses.dataclass(
repr=False,
eq=False,
)(cls)
_register_asn1_set(dataclass_cls)
return dataclass_cls

else:

@typing.dataclass_transform(kw_only_default=True)
Expand All @@ -371,6 +404,19 @@ def sequence(cls: type[U]) -> type[U]:
_register_asn1_sequence(dataclass_cls)
return dataclass_cls

@typing.dataclass_transform(kw_only_default=True)
def set(cls: type[U]) -> type[U]:
# Only add an __init__ method, with keyword-only
# parameters.
dataclass_cls = dataclasses.dataclass(
repr=False,
eq=False,
match_args=False,
kw_only=True,
)(cls)
_register_asn1_set(dataclass_cls)
return dataclass_cls


# TODO: replace with `Default[U]` once the min Python version is >= 3.12
@dataclasses.dataclass(frozen=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def non_root_python_to_rust(cls: type) -> Type: ...
class Type:
Sequence: typing.ClassVar[type]
SequenceOf: typing.ClassVar[type]
Set: typing.ClassVar[type]
SetOf: typing.ClassVar[type]
Option: typing.ClassVar[type]
Choice: typing.ClassVar[type]
Expand Down
15 changes: 15 additions & 0 deletions src/rust/src/declarative_asn1/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,21 @@ pub(crate) fn decode_annotated_type<'a>(
Ok(list.into_any())
})?
}
Type::Set(cls, fields) => {
let set_parse_result = read_value::<asn1::Set<'_>>(parser, encoding)?;

set_parse_result.parse(|d| -> ParseResult<pyo3::Bound<'a, pyo3::PyAny>> {
let kwargs = pyo3::types::PyDict::new(py);
let fields = fields.bind(py);
for (name, ann_type) in fields.into_iter() {
let ann_type = ann_type.cast::<AnnotatedType>()?;
let value = decode_annotated_type(py, d, ann_type.get())?;
kwargs.set_item(name, value)?;
}
let val = cls.call(py, (), Some(&kwargs))?.into_bound(py);
Ok(val)
})?
}
Type::SetOf(cls) => {
let setof_parse_result = read_value::<asn1::Set<'_>>(parser, encoding)?;

Expand Down
16 changes: 16 additions & 0 deletions src/rust/src/declarative_asn1/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,22 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> {

write_value(writer, &asn1::SequenceOfWriter::new(values), encoding)
}
Type::Set(_cls, fields) => write_value(
writer,
&asn1::SetWriter::new(&|w| {
for (name, ann_type) in fields.bind(py).into_iter() {
let name = name.cast::<pyo3::types::PyString>()?;
let ann_type = ann_type.cast::<AnnotatedType>()?;
let object = AnnotatedTypeObject {
annotated_type: ann_type.get(),
value: self.value.getattr(name)?,
};
w.write_element(&object)?;
}
Ok(())
}),
encoding,
),
Type::SetOf(cls) => {
let setof = value.cast::<super::types::SetOf>()?;
let values: Vec<AnnotatedTypeObject<'_>> = setof
Expand Down
5 changes: 5 additions & 0 deletions src/rust/src/declarative_asn1/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ pub enum Type {
Sequence(pyo3::Py<pyo3::types::PyType>, pyo3::Py<pyo3::types::PyDict>),
/// SEQUENCE OF (`list[`T`]`)
SequenceOf(pyo3::Py<AnnotatedType>),
/// SET(`class`, `dict`)
/// The first element is the Python class that represents the set,
/// the second element is a dict of the (already converted) fields of the class.
Set(pyo3::Py<pyo3::types::PyType>, pyo3::Py<pyo3::types::PyDict>),
/// SET OF (`list[`T`]`)
SetOf(pyo3::Py<AnnotatedType>),
/// OPTIONAL (`T | None`)
Expand Down Expand Up @@ -650,6 +654,7 @@ pub(crate) fn is_tag_valid_for_type(
) -> bool {
match type_ {
Type::Sequence(_, _) => check_tag_with_encoding(asn1::Sequence::TAG, encoding, tag),
Type::Set(_, _) => check_tag_with_encoding(asn1::Set::TAG, encoding, tag),
Type::SequenceOf(_) => check_tag_with_encoding(asn1::Sequence::TAG, encoding, tag),
Type::SetOf(_) => check_tag_with_encoding(asn1::SetOf::<()>::TAG, encoding, tag),
Type::Option(t) => is_tag_valid_for_type(py, tag, t.get().inner.get(), encoding),
Expand Down
70 changes: 70 additions & 0 deletions tests/hazmat/asn1/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,10 @@ def test_fields_of_variant_type(self) -> None:
assert seq._0 is type(None)
assert seq._1 == {}

set = declarative_asn1.Type.Set(type(None), {})
assert set._0 is type(None)
assert set._1 == {}

ann_type = declarative_asn1.AnnotatedType(
seq, declarative_asn1.Annotation()
)
Expand Down Expand Up @@ -461,3 +465,69 @@ def test_fail_optional_tlv(self) -> None:
@asn1.sequence
class Example:
invalid: typing.Union[asn1.TLV, None]


class TestSetAPI:
def test_fail_unsupported_field(self) -> None:
class Unsupported:
foo: int

with pytest.raises(TypeError, match="cannot handle type"):

@asn1.set
class Example:
foo: Unsupported

def test_fail_init_incorrect_field_name(self) -> None:
@asn1.set
class Example:
foo: int

with pytest.raises(
TypeError, match="got an unexpected keyword argument 'bar'"
):
Example(bar=3) # type: ignore[call-arg]

def test_fail_init_missing_field_name(self) -> None:
@asn1.set
class Example:
foo: int

expected_err = (
"missing 1 required keyword-only argument: 'foo'"
if sys.version_info >= (3, 10)
else "missing 1 required positional argument: 'foo'"
)

with pytest.raises(TypeError, match=expected_err):
Example() # type: ignore[call-arg]

def test_fail_positional_field_initialization(self) -> None:
@asn1.set
class Example:
foo: int

# The kw-only init is only enforced in Python >= 3.10, which is
# when the parameter `kw_only` for `dataclasses.datalass` was
# added.
if sys.version_info < (3, 10):
assert Example(5).foo == 5 # type: ignore[misc]
else:
with pytest.raises(
TypeError,
match="takes 1 positional argument but 2 were given",
):
Example(5) # type: ignore[misc]

def test_fail_malformed_root_type(self) -> None:
@asn1.set
class Invalid:
foo: int

setattr(Invalid, "__asn1_root__", int)

with pytest.raises(TypeError, match="unsupported root type"):

@asn1.set
class Example:
foo: Invalid
88 changes: 88 additions & 0 deletions tests/hazmat/asn1/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,10 +528,16 @@ def test_ok_sequence_all_types_optional(self) -> None:
class MyField:
a: int

@asn1.set
@_comparable_dataclass
class MySetField:
a: int

@asn1.sequence
@_comparable_dataclass
class Example:
a: typing.Union[MyField, None]
a2: typing.Union[MySetField, None]
b: typing.Union[int, None]
c: typing.Union[bytes, None]
d: typing.Union[asn1.PrintableString, None]
Expand All @@ -553,6 +559,7 @@ class Example:
(
Example(
a=None,
a2=None,
b=None,
c=None,
d=None,
Expand Down Expand Up @@ -589,6 +596,11 @@ class MyField:
)
default_oid = x509.ObjectIdentifier("1.3.6.1.4.1.343")

@asn1.set
@_comparable_dataclass
class MySetField:
a: int

@asn1.sequence
@_comparable_dataclass
class Example:
Expand Down Expand Up @@ -628,6 +640,10 @@ class Example:
MyField,
asn1.Default(MyField(a=9)),
]
k3: Annotated[
MySetField,
asn1.Default(MySetField(a=9)),
]
z: Annotated[str, asn1.Default("a"), asn1.Implicit(0)]
only_field_present: Annotated[
str, asn1.Default("a"), asn1.Implicit(1)
Expand All @@ -649,6 +665,7 @@ class Example:
j=3,
k=asn1.Null(),
k2=MyField(a=9),
k3=MySetField(a=9),
z="a",
only_field_present="b",
),
Expand Down Expand Up @@ -1047,6 +1064,77 @@ class Example:
)


class TestSet:
def test_ok_set_single_field(self) -> None:
@asn1.set
@_comparable_dataclass
class Example:
foo: int

assert_roundtrips([(Example(foo=9), b"\x31\x03\x02\x01\x09")])

def test_ok_set_multiple_fields(self) -> None:
@asn1.set
@_comparable_dataclass
class Example:
foo: int
bar: int

assert_roundtrips(
[(Example(foo=6, bar=9), b"\x31\x06\x02\x01\x06\x02\x01\x09")]
)

def test_fail_set_multiple_fields_wrong_order(self) -> None:
@asn1.set
@_comparable_dataclass
class Example:
foo: int
bar: int

with pytest.raises(
ValueError,
match=re.escape(
"invalid SET ordering while performing ASN.1 serialization"
),
):
assert_roundtrips(
[(Example(foo=9, bar=6), b"\x31\x06\x02\x01\x06\x02\x01\x09")]
)

def test_ok_nested_set(self) -> None:
@asn1.set
@_comparable_dataclass
class Child:
foo: int

@asn1.set
@_comparable_dataclass
class Parent:
foo: Child

assert_roundtrips(
[(Parent(foo=Child(foo=9)), b"\x31\x05\x31\x03\x02\x01\x09")]
)

def test_ok_set_multiple_types(self) -> None:
@asn1.set
@_comparable_dataclass
class Example:
a: bool
b: int
c: bytes
d: str

assert_roundtrips(
[
(
Example(a=True, b=9, c=b"c", d="d"),
b"\x31\x0c\x01\x01\xff\x02\x01\x09\x04\x01c\x0c\x01d",
)
]
)


class TestSize:
def test_ok_sequenceof_size_restriction(self) -> None:
@asn1.sequence
Expand Down
Loading