Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions lib/ocrypto/asym_encrypt_decrypt_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
package ocrypto

import (
"crypto/ecdsa"
"crypto/sha256"
"crypto/x509"
"errors"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func salty(s string) []byte {
Expand Down Expand Up @@ -346,3 +352,61 @@ MJseKiCRhbMS8XoCOTogO4Au9SqpOKqHq2CFRb4=
})
}
}

// TestDecryptWithCompressedEphemeralKey reproduces issue #3070:
// EC decrypt fails for P-384/P-521 when the ephemeral key is passed in
// compressed form, because UncompressECPubKey hardcodes P-256.
func TestDecryptWithCompressedEphemeralKey(t *testing.T) {
for _, tc := range []struct {
name string
mode ECCMode
}{
{"P-256", ECCModeSecp256r1},
{"P-384", ECCModeSecp384r1},
{"P-521", ECCModeSecp521r1},
} {
t.Run(tc.name, func(t *testing.T) {
salt := salty("TDF")
plainText := "virtru"

kasKeyPair, err := NewECKeyPair(tc.mode)
require.NoError(t, err)
kasPubPEM, err := kasKeyPair.PublicKeyInPemFormat()
require.NoError(t, err)
kasPriv, err := kasKeyPair.PrivateKey.ECDH()
require.NoError(t, err)

// Encrypt with the library's encryptor (generates ephemeral key internally)
encryptor, err := FromPublicPEMWithSalt(kasPubPEM, salt, nil)
require.NoError(t, err)
ciphertext, err := encryptor.Encrypt([]byte(plainText))
require.NoError(t, err)

// Compress the encryptor's ephemeral key (this is what rewrap.go does)
ephDER := encryptor.EphemeralKey()
compressedEphemeral, err := compressEphemeralDER(tc.mode, ephDER)
require.NoError(t, err)

// Decrypt with compressed ephemeral key — exercises UncompressECPubKey
ecDecryptor, err := NewSaltedECDecryptor(kasPriv, salt, nil)
require.NoError(t, err)

decrypted, err := ecDecryptor.DecryptWithEphemeralKey(ciphertext, compressedEphemeral)
require.NoError(t, err, "DecryptWithEphemeralKey should succeed for %s", tc.name)
assert.Equal(t, plainText, string(decrypted), "decrypted text should match plaintext")
})
}
}

// compressEphemeralDER parses a DER-encoded public key and returns its compressed form.
func compressEphemeralDER(mode ECCMode, der []byte) ([]byte, error) {
pub, err := x509.ParsePKIXPublicKey(der)
if err != nil {
return nil, err
}
ecPub, ok := pub.(*ecdsa.PublicKey)
if !ok {
return nil, errors.New("not an ECDSA public key")
}
return CompressedECPublicKey(mode, *ecPub)
}
40 changes: 40 additions & 0 deletions lib/ocrypto/ec_key_pair_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto/sha256"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -109,6 +110,45 @@ func TestECRewrapKeyGenerate(t *testing.T) {
}
}

func TestUncompressECPubKey_CurvePreservation(t *testing.T) {
for _, tc := range []struct {
name string
mode ECCMode
}{
{"P-256", ECCModeSecp256r1},
{"P-384", ECCModeSecp384r1},
{"P-521", ECCModeSecp521r1},
} {
t.Run(tc.name, func(t *testing.T) {
keyPair, err := NewECKeyPair(tc.mode)
require.NoError(t, err)

curve, err := GetECCurveFromECCMode(tc.mode)
require.NoError(t, err)
Comment on lines +126 to +127
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The curve object can be retrieved directly from the keyPair that was just created, which already holds the curve information. This avoids a redundant call to GetECCurveFromECCMode and makes the test setup slightly more concise.

Suggested change
curve, err := GetECCurveFromECCMode(tc.mode)
require.NoError(t, err)
curve := keyPair.PrivateKey.Curve


original := keyPair.PrivateKey.PublicKey

compressed, err := CompressedECPublicKey(tc.mode, original)
require.NoError(t, err)

uncompressed, err := UncompressECPubKey(curve, compressed)
require.NoError(t, err)

// The returned key's curve must match the input curve
assert.Equal(t, curve.Params().Name, uncompressed.Curve.Params().Name,
"UncompressECPubKey returned wrong curve")

// Coordinates must survive the round-trip
assert.Equal(t, original.X, uncompressed.X, "X coordinate mismatch")
assert.Equal(t, original.Y, uncompressed.Y, "Y coordinate mismatch")

// The key must be usable for ECDH (validates point is on the declared curve)
_, err = uncompressed.ECDH()
assert.NoError(t, err, "ECDH conversion should succeed for a valid key on the correct curve")
})
}
}

func TestECDSASignature(t *testing.T) {
digest := CalculateSHA256([]byte("Virtru"))
for _, cvurve := range []ECCMode{ECCModeSecp256r1, ECCModeSecp384r1, ECCModeSecp521r1} {
Expand Down
29 changes: 29 additions & 0 deletions sdk/basekey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,35 @@ func (s *BaseKeyTestSuite) TestGetBaseKeyMissingPublicKey() {
s.Require().ErrorIs(err, ErrBaseKeyInvalidFormat)
}

// TestFormatAlg_GetKasKeyAlg_RoundTrip verifies that every supported algorithm
// survives a round-trip through the SDK's own formatAlg → getKasKeyAlg path.
// This locks in the SDK-side contract: formatAlg must produce strings that
// getKasKeyAlg maps back to the original enum.
func TestFormatAlg_GetKasKeyAlg_RoundTrip(t *testing.T) {
supportedAlgs := []struct {
name string
alg policy.Algorithm
}{
{"RSA-2048", policy.Algorithm_ALGORITHM_RSA_2048},
{"RSA-4096", policy.Algorithm_ALGORITHM_RSA_4096},
{"EC-P256", policy.Algorithm_ALGORITHM_EC_P256},
{"EC-P384", policy.Algorithm_ALGORITHM_EC_P384},
{"EC-P521", policy.Algorithm_ALGORITHM_EC_P521},
}

for _, tc := range supportedAlgs {
t.Run(tc.name, func(t *testing.T) {
formatted, err := formatAlg(tc.alg)
require.NoError(t, err, "formatAlg should not error for %s", tc.name)

roundTripped := getKasKeyAlg(formatted)
assert.Equal(t, tc.alg, roundTripped,
"round-trip mismatch: formatAlg(%s) = %q → getKasKeyAlg returned %s, want %s",
tc.name, formatted, roundTripped, tc.alg)
})
}
}

func (s *BaseKeyTestSuite) TestGetBaseKeyInvalidPublicKey() {
// Create base key with invalid public_key (string instead of map)
wellknownConfig := map[string]interface{}{
Expand Down
150 changes: 150 additions & 0 deletions service/internal/security/basic_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -519,3 +519,153 @@ func TestBasicManager_GenerateECSessionKey(t *testing.T) {
require.Error(t, err)
})
}

// TestBasicManager_Decrypt_ECAllCurves reproduces issue #3070:
// BasicManager.Decrypt with compressed ephemeral keys fails for P-384/P-521
// because UncompressECPubKey hardcodes elliptic.P256() instead of the actual curve.
//
// The existing "successful EC decryption" test only covers P-256 and passes the
// ephemeral key in DER/PKIX format (which bypasses UncompressECPubKey entirely).
// This test uses compressed ephemeral keys to exercise the buggy code path.
func TestBasicManager_Decrypt_ECAllCurves(t *testing.T) {
log := logger.CreateTestLogger()
testCache := newTestCache(t, log)
rootKeyHex := "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"
rootKey, _ := hex.DecodeString(rootKeyHex)

bm, err := NewBasicManager(log, testCache, rootKeyHex)
require.NoError(t, err)

samplePayload := []byte("secret payload16") // 16 bytes for valid AES key

for _, tc := range []struct {
name string
mode ocrypto.ECCMode
algorithm string
}{
{"P-256", ocrypto.ECCModeSecp256r1, AlgorithmECP256R1},
{"P-384", ocrypto.ECCModeSecp384r1, AlgorithmECP384R1},
{"P-521", ocrypto.ECCModeSecp521r1, AlgorithmECP521R1},
} {
t.Run(tc.name, func(t *testing.T) {
ecKey, err := generateECKeyAndPEM(tc.mode)
require.NoError(t, err)
ecPrivKey, err := ecKey.PrivateKeyInPemFormat()
require.NoError(t, err)
ecPubKey, err := ecKey.PublicKeyInPemFormat()
require.NoError(t, err)

wrappedECPrivKeyStr, err := wrapKeyWithAESGCM([]byte(ecPrivKey), rootKey)
require.NoError(t, err)

// Encrypt using the KAS public key (generates an internal ephemeral key)
ecEncryptor, err := ocrypto.FromPublicPEM(ecPubKey)
require.NoError(t, err)
ciphertext, err := ecEncryptor.Encrypt(samplePayload)
require.NoError(t, err)

// Get ephemeral key in DER format and compress it.
// This simulates what rewrap.go does: parse PEM → compress → pass to Decrypt.
ephemeralDER := ecEncryptor.EphemeralKey()
pub, err := x509.ParsePKIXPublicKey(ephemeralDER)
require.NoError(t, err)
ecPub, ok := pub.(*ecdsa.PublicKey)
require.True(t, ok)
compressedEphemeral, err := ocrypto.CompressedECPublicKey(tc.mode, *ecPub)
require.NoError(t, err)

mockDetails := new(MockKeyDetails)
kid := fmt.Sprintf("ec-%s-decrypt", tc.name)
mockDetails.MID = kid
mockDetails.MAlgorithm = tc.algorithm
mockDetails.MPrivateKey = &policy.PrivateKeyCtx{WrappedKey: wrappedECPrivKeyStr}
mockDetails.On("ID").Return(trust.KeyIdentifier(kid))
mockDetails.On("Algorithm").Return(tc.algorithm)
mockDetails.On("ExportPrivateKey").Return(&trust.PrivateKey{
WrappingKeyID: trust.KeyIdentifier(mockDetails.MPrivateKey.GetKeyId()),
WrappedKey: mockDetails.MPrivateKey.GetWrappedKey(),
}, nil)

protectedKey, err := bm.Decrypt(t.Context(), mockDetails, ciphertext, compressedEphemeral)
require.NoError(t, err, "Decrypt should succeed for %s with compressed ephemeral key", tc.name)
require.NotNil(t, protectedKey)

noOpEnc := &noOpEncapsulator{}
decryptedPayload, err := protectedKey.Export(noOpEnc)
require.NoError(t, err)
assert.Equal(t, samplePayload, decryptedPayload, "decrypted payload should match for %s", tc.name)
})
}
}

// TestBasicManager_DeriveKey_ECAllCurves reproduces issue #3070:
// BasicManager.DeriveKey calls UncompressECPubKey directly (not via DecryptWithEphemeralKey),
// so the hardcoded P-256 causes DeriveKey to fail for P-384/P-521.
func TestBasicManager_DeriveKey_ECAllCurves(t *testing.T) {
log := logger.CreateTestLogger()
testCache := newTestCache(t, log)
rootKeyHex := "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"
rootKey, _ := hex.DecodeString(rootKeyHex)

bm, err := NewBasicManager(log, testCache, rootKeyHex)
require.NoError(t, err)

for _, tc := range []struct {
name string
mode ocrypto.ECCMode
algorithm string
curve elliptic.Curve
}{
{"P-256", ocrypto.ECCModeSecp256r1, AlgorithmECP256R1, elliptic.P256()},
{"P-384", ocrypto.ECCModeSecp384r1, AlgorithmECP384R1, elliptic.P384()},
{"P-521", ocrypto.ECCModeSecp521r1, AlgorithmECP521R1, elliptic.P521()},
} {
t.Run(tc.name, func(t *testing.T) {
// Generate KAS key pair for this curve
ecKey, err := generateECKeyAndPEM(tc.mode)
require.NoError(t, err)
ecPrivKey, err := ecKey.PrivateKeyInPemFormat()
require.NoError(t, err)

wrappedECPrivKeyStr, err := wrapKeyWithAESGCM([]byte(ecPrivKey), rootKey)
require.NoError(t, err)

// Generate client ephemeral key pair and compress the public key
clientKeyPair, err := ocrypto.NewECKeyPair(tc.mode)
require.NoError(t, err)
compressedClientKey, err := ocrypto.CompressedECPublicKey(tc.mode, clientKeyPair.PrivateKey.PublicKey)
require.NoError(t, err)

mockDetails := new(MockKeyDetails)
kid := fmt.Sprintf("ec-%s-derive", tc.name)
mockDetails.MID = kid
mockDetails.MAlgorithm = tc.algorithm
mockDetails.MPrivateKey = &policy.PrivateKeyCtx{WrappedKey: wrappedECPrivKeyStr}
mockDetails.On("ID").Return(trust.KeyIdentifier(kid))
mockDetails.On("Algorithm").Return(tc.algorithm)
mockDetails.On("ExportPrivateKey").Return(&trust.PrivateKey{
WrappingKeyID: trust.KeyIdentifier(mockDetails.MPrivateKey.GetKeyId()),
WrappedKey: mockDetails.MPrivateKey.GetWrappedKey(),
}, nil)

protectedKey, err := bm.DeriveKey(t.Context(), mockDetails, compressedClientKey, tc.curve)
require.NoError(t, err, "DeriveKey should succeed for %s", tc.name)
require.NotNil(t, protectedKey)

// Verify by computing expected key via direct ECDH (bypasses UncompressECPubKey)
kasPrivKey, err := ocrypto.ECPrivateKeyFromPem([]byte(ecPrivKey))
require.NoError(t, err)
clientPriv, err := clientKeyPair.PrivateKey.ECDH()
require.NoError(t, err)
expectedSharedSecret, err := kasPrivKey.ECDH(clientPriv.PublicKey())
require.NoError(t, err)
expectedDerivedKey, err := ocrypto.CalculateHKDF(TDFSalt(), expectedSharedSecret)
require.NoError(t, err)

noOpEnc := &noOpEncapsulator{}
actualDerivedKey, err := protectedKey.Export(noOpEnc)
require.NoError(t, err)
assert.Equal(t, expectedDerivedKey, actualDerivedKey, "derived key should match for %s", tc.name)
})
}
}
65 changes: 65 additions & 0 deletions service/pkg/db/marshalHelpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package db

import (
"testing"

"github.com/opentdf/platform/lib/ocrypto"
"github.com/opentdf/platform/protocol/go/policy"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// reverseAlgMap mirrors the SDK's getKasKeyAlg mapping: ocrypto.KeyType string → policy.Algorithm.
// If FormatAlg produces a string that isn't in this map, the SDK would return ALGORITHM_UNSPECIFIED.
var reverseAlgMap = map[string]policy.Algorithm{
string(ocrypto.RSA2048Key): policy.Algorithm_ALGORITHM_RSA_2048,
string(ocrypto.RSA4096Key): policy.Algorithm_ALGORITHM_RSA_4096,
string(ocrypto.EC256Key): policy.Algorithm_ALGORITHM_EC_P256,
string(ocrypto.EC384Key): policy.Algorithm_ALGORITHM_EC_P384,
string(ocrypto.EC521Key): policy.Algorithm_ALGORITHM_EC_P521,
}

func TestFormatAlg_RoundTrip(t *testing.T) {
// Every supported algorithm must survive a round-trip:
// enum → FormatAlg(enum) → reverseAlgMap[result] → must equal original enum
// This proves FormatAlg produces strings the SDK's getKasKeyAlg can parse.
supportedAlgs := []struct {
name string
alg policy.Algorithm
}{
{"RSA-2048", policy.Algorithm_ALGORITHM_RSA_2048},
{"RSA-4096", policy.Algorithm_ALGORITHM_RSA_4096},
{"EC-P256", policy.Algorithm_ALGORITHM_EC_P256},
{"EC-P384", policy.Algorithm_ALGORITHM_EC_P384},
{"EC-P521", policy.Algorithm_ALGORITHM_EC_P521},
}

for _, tc := range supportedAlgs {
t.Run(tc.name, func(t *testing.T) {
formatted, err := FormatAlg(tc.alg)
require.NoError(t, err, "FormatAlg should not error for %s", tc.name)

roundTripped, ok := reverseAlgMap[formatted]
require.True(t, ok, "FormatAlg returned %q which is not a known ocrypto.KeyType string", formatted)
assert.Equal(t, tc.alg, roundTripped, "round-trip mismatch: FormatAlg(%s) = %q maps back to %s, not %s",
tc.name, formatted, roundTripped, tc.alg)
})
}
}

func TestFormatAlg_Unsupported(t *testing.T) {
unsupported := []struct {
name string
alg policy.Algorithm
}{
{"Unspecified", policy.Algorithm_ALGORITHM_UNSPECIFIED},
{"Invalid", policy.Algorithm(99)},
}

for _, tc := range unsupported {
t.Run(tc.name, func(t *testing.T) {
_, err := FormatAlg(tc.alg)
require.Error(t, err)
})
}
}
Loading