Skip to content
This repository has been archived by the owner on May 26, 2022. It is now read-only.

implement an early data API #110

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crypto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func TestCryptoFailsIfHandshakeIncomplete(t *testing.T) {
init, resp := net.Pipe()
_ = resp.Close()

session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", true)
session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", true, nil)
_, err := session.encrypt(nil, []byte("hi"))
if err == nil {
t.Error("expected encryption error when handshake incomplete")
Expand Down
57 changes: 28 additions & 29 deletions handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,8 @@ func (s *secureSession) runHandshake(ctx context.Context) error {

if s.initiator {
// stage 0 //
// do not send the payload just yet, as it would be plaintext; not secret.
// Handshake Msg Len = len(DH ephemeral key)
err = s.sendHandshakeMessage(hs, nil, hbuf)
if err != nil {
if err := s.sendHandshakeMessage(hs, s.earlyData, hbuf); err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}

Expand All @@ -82,44 +80,45 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
if err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}
err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
if err != nil {
if err := s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()); err != nil {
return err
}

// stage 2 //
// Handshake Msg Len = len(DHT static key) + MAC(static key is encrypted) + len(Payload) + MAC(payload is encrypted)
err = s.sendHandshakeMessage(hs, payload, hbuf)
if err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}
} else {
// stage 0 //
// We don't expect any payload on the first message.
if _, err := s.readHandshakeMessage(hs); err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}

// stage 1 //
// Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) +
//MAC(payload is encrypted)
err = s.sendHandshakeMessage(hs, payload, hbuf)
if err != nil {
if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}
return nil
}

// stage 2 //
plaintext, err := s.readHandshakeMessage(hs)
if err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}
err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
if err != nil {
// stage 0 //
// We don't expect any payload on the first message.
initialPayload, err := s.readHandshakeMessage(hs)
if err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}
if s.earlyDataHandler != nil {
if err := s.earlyDataHandler(initialPayload); err != nil {
return err
}
} else if len(initialPayload) > 0 {
return fmt.Errorf("received unexpected early data (%d bytes)", len(initialPayload))
}

return nil
// stage 1 //
// Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) +
// MAC(payload is encrypted)
if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}

// stage 2 //
plaintext, err := s.readHandshakeMessage(hs)
if err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}
return s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
}

// setCipherStates sets the initial cipher states that will be used to protect
Expand Down
19 changes: 12 additions & 7 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ type secureSession struct {
insecureReader *bufio.Reader // to cushion io read syscalls
// we don't buffer writes to avoid introducing latency; optimisation possible. // TODO revisit

earlyData []byte
earlyDataHandler func([]byte) error

qseek int // queued bytes seek value.
qbuf []byte // queued bytes buffer.
rlen [2]byte // work buffer to read in the incoming message length.
Expand All @@ -38,14 +41,16 @@ type secureSession struct {

// newSecureSession creates a Noise session over the given insecureConn Conn, using
// the libp2p identity keypair from the given Transport.
func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, initiator bool) (*secureSession, error) {
func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, initiator bool, earlyData []byte) (*secureSession, error) {
s := &secureSession{
insecureConn: insecure,
insecureReader: bufio.NewReader(insecure),
initiator: initiator,
localID: tpt.localID,
localKey: tpt.privateKey,
remoteID: remote,
insecureConn: insecure,
insecureReader: bufio.NewReader(insecure),
initiator: initiator,
localID: tpt.localID,
localKey: tpt.privateKey,
remoteID: remote,
earlyData: earlyData,
earlyDataHandler: tpt.earlyDataHandler,
}

// the go-routine we create to run the handshake will
Expand Down
36 changes: 31 additions & 5 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,60 @@ const ID = "/noise"

var _ sec.SecureTransport = &Transport{}

type Option func(*Transport) error

// WithEarlyDataHandler specifies a handler for early data sent by the initiator.
// If the error returned is non-nil, the handshake is aborted.
func WithEarlyDataHandler(h func([]byte) error) Option {
return func(t *Transport) error {
t.earlyDataHandler = h
return nil
}
}

// Transport implements the interface sec.SecureTransport
// https://godoc.org/github.com/libp2p/go-libp2p-core/sec#SecureConn
type Transport struct {
localID peer.ID
privateKey crypto.PrivKey

earlyDataHandler func([]byte) error
}

// New creates a new Noise transport using the given private key as its
// libp2p identity key.
func New(privkey crypto.PrivKey) (*Transport, error) {
func New(privkey crypto.PrivKey, opts ...Option) (*Transport, error) {
localID, err := peer.IDFromPrivateKey(privkey)
if err != nil {
return nil, err
}

return &Transport{
t := &Transport{
localID: localID,
privateKey: privkey,
}, nil
}
for _, opt := range opts {
if err := opt(t); err != nil {
return nil, err
}
}

return t, nil
}

// SecureInbound runs the Noise handshake as the responder.
// If p is empty, connections from any peer are accepted.
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
return newSecureSession(t, ctx, insecure, p, false)
return newSecureSession(t, ctx, insecure, p, false, nil)
}

// SecureOutbound runs the Noise handshake as the initiator.
func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
return newSecureSession(t, ctx, insecure, p, true)
return newSecureSession(t, ctx, insecure, p, true, nil)
}

// SecureOutboundWithEarlyData runs the Noise handshake as the initiator.
// earlyData is sent (unencrypted!) along with the first handshake message.
func (t *Transport) SecureOutboundWithEarlyData(ctx context.Context, insecure net.Conn, p peer.ID, earlyData []byte) (sec.SecureConn, error) {
return newSecureSession(t, ctx, insecure, p, true, earlyData)
}
63 changes: 41 additions & 22 deletions transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,22 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"

"golang.org/x/crypto/poly1305"

"github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/sec"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func newTestTransport(t *testing.T, typ, bits int) *Transport {
priv, pub, err := crypto.GenerateKeyPair(typ, bits)
if err != nil {
t.Fatal(err)
}
id, err := peer.IDFromPublicKey(pub)
if err != nil {
t.Fatal(err)
}
return &Transport{
localID: id,
privateKey: priv,
}
func newTestTransport(t *testing.T, typ, bits int, opts ...Option) *Transport {
t.Helper()
priv, _, err := crypto.GenerateKeyPair(typ, bits)
require.NoError(t, err)
tr, err := New(priv, opts...)
require.NoError(t, err)
return tr
}

// Create a new pair of connected TCP sockets.
Expand Down Expand Up @@ -85,14 +77,27 @@ func connect(t *testing.T, initTransport, respTransport *Transport) (*secureSess
respConn, respErr := respTransport.SecureInbound(context.TODO(), resp, "")
<-done

if initErr != nil {
t.Fatal(initErr)
}
require.NoError(t, initErr)
require.NoError(t, respErr)
return initConn.(*secureSession), respConn.(*secureSession)
}

if respErr != nil {
t.Fatal(respErr)
}
func connectWithEarlyData(t *testing.T, initTransport, respTransport *Transport, earlyData []byte) (*secureSession, *secureSession) {
init, resp := newConnPair(t)

var initConn sec.SecureConn
var initErr error
done := make(chan struct{})
go func() {
defer close(done)
initConn, initErr = initTransport.SecureOutboundWithEarlyData(context.TODO(), init, respTransport.localID, earlyData)
}()

respConn, respErr := respTransport.SecureInbound(context.TODO(), resp, "")
<-done

require.NoError(t, initErr)
require.NoError(t, respErr)
return initConn.(*secureSession), respConn.(*secureSession)
}

Expand Down Expand Up @@ -373,3 +378,17 @@ func TestReadUnencryptedFails(t *testing.T) {
require.Error(t, err)
require.Equal(t, 0, afterLen)
}

func TestEarlyData(t *testing.T) {
initTransport := newTestTransport(t, crypto.Ed25519, 2048)
earlyDataChan := make(chan []byte, 1)
respTransport := newTestTransport(t, crypto.Ed25519, 2048, WithEarlyDataHandler(func(b []byte) error {
earlyDataChan <- b
return nil
}))

initConn, respConn := connectWithEarlyData(t, initTransport, respTransport, []byte("foobar"))
defer initConn.Close()
defer respConn.Close()
require.Equal(t, []byte("foobar"), <-earlyDataChan)
}