From 1bb8b5e6beacb2b36e108635292b2e2941466c81 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Sat, 9 Dec 2023 14:02:04 +0100 Subject: [PATCH] ssh: expose negotiated algorithms Fixes golang/go#58523 Fixes golang/go#46638 Change-Id: Ic64bd2fdd6e9ec96acac3ed4be842e2fbb15231d --- ssh/cipher.go | 16 ++++++------ ssh/cipher_test.go | 12 ++++----- ssh/client.go | 1 + ssh/common.go | 38 +++++++++++++++-------------- ssh/common_test.go | 34 +++++++++++++------------- ssh/connection.go | 12 +++++++++ ssh/handshake.go | 22 ++++++++++------- ssh/handshake_test.go | 57 ++++++++++++++++++++++++++++++++++++++++++- ssh/server.go | 1 + ssh/transport.go | 8 +++--- 10 files changed, 138 insertions(+), 63 deletions(-) diff --git a/ssh/cipher.go b/ssh/cipher.go index e4611d4068..f8fe83a4c1 100644 --- a/ssh/cipher.go +++ b/ssh/cipher.go @@ -58,11 +58,11 @@ func newRC4(key, iv []byte) (cipher.Stream, error) { type cipherMode struct { keySize int ivSize int - create func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) + create func(key, iv []byte, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) } -func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream, error)) func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) { - return func(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { +func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream, error)) func(key, iv []byte, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) { + return func(key, iv, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) { stream, err := createFunc(key, iv) if err != nil { return nil, err @@ -307,7 +307,7 @@ type gcmCipher struct { buf []byte } -func newGCMCipher(key, iv, unusedMacKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) { +func newGCMCipher(key, iv, unusedMacKey []byte, unusedAlgs DirectionAlgorithms) (packetCipher, error) { c, err := aes.NewCipher(key) if err != nil { return nil, err @@ -429,7 +429,7 @@ type cbcCipher struct { oracleCamouflage uint32 } -func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { +func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) { cbc := &cbcCipher{ mac: macModes[algs.MAC].new(macKey), decrypter: cipher.NewCBCDecrypter(c, iv), @@ -443,7 +443,7 @@ func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs directionAlgorith return cbc, nil } -func newAESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { +func newAESCBCCipher(key, iv, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) { c, err := aes.NewCipher(key) if err != nil { return nil, err @@ -457,7 +457,7 @@ func newAESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCi return cbc, nil } -func newTripleDESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { +func newTripleDESCBCCipher(key, iv, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) { c, err := des.NewTripleDESCipher(key) if err != nil { return nil, err @@ -646,7 +646,7 @@ type chacha20Poly1305Cipher struct { buf []byte } -func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) { +func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs DirectionAlgorithms) (packetCipher, error) { if len(key) != 64 { panic(len(key)) } diff --git a/ssh/cipher_test.go b/ssh/cipher_test.go index 32304a4ed4..8d9a81f14c 100644 --- a/ssh/cipher_test.go +++ b/ssh/cipher_test.go @@ -44,10 +44,10 @@ func TestPacketCiphers(t *testing.T) { func testPacketCipher(t *testing.T, cipher, mac string) { kr := &kexResult{Hash: crypto.SHA1} - algs := directionAlgorithms{ + algs := DirectionAlgorithms{ Cipher: cipher, MAC: mac, - Compression: compressionNone, + compression: compressionNone, } client, err := newPacketCipher(clientKeys, algs, kr) if err != nil { @@ -77,10 +77,10 @@ func testPacketCipher(t *testing.T, cipher, mac string) { func TestCBCOracleCounterMeasure(t *testing.T) { kr := &kexResult{Hash: crypto.SHA1} - algs := directionAlgorithms{ + algs := DirectionAlgorithms{ Cipher: InsecureCipherAES128CBC, MAC: InsecureHMACSHA1, - Compression: compressionNone, + compression: compressionNone, } client, err := newPacketCipher(clientKeys, algs, kr) if err != nil { @@ -204,10 +204,10 @@ func TestCVE202143565(t *testing.T) { mac := HMACSHA256 kr := &kexResult{Hash: crypto.SHA1} - algs := directionAlgorithms{ + algs := DirectionAlgorithms{ Cipher: tc.cipher, MAC: mac, - Compression: compressionNone, + compression: compressionNone, } client, err := newPacketCipher(clientKeys, algs, kr) if err != nil { diff --git a/ssh/client.go b/ssh/client.go index fd8c49749e..33079789bc 100644 --- a/ssh/client.go +++ b/ssh/client.go @@ -110,6 +110,7 @@ func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) e } c.sessionID = c.transport.getSessionID() + c.algorithms = c.transport.getAlgorithms() return c.clientAuthenticate(config) } diff --git a/ssh/common.go b/ssh/common.go index c94f305940..14b1b7d37f 100644 --- a/ssh/common.go +++ b/ssh/common.go @@ -168,6 +168,14 @@ var ( insecurePubKeyAuthAlgos = []string{KeyAlgoRSA, InsecureKeyAlgoDSA} ) +// NegotiatedAlgorithms defines algorithms negotiated between client and server. +type NegotiatedAlgorithms struct { + KeyExchange string + HostKey string + Read DirectionAlgorithms + Write DirectionAlgorithms +} + // Algorithms defines a set of algorithms that can be configured in the client // or server config for negotiation during a handshake. type Algorithms struct { @@ -278,15 +286,16 @@ func findCommon(what string, client []string, server []string) (common string, e return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server) } -// directionAlgorithms records algorithm choices in one direction (either read or write) -type directionAlgorithms struct { +// DirectionAlgorithms defines the algorithms negotiated in one direction +// (either read or write). +type DirectionAlgorithms struct { Cipher string MAC string - Compression string + compression string } // rekeyBytes returns a rekeying intervals in bytes. -func (a *directionAlgorithms) rekeyBytes() int64 { +func (a *DirectionAlgorithms) rekeyBytes() int64 { // According to RFC 4344 block ciphers should rekey after // 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is // 128. @@ -306,27 +315,20 @@ var aeadCiphers = map[string]bool{ CipherChacha20Poly1305: true, } -type algorithms struct { - kex string - hostKey string - w directionAlgorithms - r directionAlgorithms -} - -func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) { - result := &algorithms{} +func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *NegotiatedAlgorithms, err error) { + result := &NegotiatedAlgorithms{} - result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos) + result.KeyExchange, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos) if err != nil { return } - result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) + result.HostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) if err != nil { return } - stoc, ctos := &result.w, &result.r + stoc, ctos := &result.Write, &result.Read if isClient { ctos, stoc = stoc, ctos } @@ -355,12 +357,12 @@ func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMs } } - ctos.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer) + ctos.compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer) if err != nil { return } - stoc.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient) + stoc.compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient) if err != nil { return } diff --git a/ssh/common_test.go b/ssh/common_test.go index a7beee8e88..67cf1f4269 100644 --- a/ssh/common_test.go +++ b/ssh/common_test.go @@ -51,33 +51,33 @@ func TestFindAgreedAlgorithms(t *testing.T) { } } - initDirAlgs := func(a *directionAlgorithms) { + initDirAlgs := func(a *DirectionAlgorithms) { if a.Cipher == "" { a.Cipher = "cipher1" } if a.MAC == "" { a.MAC = "mac1" } - if a.Compression == "" { - a.Compression = "compression1" + if a.compression == "" { + a.compression = "compression1" } } - initAlgs := func(a *algorithms) { - if a.kex == "" { - a.kex = "kex1" + initAlgs := func(a *NegotiatedAlgorithms) { + if a.KeyExchange == "" { + a.KeyExchange = "kex1" } - if a.hostKey == "" { - a.hostKey = "hostkey1" + if a.HostKey == "" { + a.HostKey = "hostkey1" } - initDirAlgs(&a.r) - initDirAlgs(&a.w) + initDirAlgs(&a.Read) + initDirAlgs(&a.Write) } type testcase struct { name string clientIn, serverIn kexInitMsg - wantClient, wantServer algorithms + wantClient, wantServer NegotiatedAlgorithms wantErr bool } @@ -120,19 +120,19 @@ func TestFindAgreedAlgorithms(t *testing.T) { CiphersClientServer: []string{"cipher2", "cipher1"}, CiphersServerClient: []string{"cipher3", "cipher2"}, }, - wantClient: algorithms{ - r: directionAlgorithms{ + wantClient: NegotiatedAlgorithms{ + Read: DirectionAlgorithms{ Cipher: "cipher3", }, - w: directionAlgorithms{ + Write: DirectionAlgorithms{ Cipher: "cipher2", }, }, - wantServer: algorithms{ - w: directionAlgorithms{ + wantServer: NegotiatedAlgorithms{ + Write: DirectionAlgorithms{ Cipher: "cipher3", }, - r: directionAlgorithms{ + Read: DirectionAlgorithms{ Cipher: "cipher2", }, }, diff --git a/ssh/connection.go b/ssh/connection.go index 8f345ee924..613a71a7b3 100644 --- a/ssh/connection.go +++ b/ssh/connection.go @@ -74,6 +74,13 @@ type Conn interface { // Disconnect } +// AlgorithmsConnMetadata is a ConnMetadata that can return the algorithms +// negotiated between client and server. +type AlgorithmsConnMetadata interface { + ConnMetadata + Algorithms() NegotiatedAlgorithms +} + // DiscardRequests consumes and rejects all requests from the // passed-in channel. func DiscardRequests(in <-chan *Request) { @@ -106,6 +113,7 @@ type sshConn struct { sessionID []byte clientVersion []byte serverVersion []byte + algorithms NegotiatedAlgorithms } func dup(src []byte) []byte { @@ -141,3 +149,7 @@ func (c *sshConn) ClientVersion() []byte { func (c *sshConn) ServerVersion() []byte { return dup(c.serverVersion) } + +func (c *sshConn) Algorithms() NegotiatedAlgorithms { + return c.algorithms +} diff --git a/ssh/handshake.go b/ssh/handshake.go index 8edfb69d5a..7f6ee18a4a 100644 --- a/ssh/handshake.go +++ b/ssh/handshake.go @@ -34,7 +34,7 @@ type keyingTransport interface { // prepareKeyChange sets up a key change. The key change for a // direction will be effected if a msgNewKeys message is sent // or received. - prepareKeyChange(*algorithms, *kexResult) error + prepareKeyChange(*NegotiatedAlgorithms, *kexResult) error // setStrictMode sets the strict KEX mode, notably triggering // sequence number resets on sending or receiving msgNewKeys. @@ -102,7 +102,7 @@ type handshakeTransport struct { bannerCallback BannerCallback // Algorithms agreed in the last key exchange. - algorithms *algorithms + algorithms *NegotiatedAlgorithms // Counters exclusively owned by readLoop. readPacketsLeft uint32 @@ -170,6 +170,10 @@ func (t *handshakeTransport) getSessionID() []byte { return t.sessionID } +func (t *handshakeTransport) getAlgorithms() NegotiatedAlgorithms { + return *t.algorithms +} + // waitSession waits for the session to be established. This should be // the first thing to call after instantiating handshakeTransport. func (t *handshakeTransport) waitSession() error { @@ -275,7 +279,7 @@ func (t *handshakeTransport) resetWriteThresholds() { if t.config.RekeyThreshold > 0 { t.writeBytesLeft = int64(t.config.RekeyThreshold) } else if t.algorithms != nil { - t.writeBytesLeft = t.algorithms.w.rekeyBytes() + t.writeBytesLeft = t.algorithms.Write.rekeyBytes() } else { t.writeBytesLeft = 1 << 30 } @@ -390,7 +394,7 @@ func (t *handshakeTransport) resetReadThresholds() { if t.config.RekeyThreshold > 0 { t.readBytesLeft = int64(t.config.RekeyThreshold) } else if t.algorithms != nil { - t.readBytesLeft = t.algorithms.r.rekeyBytes() + t.readBytesLeft = t.algorithms.Read.rekeyBytes() } else { t.readBytesLeft = 1 << 30 } @@ -664,9 +668,9 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { } } - kex, ok := kexAlgoMap[t.algorithms.kex] + kex, ok := kexAlgoMap[t.algorithms.KeyExchange] if !ok { - return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex) + return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.KeyExchange) } var result *kexResult @@ -773,12 +777,12 @@ func pickHostKey(hostKeys []Signer, algo string) AlgorithmSigner { } func (t *handshakeTransport) server(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) { - hostKey := pickHostKey(t.hostKeys, t.algorithms.hostKey) + hostKey := pickHostKey(t.hostKeys, t.algorithms.HostKey) if hostKey == nil { return nil, errors.New("ssh: internal error: negotiated unsupported signature type") } - r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey, t.algorithms.hostKey) + r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey, t.algorithms.HostKey) return r, err } @@ -793,7 +797,7 @@ func (t *handshakeTransport) client(kex kexAlgorithm, magics *handshakeMagics) ( return nil, err } - if err := verifyHostKeySignature(hostKey, t.algorithms.hostKey, result); err != nil { + if err := verifyHostKeySignature(hostKey, t.algorithms.HostKey, result); err != nil { return nil, err } diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go index 9c59d53c5b..c967c92a05 100644 --- a/ssh/handshake_test.go +++ b/ssh/handshake_test.go @@ -367,7 +367,7 @@ type errorKeyingTransport struct { readLeft, writeLeft int } -func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error { +func (n *errorKeyingTransport) prepareKeyChange(*NegotiatedAlgorithms, *kexResult) error { return nil } @@ -1019,3 +1019,58 @@ func TestStrictKEXMixed(t *testing.T) { t.Fatalf("client.waitSession: %v", err) } } + +func TestNegotiatedAlgorithms(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + var serverAlgorithms NegotiatedAlgorithms + + serverConf := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + if algorithmConn, ok := conn.(AlgorithmsConnMetadata); ok { + serverAlgorithms = algorithmConn.Algorithms() + return &Permissions{}, nil + } + return nil, errors.New("server conn does not implement AlgorithmsConnMetadata") + }, + } + serverConf.AddHostKey(testSigners["rsa"]) + go NewServerConn(c1, serverConf) + + clientConf := &ClientConfig{ + User: "test", + Auth: []AuthMethod{Password("testpw")}, + HostKeyCallback: FixedHostKey(testSigners["rsa"].PublicKey()), + } + + if conn, _, _, err := NewClientConn(c2, "", clientConf); err != nil { + t.Fatal(err) + } else { + if algorithmConn, ok := conn.(AlgorithmsConnMetadata); ok { + clientAlgorithms := algorithmConn.Algorithms() + if clientAlgorithms.HostKey == "" { + t.Fatal("negotiated client host key is empty") + } + if clientAlgorithms.KeyExchange == "" { + t.Fatal("negotiated client KEX is empty") + } + if clientAlgorithms.Read.Cipher == "" { + t.Fatal("negotiated client read cipher is empty") + } + if clientAlgorithms.Write.Cipher == "" { + t.Fatal("negotiated client write cipher is empty") + } + if !reflect.DeepEqual(clientAlgorithms, serverAlgorithms) { + t.Fatalf("negotiated client algorithms: %+v differs from negotiated server algorithms: %+v", + clientAlgorithms, serverAlgorithms) + } + } else { + t.Fatal("client conn does not implement AlgorithmsConnMetadata") + } + } +} diff --git a/ssh/server.go b/ssh/server.go index 30c4376bc8..314dba0f9b 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -281,6 +281,7 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) // We just did the key change, so the session ID is established. s.sessionID = s.transport.getSessionID() + s.algorithms = s.transport.getAlgorithms() var packet []byte if packet, err = s.transport.readPacket(); err != nil { diff --git a/ssh/transport.go b/ssh/transport.go index 9e51085f8b..663619845c 100644 --- a/ssh/transport.go +++ b/ssh/transport.go @@ -85,14 +85,14 @@ func (t *transport) setInitialKEXDone() { // prepareKeyChange sets up key material for a keychange. The key changes in // both directions are triggered by reading and writing a msgNewKey packet // respectively. -func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error { - ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult) +func (t *transport) prepareKeyChange(algs *NegotiatedAlgorithms, kexResult *kexResult) error { + ciph, err := newPacketCipher(t.reader.dir, algs.Read, kexResult) if err != nil { return err } t.reader.pendingKeyChange <- ciph - ciph, err = newPacketCipher(t.writer.dir, algs.w, kexResult) + ciph, err = newPacketCipher(t.writer.dir, algs.Write, kexResult) if err != nil { return err } @@ -252,7 +252,7 @@ var ( // setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as // described in RFC 4253, section 6.4. direction should either be serverKeys // (to setup server->client keys) or clientKeys (for client->server keys). -func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) { +func newPacketCipher(d direction, algs DirectionAlgorithms, kex *kexResult) (packetCipher, error) { cipherMode := cipherModes[algs.Cipher] iv := make([]byte, cipherMode.ivSize)