Skip to content

Commit

Permalink
Merge pull request #143 from button-chen/master
Browse files Browse the repository at this point in the history
add connect timeout
  • Loading branch information
worg authored Oct 24, 2024
2 parents 80a8c9f + 3b00b6a commit fc01894
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 8 deletions.
21 changes: 20 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package stomp

import (
"context"
"errors"
"io"
"net"
Expand Down Expand Up @@ -66,6 +67,10 @@ type writeRequest struct {
// STOMP server is specified by network and addr. STOMP protocol
// options can be specified in opts.
func Dial(network, addr string, opts ...func(*Conn) error) (*Conn, error) {
return DialWithContext(context.Background(), network, addr, opts...)
}

func DialWithContext(ctx context.Context, network, addr string, opts ...func(*Conn) error) (*Conn, error) {
c, err := net.Dial(network, addr)
if err != nil {
return nil, err
Expand All @@ -81,14 +86,18 @@ func Dial(network, addr string, opts ...func(*Conn) error) (*Conn, error) {
// so that if host has been explicitly specified it will override.
opts = append([]func(*Conn) error{ConnOpt.Host(host)}, opts...)

return Connect(c, opts...)
return ConnectWithContext(ctx, c, opts...)
}

// Connect creates a STOMP connection and performs the STOMP connect
// protocol sequence. The connection to the STOMP server has already
// been created by the program. The opts parameter provides the
// opportunity to specify STOMP protocol options.
func Connect(conn io.ReadWriteCloser, opts ...func(*Conn) error) (*Conn, error) {
return ConnectWithContext(context.Background(), conn, opts...)
}

func ConnectWithContext(ctx context.Context, conn io.ReadWriteCloser, opts ...func(*Conn) error) (*Conn, error) {
reader := frame.NewReader(conn)
writer := frame.NewWriter(conn)

Expand Down Expand Up @@ -152,10 +161,20 @@ func Connect(conn io.ReadWriteCloser, opts ...func(*Conn) error) (*Conn, error)
return nil, err
}

connection, isNetConn := conn.(net.Conn)
deadline, ok := ctx.Deadline()
if ok && isNetConn {
connection.SetReadDeadline(deadline)
}

response, err := reader.Read()
if err != nil {
return nil, err
}
// Restore Conn-level deadlines
if ok && isNetConn {
connection.SetReadDeadline(time.Time{})
}
if response == nil {
return nil, errors.New("unexpected empty frame")
}
Expand Down
18 changes: 18 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package stomp

import (
"context"
"fmt"
"io"
"time"
Expand Down Expand Up @@ -789,3 +790,20 @@ func (s *StompSuite) Test_ZeroTimeout(c *C) {

c.Assert(err, IsNil)
}

func (s *StompSuite) Test_ConnectWithContext(c *C) {
fc1, fc2 := testutil.NewFakeConn(c)

go func() {
buff := make([]byte, 1024)
fc2.Read(buff)
}()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

_, err := ConnectWithContext(ctx, fc1)
// the err here is "io timeout" because the server did not reply to any stomp message
// and the connection waited longer than the 5 seconds we set
c.Assert(err, NotNil)
}
24 changes: 17 additions & 7 deletions testutil/fake_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,17 @@ func (addr *FakeAddr) String() string {
// the net.Conn interface and is useful for simulating I/O between
// STOMP clients and a STOMP server.
type FakeConn struct {
C *C
writer io.WriteCloser
reader io.ReadCloser
localAddr net.Addr
remoteAddr net.Addr
C *C
writer io.WriteCloser
reader io.ReadCloser
localAddr net.Addr
remoteAddr net.Addr
readDeadline time.Time
}

var (
ErrClosing = errors.New("use of closed network connection")
ErrClosing = errors.New("use of closed network connection")
ErrIOTimeout = errors.New("io timeout")
)

// NewFakeConn returns a pair of fake connections suitable for
Expand Down Expand Up @@ -63,6 +65,13 @@ func NewFakeConn(c *C) (client *FakeConn, server *FakeConn) {
}

func (fc *FakeConn) Read(p []byte) (n int, err error) {
if !fc.readDeadline.IsZero() {
t := time.Until(fc.readDeadline)
if t.Seconds() > 0 {
time.Sleep(t)
}
return 0, ErrIOTimeout
}
n, err = fc.reader.Read(p)
return
}
Expand Down Expand Up @@ -105,7 +114,8 @@ func (fc *FakeConn) SetDeadline(t time.Time) error {
}

func (fc *FakeConn) SetReadDeadline(t time.Time) error {
panic("not implemented")
fc.readDeadline = t
return nil
}

func (fc *FakeConn) SetWriteDeadline(t time.Time) error {
Expand Down

0 comments on commit fc01894

Please sign in to comment.