crypto/hpke: don't corrupt enc's excess capacity in DHKEM decap

Caught because the one-shop APIs put the ciphertext after enc in a
single slice, so Recipient.Open would corrupt the ciphertext.

Change-Id: I15fe1dfcc05a5a7f5cd0b4ada21661e66a6a6964
Reviewed-on: https://go-review.googlesource.com/c/go/+/728500
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
Filippo Valsorda 2025-01-09 16:56:37 +01:00 committed by Gopher Robot
parent cd873cf7e9
commit db0ab834d6
2 changed files with 179 additions and 1 deletions

View File

@ -69,6 +69,183 @@ func Example() {
// Decrypted message: |-()-|
}
func TestRoundTrip(t *testing.T) {
kems := []KEM{
DHKEM(ecdh.P256()),
DHKEM(ecdh.P384()),
DHKEM(ecdh.P521()),
DHKEM(ecdh.X25519()),
MLKEM768(),
MLKEM1024(),
MLKEM768P256(),
MLKEM1024P384(),
MLKEM768X25519(),
}
kdfs := []KDF{
HKDFSHA256(),
HKDFSHA384(),
HKDFSHA512(),
SHAKE128(),
SHAKE256(),
}
aeads := []AEAD{
AES128GCM(),
AES256GCM(),
ChaCha20Poly1305(),
}
for _, kem := range kems {
t.Run(fmt.Sprintf("KEM_%04x", kem.ID()), func(t *testing.T) {
k, err := kem.GenerateKey()
if err != nil {
t.Fatal(err)
}
kb, err := k.Bytes()
if err != nil {
t.Fatal(err)
}
kk, err := kem.NewPrivateKey(kb)
if err != nil {
t.Fatal(err)
}
if got, err := kk.Bytes(); err != nil {
t.Fatal(err)
} else if !bytes.Equal(got, kb) {
t.Errorf("re-serialized key mismatch: got %x, want %x", got, kb)
}
pk, err := kem.NewPublicKey(k.PublicKey().Bytes())
if err != nil {
t.Fatal(err)
}
if got := pk.Bytes(); !bytes.Equal(got, k.PublicKey().Bytes()) {
t.Errorf("re-serialized public key mismatch: got %x, want %x", got, k.PublicKey().Bytes())
}
for _, kdf := range kdfs {
t.Run(fmt.Sprintf("KDF_%04x", kdf.ID()), func(t *testing.T) {
for _, aead := range aeads {
t.Run(fmt.Sprintf("AEAD_%04x", aead.ID()), func(t *testing.T) {
c, err := Seal(pk, kdf, aead, []byte("info"), []byte("plaintext"))
if err != nil {
t.Fatal(err)
}
p, err := Open(kk, kdf, aead, []byte("info"), c)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(p, []byte("plaintext")) {
t.Errorf("unexpected plaintext: got %x, want %x", p, []byte("plaintext"))
}
p, err = Open(kk, kdf, aead, []byte("wrong"), c)
if err == nil {
t.Errorf("expected error when opening with wrong info, got plaintext %x", p)
}
c[len(c)-1] ^= 0xFF
p, err = Open(kk, kdf, aead, []byte("info"), c)
if err == nil {
t.Errorf("expected error when opening with corrupted ciphertext, got plaintext %x", p)
}
c, err = Seal(k.PublicKey(), kdf, aead, nil, nil)
if err != nil {
t.Fatal(err)
}
p, err = Open(k, kdf, aead, nil, c)
if err != nil {
t.Fatal(err)
}
if len(p) != 0 {
t.Errorf("unexpected plaintext: got %x, want empty", p)
}
// Test that Seal and Open don't modify the excess capacity of input
// slices. This is a regression test for a bug where decap would
// append to the enc slice, corrupting the ciphertext if they shared
// a backing array.
padSlice := func(b []byte) []byte {
s := make([]byte, len(b), len(b)+2000)
copy(s, b)
for i := len(b); i < cap(s); i++ {
s[:cap(s)][i] = 0xAA
}
return s[:len(b)]
}
checkSlice := func(name string, s []byte) {
for i := len(s); i < cap(s); i++ {
if s[:cap(s)][i] != 0xAA {
t.Errorf("%s: modified byte at index %d beyond slice length", name, i)
return
}
}
}
infoS := padSlice([]byte("info"))
plaintextS := padSlice([]byte("plaintext"))
c, err = Seal(pk, kdf, aead, infoS, plaintextS)
if err != nil {
t.Fatal(err)
}
checkSlice("Seal info", infoS)
checkSlice("Seal plaintext", plaintextS)
infoO := padSlice([]byte("info"))
ciphertextO := padSlice(c)
p, err = Open(kk, kdf, aead, infoO, ciphertextO)
if err != nil {
t.Fatalf("Open with large capacity slices failed: %v", err)
}
if !bytes.Equal(p, []byte("plaintext")) {
t.Errorf("unexpected plaintext: got %x, want %x", p, []byte("plaintext"))
}
checkSlice("Open info", infoO)
checkSlice("Open ciphertext", ciphertextO)
// Also test the Sender.Seal and Recipient.Open methods.
infoSender := padSlice([]byte("info"))
enc, sender, err := NewSender(pk, kdf, aead, infoSender)
if err != nil {
t.Fatal(err)
}
checkSlice("NewSender info", infoSender)
aadSeal := padSlice([]byte("aad"))
plaintextSeal := padSlice([]byte("plaintext"))
ct, err := sender.Seal(aadSeal, plaintextSeal)
if err != nil {
t.Fatal(err)
}
checkSlice("Sender.Seal aad", aadSeal)
checkSlice("Sender.Seal plaintext", plaintextSeal)
infoRecipient := padSlice([]byte("info"))
encPadded := padSlice(enc)
recipient, err := NewRecipient(encPadded, kk, kdf, aead, infoRecipient)
if err != nil {
t.Fatal(err)
}
checkSlice("NewRecipient info", infoRecipient)
checkSlice("NewRecipient enc", encPadded)
aadOpen := padSlice([]byte("aad"))
ctPadded := padSlice(ct)
p, err = recipient.Open(aadOpen, ctPadded)
if err != nil {
t.Fatalf("Recipient.Open failed: %v", err)
}
if !bytes.Equal(p, []byte("plaintext")) {
t.Errorf("unexpected plaintext: got %x, want %x", p, []byte("plaintext"))
}
checkSlice("Recipient.Open aad", aadOpen)
checkSlice("Recipient.Open ciphertext", ctPadded)
})
}
})
}
})
}
}
func mustDecodeHex(t *testing.T, in string) []byte {
t.Helper()
b, err := hex.DecodeString(in)

View File

@ -9,6 +9,7 @@ import (
"crypto/internal/rand"
"errors"
"internal/byteorder"
"slices"
)
// A KEM is a Key Encapsulation Mechanism, one of the three components of an
@ -377,6 +378,6 @@ func (k *dhKEMPrivateKey) decap(encPubEph []byte) ([]byte, error) {
if err != nil {
return nil, err
}
kemContext := append(encPubEph, k.priv.PublicKey().Bytes()...)
kemContext := append(slices.Clip(encPubEph), k.priv.PublicKey().Bytes()...)
return k.kem.extractAndExpand(dhVal, kemContext)
}