diff --git a/p2p/transport/websocket/conn.go b/p2p/transport/websocket/conn.go index 1d8de2a01f..168497bf39 100644 --- a/p2p/transport/websocket/conn.go +++ b/p2p/transport/websocket/conn.go @@ -3,11 +3,16 @@ package websocket import ( "io" "net" + "sync" "time" ws "github.com/gorilla/websocket" ) +// GracefulCloseTimeout is the time to wait trying to gracefully close a +// connection before simply cutting it. +var GracefulCloseTimeout = 100 * time.Millisecond + var _ net.Conn = (*Conn)(nil) // Conn implements net.Conn interface for gorilla/websocket. @@ -16,6 +21,7 @@ type Conn struct { DefaultMessageType int done func() reader io.Reader + closeOnce sync.Once } func (c *Conn) Read(b []byte) (int, error) { @@ -73,13 +79,22 @@ func (c *Conn) Write(b []byte) (n int, err error) { return len(b), nil } +// Close closes the connection. Only the first call to Close will receive the +// close error, subsequent and concurrent calls will return nil. +// This method is thread-safe. func (c *Conn) Close() error { - if c.done != nil { - c.done() - } + var err error + c.closeOnce.Do(func() { + if c.done != nil { + c.done() + // Be nice to GC + c.done = nil + } - c.Conn.WriteMessage(ws.CloseMessage, nil) - return c.Conn.Close() + c.Conn.WriteControl(ws.CloseMessage, nil, time.Now().Add(GracefulCloseTimeout)) + err = c.Conn.Close() + }) + return err } func (c *Conn) LocalAddr() net.Addr { diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 7aab918576..61bfa14230 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -2,6 +2,7 @@ package websocket import ( "bytes" + "io" "io/ioutil" "testing" "testing/iotest" @@ -53,3 +54,92 @@ func TestWebsocketListen(t *testing.T) { t.Fatal("got wrong message", out, msg) } } + +func TestConcurrentClose(t *testing.T) { + zero, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/0/ws") + if err != nil { + t.Fatal(err) + } + + tpt := &WebsocketTransport{} + l, err := tpt.Listen(zero) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + msg := []byte("HELLO WORLD") + + go func() { + d, _ := tpt.Dialer(nil) + for i := 0; i < 100; i++ { + c, err := d.Dial(l.Multiaddr()) + if err != nil { + t.Error(err) + return + } + + go c.Write(msg) + go c.Close() + } + }() + + for i := 0; i < 100; i++ { + c, err := l.Accept() + if err != nil { + t.Fatal(err) + } + c.Close() + } +} + +func TestWriteZero(t *testing.T) { + zero, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/0/ws") + if err != nil { + t.Fatal(err) + } + + tpt := &WebsocketTransport{} + l, err := tpt.Listen(zero) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + msg := []byte(nil) + + go func() { + d, _ := tpt.Dialer(nil) + c, err := d.Dial(l.Multiaddr()) + defer c.Close() + if err != nil { + t.Error(err) + return + } + + for i := 0; i < 100; i++ { + n, err := c.Write(msg) + if n != 0 { + t.Errorf("expected to write 0 bytes, wrote %d", n) + } + if err != nil { + t.Error(err) + return + } + } + }() + + c, err := l.Accept() + defer c.Close() + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 100) + n, err := c.Read(buf) + if n != 0 { + t.Errorf("read %d bytes, expected 0", n) + } + if err != io.EOF { + t.Errorf("expected EOF, got err: %s", err) + } +}