diff --git a/src/crypto/hpke/hpke_test.go b/src/crypto/hpke/hpke_test.go index b54a234fe5..ceb33263a6 100644 --- a/src/crypto/hpke/hpke_test.go +++ b/src/crypto/hpke/hpke_test.go @@ -69,6 +69,183 @@ func Example() { // Decrypted message: |-()-| } +func TestRoundTrip(t *testing.T) { + kems := []KEM{ + DHKEM(ecdh.P256()), + DHKEM(ecdh.P384()), + DHKEM(ecdh.P521()), + DHKEM(ecdh.X25519()), + MLKEM768(), + MLKEM1024(), + MLKEM768P256(), + MLKEM1024P384(), + MLKEM768X25519(), + } + kdfs := []KDF{ + HKDFSHA256(), + HKDFSHA384(), + HKDFSHA512(), + SHAKE128(), + SHAKE256(), + } + aeads := []AEAD{ + AES128GCM(), + AES256GCM(), + ChaCha20Poly1305(), + } + + for _, kem := range kems { + t.Run(fmt.Sprintf("KEM_%04x", kem.ID()), func(t *testing.T) { + k, err := kem.GenerateKey() + if err != nil { + t.Fatal(err) + } + kb, err := k.Bytes() + if err != nil { + t.Fatal(err) + } + kk, err := kem.NewPrivateKey(kb) + if err != nil { + t.Fatal(err) + } + if got, err := kk.Bytes(); err != nil { + t.Fatal(err) + } else if !bytes.Equal(got, kb) { + t.Errorf("re-serialized key mismatch: got %x, want %x", got, kb) + } + pk, err := kem.NewPublicKey(k.PublicKey().Bytes()) + if err != nil { + t.Fatal(err) + } + if got := pk.Bytes(); !bytes.Equal(got, k.PublicKey().Bytes()) { + t.Errorf("re-serialized public key mismatch: got %x, want %x", got, k.PublicKey().Bytes()) + } + + for _, kdf := range kdfs { + t.Run(fmt.Sprintf("KDF_%04x", kdf.ID()), func(t *testing.T) { + for _, aead := range aeads { + t.Run(fmt.Sprintf("AEAD_%04x", aead.ID()), func(t *testing.T) { + c, err := Seal(pk, kdf, aead, []byte("info"), []byte("plaintext")) + if err != nil { + t.Fatal(err) + } + p, err := Open(kk, kdf, aead, []byte("info"), c) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(p, []byte("plaintext")) { + t.Errorf("unexpected plaintext: got %x, want %x", p, []byte("plaintext")) + } + + p, err = Open(kk, kdf, aead, []byte("wrong"), c) + if err == nil { + t.Errorf("expected error when opening with wrong info, got plaintext %x", p) + } + c[len(c)-1] ^= 0xFF + p, err = Open(kk, kdf, aead, []byte("info"), c) + if err == nil { + t.Errorf("expected error when opening with corrupted ciphertext, got plaintext %x", p) + } + + c, err = Seal(k.PublicKey(), kdf, aead, nil, nil) + if err != nil { + t.Fatal(err) + } + p, err = Open(k, kdf, aead, nil, c) + if err != nil { + t.Fatal(err) + } + if len(p) != 0 { + t.Errorf("unexpected plaintext: got %x, want empty", p) + } + + // Test that Seal and Open don't modify the excess capacity of input + // slices. This is a regression test for a bug where decap would + // append to the enc slice, corrupting the ciphertext if they shared + // a backing array. + padSlice := func(b []byte) []byte { + s := make([]byte, len(b), len(b)+2000) + copy(s, b) + for i := len(b); i < cap(s); i++ { + s[:cap(s)][i] = 0xAA + } + return s[:len(b)] + } + checkSlice := func(name string, s []byte) { + for i := len(s); i < cap(s); i++ { + if s[:cap(s)][i] != 0xAA { + t.Errorf("%s: modified byte at index %d beyond slice length", name, i) + return + } + } + } + + infoS := padSlice([]byte("info")) + plaintextS := padSlice([]byte("plaintext")) + c, err = Seal(pk, kdf, aead, infoS, plaintextS) + if err != nil { + t.Fatal(err) + } + checkSlice("Seal info", infoS) + checkSlice("Seal plaintext", plaintextS) + + infoO := padSlice([]byte("info")) + ciphertextO := padSlice(c) + p, err = Open(kk, kdf, aead, infoO, ciphertextO) + if err != nil { + t.Fatalf("Open with large capacity slices failed: %v", err) + } + if !bytes.Equal(p, []byte("plaintext")) { + t.Errorf("unexpected plaintext: got %x, want %x", p, []byte("plaintext")) + } + checkSlice("Open info", infoO) + checkSlice("Open ciphertext", ciphertextO) + + // Also test the Sender.Seal and Recipient.Open methods. + infoSender := padSlice([]byte("info")) + enc, sender, err := NewSender(pk, kdf, aead, infoSender) + if err != nil { + t.Fatal(err) + } + checkSlice("NewSender info", infoSender) + + aadSeal := padSlice([]byte("aad")) + plaintextSeal := padSlice([]byte("plaintext")) + ct, err := sender.Seal(aadSeal, plaintextSeal) + if err != nil { + t.Fatal(err) + } + checkSlice("Sender.Seal aad", aadSeal) + checkSlice("Sender.Seal plaintext", plaintextSeal) + + infoRecipient := padSlice([]byte("info")) + encPadded := padSlice(enc) + recipient, err := NewRecipient(encPadded, kk, kdf, aead, infoRecipient) + if err != nil { + t.Fatal(err) + } + checkSlice("NewRecipient info", infoRecipient) + checkSlice("NewRecipient enc", encPadded) + + aadOpen := padSlice([]byte("aad")) + ctPadded := padSlice(ct) + p, err = recipient.Open(aadOpen, ctPadded) + if err != nil { + t.Fatalf("Recipient.Open failed: %v", err) + } + if !bytes.Equal(p, []byte("plaintext")) { + t.Errorf("unexpected plaintext: got %x, want %x", p, []byte("plaintext")) + } + checkSlice("Recipient.Open aad", aadOpen) + checkSlice("Recipient.Open ciphertext", ctPadded) + }) + } + }) + } + }) + } +} + func mustDecodeHex(t *testing.T, in string) []byte { t.Helper() b, err := hex.DecodeString(in) diff --git a/src/crypto/hpke/kem.go b/src/crypto/hpke/kem.go index 7633aa2b71..132e0a754c 100644 --- a/src/crypto/hpke/kem.go +++ b/src/crypto/hpke/kem.go @@ -9,6 +9,7 @@ import ( "crypto/internal/rand" "errors" "internal/byteorder" + "slices" ) // A KEM is a Key Encapsulation Mechanism, one of the three components of an @@ -377,6 +378,6 @@ func (k *dhKEMPrivateKey) decap(encPubEph []byte) ([]byte, error) { if err != nil { return nil, err } - kemContext := append(encPubEph, k.priv.PublicKey().Bytes()...) + kemContext := append(slices.Clip(encPubEph), k.priv.PublicKey().Bytes()...) return k.kem.extractAndExpand(dhVal, kemContext) }