Skip to content

Commit

Permalink
ssh: invert algorithm choices on the server
Browse files Browse the repository at this point in the history
At the protocol level, SSH lets client and server specify different
algorithms for the read and write half of the connection. This has
never worked correctly, as Client-to-Server was always interpreted as
the "write" side, even if we were the server.

This has never been a problem because, apparently, there are no
clients that insist on different algorithm choices running against Go
SSH servers.

Since the SSH package does not expose a mechanism to specify
algorithms for read/write separately, there is end-to-end for this
change, so add a unittest instead.

Change-Id: Ie3aa781630a3bb7a3b0e3754cb67b3ce12581544
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/172538
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
  • Loading branch information
hanwen committed Apr 18, 2019
1 parent b43e412 commit df01cb2
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 9 deletions.
20 changes: 13 additions & 7 deletions ssh/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ func findCommon(what string, client []string, server []string) (common string, e
return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server)
}

// directionAlgorithms records algorithm choices in one direction (either read or write)
type directionAlgorithms struct {
Cipher string
MAC string
Expand Down Expand Up @@ -137,7 +138,7 @@ type algorithms struct {
r directionAlgorithms
}

func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) {
func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) {
result := &algorithms{}

result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos)
Expand All @@ -150,32 +151,37 @@ func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algor
return
}

result.w.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
stoc, ctos := &result.w, &result.r
if isClient {
ctos, stoc = stoc, ctos
}

ctos.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
if err != nil {
return
}

result.r.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
stoc.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
if err != nil {
return
}

result.w.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
ctos.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
if err != nil {
return
}

result.r.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
stoc.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
if err != nil {
return
}

result.w.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
ctos.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
if err != nil {
return
}

result.r.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
stoc.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
if err != nil {
return
}
Expand Down
176 changes: 176 additions & 0 deletions ssh/common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ssh

import (
"reflect"
"testing"
)

func TestFindAgreedAlgorithms(t *testing.T) {
initKex := func(k *kexInitMsg) {
if k.KexAlgos == nil {
k.KexAlgos = []string{"kex1"}
}
if k.ServerHostKeyAlgos == nil {
k.ServerHostKeyAlgos = []string{"hostkey1"}
}
if k.CiphersClientServer == nil {
k.CiphersClientServer = []string{"cipher1"}

}
if k.CiphersServerClient == nil {
k.CiphersServerClient = []string{"cipher1"}

}
if k.MACsClientServer == nil {
k.MACsClientServer = []string{"mac1"}

}
if k.MACsServerClient == nil {
k.MACsServerClient = []string{"mac1"}

}
if k.CompressionClientServer == nil {
k.CompressionClientServer = []string{"compression1"}

}
if k.CompressionServerClient == nil {
k.CompressionServerClient = []string{"compression1"}

}
if k.LanguagesClientServer == nil {
k.LanguagesClientServer = []string{"language1"}

}
if k.LanguagesServerClient == nil {
k.LanguagesServerClient = []string{"language1"}

}
}

initDirAlgs := func(a *directionAlgorithms) {
if a.Cipher == "" {
a.Cipher = "cipher1"
}
if a.MAC == "" {
a.MAC = "mac1"
}
if a.Compression == "" {
a.Compression = "compression1"
}
}

initAlgs := func(a *algorithms) {
if a.kex == "" {
a.kex = "kex1"
}
if a.hostKey == "" {
a.hostKey = "hostkey1"
}
initDirAlgs(&a.r)
initDirAlgs(&a.w)
}

type testcase struct {
name string
clientIn, serverIn kexInitMsg
wantClient, wantServer algorithms
wantErr bool
}

cases := []testcase{
testcase{
name: "standard",
},

testcase{
name: "no common hostkey",
serverIn: kexInitMsg{
ServerHostKeyAlgos: []string{"hostkey2"},
},
wantErr: true,
},

testcase{
name: "no common kex",
serverIn: kexInitMsg{
KexAlgos: []string{"kex2"},
},
wantErr: true,
},

testcase{
name: "no common cipher",
serverIn: kexInitMsg{
CiphersClientServer: []string{"cipher2"},
},
wantErr: true,
},

testcase{
name: "client decides cipher",
serverIn: kexInitMsg{
CiphersClientServer: []string{"cipher1", "cipher2"},
CiphersServerClient: []string{"cipher2", "cipher3"},
},
clientIn: kexInitMsg{
CiphersClientServer: []string{"cipher2", "cipher1"},
CiphersServerClient: []string{"cipher3", "cipher2"},
},
wantClient: algorithms{
r: directionAlgorithms{
Cipher: "cipher3",
},
w: directionAlgorithms{
Cipher: "cipher2",
},
},
wantServer: algorithms{
w: directionAlgorithms{
Cipher: "cipher3",
},
r: directionAlgorithms{
Cipher: "cipher2",
},
},
},

// TODO(hanwen): fix and add tests for AEAD ignoring
// the MACs field
}

for i := range cases {
initKex(&cases[i].clientIn)
initKex(&cases[i].serverIn)
initAlgs(&cases[i].wantClient)
initAlgs(&cases[i].wantServer)
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
serverAlgs, serverErr := findAgreedAlgorithms(false, &c.clientIn, &c.serverIn)
clientAlgs, clientErr := findAgreedAlgorithms(true, &c.clientIn, &c.serverIn)

serverHasErr := serverErr != nil
clientHasErr := clientErr != nil
if c.wantErr != serverHasErr || c.wantErr != clientHasErr {
t.Fatalf("got client/server error (%v, %v), want hasError %v",
clientErr, serverErr, c.wantErr)

}
if c.wantErr {
return
}

if !reflect.DeepEqual(serverAlgs, &c.wantServer) {
t.Errorf("server: got algs %#v, want %#v", serverAlgs, &c.wantServer)
}
if !reflect.DeepEqual(clientAlgs, &c.wantClient) {
t.Errorf("server: got algs %#v, want %#v", clientAlgs, &c.wantClient)
}
})
}
}
5 changes: 3 additions & 2 deletions ssh/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -543,15 +543,16 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {

clientInit := otherInit
serverInit := t.sentInitMsg
if len(t.hostKeys) == 0 {
isClient := len(t.hostKeys) == 0
if isClient {
clientInit, serverInit = serverInit, clientInit

magics.clientKexInit = t.sentInitPacket
magics.serverKexInit = otherInitPacket
}

var err error
t.algorithms, err = findAgreedAlgorithms(clientInit, serverInit)
t.algorithms, err = findAgreedAlgorithms(isClient, clientInit, serverInit)
if err != nil {
return err
}
Expand Down

0 comments on commit df01cb2

Please sign in to comment.