Skip to content

Commit

Permalink
ssh: expose negotiated algorithms
Browse files Browse the repository at this point in the history
Fixes golang/go#58523
Fixes golang/go#46638

Change-Id: Ic64bd2fdd6e9ec96acac3ed4be842e2fbb15231d
  • Loading branch information
drakkan committed Aug 11, 2024
1 parent 893767f commit 7e15e34
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 63 deletions.
16 changes: 8 additions & 8 deletions ssh/cipher.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ func newRC4(key, iv []byte) (cipher.Stream, error) {
type cipherMode struct {
keySize int
ivSize int
create func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error)
create func(key, iv []byte, macKey []byte, algs DirectionAlgorithms) (packetCipher, error)
}

func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream, error)) func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
return func(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream, error)) func(key, iv []byte, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) {
return func(key, iv, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) {
stream, err := createFunc(key, iv)
if err != nil {
return nil, err
Expand Down Expand Up @@ -307,7 +307,7 @@ type gcmCipher struct {
buf []byte
}

func newGCMCipher(key, iv, unusedMacKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
func newGCMCipher(key, iv, unusedMacKey []byte, unusedAlgs DirectionAlgorithms) (packetCipher, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
Expand Down Expand Up @@ -429,7 +429,7 @@ type cbcCipher struct {
oracleCamouflage uint32
}

func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) {
cbc := &cbcCipher{
mac: macModes[algs.MAC].new(macKey),
decrypter: cipher.NewCBCDecrypter(c, iv),
Expand All @@ -443,7 +443,7 @@ func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs directionAlgorith
return cbc, nil
}

func newAESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
func newAESCBCCipher(key, iv, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
Expand All @@ -457,7 +457,7 @@ func newAESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCi
return cbc, nil
}

func newTripleDESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
func newTripleDESCBCCipher(key, iv, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) {
c, err := des.NewTripleDESCipher(key)
if err != nil {
return nil, err
Expand Down Expand Up @@ -646,7 +646,7 @@ type chacha20Poly1305Cipher struct {
buf []byte
}

func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs DirectionAlgorithms) (packetCipher, error) {
if len(key) != 64 {
panic(len(key))
}
Expand Down
12 changes: 6 additions & 6 deletions ssh/cipher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ func TestPacketCiphers(t *testing.T) {

func testPacketCipher(t *testing.T, cipher, mac string) {
kr := &kexResult{Hash: crypto.SHA1}
algs := directionAlgorithms{
algs := DirectionAlgorithms{
Cipher: cipher,
MAC: mac,
Compression: compressionNone,
compression: compressionNone,
}
client, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
Expand Down Expand Up @@ -77,10 +77,10 @@ func testPacketCipher(t *testing.T, cipher, mac string) {

func TestCBCOracleCounterMeasure(t *testing.T) {
kr := &kexResult{Hash: crypto.SHA1}
algs := directionAlgorithms{
algs := DirectionAlgorithms{
Cipher: InsecureCipherAES128CBC,
MAC: InsecureHMACSHA1,
Compression: compressionNone,
compression: compressionNone,
}
client, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
Expand Down Expand Up @@ -204,10 +204,10 @@ func TestCVE202143565(t *testing.T) {
mac := HMACSHA256

kr := &kexResult{Hash: crypto.SHA1}
algs := directionAlgorithms{
algs := DirectionAlgorithms{
Cipher: tc.cipher,
MAC: mac,
Compression: compressionNone,
compression: compressionNone,
}
client, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions ssh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) e
}

c.sessionID = c.transport.getSessionID()
c.algorithms = c.transport.getAlgorithms()
return c.clientAuthenticate(config)
}

Expand Down
38 changes: 20 additions & 18 deletions ssh/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,14 @@ var (
insecurePubKeyAuthAlgos = []string{KeyAlgoRSA, InsecureKeyAlgoDSA}
)

// NegotiatedAlgorithms defines algorithms negotiated between client and server.
type NegotiatedAlgorithms struct {
KeyExchange string
HostKey string
Read DirectionAlgorithms
Write DirectionAlgorithms
}

// Algorithms defines a set of algorithms that can be configured in the client
// or server config for negotiation during a handshake.
type Algorithms struct {
Expand Down Expand Up @@ -278,15 +286,16 @@ func findCommon(what string, client []string, server []string) (common string, e
return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server)
}

// directionAlgorithms records algorithm choices in one direction (either read or write)
type directionAlgorithms struct {
// DirectionAlgorithms defines the algorithms negotiated in one direction
// (either read or write).
type DirectionAlgorithms struct {
Cipher string
MAC string
Compression string
compression string
}

// rekeyBytes returns a rekeying intervals in bytes.
func (a *directionAlgorithms) rekeyBytes() int64 {
func (a *DirectionAlgorithms) rekeyBytes() int64 {
// According to RFC 4344 block ciphers should rekey after
// 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is
// 128.
Expand All @@ -306,27 +315,20 @@ var aeadCiphers = map[string]bool{
CipherChacha20Poly1305: true,
}

type algorithms struct {
kex string
hostKey string
w directionAlgorithms
r directionAlgorithms
}

func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) {
result := &algorithms{}
func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *NegotiatedAlgorithms, err error) {
result := &NegotiatedAlgorithms{}

result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos)
result.KeyExchange, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos)
if err != nil {
return
}

result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
result.HostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
if err != nil {
return
}

stoc, ctos := &result.w, &result.r
stoc, ctos := &result.Write, &result.Read
if isClient {
ctos, stoc = stoc, ctos
}
Expand Down Expand Up @@ -355,12 +357,12 @@ func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMs
}
}

ctos.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
ctos.compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
if err != nil {
return
}

stoc.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
stoc.compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
if err != nil {
return
}
Expand Down
34 changes: 17 additions & 17 deletions ssh/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,33 +51,33 @@ func TestFindAgreedAlgorithms(t *testing.T) {
}
}

initDirAlgs := func(a *directionAlgorithms) {
initDirAlgs := func(a *DirectionAlgorithms) {
if a.Cipher == "" {
a.Cipher = "cipher1"
}
if a.MAC == "" {
a.MAC = "mac1"
}
if a.Compression == "" {
a.Compression = "compression1"
if a.compression == "" {
a.compression = "compression1"
}
}

initAlgs := func(a *algorithms) {
if a.kex == "" {
a.kex = "kex1"
initAlgs := func(a *NegotiatedAlgorithms) {
if a.KeyExchange == "" {
a.KeyExchange = "kex1"
}
if a.hostKey == "" {
a.hostKey = "hostkey1"
if a.HostKey == "" {
a.HostKey = "hostkey1"
}
initDirAlgs(&a.r)
initDirAlgs(&a.w)
initDirAlgs(&a.Read)
initDirAlgs(&a.Write)
}

type testcase struct {
name string
clientIn, serverIn kexInitMsg
wantClient, wantServer algorithms
wantClient, wantServer NegotiatedAlgorithms
wantErr bool
}

Expand Down Expand Up @@ -120,19 +120,19 @@ func TestFindAgreedAlgorithms(t *testing.T) {
CiphersClientServer: []string{"cipher2", "cipher1"},
CiphersServerClient: []string{"cipher3", "cipher2"},
},
wantClient: algorithms{
r: directionAlgorithms{
wantClient: NegotiatedAlgorithms{
Read: DirectionAlgorithms{
Cipher: "cipher3",
},
w: directionAlgorithms{
Write: DirectionAlgorithms{
Cipher: "cipher2",
},
},
wantServer: algorithms{
w: directionAlgorithms{
wantServer: NegotiatedAlgorithms{
Write: DirectionAlgorithms{
Cipher: "cipher3",
},
r: directionAlgorithms{
Read: DirectionAlgorithms{
Cipher: "cipher2",
},
},
Expand Down
12 changes: 12 additions & 0 deletions ssh/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ type Conn interface {
// Disconnect
}

// AlgorithmsConnMetadata is a ConnMetadata that can return the algorithms
// negotiated between client and server.
type AlgorithmsConnMetadata interface {
ConnMetadata
Algorithms() NegotiatedAlgorithms
}

// DiscardRequests consumes and rejects all requests from the
// passed-in channel.
func DiscardRequests(in <-chan *Request) {
Expand Down Expand Up @@ -106,6 +113,7 @@ type sshConn struct {
sessionID []byte
clientVersion []byte
serverVersion []byte
algorithms NegotiatedAlgorithms
}

func dup(src []byte) []byte {
Expand Down Expand Up @@ -141,3 +149,7 @@ func (c *sshConn) ClientVersion() []byte {
func (c *sshConn) ServerVersion() []byte {
return dup(c.serverVersion)
}

func (c *sshConn) Algorithms() NegotiatedAlgorithms {
return c.algorithms
}
22 changes: 13 additions & 9 deletions ssh/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ type keyingTransport interface {
// prepareKeyChange sets up a key change. The key change for a
// direction will be effected if a msgNewKeys message is sent
// or received.
prepareKeyChange(*algorithms, *kexResult) error
prepareKeyChange(*NegotiatedAlgorithms, *kexResult) error

// setStrictMode sets the strict KEX mode, notably triggering
// sequence number resets on sending or receiving msgNewKeys.
Expand Down Expand Up @@ -102,7 +102,7 @@ type handshakeTransport struct {
bannerCallback BannerCallback

// Algorithms agreed in the last key exchange.
algorithms *algorithms
algorithms *NegotiatedAlgorithms

// Counters exclusively owned by readLoop.
readPacketsLeft uint32
Expand Down Expand Up @@ -170,6 +170,10 @@ func (t *handshakeTransport) getSessionID() []byte {
return t.sessionID
}

func (t *handshakeTransport) getAlgorithms() NegotiatedAlgorithms {
return *t.algorithms
}

// waitSession waits for the session to be established. This should be
// the first thing to call after instantiating handshakeTransport.
func (t *handshakeTransport) waitSession() error {
Expand Down Expand Up @@ -275,7 +279,7 @@ func (t *handshakeTransport) resetWriteThresholds() {
if t.config.RekeyThreshold > 0 {
t.writeBytesLeft = int64(t.config.RekeyThreshold)
} else if t.algorithms != nil {
t.writeBytesLeft = t.algorithms.w.rekeyBytes()
t.writeBytesLeft = t.algorithms.Write.rekeyBytes()
} else {
t.writeBytesLeft = 1 << 30
}
Expand Down Expand Up @@ -390,7 +394,7 @@ func (t *handshakeTransport) resetReadThresholds() {
if t.config.RekeyThreshold > 0 {
t.readBytesLeft = int64(t.config.RekeyThreshold)
} else if t.algorithms != nil {
t.readBytesLeft = t.algorithms.r.rekeyBytes()
t.readBytesLeft = t.algorithms.Read.rekeyBytes()
} else {
t.readBytesLeft = 1 << 30
}
Expand Down Expand Up @@ -664,9 +668,9 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
}
}

kex, ok := kexAlgoMap[t.algorithms.kex]
kex, ok := kexAlgoMap[t.algorithms.KeyExchange]
if !ok {
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex)
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.KeyExchange)
}

var result *kexResult
Expand Down Expand Up @@ -773,12 +777,12 @@ func pickHostKey(hostKeys []Signer, algo string) AlgorithmSigner {
}

func (t *handshakeTransport) server(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) {
hostKey := pickHostKey(t.hostKeys, t.algorithms.hostKey)
hostKey := pickHostKey(t.hostKeys, t.algorithms.HostKey)
if hostKey == nil {
return nil, errors.New("ssh: internal error: negotiated unsupported signature type")
}

r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey, t.algorithms.hostKey)
r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey, t.algorithms.HostKey)
return r, err
}

Expand All @@ -793,7 +797,7 @@ func (t *handshakeTransport) client(kex kexAlgorithm, magics *handshakeMagics) (
return nil, err
}

if err := verifyHostKeySignature(hostKey, t.algorithms.hostKey, result); err != nil {
if err := verifyHostKeySignature(hostKey, t.algorithms.HostKey, result); err != nil {
return nil, err
}

Expand Down
Loading

0 comments on commit 7e15e34

Please sign in to comment.