Skip to content

Commit 974aa0e

Browse files
authored
Merge pull request #180 from kimbeelen/signature-alg
Added support for EC and DSA keys. DCO check passed.
2 parents dd17127 + 60c6866 commit 974aa0e

2 files changed

Lines changed: 82 additions & 4 deletions

File tree

openleadr/messaging.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,16 @@
2323
from openleadr import errors
2424
from datetime import datetime, timezone, timedelta
2525
import os
26+
from signxml.algorithms import SignatureMethod
27+
from cryptography.hazmat.primitives import serialization
28+
from cryptography.hazmat.primitives.asymmetric import rsa, dsa, ec, ed25519, ed448
2629

2730
from openleadr import utils
2831
from .preflight import preflight_message
2932

3033
import logging
3134
logger = logging.getLogger('openleadr')
3235

33-
SIGNER = XMLSigner(method=methods.detached,
34-
c14n_algorithm="http://www.w3.org/TR/2001/REC-xml-c14n-20010315")
35-
SIGNER.namespaces['oadr'] = "http://openadr.org/oadr-2.0b/2012/07"
3636
VERIFIER = XMLVerifier()
3737

3838
XML_SCHEMA_LOCATION = os.path.join(os.path.dirname(__file__), 'schema', 'oadr_20b.xsd')
@@ -62,6 +62,45 @@ def parse_message(data):
6262
return message_type, message_payload
6363

6464

65+
def load_private_key(key_data, passphrase=None):
66+
"""
67+
Load the key based on key data. Supports .pem and .der keys.
68+
69+
Returns a private key object.
70+
"""
71+
passphrase_bytes = passphrase.encode() if passphrase else None
72+
try:
73+
key = serialization.load_pem_private_key(key_data, passphrase_bytes)
74+
except ValueError:
75+
try:
76+
key = serialization.load_der_private_key(key_data, passphrase_bytes)
77+
except ValueError:
78+
logger.warning("Could not load key: unknown key file format.")
79+
return key
80+
81+
82+
def get_signature_algorithm_from_private_key(key_data, passphrase=None, default_algorithm="rsa-sha256"):
83+
"""
84+
Derive a signature algorithm based on the private key type. Returns a string that can be used to lookup
85+
a signature algorithm by fragment. Algorithms are chosen based on NIST recommendations.
86+
87+
SignXML supports only RSA-, DSA- and EC-based signature methods. As XMLSigner uses RSA_SHA256 as default
88+
signature algorithm, a fragment that results in this algorithm is returned for unsupported keys.
89+
"""
90+
key = load_private_key(key_data, passphrase)
91+
if isinstance(key, rsa.RSAPrivateKey):
92+
return "rsa-sha256"
93+
elif isinstance(key, dsa.DSAPrivateKey):
94+
return "dsa-sha256"
95+
elif isinstance(key, ec.EllipticCurvePrivateKey):
96+
return "ecdsa-sha256"
97+
elif isinstance(key, ed25519.Ed25519PrivateKey):
98+
logger.warning("ED25519 keys are not supported")
99+
elif isinstance(key, ed448.Ed448PrivateKey):
100+
logger.warning("ED448 keys are not supported")
101+
return default_algorithm
102+
103+
65104
def create_message(message_type, cert=None, key=None, passphrase=None, disable_signature=False, **message_payload):
66105
"""
67106
Create and optionally sign an OpenADR message. Returns an XML string.
@@ -72,6 +111,12 @@ def create_message(message_type, cert=None, key=None, passphrase=None, disable_s
72111
envelope = TEMPLATES.get_template('oadrPayload.xml')
73112
if cert and key and not disable_signature:
74113
tree = etree.fromstring(signed_object)
114+
SIGNER = XMLSigner(
115+
method=methods.detached,
116+
c14n_algorithm="http://www.w3.org/TR/2001/REC-xml-c14n-20010315"
117+
)
118+
SIGNER.namespaces['oadr'] = "http://openadr.org/oadr-2.0b/2012/07"
119+
SIGNER.sign_alg = SignatureMethod.from_fragment(get_signature_algorithm_from_private_key(key, passphrase))
75120
signature_tree = SIGNER.sign(tree,
76121
key=key,
77122
cert=cert,
@@ -83,7 +128,8 @@ def create_message(message_type, cert=None, key=None, passphrase=None, disable_s
83128
signature = None
84129
msg = envelope.render(template=f'{message_type}',
85130
signature=signature,
86-
signed_object=signed_object)
131+
signed_object=signed_object
132+
)
87133
logger.debug(f"Created message: {msg}")
88134
return msg
89135

test/test_signature_algorithms.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import pytest
2+
from cryptography.hazmat.primitives import serialization
3+
from cryptography.hazmat.primitives.asymmetric import rsa, dsa, ec, ed25519, ed448
4+
from openleadr.messaging import get_signature_algorithm_from_private_key
5+
6+
7+
test_keys = {
8+
"rsa": rsa.generate_private_key(public_exponent=65537, key_size=2048),
9+
"dsa": dsa.generate_private_key(key_size=2048),
10+
"ec": ec.generate_private_key(ec.SECP256R1()),
11+
"ed25519": ed25519.Ed25519PrivateKey.generate(),
12+
"ed448": ed448.Ed448PrivateKey.generate()
13+
}
14+
15+
16+
@pytest.mark.parametrize("key_type, expected_alg", [
17+
("rsa", "rsa-sha256"),
18+
("dsa", "dsa-sha256"),
19+
("ec", "ecdsa-sha256"),
20+
("ed25519", "rsa-sha256"),
21+
("ed448", "rsa-sha256"),
22+
])
23+
def test_key_type_sign_alg_match(key_type, expected_alg):
24+
test_key = test_keys[key_type]
25+
key_encoding = serialization.Encoding.PEM
26+
key_format = serialization.PrivateFormat.PKCS8
27+
key_encryption_alg = serialization.NoEncryption()
28+
key_bytes = test_key.private_bytes(key_encoding, key_format, key_encryption_alg)
29+
30+
detected_alg = get_signature_algorithm_from_private_key(key_bytes)
31+
32+
assert detected_alg == expected_alg

0 commit comments

Comments
 (0)