diff --git a/src/cryptography/hazmat/backends/openssl/backend.py b/src/cryptography/hazmat/backends/openssl/backend.py index 1ac8335a653d..3e951356854d 100644 --- a/src/cryptography/hazmat/backends/openssl/backend.py +++ b/src/cryptography/hazmat/backends/openssl/backend.py @@ -272,6 +272,9 @@ def x448_supported(self) -> bool: and not rust_openssl.CRYPTOGRAPHY_IS_AWSLC ) + def mldsa_supported(self) -> bool: + return rust_openssl.CRYPTOGRAPHY_IS_AWSLC + def ed25519_supported(self) -> bool: return not self._fips_enabled diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi index 1504f458ca32..16c8a2b80b2d 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi @@ -18,6 +18,7 @@ from cryptography.hazmat.bindings._rust.openssl import ( hpke, kdf, keys, + mldsa, poly1305, rsa, x448, @@ -38,6 +39,7 @@ __all__ = [ "hpke", "kdf", "keys", + "mldsa", "openssl_version", "openssl_version_text", "poly1305", diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/mldsa.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/mldsa.pyi new file mode 100644 index 000000000000..83ef45f65dc0 --- /dev/null +++ b/src/cryptography/hazmat/bindings/_rust/openssl/mldsa.pyi @@ -0,0 +1,13 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from cryptography.hazmat.primitives.asymmetric import mldsa +from cryptography.utils import Buffer + +class MlDsa65PrivateKey: ... +class MlDsa65PublicKey: ... + +def generate_key() -> mldsa.MlDsa65PrivateKey: ... +def from_public_bytes(data: bytes) -> mldsa.MlDsa65PublicKey: ... +def from_seed_bytes(data: Buffer) -> mldsa.MlDsa65PrivateKey: ... diff --git a/src/cryptography/hazmat/primitives/asymmetric/mldsa.py b/src/cryptography/hazmat/primitives/asymmetric/mldsa.py new file mode 100644 index 000000000000..b4ac1a08c757 --- /dev/null +++ b/src/cryptography/hazmat/primitives/asymmetric/mldsa.py @@ -0,0 +1,155 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc + +from cryptography.exceptions import UnsupportedAlgorithm, _Reasons +from cryptography.hazmat.bindings._rust import openssl as rust_openssl +from cryptography.hazmat.primitives import _serialization +from cryptography.utils import Buffer + + +class MlDsa65PublicKey(metaclass=abc.ABCMeta): + @classmethod + def from_public_bytes(cls, data: bytes) -> MlDsa65PublicKey: + from cryptography.hazmat.backends.openssl.backend import backend + + if not backend.mldsa_supported(): + raise UnsupportedAlgorithm( + "ML-DSA-65 is not supported by this backend.", + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM, + ) + + return rust_openssl.mldsa.from_public_bytes(data) + + @abc.abstractmethod + def public_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PublicFormat, + ) -> bytes: + """ + The serialized bytes of the public key. + """ + + @abc.abstractmethod + def public_bytes_raw(self) -> bytes: + """ + The raw bytes of the public key. + Equivalent to public_bytes(Raw, Raw). + + The public key is 1,952 bytes for MLDSA-65. + """ + + @abc.abstractmethod + def verify( + self, + signature: Buffer, + data: Buffer, + context: Buffer | None = None, + ) -> None: + """ + Verify the signature. + """ + + @abc.abstractmethod + def __eq__(self, other: object) -> bool: + """ + Checks equality. + """ + + @abc.abstractmethod + def __copy__(self) -> MlDsa65PublicKey: + """ + Returns a copy. + """ + + @abc.abstractmethod + def __deepcopy__(self, memo: dict) -> MlDsa65PublicKey: + """ + Returns a deep copy. + """ + + +if hasattr(rust_openssl, "mldsa"): + MlDsa65PublicKey.register(rust_openssl.mldsa.MlDsa65PublicKey) + + +class MlDsa65PrivateKey(metaclass=abc.ABCMeta): + @classmethod + def generate(cls) -> MlDsa65PrivateKey: + from cryptography.hazmat.backends.openssl.backend import backend + + if not backend.mldsa_supported(): + raise UnsupportedAlgorithm( + "ML-DSA-65 is not supported by this backend.", + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM, + ) + + return rust_openssl.mldsa.generate_key() + + @classmethod + def from_seed_bytes(cls, data: Buffer) -> MlDsa65PrivateKey: + from cryptography.hazmat.backends.openssl.backend import backend + + if not backend.mldsa_supported(): + raise UnsupportedAlgorithm( + "ML-DSA-65 is not supported by this backend.", + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM, + ) + + return rust_openssl.mldsa.from_seed_bytes(data) + + @abc.abstractmethod + def public_key(self) -> MlDsa65PublicKey: + """ + The MlDsa65PublicKey derived from the private key. + """ + + @abc.abstractmethod + def private_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PrivateFormat, + encryption_algorithm: _serialization.KeySerializationEncryption, + ) -> bytes: + """ + The serialized bytes of the private key. + + This method only returns the serialization of the seed form of the + private key, never the expanded one. + """ + + @abc.abstractmethod + def private_bytes_raw(self) -> bytes: + """ + The raw bytes of the private key. + Equivalent to private_bytes(Raw, Raw, NoEncryption()). + + This method only returns the seed form of the private key (32 bytes). + """ + + @abc.abstractmethod + def sign(self, data: Buffer, context: Buffer | None = None) -> bytes: + """ + Signs the data. + """ + + @abc.abstractmethod + def __copy__(self) -> MlDsa65PrivateKey: + """ + Returns a copy. + """ + + @abc.abstractmethod + def __deepcopy__(self, memo: dict) -> MlDsa65PrivateKey: + """ + Returns a deep copy. + """ + + +if hasattr(rust_openssl, "mldsa"): + MlDsa65PrivateKey.register(rust_openssl.mldsa.MlDsa65PrivateKey) diff --git a/src/cryptography/hazmat/primitives/asymmetric/types.py b/src/cryptography/hazmat/primitives/asymmetric/types.py index 1fe4eaf51d85..3854e2e234a9 100644 --- a/src/cryptography/hazmat/primitives/asymmetric/types.py +++ b/src/cryptography/hazmat/primitives/asymmetric/types.py @@ -13,6 +13,7 @@ ec, ed448, ed25519, + mldsa, rsa, x448, x25519, @@ -26,6 +27,7 @@ ec.EllipticCurvePublicKey, ed25519.Ed25519PublicKey, ed448.Ed448PublicKey, + mldsa.MlDsa65PublicKey, x25519.X25519PublicKey, x448.X448PublicKey, ] @@ -42,6 +44,7 @@ dh.DHPrivateKey, ed25519.Ed25519PrivateKey, ed448.Ed448PrivateKey, + mldsa.MlDsa65PrivateKey, rsa.RSAPrivateKey, dsa.DSAPrivateKey, ec.EllipticCurvePrivateKey, diff --git a/src/rust/cryptography-key-parsing/src/pkcs8.rs b/src/rust/cryptography-key-parsing/src/pkcs8.rs index 2ac84d0d177c..5c7354d25372 100644 --- a/src/rust/cryptography-key-parsing/src/pkcs8.rs +++ b/src/rust/cryptography-key-parsing/src/pkcs8.rs @@ -22,6 +22,28 @@ pub struct PrivateKeyInfo<'a> { pub attributes: Option>, } +// RFC 9881 Section 6.5 +#[cfg(CRYPTOGRAPHY_IS_AWSLC)] +#[derive(asn1::Asn1Read, asn1::Asn1Write)] +pub enum MlDsaPrivateKey { + #[implicit(0)] + Seed([u8; 32]), +} + +/// Extract the 32-byte ML-DSA-65 seed from a private key. +/// +/// AWS-LC's `raw_private_key()` returns the expanded key, not the seed. +/// This function round-trips through the native PKCS#8 encoding to extract it. +/// https://github.com/aws/aws-lc/issues/3072 +#[cfg(CRYPTOGRAPHY_IS_AWSLC)] +pub fn mldsa_seed_from_pkey( + pkey: &openssl::pkey::PKeyRef, +) -> Result { + let pkcs8_der = pkey.private_key_to_pkcs8()?; + let pki = asn1::parse_single::>(&pkcs8_der).unwrap(); + Ok(asn1::parse_single::(pki.private_key).unwrap()) +} + pub fn parse_private_key( data: &[u8], ) -> KeyParsingResult> { @@ -108,6 +130,12 @@ pub fn parse_private_key( )?) } + #[cfg(CRYPTOGRAPHY_IS_AWSLC)] + AlgorithmParameters::MlDsa65 => { + let MlDsaPrivateKey::Seed(seed) = asn1::parse_single::(k.private_key)?; + Ok(cryptography_openssl::mldsa::new_raw_private_key(&seed)?) + } + _ => Err(KeyParsingError::UnsupportedKeyType( k.algorithm.oid().clone(), )), @@ -443,6 +471,11 @@ pub fn serialize_private_key( (params, private_key_der) } + #[cfg(CRYPTOGRAPHY_IS_AWSLC)] + cryptography_openssl::mldsa::PKEY_ID => { + let private_key_der = asn1::write_single(&mldsa_seed_from_pkey(pkey)?)?; + (AlgorithmParameters::MlDsa65, private_key_der) + } _ => { unimplemented!("Unknown key type"); } diff --git a/src/rust/cryptography-key-parsing/src/spki.rs b/src/rust/cryptography-key-parsing/src/spki.rs index 7ce292b642d0..da18f6f6fdb8 100644 --- a/src/rust/cryptography-key-parsing/src/spki.rs +++ b/src/rust/cryptography-key-parsing/src/spki.rs @@ -100,6 +100,12 @@ pub fn parse_public_key( Ok(openssl::pkey::PKey::from_dh(dh)?) } + #[cfg(CRYPTOGRAPHY_IS_AWSLC)] + AlgorithmParameters::MlDsa65 => Ok(cryptography_openssl::mldsa::new_raw_public_key( + k.subject_public_key.as_bytes(), + ) + .map_err(|_| KeyParsingError::InvalidKey)?), + _ => Err(KeyParsingError::UnsupportedKeyType( k.algorithm.oid().clone(), )), @@ -214,6 +220,15 @@ pub fn serialize_public_key( (params, pub_key_der) } + #[cfg(CRYPTOGRAPHY_IS_AWSLC)] + cryptography_openssl::mldsa::PKEY_ID => { + let raw_bytes = pkey.raw_public_key()?; + assert_eq!( + raw_bytes.len(), + cryptography_openssl::mldsa::MLDSA65_PUBLIC_KEY_BYTES + ); + (AlgorithmParameters::MlDsa65, raw_bytes) + } _ => { unimplemented!("Unknown key type"); } diff --git a/src/rust/cryptography-openssl/src/lib.rs b/src/rust/cryptography-openssl/src/lib.rs index 7dcf8599f0d5..1f90f08c5062 100644 --- a/src/rust/cryptography-openssl/src/lib.rs +++ b/src/rust/cryptography-openssl/src/lib.rs @@ -9,6 +9,8 @@ pub mod aead; pub mod cmac; pub mod fips; pub mod hmac; +#[cfg(CRYPTOGRAPHY_IS_AWSLC)] +pub mod mldsa; #[cfg(any( CRYPTOGRAPHY_IS_BORINGSSL, CRYPTOGRAPHY_IS_LIBRESSL, diff --git a/src/rust/cryptography-openssl/src/mldsa.rs b/src/rust/cryptography-openssl/src/mldsa.rs new file mode 100644 index 000000000000..617d82e063a7 --- /dev/null +++ b/src/rust/cryptography-openssl/src/mldsa.rs @@ -0,0 +1,152 @@ +// This file is dual licensed under the terms of the Apache License, Version +// 2.0, and the BSD License. See the LICENSE file in the root of this repository +// for complete details. + +use foreign_types_shared::ForeignType; +use openssl_sys as ffi; +use std::os::raw::c_int; + +use crate::{cvt, cvt_p, OpenSSLResult}; + +const NID_ML_DSA_65: c_int = ffi::NID_MLDSA65; +pub const PKEY_ID: openssl::pkey::Id = openssl::pkey::Id::from_raw(ffi::NID_PQDSA); +const MLDSA65_SIGNATURE_BYTES: usize = 3309; +pub const MLDSA65_PUBLIC_KEY_BYTES: usize = 1952; +pub const MLDSA65_SEED_BYTES: usize = 32; + +extern "C" { + // We call ml_dsa_65_sign/verify directly instead of going through + // EVP_DigestSign/EVP_DigestVerify because the EVP PQDSA path hardcodes + // context to (NULL, 0), so we'd lose context string support. + fn ml_dsa_65_sign( + private_key: *const u8, + sig: *mut u8, + sig_len: *mut usize, + message: *const u8, + message_len: usize, + ctx_string: *const u8, + ctx_string_len: usize, + ) -> c_int; + + fn ml_dsa_65_verify( + public_key: *const u8, + sig: *const u8, + sig_len: usize, + message: *const u8, + message_len: usize, + ctx_string: *const u8, + ctx_string_len: usize, + ) -> c_int; +} + +pub fn new_raw_private_key( + data: &[u8], +) -> OpenSSLResult> { + // SAFETY: EVP_PKEY_pqdsa_new_raw_private_key creates a new EVP_PKEY from + // raw key bytes. For ML-DSA-65, a 32-byte seed expands into the full + // keypair. + unsafe { + let pkey = cvt_p(ffi::EVP_PKEY_pqdsa_new_raw_private_key( + NID_ML_DSA_65, + data.as_ptr(), + data.len(), + ))?; + Ok(openssl::pkey::PKey::from_ptr(pkey)) + } +} + +pub fn new_raw_public_key( + data: &[u8], +) -> OpenSSLResult> { + // SAFETY: EVP_PKEY_pqdsa_new_raw_public_key creates a new EVP_PKEY from + // raw public key bytes. + unsafe { + let pkey = cvt_p(ffi::EVP_PKEY_pqdsa_new_raw_public_key( + NID_ML_DSA_65, + data.as_ptr(), + data.len(), + ))?; + Ok(openssl::pkey::PKey::from_ptr(pkey)) + } +} + +pub fn sign( + pkey: &openssl::pkey::PKeyRef, + data: &[u8], + context: &[u8], +) -> OpenSSLResult> { + let raw_key = pkey.raw_private_key()?; + + let mut sig = vec![0u8; MLDSA65_SIGNATURE_BYTES]; + let mut sig_len: usize = 0; + + let msg_ptr = if data.is_empty() { + std::ptr::null() + } else { + data.as_ptr() + }; + let ctx_ptr = if context.is_empty() { + std::ptr::null() + } else { + context.as_ptr() + }; + + // SAFETY: ml_dsa_65_sign takes raw key bytes, message, and context. + unsafe { + let r = ml_dsa_65_sign( + raw_key.as_ptr(), + sig.as_mut_ptr(), + &mut sig_len, + msg_ptr, + data.len(), + ctx_ptr, + context.len(), + ); + cvt(r)?; + } + + sig.truncate(sig_len); + Ok(sig) +} + +pub fn verify( + pkey: &openssl::pkey::PKeyRef, + signature: &[u8], + data: &[u8], + context: &[u8], +) -> OpenSSLResult { + let raw_key = pkey.raw_public_key()?; + + let msg_ptr = if data.is_empty() { + std::ptr::null() + } else { + data.as_ptr() + }; + let ctx_ptr = if context.is_empty() { + std::ptr::null() + } else { + context.as_ptr() + }; + + // SAFETY: ml_dsa_65_verify takes raw key bytes, signature, message, + // and context. + let r = unsafe { + ml_dsa_65_verify( + raw_key.as_ptr(), + signature.as_ptr(), + signature.len(), + msg_ptr, + data.len(), + ctx_ptr, + context.len(), + ) + }; + + if r != 1 { + // Clear any errors from the OpenSSL error stack to prevent + // leaking errors into subsequent operations. + let _ = openssl::error::ErrorStack::get(); + } + + Ok(r == 1) +} diff --git a/src/rust/cryptography-x509/src/common.rs b/src/rust/cryptography-x509/src/common.rs index 8d9e9ddcff23..7a22c100cabe 100644 --- a/src/rust/cryptography-x509/src/common.rs +++ b/src/rust/cryptography-x509/src/common.rs @@ -53,6 +53,9 @@ pub enum AlgorithmParameters<'a> { #[defined_by(oid::ED448_OID)] Ed448, + #[defined_by(oid::ML_DSA_65_OID)] + MlDsa65, + #[defined_by(oid::X25519_OID)] X25519, #[defined_by(oid::X448_OID)] diff --git a/src/rust/cryptography-x509/src/oid.rs b/src/rust/cryptography-x509/src/oid.rs index edae48def631..4a69312d7ceb 100644 --- a/src/rust/cryptography-x509/src/oid.rs +++ b/src/rust/cryptography-x509/src/oid.rs @@ -109,6 +109,8 @@ pub const X448_OID: asn1::ObjectIdentifier = asn1::oid!(1, 3, 101, 111); pub const ED25519_OID: asn1::ObjectIdentifier = asn1::oid!(1, 3, 101, 112); pub const ED448_OID: asn1::ObjectIdentifier = asn1::oid!(1, 3, 101, 113); +pub const ML_DSA_65_OID: asn1::ObjectIdentifier = asn1::oid!(2, 16, 840, 1, 101, 3, 4, 3, 18); + // Hashes pub const SHA1_OID: asn1::ObjectIdentifier = asn1::oid!(1, 3, 14, 3, 2, 26); pub const SHA224_OID: asn1::ObjectIdentifier = asn1::oid!(2, 16, 840, 1, 101, 3, 4, 2, 4); diff --git a/src/rust/src/backend/keys.rs b/src/rust/src/backend/keys.rs index b8fc6f247781..a638c305f096 100644 --- a/src/rust/src/backend/keys.rs +++ b/src/rust/src/backend/keys.rs @@ -167,6 +167,17 @@ fn private_key_from_pkey<'p>( openssl::pkey::Id::DHX => Ok(crate::backend::dh::private_key_from_pkey(pkey) .into_pyobject(py)? .into_any()), + #[cfg(CRYPTOGRAPHY_IS_AWSLC)] + cryptography_openssl::mldsa::PKEY_ID => { + let pub_len = pkey.raw_public_key()?.len(); + assert_eq!( + pub_len, + cryptography_openssl::mldsa::MLDSA65_PUBLIC_KEY_BYTES + ); + Ok(crate::backend::mldsa::private_key_from_pkey(pkey) + .into_pyobject(py)? + .into_any()) + } _ => Err(CryptographyError::from( exceptions::UnsupportedAlgorithm::new_err("Unsupported key type."), )), @@ -294,6 +305,17 @@ fn public_key_from_pkey<'p>( .into_pyobject(py)? .into_any()), + #[cfg(CRYPTOGRAPHY_IS_AWSLC)] + cryptography_openssl::mldsa::PKEY_ID => { + let pub_len = pkey.raw_public_key()?.len(); + assert_eq!( + pub_len, + cryptography_openssl::mldsa::MLDSA65_PUBLIC_KEY_BYTES + ); + Ok(crate::backend::mldsa::public_key_from_pkey(pkey) + .into_pyobject(py)? + .into_any()) + } _ => Err(CryptographyError::from( exceptions::UnsupportedAlgorithm::new_err("Unsupported key type."), )), diff --git a/src/rust/src/backend/mldsa.rs b/src/rust/src/backend/mldsa.rs new file mode 100644 index 000000000000..87961630ff33 --- /dev/null +++ b/src/rust/src/backend/mldsa.rs @@ -0,0 +1,210 @@ +// This file is dual licensed under the terms of the Apache License, Version +// 2.0, and the BSD License. See the LICENSE file in the root of this repository +// for complete details. + +use pyo3::types::PyAnyMethods; + +use crate::backend::utils; +use crate::buf::CffiBuf; +use crate::error::{CryptographyError, CryptographyResult}; +use crate::exceptions; + +const MAX_CONTEXT_BYTES: usize = 255; + +#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.mldsa")] +pub(crate) struct MlDsa65PrivateKey { + pkey: openssl::pkey::PKey, +} + +#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.mldsa")] +pub(crate) struct MlDsa65PublicKey { + pkey: openssl::pkey::PKey, +} + +pub(crate) fn private_key_from_pkey( + pkey: &openssl::pkey::PKeyRef, +) -> MlDsa65PrivateKey { + MlDsa65PrivateKey { + pkey: pkey.to_owned(), + } +} + +pub(crate) fn public_key_from_pkey( + pkey: &openssl::pkey::PKeyRef, +) -> MlDsa65PublicKey { + MlDsa65PublicKey { + pkey: pkey.to_owned(), + } +} + +#[pyo3::pyfunction] +fn generate_key() -> CryptographyResult { + let mut seed = [0u8; cryptography_openssl::mldsa::MLDSA65_SEED_BYTES]; + cryptography_openssl::rand::rand_bytes(&mut seed)?; + let pkey = cryptography_openssl::mldsa::new_raw_private_key(&seed)?; + Ok(MlDsa65PrivateKey { pkey }) +} + +#[pyo3::pyfunction] +fn from_seed_bytes(data: CffiBuf<'_>) -> pyo3::PyResult { + let pkey = cryptography_openssl::mldsa::new_raw_private_key(data.as_bytes()).map_err(|_| { + pyo3::exceptions::PyValueError::new_err("An ML-DSA-65 seed is 32 bytes long") + })?; + Ok(MlDsa65PrivateKey { pkey }) +} + +#[pyo3::pyfunction] +fn from_public_bytes(data: &[u8]) -> pyo3::PyResult { + let pkey = cryptography_openssl::mldsa::new_raw_public_key(data).map_err(|_| { + pyo3::exceptions::PyValueError::new_err("An ML-DSA-65 public key is 1952 bytes long") + })?; + Ok(MlDsa65PublicKey { pkey }) +} + +#[pyo3::pymethods] +impl MlDsa65PrivateKey { + #[pyo3(signature = (data, context=None))] + fn sign<'p>( + &self, + py: pyo3::Python<'p>, + data: CffiBuf<'_>, + context: Option>, + ) -> CryptographyResult> { + let ctx_bytes = context.as_ref().map_or(&[][..], |c| c.as_bytes()); + if ctx_bytes.len() > MAX_CONTEXT_BYTES { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Context must be at most 255 bytes"), + )); + } + let sig = cryptography_openssl::mldsa::sign(&self.pkey, data.as_bytes(), ctx_bytes)?; + Ok(pyo3::types::PyBytes::new(py, &sig)) + } + + fn public_key(&self) -> CryptographyResult { + let raw_bytes = self.pkey.raw_public_key()?; + Ok(MlDsa65PublicKey { + pkey: cryptography_openssl::mldsa::new_raw_public_key(&raw_bytes)?, + }) + } + + fn private_bytes_raw<'p>( + &self, + py: pyo3::Python<'p>, + ) -> CryptographyResult> { + let cryptography_key_parsing::pkcs8::MlDsaPrivateKey::Seed(seed) = + cryptography_key_parsing::pkcs8::mldsa_seed_from_pkey(&self.pkey)?; + Ok(pyo3::types::PyBytes::new(py, &seed)) + } + + fn private_bytes<'p>( + slf: &pyo3::Bound<'p, Self>, + py: pyo3::Python<'p>, + encoding: crate::serialization::Encoding, + format: crate::serialization::PrivateFormat, + encryption_algorithm: &pyo3::Bound<'p, pyo3::PyAny>, + ) -> CryptographyResult> { + // Intercept Raw/Raw/NoEncryption so we return the seed. + // The generic pkey_private_bytes raw path calls raw_private_key() + // which returns the expanded key on AWS-LC, not the seed. + if encoding == crate::serialization::Encoding::Raw + && format == crate::serialization::PrivateFormat::Raw + && encryption_algorithm.is_instance(&crate::types::NO_ENCRYPTION.get(py)?)? + { + return slf.borrow().private_bytes_raw(py); + } + utils::pkey_private_bytes( + py, + slf, + &slf.borrow().pkey, + encoding, + format, + encryption_algorithm, + true, + false, + ) + } + + fn __copy__(slf: pyo3::PyRef<'_, Self>) -> pyo3::PyRef<'_, Self> { + slf + } + + fn __deepcopy__<'p>( + slf: pyo3::PyRef<'p, Self>, + _memo: &pyo3::Bound<'p, pyo3::PyAny>, + ) -> pyo3::PyRef<'p, Self> { + slf + } +} + +#[pyo3::pymethods] +impl MlDsa65PublicKey { + #[pyo3(signature = (signature, data, context=None))] + fn verify( + &self, + signature: CffiBuf<'_>, + data: CffiBuf<'_>, + context: Option>, + ) -> CryptographyResult<()> { + let ctx_bytes = context.as_ref().map_or(&[][..], |c| c.as_bytes()); + if ctx_bytes.len() > MAX_CONTEXT_BYTES { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Context must be at most 255 bytes"), + )); + } + let valid = cryptography_openssl::mldsa::verify( + &self.pkey, + signature.as_bytes(), + data.as_bytes(), + ctx_bytes, + ) + .unwrap_or(false); + + if !valid { + return Err(CryptographyError::from( + exceptions::InvalidSignature::new_err(()), + )); + } + + Ok(()) + } + + fn public_bytes_raw<'p>( + &self, + py: pyo3::Python<'p>, + ) -> CryptographyResult> { + let raw_bytes = self.pkey.raw_public_key()?; + Ok(pyo3::types::PyBytes::new(py, &raw_bytes)) + } + + fn public_bytes<'p>( + slf: &pyo3::Bound<'p, Self>, + py: pyo3::Python<'p>, + encoding: crate::serialization::Encoding, + format: crate::serialization::PublicFormat, + ) -> CryptographyResult> { + utils::pkey_public_bytes(py, slf, &slf.borrow().pkey, encoding, format, true, true) + } + + fn __eq__(&self, other: pyo3::PyRef<'_, Self>) -> bool { + self.pkey.public_eq(&other.pkey) + } + + fn __copy__(slf: pyo3::PyRef<'_, Self>) -> pyo3::PyRef<'_, Self> { + slf + } + + fn __deepcopy__<'p>( + slf: pyo3::PyRef<'p, Self>, + _memo: &pyo3::Bound<'p, pyo3::PyAny>, + ) -> pyo3::PyRef<'p, Self> { + slf + } +} + +#[pyo3::pymodule(gil_used = false)] +pub(crate) mod mldsa { + #[pymodule_export] + use super::{ + from_public_bytes, from_seed_bytes, generate_key, MlDsa65PrivateKey, MlDsa65PublicKey, + }; +} diff --git a/src/rust/src/backend/mod.rs b/src/rust/src/backend/mod.rs index a9133cafb8c8..a5e47e360357 100644 --- a/src/rust/src/backend/mod.rs +++ b/src/rust/src/backend/mod.rs @@ -21,6 +21,8 @@ pub(crate) mod hmac; pub(crate) mod hpke; pub(crate) mod kdf; pub(crate) mod keys; +#[cfg(CRYPTOGRAPHY_IS_AWSLC)] +pub(crate) mod mldsa; pub(crate) mod poly1305; pub(crate) mod rand; pub(crate) mod rsa; diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index aca7b7d1e22d..eab36ee35d64 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -242,6 +242,9 @@ mod _rust { use crate::backend::kdf::kdf; #[pymodule_export] use crate::backend::keys::keys; + #[cfg(CRYPTOGRAPHY_IS_AWSLC)] + #[pymodule_export] + use crate::backend::mldsa::mldsa; #[pymodule_export] use crate::backend::poly1305::poly1305; #[pymodule_export] diff --git a/tests/hazmat/primitives/test_mldsa.py b/tests/hazmat/primitives/test_mldsa.py new file mode 100644 index 000000000000..1364ce73b7d0 --- /dev/null +++ b/tests/hazmat/primitives/test_mldsa.py @@ -0,0 +1,392 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + + +import binascii +import copy +import os + +import pytest + +from cryptography.exceptions import InvalidSignature, _Reasons +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.mldsa import ( + MlDsa65PrivateKey, + MlDsa65PublicKey, +) + +from ...doubles import DummyKeySerializationEncryption +from ...utils import ( + load_nist_vectors, + load_vectors_from_file, + raises_unsupported_algorithm, +) + + +@pytest.mark.supported( + only_if=lambda backend: not backend.mldsa_supported(), + skip_message="Requires a backend without ML-DSA-65 support", +) +def test_mldsa_unsupported(backend): + with raises_unsupported_algorithm( + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM + ): + MlDsa65PublicKey.from_public_bytes(b"0" * 1952) + + with raises_unsupported_algorithm( + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM + ): + MlDsa65PrivateKey.from_seed_bytes(b"0" * 32) + + with raises_unsupported_algorithm( + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM + ): + MlDsa65PrivateKey.generate() + + +@pytest.mark.supported( + only_if=lambda backend: backend.mldsa_supported(), + skip_message="Requires a backend with ML-DSA-65 support", +) +class TestMlDsa65: + def test_sign_verify(self, backend): + key = MlDsa65PrivateKey.generate() + sig = key.sign(b"test data") + key.public_key().verify(sig, b"test data") + + def test_sign_verify_empty_message(self, backend): + key = MlDsa65PrivateKey.generate() + sig = key.sign(b"") + key.public_key().verify(sig, b"") + + @pytest.mark.parametrize( + "ctx", + [ + b"ctx", + b"a" * 255, + ], + ) + def test_sign_verify_with_context(self, backend, ctx): + key = MlDsa65PrivateKey.generate() + sig = key.sign(b"test data", ctx) + key.public_key().verify(sig, b"test data", ctx) + + def test_empty_context_equivalence(self, backend): + key = MlDsa65PrivateKey.generate() + pub = key.public_key() + data = b"test data" + sig = key.sign(data) + pub.verify(sig, data, b"") + sig2 = key.sign(data, b"") + pub.verify(sig2, data) + + def test_kat_vectors(self, backend, subtests): + vectors = load_vectors_from_file( + os.path.join("asymmetric", "MLDSA", "kat_MLDSA_65_det_pure.rsp"), + load_nist_vectors, + ) + for vector in vectors: + with subtests.test(): + xi = binascii.unhexlify(vector["xi"]) + pk = binascii.unhexlify(vector["pk"]) + msg = binascii.unhexlify(vector["msg"]) + ctx = binascii.unhexlify(vector["ctx"]) + sm = binascii.unhexlify(vector["sm"]) + expected_sig = sm[:3309] + + key = MlDsa65PrivateKey.from_seed_bytes(xi) + assert key.private_bytes_raw() == xi + assert key.public_key().public_bytes_raw() == pk + + pub = MlDsa65PublicKey.from_public_bytes(pk) + pub.verify(expected_sig, msg, ctx) + + def test_private_bytes_raw_round_trip(self, backend): + key = MlDsa65PrivateKey.generate() + seed = key.private_bytes_raw() + assert len(seed) == 32 + key2 = MlDsa65PrivateKey.from_seed_bytes(seed) + assert key2.private_bytes_raw() == seed + assert seed == key.private_bytes( + serialization.Encoding.Raw, + serialization.PrivateFormat.Raw, + serialization.NoEncryption(), + ) + + pub = key.public_key() + raw_pub = pub.public_bytes_raw() + assert len(raw_pub) == 1952 + pub2 = MlDsa65PublicKey.from_public_bytes(raw_pub) + assert pub2.public_bytes_raw() == raw_pub + + @pytest.mark.parametrize( + ("encoding", "fmt", "encryption", "passwd", "load_func"), + [ + ( + serialization.Encoding.PEM, + serialization.PrivateFormat.PKCS8, + serialization.NoEncryption(), + None, + serialization.load_pem_private_key, + ), + ( + serialization.Encoding.DER, + serialization.PrivateFormat.PKCS8, + serialization.NoEncryption(), + None, + serialization.load_der_private_key, + ), + ( + serialization.Encoding.PEM, + serialization.PrivateFormat.PKCS8, + serialization.BestAvailableEncryption(b"password"), + b"password", + serialization.load_pem_private_key, + ), + ( + serialization.Encoding.DER, + serialization.PrivateFormat.PKCS8, + serialization.BestAvailableEncryption(b"password"), + b"password", + serialization.load_der_private_key, + ), + ], + ) + def test_round_trip_private_serialization( + self, encoding, fmt, encryption, passwd, load_func, backend + ): + key = MlDsa65PrivateKey.generate() + serialized = key.private_bytes(encoding, fmt, encryption) + loaded_key = load_func(serialized, passwd, backend) + assert isinstance(loaded_key, MlDsa65PrivateKey) + assert loaded_key.private_bytes_raw() == key.private_bytes_raw() + sig = loaded_key.sign(b"test data") + key.public_key().verify(sig, b"test data") + + @pytest.mark.parametrize( + ("encoding", "fmt", "load_func"), + [ + ( + serialization.Encoding.PEM, + serialization.PublicFormat.SubjectPublicKeyInfo, + serialization.load_pem_public_key, + ), + ( + serialization.Encoding.DER, + serialization.PublicFormat.SubjectPublicKeyInfo, + serialization.load_der_public_key, + ), + ], + ) + def test_round_trip_public_serialization( + self, encoding, fmt, load_func, backend + ): + key = MlDsa65PrivateKey.generate() + pub = key.public_key() + serialized = pub.public_bytes(encoding, fmt) + loaded_pub = load_func(serialized, backend) + assert isinstance(loaded_pub, MlDsa65PublicKey) + assert loaded_pub == pub + + def test_invalid_signature(self, backend): + key = MlDsa65PrivateKey.generate() + sig = key.sign(b"test data") + with pytest.raises(InvalidSignature): + key.public_key().verify(sig, b"wrong data") + + with pytest.raises(InvalidSignature): + key.public_key().verify(b"0" * 3309, b"test data") + + def test_context_wrong_context(self, backend): + key = MlDsa65PrivateKey.generate() + sig = key.sign(b"test data", b"ctx-a") + with pytest.raises(InvalidSignature): + key.public_key().verify(sig, b"test data", b"ctx-b") + + def test_context_too_long(self, backend): + key = MlDsa65PrivateKey.generate() + with pytest.raises(ValueError): + key.sign(b"data", b"x" * 256) + with pytest.raises(ValueError): + key.public_key().verify(b"sig", b"data", b"x" * 256) + + def test_invalid_length_from_public_bytes(self, backend): + with pytest.raises(ValueError): + MlDsa65PublicKey.from_public_bytes(b"a" * 10) + + def test_invalid_length_from_seed_bytes(self, backend): + with pytest.raises(ValueError): + MlDsa65PrivateKey.from_seed_bytes(b"a" * 10) + + def test_invalid_type_public_bytes(self, backend): + with pytest.raises(TypeError): + MlDsa65PublicKey.from_public_bytes( + object() # type: ignore[arg-type] + ) + + def test_invalid_type_seed_bytes(self, backend): + with pytest.raises(TypeError): + MlDsa65PrivateKey.from_seed_bytes( + object() # type: ignore[arg-type] + ) + + def test_invalid_private_bytes(self, backend): + key = MlDsa65PrivateKey.generate() + with pytest.raises(TypeError): + key.private_bytes( + serialization.Encoding.Raw, + serialization.PrivateFormat.Raw, + None, # type: ignore[arg-type] + ) + with pytest.raises(ValueError): + key.private_bytes( + serialization.Encoding.Raw, + serialization.PrivateFormat.Raw, + DummyKeySerializationEncryption(), + ) + + with pytest.raises(ValueError): + key.private_bytes( + serialization.Encoding.Raw, + serialization.PrivateFormat.PKCS8, + DummyKeySerializationEncryption(), + ) + + with pytest.raises(ValueError): + key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.Raw, + serialization.NoEncryption(), + ) + + def test_invalid_public_bytes(self, backend): + key = MlDsa65PrivateKey.generate().public_key() + with pytest.raises(ValueError): + key.public_bytes( + serialization.Encoding.Raw, + serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + with pytest.raises(ValueError): + key.public_bytes( + serialization.Encoding.PEM, + serialization.PublicFormat.PKCS1, + ) + + with pytest.raises(ValueError): + key.public_bytes( + serialization.Encoding.PEM, + serialization.PublicFormat.Raw, + ) + + +@pytest.mark.supported( + only_if=lambda backend: backend.mldsa_supported(), + skip_message="Requires a backend with ML-DSA-65 support", +) +def test_unsupported_mldsa_variant_private_key(backend): + # ML-DSA-44 is not supported; loading it must raise UnsupportedAlgorithm. + pkcs8_der = load_vectors_from_file( + os.path.join("asymmetric", "MLDSA", "mldsa44_priv.der"), + lambda derfile: derfile.read(), + mode="rb", + ) + with raises_unsupported_algorithm( + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM + ): + serialization.load_der_private_key( + pkcs8_der, password=None, backend=backend + ) + + +@pytest.mark.supported( + only_if=lambda backend: backend.mldsa_supported(), + skip_message="Requires a backend with ML-DSA-65 support", +) +def test_mldsa65_private_key_no_seed(backend): + pkcs8_der = load_vectors_from_file( + os.path.join("asymmetric", "MLDSA", "mldsa65_noseed_priv.der"), + lambda derfile: derfile.read(), + mode="rb", + ) + with pytest.raises(ValueError): + serialization.load_der_private_key( + pkcs8_der, password=None, backend=backend + ) + + +@pytest.mark.supported( + only_if=lambda backend: backend.mldsa_supported(), + skip_message="Requires a backend with ML-DSA-65 support", +) +def test_unsupported_mldsa_variant_public_key(backend): + # ML-DSA-44 is not supported; loading it must raise UnsupportedAlgorithm. + spki_der = load_vectors_from_file( + os.path.join("asymmetric", "MLDSA", "mldsa44_pub.der"), + lambda derfile: derfile.read(), + mode="rb", + ) + with raises_unsupported_algorithm( + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM + ): + serialization.load_der_public_key(spki_der, backend=backend) + + +@pytest.mark.supported( + only_if=lambda backend: backend.mldsa_supported(), + skip_message="Requires a backend with ML-DSA-65 support", +) +def test_public_key_equality(backend): + key = MlDsa65PrivateKey.generate() + pub1 = key.public_key() + pub2 = key.public_key() + pub3 = MlDsa65PrivateKey.generate().public_key() + assert pub1 == pub2 + assert pub1 != pub3 + assert pub1 != object() + + with pytest.raises(TypeError): + pub1 < pub2 # type: ignore[operator] + + +@pytest.mark.supported( + only_if=lambda backend: backend.mldsa_supported(), + skip_message="Requires a backend with ML-DSA-65 support", +) +def test_public_key_copy(backend): + key = MlDsa65PrivateKey.generate() + pub1 = key.public_key() + pub2 = copy.copy(pub1) + assert pub1 == pub2 + + +@pytest.mark.supported( + only_if=lambda backend: backend.mldsa_supported(), + skip_message="Requires a backend with ML-DSA-65 support", +) +def test_public_key_deepcopy(backend): + key = MlDsa65PrivateKey.generate() + pub1 = key.public_key() + pub2 = copy.deepcopy(pub1) + assert pub1 == pub2 + + +@pytest.mark.supported( + only_if=lambda backend: backend.mldsa_supported(), + skip_message="Requires a backend with ML-DSA-65 support", +) +def test_private_key_copy(backend): + key1 = MlDsa65PrivateKey.generate() + key2 = copy.copy(key1) + assert key1.private_bytes_raw() == key2.private_bytes_raw() + + +@pytest.mark.supported( + only_if=lambda backend: backend.mldsa_supported(), + skip_message="Requires a backend with ML-DSA-65 support", +) +def test_private_key_deepcopy(backend): + key1 = MlDsa65PrivateKey.generate() + key2 = copy.deepcopy(key1) + assert key1.private_bytes_raw() == key2.private_bytes_raw() diff --git a/tests/wycheproof/test_mldsa.py b/tests/wycheproof/test_mldsa.py new file mode 100644 index 000000000000..48615356efcc --- /dev/null +++ b/tests/wycheproof/test_mldsa.py @@ -0,0 +1,86 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import binascii + +import pytest + +from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.primitives.asymmetric.mldsa import ( + MlDsa65PrivateKey, + MlDsa65PublicKey, +) + +from .utils import wycheproof_tests + + +@pytest.mark.supported( + only_if=lambda backend: backend.mldsa_supported(), + skip_message="Requires a backend with ML-DSA-65 support", +) +@wycheproof_tests("mldsa_65_verify_test.json") +def test_mldsa65_verify(backend, wycheproof): + try: + pub = MlDsa65PublicKey.from_public_bytes( + binascii.unhexlify(wycheproof.testgroup["publicKey"]) + ) + except ValueError: + assert wycheproof.invalid + assert wycheproof.has_flag("IncorrectPublicKeyLength") + return + + msg = binascii.unhexlify(wycheproof.testcase["msg"]) + sig = binascii.unhexlify(wycheproof.testcase["sig"]) + has_ctx = "ctx" in wycheproof.testcase + ctx = binascii.unhexlify(wycheproof.testcase["ctx"]) if has_ctx else None + + if wycheproof.valid: + pub.verify(sig, msg, ctx) + else: + with pytest.raises( + ( + ValueError, + InvalidSignature, + ) + ): + pub.verify(sig, msg, ctx) + + +@pytest.mark.supported( + only_if=lambda backend: backend.mldsa_supported(), + skip_message="Requires a backend with ML-DSA-65 support", +) +@wycheproof_tests("mldsa_65_sign_seed_test.json") +def test_mldsa65_sign_seed(backend, wycheproof): + # Skip "Internal" tests, they use the inner method `Sign_internal` + # instead of `Sign` which we do not expose. + if wycheproof.has_flag("Internal"): + return + + seed = binascii.unhexlify(wycheproof.testgroup["privateSeed"]) + try: + key = MlDsa65PrivateKey.from_seed_bytes(seed) + except ValueError: + assert wycheproof.invalid + assert wycheproof.has_flag("IncorrectPrivateKeyLength") + return + pub = MlDsa65PublicKey.from_public_bytes( + binascii.unhexlify(wycheproof.testgroup["publicKey"]) + ) + + assert key.public_key() == pub + + msg = binascii.unhexlify(wycheproof.testcase["msg"]) + has_ctx = "ctx" in wycheproof.testcase + ctx = binascii.unhexlify(wycheproof.testcase["ctx"]) if has_ctx else None + + if wycheproof.valid or wycheproof.acceptable: + # Sign and verify round-trip. We don't compare exact signature + # bytes because some backends use hedged (randomized) signing. + sig = key.sign(msg, ctx) + pub.verify(sig, msg, ctx) + else: + with pytest.raises(ValueError): + assert has_ctx + key.sign(msg, ctx)