From fe671ac5d0fc3793de8e62425025bbdfe306f532 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 16 Apr 2022 18:35:27 +0100 Subject: [PATCH] implement an early data API --- crypto_test.go | 2 +- handshake.go | 57 +++++++++++++++++++++--------------------- session.go | 19 ++++++++------ transport.go | 36 +++++++++++++++++++++++---- transport_test.go | 63 ++++++++++++++++++++++++++++++----------------- 5 files changed, 113 insertions(+), 64 deletions(-) diff --git a/crypto_test.go b/crypto_test.go index ca5125c..0a06e97 100644 --- a/crypto_test.go +++ b/crypto_test.go @@ -93,7 +93,7 @@ func TestCryptoFailsIfHandshakeIncomplete(t *testing.T) { init, resp := net.Pipe() _ = resp.Close() - session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", true) + session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", true, nil) _, err := session.encrypt(nil, []byte("hi")) if err == nil { t.Error("expected encryption error when handshake incomplete") diff --git a/handshake.go b/handshake.go index d70a529..ee94480 100644 --- a/handshake.go +++ b/handshake.go @@ -70,10 +70,8 @@ func (s *secureSession) runHandshake(ctx context.Context) error { if s.initiator { // stage 0 // - // do not send the payload just yet, as it would be plaintext; not secret. // Handshake Msg Len = len(DH ephemeral key) - err = s.sendHandshakeMessage(hs, nil, hbuf) - if err != nil { + if err := s.sendHandshakeMessage(hs, s.earlyData, hbuf); err != nil { return fmt.Errorf("error sending handshake message: %w", err) } @@ -82,44 +80,45 @@ func (s *secureSession) runHandshake(ctx context.Context) error { if err != nil { return fmt.Errorf("error reading handshake message: %w", err) } - err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()) - if err != nil { + if err := s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()); err != nil { return err } // stage 2 // // Handshake Msg Len = len(DHT static key) + MAC(static key is encrypted) + len(Payload) + MAC(payload is encrypted) - err = s.sendHandshakeMessage(hs, payload, hbuf) - if err != nil { - return fmt.Errorf("error sending handshake message: %w", err) - } - } else { - // stage 0 // - // We don't expect any payload on the first message. - if _, err := s.readHandshakeMessage(hs); err != nil { - return fmt.Errorf("error reading handshake message: %w", err) - } - - // stage 1 // - // Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) + - //MAC(payload is encrypted) - err = s.sendHandshakeMessage(hs, payload, hbuf) - if err != nil { + if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil { return fmt.Errorf("error sending handshake message: %w", err) } + return nil + } - // stage 2 // - plaintext, err := s.readHandshakeMessage(hs) - if err != nil { - return fmt.Errorf("error reading handshake message: %w", err) - } - err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()) - if err != nil { + // stage 0 // + // We don't expect any payload on the first message. + initialPayload, err := s.readHandshakeMessage(hs) + if err != nil { + return fmt.Errorf("error reading handshake message: %w", err) + } + if s.earlyDataHandler != nil { + if err := s.earlyDataHandler(initialPayload); err != nil { return err } + } else if len(initialPayload) > 0 { + return fmt.Errorf("received unexpected early data (%d bytes)", len(initialPayload)) } - return nil + // stage 1 // + // Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) + + // MAC(payload is encrypted) + if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil { + return fmt.Errorf("error sending handshake message: %w", err) + } + + // stage 2 // + plaintext, err := s.readHandshakeMessage(hs) + if err != nil { + return fmt.Errorf("error reading handshake message: %w", err) + } + return s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()) } // setCipherStates sets the initial cipher states that will be used to protect diff --git a/session.go b/session.go index 9675cc0..d84799a 100644 --- a/session.go +++ b/session.go @@ -28,6 +28,9 @@ type secureSession struct { insecureReader *bufio.Reader // to cushion io read syscalls // we don't buffer writes to avoid introducing latency; optimisation possible. // TODO revisit + earlyData []byte + earlyDataHandler func([]byte) error + qseek int // queued bytes seek value. qbuf []byte // queued bytes buffer. rlen [2]byte // work buffer to read in the incoming message length. @@ -38,14 +41,16 @@ type secureSession struct { // newSecureSession creates a Noise session over the given insecureConn Conn, using // the libp2p identity keypair from the given Transport. -func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, initiator bool) (*secureSession, error) { +func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, initiator bool, earlyData []byte) (*secureSession, error) { s := &secureSession{ - insecureConn: insecure, - insecureReader: bufio.NewReader(insecure), - initiator: initiator, - localID: tpt.localID, - localKey: tpt.privateKey, - remoteID: remote, + insecureConn: insecure, + insecureReader: bufio.NewReader(insecure), + initiator: initiator, + localID: tpt.localID, + localKey: tpt.privateKey, + remoteID: remote, + earlyData: earlyData, + earlyDataHandler: tpt.earlyDataHandler, } // the go-routine we create to run the handshake will diff --git a/transport.go b/transport.go index c8d7a44..e1e246f 100644 --- a/transport.go +++ b/transport.go @@ -14,34 +14,60 @@ const ID = "/noise" var _ sec.SecureTransport = &Transport{} +type Option func(*Transport) error + +// WithEarlyDataHandler specifies a handler for early data sent by the initiator. +// If the error returned is non-nil, the handshake is aborted. +func WithEarlyDataHandler(h func([]byte) error) Option { + return func(t *Transport) error { + t.earlyDataHandler = h + return nil + } +} + // Transport implements the interface sec.SecureTransport // https://godoc.org/github.com/libp2p/go-libp2p-core/sec#SecureConn type Transport struct { localID peer.ID privateKey crypto.PrivKey + + earlyDataHandler func([]byte) error } // New creates a new Noise transport using the given private key as its // libp2p identity key. -func New(privkey crypto.PrivKey) (*Transport, error) { +func New(privkey crypto.PrivKey, opts ...Option) (*Transport, error) { localID, err := peer.IDFromPrivateKey(privkey) if err != nil { return nil, err } - return &Transport{ + t := &Transport{ localID: localID, privateKey: privkey, - }, nil + } + for _, opt := range opts { + if err := opt(t); err != nil { + return nil, err + } + } + + return t, nil } // SecureInbound runs the Noise handshake as the responder. // If p is empty, connections from any peer are accepted. func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - return newSecureSession(t, ctx, insecure, p, false) + return newSecureSession(t, ctx, insecure, p, false, nil) } // SecureOutbound runs the Noise handshake as the initiator. func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - return newSecureSession(t, ctx, insecure, p, true) + return newSecureSession(t, ctx, insecure, p, true, nil) +} + +// SecureOutboundWithEarlyData runs the Noise handshake as the initiator. +// earlyData is sent (unencrypted!) along with the first handshake message. +func (t *Transport) SecureOutboundWithEarlyData(ctx context.Context, insecure net.Conn, p peer.ID, earlyData []byte) (sec.SecureConn, error) { + return newSecureSession(t, ctx, insecure, p, true, earlyData) } diff --git a/transport_test.go b/transport_test.go index b65b9cb..ea9be18 100644 --- a/transport_test.go +++ b/transport_test.go @@ -11,30 +11,22 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "golang.org/x/crypto/poly1305" "github.com/libp2p/go-libp2p-core/crypto" - "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/sec" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func newTestTransport(t *testing.T, typ, bits int) *Transport { - priv, pub, err := crypto.GenerateKeyPair(typ, bits) - if err != nil { - t.Fatal(err) - } - id, err := peer.IDFromPublicKey(pub) - if err != nil { - t.Fatal(err) - } - return &Transport{ - localID: id, - privateKey: priv, - } +func newTestTransport(t *testing.T, typ, bits int, opts ...Option) *Transport { + t.Helper() + priv, _, err := crypto.GenerateKeyPair(typ, bits) + require.NoError(t, err) + tr, err := New(priv, opts...) + require.NoError(t, err) + return tr } // Create a new pair of connected TCP sockets. @@ -85,14 +77,27 @@ func connect(t *testing.T, initTransport, respTransport *Transport) (*secureSess respConn, respErr := respTransport.SecureInbound(context.TODO(), resp, "") <-done - if initErr != nil { - t.Fatal(initErr) - } + require.NoError(t, initErr) + require.NoError(t, respErr) + return initConn.(*secureSession), respConn.(*secureSession) +} - if respErr != nil { - t.Fatal(respErr) - } +func connectWithEarlyData(t *testing.T, initTransport, respTransport *Transport, earlyData []byte) (*secureSession, *secureSession) { + init, resp := newConnPair(t) + var initConn sec.SecureConn + var initErr error + done := make(chan struct{}) + go func() { + defer close(done) + initConn, initErr = initTransport.SecureOutboundWithEarlyData(context.TODO(), init, respTransport.localID, earlyData) + }() + + respConn, respErr := respTransport.SecureInbound(context.TODO(), resp, "") + <-done + + require.NoError(t, initErr) + require.NoError(t, respErr) return initConn.(*secureSession), respConn.(*secureSession) } @@ -373,3 +378,17 @@ func TestReadUnencryptedFails(t *testing.T) { require.Error(t, err) require.Equal(t, 0, afterLen) } + +func TestEarlyData(t *testing.T) { + initTransport := newTestTransport(t, crypto.Ed25519, 2048) + earlyDataChan := make(chan []byte, 1) + respTransport := newTestTransport(t, crypto.Ed25519, 2048, WithEarlyDataHandler(func(b []byte) error { + earlyDataChan <- b + return nil + })) + + initConn, respConn := connectWithEarlyData(t, initTransport, respTransport, []byte("foobar")) + defer initConn.Close() + defer respConn.Close() + require.Equal(t, []byte("foobar"), <-earlyDataChan) +}