diff --git a/interface.go b/interface.go index c3ad0c2..d48bf61 100644 --- a/interface.go +++ b/interface.go @@ -16,11 +16,9 @@ type SessionGenerator struct { PrivateKey ci.PrivKey } -// NewSession takes an insecure io.ReadWriter, sets up a TLS-like +// NewSession takes an insecure io.ReadWriter, performs a TLS-like // handshake with the other side, and returns a secure session. -// The handshake isn't run until the connection is read or written to. // See the source for the protocol details and security implementation. -// The provided Context is only needed for the duration of this function. func (sg *SessionGenerator) NewSession(ctx context.Context, insecure io.ReadWriteCloser) (Session, error) { return newSecureSession(ctx, sg.LocalID, sg.PrivateKey, insecure) } @@ -50,9 +48,6 @@ type Session interface { // ReadWriter returns the encrypted communication channel func (s *secureSession) ReadWriter() msgio.ReadWriteCloser { - if err := s.Handshake(); err != nil { - return &closedRW{err} - } return s.secure } @@ -68,59 +63,15 @@ func (s *secureSession) LocalPrivateKey() ci.PrivKey { // RemotePeer retrieves the remote peer. func (s *secureSession) RemotePeer() peer.ID { - if err := s.Handshake(); err != nil { - return "" - } return s.remotePeer } // RemotePublicKey retrieves the remote public key. func (s *secureSession) RemotePublicKey() ci.PubKey { - if err := s.Handshake(); err != nil { - return nil - } return s.remote.permanentPubKey } // Close closes the secure session func (s *secureSession) Close() error { - s.handshakeMu.Lock() - defer s.handshakeMu.Unlock() - if s.secure == nil { - return s.insecure.Close() // hadn't secured yet. - } return s.secure.Close() } - -// closedRW implements a stub msgio interface that's already -// closed and errored. -type closedRW struct { - err error -} - -func (c *closedRW) Read(buf []byte) (int, error) { - return 0, c.err -} - -func (c *closedRW) Write(buf []byte) (int, error) { - return 0, c.err -} - -func (c *closedRW) NextMsgLen() (int, error) { - return 0, c.err -} - -func (c *closedRW) ReadMsg() ([]byte, error) { - return nil, c.err -} - -func (c *closedRW) WriteMsg(buf []byte) error { - return c.err -} - -func (c *closedRW) Close() error { - return c.err -} - -func (c *closedRW) ReleaseMsg(m []byte) { -} diff --git a/protocol.go b/protocol.go index d1a6d0a..cbaa41e 100644 --- a/protocol.go +++ b/protocol.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "io" - "sync" "time" logging "github.com/ipfs/go-log" @@ -41,8 +40,6 @@ const nonceSize = 16 // secureSession encapsulates all the parameters needed for encrypting // and decrypting traffic from an insecure channel. type secureSession struct { - ctx context.Context - secure msgio.ReadWriteCloser insecure io.ReadWriteCloser insecureM msgio.ReadWriter @@ -55,12 +52,10 @@ type secureSession struct { remote encParams sharedSecret []byte - - handshakeMu sync.Mutex // guards handshakeDone + handshakeErr - handshakeDone bool - handshakeErr error } +var _ Session = &secureSession{} + func (s *secureSession) Loggable() map[string]interface{} { m := make(map[string]interface{}) m["localPeer"] = s.localPeer.Pretty() @@ -83,25 +78,12 @@ func newSecureSession(ctx context.Context, local peer.ID, key ci.PrivKey, insecu return nil, fmt.Errorf("insecure ReadWriter is nil") } - s.ctx = ctx s.insecure = insecure s.insecureM = msgio.NewReadWriter(insecure) - return s, nil -} - -func (s *secureSession) Handshake() error { - s.handshakeMu.Lock() - defer s.handshakeMu.Unlock() - if s.handshakeErr != nil { - return s.handshakeErr - } - - if !s.handshakeDone { - s.handshakeErr = s.runHandshake() - s.handshakeDone = true - } - return s.handshakeErr + handshakeCtx, cancel := context.WithTimeout(ctx, HandshakeTimeout) // remove + defer cancel() + return s, s.runHandshake(handshakeCtx) } func hashSha256(data []byte) mh.Multihash { @@ -118,11 +100,7 @@ func hashSha256(data []byte) mh.Multihash { // runHandshake performs initial communication over insecure channel to share // keys, IDs, and initiate communication, assigning all necessary params. // requires the duplex channel to be a msgio.ReadWriter (for framed messaging) -func (s *secureSession) runHandshake() error { - defer func() { s.ctx = nil }() // clear to save memory - ctx, cancel := context.WithTimeout(s.ctx, HandshakeTimeout) // remove - defer cancel() - +func (s *secureSession) runHandshake(ctx context.Context) error { // ============================================================================= // step 1. Propose -- propose cipher suite + send pubkeys + nonce