Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/joserfc/_rfc7518/ec_key.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

import typing as t
from functools import cached_property
from cryptography.hazmat.primitives import hashes
Expand Down Expand Up @@ -74,11 +75,12 @@ def import_private_key(cls, obj: ECDictKey) -> EllipticCurvePrivateKey:
@classmethod
def export_private_key(cls, key: EllipticCurvePrivateKey) -> ECDictKey:
numbers = key.private_numbers()
byte_count = (key.key_size + 7) // 8
return {
"crv": cls._curves_dss[key.curve.name],
"x": int_to_base64(numbers.public_numbers.x),
"y": int_to_base64(numbers.public_numbers.y),
"d": int_to_base64(numbers.private_value),
"x": int_to_base64(numbers.public_numbers.x, byte_count),
"y": int_to_base64(numbers.public_numbers.y, byte_count),
"d": int_to_base64(numbers.private_value, byte_count),
}

@classmethod
Expand All @@ -94,10 +96,11 @@ def import_public_key(cls, obj: ECDictKey) -> EllipticCurvePublicKey:
@classmethod
def export_public_key(cls, key: EllipticCurvePublicKey) -> ECDictKey:
numbers = key.public_numbers()
byte_count = (key.key_size + 7) // 8
return {
"crv": cls._curves_dss[numbers.curve.name],
"x": int_to_base64(numbers.x),
"y": int_to_base64(numbers.y),
"x": int_to_base64(numbers.x, byte_count),
"y": int_to_base64(numbers.y, byte_count),
}


Expand Down
11 changes: 8 additions & 3 deletions src/joserfc/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Any
from typing import Any, Optional
import base64
import struct
import binascii
Expand Down Expand Up @@ -48,11 +48,16 @@ def base64_to_int(s: str) -> int:
return int("".join(["%02x" % byte for byte in buf]), 16)


def int_to_base64(num: int) -> str:
def int_to_base64(num: int, byte_count: Optional[int] = None) -> str:
if num < 0:
raise ValueError("Must be a positive integer")

s = num.to_bytes((num.bit_length() + 7) // 8, "big", signed=False)
if byte_count is None:
byte_count = (num.bit_length() + 7) // 8
elif num.bit_length() > byte_count * 8:
raise ValueError("Number too large for byte count")

s = num.to_bytes(byte_count, "big", signed=False)
return urlsafe_b64encode(s).decode("utf-8", "strict")


Expand Down
51 changes: 51 additions & 0 deletions tests/jwk/test_ec_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,54 @@ def test_derive_key_with_different_hash(self):
key1 = ECKey.derive_key("ec-secret-key", "P-256", kdf_options={"algorithm": hashes.SHA256()})
key2 = ECKey.derive_key("ec-secret-key", "P-256", kdf_options={"algorithm": hashes.SHA512()})
self.assertNotEqual(key1, key2)

def run_verify_full_size(self, curve_name: str, expected_base64_count: int):
"""
Verifies that the full-size keys (private and public) generated using the specified curve conform to the expected
Base64-encoded string length for their respective components. The checks involve generating keys that could lead
to truncated values when encoded and ensuring their lengths match the specified expectation.

See section: https://datatracker.ietf.org/doc/html/rfc7518#section-6.2

Parameters:
curve_name (str): The name of the elliptic curve to use for key generation.
expected_base64_count (int): The expected length of the Base64-encoded key components (x, y, d).

Raises:
AssertionError: Raised if any of the generated private or public key components fail to match the expected lengths.
"""
private_key = ECKey.generate_key(curve_name)
# find the number which requires one less byte(octet) than a full padding
byte_count = (private_key.curve_key_size + 7) // 8
lower_cap = pow(2, 8 * (byte_count - 1))
attempts_remaining = 1000000

# now generate keys until we find a parameter which could be truncated
while (
private_key.public_key.public_numbers().x >= lower_cap
and private_key.public_key.public_numbers().y >= lower_cap
and private_key.private_key.private_numbers().private_value >= lower_cap
):
private_key = ECKey.generate_key(curve_name)
attempts_remaining -= 1
if attempts_remaining == 0:
raise AssertionError("Failed to find a key parameter that could be truncated")

output_private = private_key.as_dict(private=True)
self.assertEqual(expected_base64_count, len(output_private["x"]))
self.assertEqual(expected_base64_count, len(output_private["y"]))
self.assertEqual(expected_base64_count, len(output_private["d"]))

pub_key = ECKey.import_key(private_key.public_key)
output_public = pub_key.as_dict(private=False)
self.assertEqual(expected_base64_count, len(output_public["x"]))
self.assertEqual(expected_base64_count, len(output_public["y"]))

def test_p256_full_size(self):
self.run_verify_full_size("P-256", 43)

def test_p384_full_size(self):
self.run_verify_full_size("P-384", 64)

def test_p521_full_size(self):
self.run_verify_full_size("P-521", 88)
Loading