diff --git a/src/crypto/tls/cfkem.go b/src/crypto/tls/cfkem.go index 083f2921b64..8d440e4c3c5 100644 --- a/src/crypto/tls/cfkem.go +++ b/src/crypto/tls/cfkem.go @@ -6,13 +6,12 @@ // To enable set CurvePreferences with the desired scheme as the first element: // // import ( -// "github.com/cloudflare/circl/kem/tls" -// "github.com/cloudflare/circl/kem/hybrid" +// "crypto/tls" // // [...] // // config.CurvePreferences = []tls.CurveID{ -// hybrid.X25519Kyber512Draft00().(tls.TLSScheme).TLSCurveID(), +// tls.X25519Kyber768Draft00, // tls.X25519, // tls.P256, // } @@ -29,38 +28,27 @@ import ( "github.com/cloudflare/circl/kem/hybrid" ) -// Either ecdheParameters or kem.PrivateKey +// Either *ecdh.PrivateKey or *kemPrivateKey type clientKeySharePrivate interface{} +type kemPrivateKey struct { + secretKey kem.PrivateKey + curveID CurveID +} + var ( - X25519Kyber512Draft00 = CurveID(0xfe30) - X25519Kyber768Draft00 = CurveID(0xfe31) - P256Kyber768Draft00 = CurveID(0xfe32) - invalidCurveID = CurveID(0) + X25519Kyber512Draft00 = CurveID(0xfe30) + X25519Kyber768Draft00 = CurveID(0x6399) + X25519Kyber768Draft00Old = CurveID(0xfe31) + P256Kyber768Draft00 = CurveID(0xfe32) + invalidCurveID = CurveID(0) ) -func kemSchemeKeyToCurveID(s kem.Scheme) CurveID { - switch s.Name() { - case "Kyber512-X25519": - return X25519Kyber512Draft00 - case "Kyber768-X25519": - return X25519Kyber768Draft00 - case "P256Kyber768Draft00": - return P256Kyber768Draft00 - default: - return invalidCurveID - } -} - // Extract CurveID from clientKeySharePrivate func clientKeySharePrivateCurveID(ks clientKeySharePrivate) CurveID { switch v := ks.(type) { - case kem.PrivateKey: - ret := kemSchemeKeyToCurveID(v.Scheme()) - if ret == invalidCurveID { - panic("cfkem: internal error: don't know CurveID for this KEM") - } - return ret + case *kemPrivateKey: + return v.curveID case *ecdh.PrivateKey: ret, ok := curveIDForCurve(v.Curve()) if !ok { @@ -77,7 +65,7 @@ func curveIdToCirclScheme(id CurveID) kem.Scheme { switch id { case X25519Kyber512Draft00: return hybrid.Kyber512X25519() - case X25519Kyber768Draft00: + case X25519Kyber768Draft00, X25519Kyber768Draft00Old: return hybrid.Kyber768X25519() case P256Kyber768Draft00: return hybrid.P256Kyber768Draft00() @@ -102,12 +90,12 @@ func encapsulateForKem(scheme kem.Scheme, rnd io.Reader, ppk []byte) ( } // Generate a new keypair using randomness from rnd. -func generateKemKeyPair(scheme kem.Scheme, rnd io.Reader) ( - kem.PublicKey, kem.PrivateKey, error) { +func generateKemKeyPair(scheme kem.Scheme, curveID CurveID, rnd io.Reader) ( + kem.PublicKey, *kemPrivateKey, error) { seed := make([]byte, scheme.SeedSize()) if _, err := io.ReadFull(rnd, seed); err != nil { return nil, nil, err } pk, sk := scheme.DeriveKeyPair(seed) - return pk, sk, nil + return pk, &kemPrivateKey{sk, curveID}, nil } diff --git a/src/crypto/tls/cfkem_test.go b/src/crypto/tls/cfkem_test.go index fb2156aa963..85da45ede8c 100644 --- a/src/crypto/tls/cfkem_test.go +++ b/src/crypto/tls/cfkem_test.go @@ -7,28 +7,16 @@ import ( "context" "fmt" "testing" - - "github.com/cloudflare/circl/kem" - "github.com/cloudflare/circl/kem/hybrid" ) -func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ, +func testHybridKEX(t *testing.T, curveID CurveID, clientPQ, serverPQ, clientTLS12, serverTLS12 bool) { var clientSelectedKEX *CurveID var retry bool - rsaCert := Certificate{ - Certificate: [][]byte{testRSACertificate}, - PrivateKey: testRSAPrivateKey, - } - serverCerts := []Certificate{rsaCert} - clientConfig := testConfig.Clone() if clientPQ { - clientConfig.CurvePreferences = []CurveID{ - kemSchemeKeyToCurveID(scheme), - X25519, - } + clientConfig.CurvePreferences = []CurveID{curveID, X25519} } clientCFEventHandler := func(ev CFEvent) { switch e := ev.(type) { @@ -44,15 +32,13 @@ func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ, serverConfig := testConfig.Clone() if serverPQ { - serverConfig.CurvePreferences = []CurveID{ - kemSchemeKeyToCurveID(scheme), - X25519, - } + serverConfig.CurvePreferences = []CurveID{curveID, X25519} + } else { + serverConfig.CurvePreferences = []CurveID{X25519} } if serverTLS12 { serverConfig.MaxVersion = VersionTLS12 } - serverConfig.Certificates = serverCerts c, s := localPipe(t) done := make(chan error) @@ -78,7 +64,7 @@ func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ, var expectedRetry bool if clientPQ && serverPQ && !clientTLS12 && !serverTLS12 { - expectedKEX = kemSchemeKeyToCurveID(scheme) + expectedKEX = curveID } else { expectedKEX = X25519 } @@ -86,36 +72,35 @@ func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ, expectedRetry = true } + if expectedRetry != retry { + t.Errorf("Expected retry=%v, got retry=%v", expectedRetry, retry) + } if clientSelectedKEX == nil { t.Error("No KEX happened?") - } - - if *clientSelectedKEX != expectedKEX { + } else if *clientSelectedKEX != expectedKEX { t.Errorf("failed to negotiate: expected %d, got %d", expectedKEX, *clientSelectedKEX) } - if expectedRetry != retry { - t.Errorf("Expected retry=%v, got retry=%v", expectedRetry, retry) - } } func TestHybridKEX(t *testing.T) { - run := func(scheme kem.Scheme, clientPQ, serverPQ, clientTLS12, serverTLS12 bool) { - t.Run(fmt.Sprintf("%s serverPQ:%v clientPQ:%v serverTLS12:%v clientTLS12:%v", scheme.Name(), + run := func(curveID CurveID, clientPQ, serverPQ, clientTLS12, serverTLS12 bool) { + t.Run(fmt.Sprintf("%#04x serverPQ:%v clientPQ:%v serverTLS12:%v clientTLS12:%v", uint16(curveID), serverPQ, clientPQ, serverTLS12, clientTLS12), func(t *testing.T) { - testHybridKEX(t, scheme, clientPQ, serverPQ, clientTLS12, serverTLS12) + testHybridKEX(t, curveID, clientPQ, serverPQ, clientTLS12, serverTLS12) }) } - for _, scheme := range []kem.Scheme{ - hybrid.Kyber512X25519(), - hybrid.Kyber768X25519(), - hybrid.P256Kyber768Draft00(), + for _, curveID := range []CurveID{ + X25519Kyber512Draft00, + X25519Kyber768Draft00, + X25519Kyber768Draft00Old, + P256Kyber768Draft00, } { - run(scheme, true, true, false, false) - run(scheme, true, false, false, false) - run(scheme, false, true, false, false) - run(scheme, true, true, true, false) - run(scheme, true, true, false, true) - run(scheme, true, true, true, true) + run(curveID, true, true, false, false) + run(curveID, true, false, false, false) + run(curveID, false, true, false, false) + run(curveID, true, true, true, false) + run(curveID, true, true, false, true) + run(curveID, true, true, true, true) } } diff --git a/src/crypto/tls/handshake_client.go b/src/crypto/tls/handshake_client.go index cfe1c25ec31..7c026f9284a 100644 --- a/src/crypto/tls/handshake_client.go +++ b/src/crypto/tls/handshake_client.go @@ -136,7 +136,7 @@ func (c *Conn) makeClientHello(minVersion uint16) (*clientHelloMsg, clientKeySha curveID := config.curvePreferences()[0] if scheme := curveIdToCirclScheme(curveID); scheme != nil { - pk, sk, err := generateKemKeyPair(scheme, config.rand()) + pk, sk, err := generateKemKeyPair(scheme, curveID, config.rand()) if err != nil { return nil, nil, fmt.Errorf("generateKemKeyPair %s: %w", scheme.Name(), err) diff --git a/src/crypto/tls/handshake_client_tls13.go b/src/crypto/tls/handshake_client_tls13.go index 6dd1627f08d..74780d5a01c 100644 --- a/src/crypto/tls/handshake_client_tls13.go +++ b/src/crypto/tls/handshake_client_tls13.go @@ -16,8 +16,6 @@ import ( "fmt" "hash" "time" - - circlKem "github.com/cloudflare/circl/kem" ) type clientHandshakeStateTLS13 struct { @@ -382,7 +380,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") } if scheme := curveIdToCirclScheme(curveID); scheme != nil { - pk, sk, err := generateKemKeyPair(scheme, c.config.rand()) + pk, sk, err := generateKemKeyPair(scheme, curveID, c.config.rand()) if err != nil { c.sendAlert(alertInternalError) return fmt.Errorf("HRR generateKemKeyPair %s: %w", @@ -610,7 +608,8 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { if err == nil { sharedKey, _ = key.ECDH(peerKey) } - } else if sk, ok := hs.keySharePrivate.(circlKem.PrivateKey); ok { + } else if key, ok := hs.keySharePrivate.(*kemPrivateKey); ok { + sk := key.secretKey sharedKey, err = sk.Scheme().Decapsulate(sk, hs.serverHello.serverShare.data) if err != nil { c.sendAlert(alertIllegalParameter)