diff --git a/src/joserfc/_rfc7518/ec_key.py b/src/joserfc/_rfc7518/ec_key.py index 49805d8..25d4e6f 100644 --- a/src/joserfc/_rfc7518/ec_key.py +++ b/src/joserfc/_rfc7518/ec_key.py @@ -1,4 +1,5 @@ from __future__ import annotations + import typing as t from functools import cached_property from cryptography.hazmat.primitives import hashes @@ -74,11 +75,12 @@ def import_private_key(cls, obj: ECDictKey) -> EllipticCurvePrivateKey: @classmethod def export_private_key(cls, key: EllipticCurvePrivateKey) -> ECDictKey: numbers = key.private_numbers() + byte_count = (key.key_size + 7) // 8 return { "crv": cls._curves_dss[key.curve.name], - "x": int_to_base64(numbers.public_numbers.x), - "y": int_to_base64(numbers.public_numbers.y), - "d": int_to_base64(numbers.private_value), + "x": int_to_base64(numbers.public_numbers.x, byte_count), + "y": int_to_base64(numbers.public_numbers.y, byte_count), + "d": int_to_base64(numbers.private_value, byte_count), } @classmethod @@ -94,10 +96,11 @@ def import_public_key(cls, obj: ECDictKey) -> EllipticCurvePublicKey: @classmethod def export_public_key(cls, key: EllipticCurvePublicKey) -> ECDictKey: numbers = key.public_numbers() + byte_count = (key.key_size + 7) // 8 return { "crv": cls._curves_dss[numbers.curve.name], - "x": int_to_base64(numbers.x), - "y": int_to_base64(numbers.y), + "x": int_to_base64(numbers.x, byte_count), + "y": int_to_base64(numbers.y, byte_count), } diff --git a/src/joserfc/util.py b/src/joserfc/util.py index c287606..f0de753 100644 --- a/src/joserfc/util.py +++ b/src/joserfc/util.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Any +from typing import Any, Optional import base64 import struct import binascii @@ -48,11 +48,16 @@ def base64_to_int(s: str) -> int: return int("".join(["%02x" % byte for byte in buf]), 16) -def int_to_base64(num: int) -> str: +def int_to_base64(num: int, byte_count: Optional[int] = None) -> str: if num < 0: raise ValueError("Must be a positive integer") - s = num.to_bytes((num.bit_length() + 7) // 8, "big", signed=False) + if byte_count is None: + byte_count = (num.bit_length() + 7) // 8 + elif num.bit_length() > byte_count * 8: + raise ValueError("Number too large for byte count") + + s = num.to_bytes(byte_count, "big", signed=False) return urlsafe_b64encode(s).decode("utf-8", "strict") diff --git a/tests/jwk/test_ec_key.py b/tests/jwk/test_ec_key.py index d75e25c..443f63e 100644 --- a/tests/jwk/test_ec_key.py +++ b/tests/jwk/test_ec_key.py @@ -129,3 +129,54 @@ def test_derive_key_with_different_hash(self): key1 = ECKey.derive_key("ec-secret-key", "P-256", kdf_options={"algorithm": hashes.SHA256()}) key2 = ECKey.derive_key("ec-secret-key", "P-256", kdf_options={"algorithm": hashes.SHA512()}) self.assertNotEqual(key1, key2) + + def run_verify_full_size(self, curve_name: str, expected_base64_count: int): + """ + Verifies that the full-size keys (private and public) generated using the specified curve conform to the expected + Base64-encoded string length for their respective components. The checks involve generating keys that could lead + to truncated values when encoded and ensuring their lengths match the specified expectation. + + See section: https://datatracker.ietf.org/doc/html/rfc7518#section-6.2 + + Parameters: + curve_name (str): The name of the elliptic curve to use for key generation. + expected_base64_count (int): The expected length of the Base64-encoded key components (x, y, d). + + Raises: + AssertionError: Raised if any of the generated private or public key components fail to match the expected lengths. + """ + private_key = ECKey.generate_key(curve_name) + # find the number which requires one less byte(octet) than a full padding + byte_count = (private_key.curve_key_size + 7) // 8 + lower_cap = pow(2, 8 * (byte_count - 1)) + attempts_remaining = 1000000 + + # now generate keys until we find a parameter which could be truncated + while ( + private_key.public_key.public_numbers().x >= lower_cap + and private_key.public_key.public_numbers().y >= lower_cap + and private_key.private_key.private_numbers().private_value >= lower_cap + ): + private_key = ECKey.generate_key(curve_name) + attempts_remaining -= 1 + if attempts_remaining == 0: + raise AssertionError("Failed to find a key parameter that could be truncated") + + output_private = private_key.as_dict(private=True) + self.assertEqual(expected_base64_count, len(output_private["x"])) + self.assertEqual(expected_base64_count, len(output_private["y"])) + self.assertEqual(expected_base64_count, len(output_private["d"])) + + pub_key = ECKey.import_key(private_key.public_key) + output_public = pub_key.as_dict(private=False) + self.assertEqual(expected_base64_count, len(output_public["x"])) + self.assertEqual(expected_base64_count, len(output_public["y"])) + + def test_p256_full_size(self): + self.run_verify_full_size("P-256", 43) + + def test_p384_full_size(self): + self.run_verify_full_size("P-384", 64) + + def test_p521_full_size(self): + self.run_verify_full_size("P-521", 88)