Skip to content

Commit

Permalink
Merge pull request #19 from libp2p/fix/thread-safe-close
Browse files Browse the repository at this point in the history
Make close thread safe
  • Loading branch information
Stebalien committed Sep 5, 2017
2 parents e57234e + da59505 commit a34b3e7
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 5 deletions.
25 changes: 20 additions & 5 deletions p2p/transport/websocket/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@ package websocket
import (
"io"
"net"
"sync"
"time"

ws "github.com/gorilla/websocket"
)

// GracefulCloseTimeout is the time to wait trying to gracefully close a
// connection before simply cutting it.
var GracefulCloseTimeout = 100 * time.Millisecond

var _ net.Conn = (*Conn)(nil)

// Conn implements net.Conn interface for gorilla/websocket.
Expand All @@ -16,6 +21,7 @@ type Conn struct {
DefaultMessageType int
done func()
reader io.Reader
closeOnce sync.Once
}

func (c *Conn) Read(b []byte) (int, error) {
Expand Down Expand Up @@ -73,13 +79,22 @@ func (c *Conn) Write(b []byte) (n int, err error) {
return len(b), nil
}

// Close closes the connection. Only the first call to Close will receive the
// close error, subsequent and concurrent calls will return nil.
// This method is thread-safe.
func (c *Conn) Close() error {
if c.done != nil {
c.done()
}
var err error
c.closeOnce.Do(func() {
if c.done != nil {
c.done()
// Be nice to GC
c.done = nil
}

c.Conn.WriteMessage(ws.CloseMessage, nil)
return c.Conn.Close()
c.Conn.WriteControl(ws.CloseMessage, nil, time.Now().Add(GracefulCloseTimeout))
err = c.Conn.Close()
})
return err
}

func (c *Conn) LocalAddr() net.Addr {
Expand Down
90 changes: 90 additions & 0 deletions p2p/transport/websocket/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package websocket

import (
"bytes"
"io"
"io/ioutil"
"testing"
"testing/iotest"
Expand Down Expand Up @@ -53,3 +54,92 @@ func TestWebsocketListen(t *testing.T) {
t.Fatal("got wrong message", out, msg)
}
}

func TestConcurrentClose(t *testing.T) {
zero, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/0/ws")
if err != nil {
t.Fatal(err)
}

tpt := &WebsocketTransport{}
l, err := tpt.Listen(zero)
if err != nil {
t.Fatal(err)
}
defer l.Close()

msg := []byte("HELLO WORLD")

go func() {
d, _ := tpt.Dialer(nil)
for i := 0; i < 100; i++ {
c, err := d.Dial(l.Multiaddr())
if err != nil {
t.Error(err)
return
}

go c.Write(msg)
go c.Close()
}
}()

for i := 0; i < 100; i++ {
c, err := l.Accept()
if err != nil {
t.Fatal(err)
}
c.Close()
}
}

func TestWriteZero(t *testing.T) {
zero, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/0/ws")
if err != nil {
t.Fatal(err)
}

tpt := &WebsocketTransport{}
l, err := tpt.Listen(zero)
if err != nil {
t.Fatal(err)
}
defer l.Close()

msg := []byte(nil)

go func() {
d, _ := tpt.Dialer(nil)
c, err := d.Dial(l.Multiaddr())
defer c.Close()
if err != nil {
t.Error(err)
return
}

for i := 0; i < 100; i++ {
n, err := c.Write(msg)
if n != 0 {
t.Errorf("expected to write 0 bytes, wrote %d", n)
}
if err != nil {
t.Error(err)
return
}
}
}()

c, err := l.Accept()
defer c.Close()
if err != nil {
t.Fatal(err)
}
buf := make([]byte, 100)
n, err := c.Read(buf)
if n != 0 {
t.Errorf("read %d bytes, expected 0", n)
}
if err != io.EOF {
t.Errorf("expected EOF, got err: %s", err)
}
}

0 comments on commit a34b3e7

Please sign in to comment.