mirror of
https://github.com/golang/go.git
synced 2025-12-28 06:34:04 +00:00
crypto/hpke: don't corrupt enc's excess capacity in DHKEM decap
Caught because the one-shop APIs put the ciphertext after enc in a single slice, so Recipient.Open would corrupt the ciphertext. Change-Id: I15fe1dfcc05a5a7f5cd0b4ada21661e66a6a6964 Reviewed-on: https://go-review.googlesource.com/c/go/+/728500 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Roland Shoemaker <roland@golang.org> Reviewed-by: Dmitri Shuralyov <dmitshur@google.com> Auto-Submit: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
parent
cd873cf7e9
commit
db0ab834d6
@ -69,6 +69,183 @@ func Example() {
|
||||
// Decrypted message: |-()-|
|
||||
}
|
||||
|
||||
func TestRoundTrip(t *testing.T) {
|
||||
kems := []KEM{
|
||||
DHKEM(ecdh.P256()),
|
||||
DHKEM(ecdh.P384()),
|
||||
DHKEM(ecdh.P521()),
|
||||
DHKEM(ecdh.X25519()),
|
||||
MLKEM768(),
|
||||
MLKEM1024(),
|
||||
MLKEM768P256(),
|
||||
MLKEM1024P384(),
|
||||
MLKEM768X25519(),
|
||||
}
|
||||
kdfs := []KDF{
|
||||
HKDFSHA256(),
|
||||
HKDFSHA384(),
|
||||
HKDFSHA512(),
|
||||
SHAKE128(),
|
||||
SHAKE256(),
|
||||
}
|
||||
aeads := []AEAD{
|
||||
AES128GCM(),
|
||||
AES256GCM(),
|
||||
ChaCha20Poly1305(),
|
||||
}
|
||||
|
||||
for _, kem := range kems {
|
||||
t.Run(fmt.Sprintf("KEM_%04x", kem.ID()), func(t *testing.T) {
|
||||
k, err := kem.GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
kb, err := k.Bytes()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
kk, err := kem.NewPrivateKey(kb)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, err := kk.Bytes(); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if !bytes.Equal(got, kb) {
|
||||
t.Errorf("re-serialized key mismatch: got %x, want %x", got, kb)
|
||||
}
|
||||
pk, err := kem.NewPublicKey(k.PublicKey().Bytes())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := pk.Bytes(); !bytes.Equal(got, k.PublicKey().Bytes()) {
|
||||
t.Errorf("re-serialized public key mismatch: got %x, want %x", got, k.PublicKey().Bytes())
|
||||
}
|
||||
|
||||
for _, kdf := range kdfs {
|
||||
t.Run(fmt.Sprintf("KDF_%04x", kdf.ID()), func(t *testing.T) {
|
||||
for _, aead := range aeads {
|
||||
t.Run(fmt.Sprintf("AEAD_%04x", aead.ID()), func(t *testing.T) {
|
||||
c, err := Seal(pk, kdf, aead, []byte("info"), []byte("plaintext"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p, err := Open(kk, kdf, aead, []byte("info"), c)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(p, []byte("plaintext")) {
|
||||
t.Errorf("unexpected plaintext: got %x, want %x", p, []byte("plaintext"))
|
||||
}
|
||||
|
||||
p, err = Open(kk, kdf, aead, []byte("wrong"), c)
|
||||
if err == nil {
|
||||
t.Errorf("expected error when opening with wrong info, got plaintext %x", p)
|
||||
}
|
||||
c[len(c)-1] ^= 0xFF
|
||||
p, err = Open(kk, kdf, aead, []byte("info"), c)
|
||||
if err == nil {
|
||||
t.Errorf("expected error when opening with corrupted ciphertext, got plaintext %x", p)
|
||||
}
|
||||
|
||||
c, err = Seal(k.PublicKey(), kdf, aead, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p, err = Open(k, kdf, aead, nil, c)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(p) != 0 {
|
||||
t.Errorf("unexpected plaintext: got %x, want empty", p)
|
||||
}
|
||||
|
||||
// Test that Seal and Open don't modify the excess capacity of input
|
||||
// slices. This is a regression test for a bug where decap would
|
||||
// append to the enc slice, corrupting the ciphertext if they shared
|
||||
// a backing array.
|
||||
padSlice := func(b []byte) []byte {
|
||||
s := make([]byte, len(b), len(b)+2000)
|
||||
copy(s, b)
|
||||
for i := len(b); i < cap(s); i++ {
|
||||
s[:cap(s)][i] = 0xAA
|
||||
}
|
||||
return s[:len(b)]
|
||||
}
|
||||
checkSlice := func(name string, s []byte) {
|
||||
for i := len(s); i < cap(s); i++ {
|
||||
if s[:cap(s)][i] != 0xAA {
|
||||
t.Errorf("%s: modified byte at index %d beyond slice length", name, i)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
infoS := padSlice([]byte("info"))
|
||||
plaintextS := padSlice([]byte("plaintext"))
|
||||
c, err = Seal(pk, kdf, aead, infoS, plaintextS)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
checkSlice("Seal info", infoS)
|
||||
checkSlice("Seal plaintext", plaintextS)
|
||||
|
||||
infoO := padSlice([]byte("info"))
|
||||
ciphertextO := padSlice(c)
|
||||
p, err = Open(kk, kdf, aead, infoO, ciphertextO)
|
||||
if err != nil {
|
||||
t.Fatalf("Open with large capacity slices failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(p, []byte("plaintext")) {
|
||||
t.Errorf("unexpected plaintext: got %x, want %x", p, []byte("plaintext"))
|
||||
}
|
||||
checkSlice("Open info", infoO)
|
||||
checkSlice("Open ciphertext", ciphertextO)
|
||||
|
||||
// Also test the Sender.Seal and Recipient.Open methods.
|
||||
infoSender := padSlice([]byte("info"))
|
||||
enc, sender, err := NewSender(pk, kdf, aead, infoSender)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
checkSlice("NewSender info", infoSender)
|
||||
|
||||
aadSeal := padSlice([]byte("aad"))
|
||||
plaintextSeal := padSlice([]byte("plaintext"))
|
||||
ct, err := sender.Seal(aadSeal, plaintextSeal)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
checkSlice("Sender.Seal aad", aadSeal)
|
||||
checkSlice("Sender.Seal plaintext", plaintextSeal)
|
||||
|
||||
infoRecipient := padSlice([]byte("info"))
|
||||
encPadded := padSlice(enc)
|
||||
recipient, err := NewRecipient(encPadded, kk, kdf, aead, infoRecipient)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
checkSlice("NewRecipient info", infoRecipient)
|
||||
checkSlice("NewRecipient enc", encPadded)
|
||||
|
||||
aadOpen := padSlice([]byte("aad"))
|
||||
ctPadded := padSlice(ct)
|
||||
p, err = recipient.Open(aadOpen, ctPadded)
|
||||
if err != nil {
|
||||
t.Fatalf("Recipient.Open failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(p, []byte("plaintext")) {
|
||||
t.Errorf("unexpected plaintext: got %x, want %x", p, []byte("plaintext"))
|
||||
}
|
||||
checkSlice("Recipient.Open aad", aadOpen)
|
||||
checkSlice("Recipient.Open ciphertext", ctPadded)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustDecodeHex(t *testing.T, in string) []byte {
|
||||
t.Helper()
|
||||
b, err := hex.DecodeString(in)
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
"crypto/internal/rand"
|
||||
"errors"
|
||||
"internal/byteorder"
|
||||
"slices"
|
||||
)
|
||||
|
||||
// A KEM is a Key Encapsulation Mechanism, one of the three components of an
|
||||
@ -377,6 +378,6 @@ func (k *dhKEMPrivateKey) decap(encPubEph []byte) ([]byte, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
kemContext := append(encPubEph, k.priv.PublicKey().Bytes()...)
|
||||
kemContext := append(slices.Clip(encPubEph), k.priv.PublicKey().Bytes()...)
|
||||
return k.kem.extractAndExpand(dhVal, kemContext)
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user