Skip to content

Commit

Permalink
[FIXED] Call ConnectedCB with RetryOnFailedConnect when initial conn …
Browse files Browse the repository at this point in the history
…failed

Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
  • Loading branch information
piotrpio committed Apr 19, 2024
1 parent 9d4b227 commit 8d178b4
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 29 deletions.
15 changes: 10 additions & 5 deletions nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -2875,15 +2875,20 @@ func (nc *Conn) doReconnect(err error) {
// This is where we are truly connected.
nc.status = CONNECTED

// Queue up the correct callback. If we are in initial connect state
// (using retry on failed connect), we will call the ConnectedCB,
// otherwise the ReconnectedCB.
if nc.Opts.ReconnectedCB != nil && !nc.initc {
nc.ach.push(func() { nc.Opts.ReconnectedCB(nc) })
} else if nc.Opts.ConnectedCB != nil && nc.initc {
fmt.Println()
nc.ach.push(func() { nc.Opts.ConnectedCB(nc) })
}

// If we are here with a retry on failed connect, indicate that the
// initial connect is now complete.
nc.initc = false

// Queue up the reconnect callback.
if nc.Opts.ReconnectedCB != nil {
nc.ach.push(func() { nc.Opts.ReconnectedCB(nc) })
}

// Release lock here, we will return below.
nc.mu.Unlock()

Expand Down
120 changes: 96 additions & 24 deletions test/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1094,16 +1094,21 @@ func TestCallbacksOrder(t *testing.T) {
}

func TestConnectHandler(t *testing.T) {
handler := func(ch chan bool) func(*nats.Conn) {
return func(*nats.Conn) {
ch <- true
}
}
t.Run("with RetryOnFailedConnect, connection established", func(t *testing.T) {
s := RunDefaultServer()
defer s.Shutdown()

connected := make(chan bool)
connHandler := func(*nats.Conn) {
connected <- true
}
reconnected := make(chan bool)

nc, err := nats.Connect(nats.DefaultURL,
nats.ConnectHandler(connHandler),
nats.ConnectHandler(handler(connected)),
nats.ReconnectHandler(handler(reconnected)),
nats.RetryOnFailedConnect(true))

if err != nil {
Expand All @@ -1113,59 +1118,126 @@ func TestConnectHandler(t *testing.T) {
if err = Wait(connected); err != nil {
t.Fatal("Timeout waiting for connect handler")
}
if err = WaitTime(reconnected, 100*time.Millisecond); err == nil {
t.Fatal("Reconnect handler should not have been invoked")
}
})
t.Run("with RetryOnFailedConnect, connection failed", func(t *testing.T) {
connected := make(chan bool)
connHandler := func(*nats.Conn) {
connected <- true
}
reconnected := make(chan bool)

nc, err := nats.Connect(nats.DefaultURL,
nats.ConnectHandler(connHandler),
nats.ConnectHandler(handler(connected)),
nats.ReconnectHandler(handler(reconnected)),
nats.RetryOnFailedConnect(true))

if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
select {
case <-connected:
t.Fatalf("ConnectedCB invoked when no connection established")
case <-time.After(100 * time.Millisecond):
if err = WaitTime(connected, 100*time.Millisecond); err == nil {
t.Fatal("Connected handler should not have been invoked")
}
if err = WaitTime(reconnected, 100*time.Millisecond); err == nil {
t.Fatal("Reconnect handler should not have been invoked")
}
})
t.Run("no RetryOnFailedConnect, connection established", func(t *testing.T) {
s := RunDefaultServer()
defer s.Shutdown()

connected := make(chan bool)
connHandler := func(*nats.Conn) {
connected <- true
}
reconnected := make(chan bool)
nc, err := nats.Connect(nats.DefaultURL,
nats.ConnectHandler(connHandler))
nats.ConnectHandler(handler(connected)),
nats.ReconnectHandler(handler(reconnected)))

if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
if err = Wait(connected); err != nil {
t.Fatal("Timeout waiting for connect handler")
}
if err = WaitTime(reconnected, 100*time.Millisecond); err == nil {
t.Fatal("Reconnect handler should not have been invoked")
}
})
t.Run("no RetryOnFailedConnect, connection failed", func(t *testing.T) {
connected := make(chan bool)
connHandler := func(*nats.Conn) {
connected <- true
}
reconnected := make(chan bool)
_, err := nats.Connect(nats.DefaultURL,
nats.ConnectHandler(connHandler))
nats.ConnectHandler(handler(connected)),
nats.ReconnectHandler(handler(reconnected)))

if err == nil {
t.Fatalf("Expected error on connect, got nil")
}
select {
case <-connected:
t.Fatalf("ConnectedCB invoked when no connection established")
case <-time.After(100 * time.Millisecond):
if err = WaitTime(connected, 100*time.Millisecond); err == nil {
t.Fatal("Connected handler should not have been invoked")
}
if err = WaitTime(reconnected, 100*time.Millisecond); err == nil {
t.Fatal("Reconnect handler should not have been invoked")
}
})
t.Run("with RetryOnFailedConnect, initial connection failed, reconnect successful", func(t *testing.T) {
connected := make(chan bool)
reconnected := make(chan bool)

nc, err := nats.Connect(nats.DefaultURL,
nats.ConnectHandler(handler(connected)),
nats.ReconnectHandler(handler(reconnected)),
nats.RetryOnFailedConnect(true),
nats.ReconnectWait(100*time.Millisecond))
defer nc.Close()

s := RunDefaultServer()
defer s.Shutdown()

if err != nil {
t.Fatalf("Expected error on connect, got nil")
}
if err = Wait(connected); err != nil {
t.Fatal("Timeout waiting for reconnect handler")
}
if err = WaitTime(reconnected, 100*time.Millisecond); err == nil {
t.Fatal("Reconnect handler should not have been invoked")
}
})
t.Run("with RetryOnFailedConnect, initial connection successful, server restart", func(t *testing.T) {
connected := make(chan bool)
reconnected := make(chan bool)

s := RunDefaultServer()
defer s.Shutdown()

nc, err := nats.Connect(nats.DefaultURL,
nats.ConnectHandler(handler(connected)),
nats.ReconnectHandler(handler(reconnected)),
nats.RetryOnFailedConnect(true),
nats.ReconnectWait(100*time.Millisecond))
defer nc.Close()

if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if err = Wait(connected); err != nil {
t.Fatal("Timeout waiting for connect handler")
}
if err = WaitTime(reconnected, 100*time.Millisecond); err == nil {
t.Fatal("Reconnect handler should not have been invoked")
}

s.Shutdown()

s = RunDefaultServer()
defer s.Shutdown()

if err = Wait(reconnected); err != nil {
t.Fatal("Timeout waiting for reconnect handler")
}
if err = WaitTime(connected, 100*time.Millisecond); err == nil {
t.Fatal("Connected handler should not have been invoked")
}
})
}
Expand Down

0 comments on commit 8d178b4

Please sign in to comment.