Skip to content

Commit

Permalink
Tidy up Mu, termination process and keep alive
Browse files Browse the repository at this point in the history
It was unclear what Client.mu protected.
Using and using contexts to manage shutdown is easier to follow.
Reimplement Pinger using Context

Closes #227
Ref #148
  • Loading branch information
MattBrittan committed Jan 15, 2024
1 parent 2aef8db commit b368787
Show file tree
Hide file tree
Showing 7 changed files with 296 additions and 221 deletions.
18 changes: 15 additions & 3 deletions autopaho/auto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,15 @@ func TestReconnect(t *testing.T) {
cancelFn func() // Function to cancel test server context
done chan struct{} // Will be closed when the test server has disconnected (and shutdown)
}
tsConnUpChan := make(chan tsConnUpMsg) // Message will be sent when test server connection is up
pahoConnUpChan := make(chan struct{}) // When autopaho reports connection is up write to channel will occur
tsConnUpChan := make(chan tsConnUpMsg, 1) // Message will be sent when test server connection is up (buffered so we can detect unexpected attempts)
pahoConnUpChan := make(chan struct{}, 1) // When autopaho reports connection is up write to channel will occur

atCount := 0

// If we don't set the pinger, paho will recreate it each time; to confirm issue #277 does not reoccur we set it
pinger := paho.NewDefaultPinger()
pinger.SetDebug(paholog.NewTestLogger(t, "pinger:"))

config := ClientConfig{
ServerUrls: []*url.URL{server},
KeepAlive: 60,
Expand All @@ -205,7 +209,8 @@ func TestReconnect(t *testing.T) {
PahoDebug: logger,
PahoErrors: logger,
ClientConfig: paho.ClientConfig{
ClientID: "test",
ClientID: "test",
PingHandler: pinger,
},
}

Expand Down Expand Up @@ -249,6 +254,13 @@ func TestReconnect(t *testing.T) {
t.Fatal("timeout awaiting reconnection up")
}

// Ensure connection is stable (ref issue #227 where pinger caused connection to drop)
select {
case <-tsConnUpChan:
t.Fatalf("connection should be stable after reconnection")
case <-time.After(shortDelay):
}

// Clean shutdown
cancel() // Cancelling outer context will cascade

Expand Down
4 changes: 2 additions & 2 deletions autopaho/examples/docker/publisher/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ func main() {
Payload: msg,
})
if err != nil {
fmt.Printf("error publishing: %s\n", err)
fmt.Printf("error publishing message %s: %s\n", msg, err)
} else if pr.ReasonCode != 0 && pr.ReasonCode != 16 { // 16 = Server received message but there are no subscribers
fmt.Printf("reason code %d received\n", pr.ReasonCode)
fmt.Printf("reason code %d received for message %s\n", pr.ReasonCode, msg)
} else if cfg.printMessages {
fmt.Printf("sent message: %s\n", msg)
}
Expand Down
118 changes: 68 additions & 50 deletions paho/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,8 @@ type (
Session session.SessionManager
autoCloseSession bool

AuthHandler Auther
PingHandler Pinger
defaultPinger bool
AuthHandler Auther
PingHandler Pinger

// Router - new inbound messages will be passed to the `Route(*packets.Publish)` function.
//
Expand Down Expand Up @@ -112,18 +111,23 @@ type (
}
// Client is the struct representing an MQTT client
Client struct {
mu sync.Mutex
config ClientConfig

// OnPublishReceived copy of OnPublishReceived from ClientConfig (perhaps with added callback form Router)
onPublishReceived []func(PublishReceived) (bool, error)
onPublishReceivedTracker []int // Used to track positions in above
onPublishReceivedMu sync.Mutex

// authResponse is used for handling the MQTTv5 authentication exchange.
// authResponse is used for handling the MQTTv5 authentication exchange (MUST be buffered)
authResponse chan<- packets.ControlPacket
stop chan struct{}
done chan struct{} // closed when shutdown complete (only valid after Connect returns nil error)
authResponseMu sync.Mutex // protects the above

cancelFunc func()

connectCalled bool // if true `Connect` has been called and a connection is being managed
connectCalledMu sync.Mutex // protects the above

done <-chan struct{} // closed when shutdown complete (only valid after Connect returns nil error)
publishPackets chan *packets.Publish
acksTracker acksTracker
workers sync.WaitGroup
Expand Down Expand Up @@ -202,7 +206,6 @@ func NewClient(conf ClientConfig) *Client {
c.onPublishReceivedTracker = make([]int, len(c.onPublishReceived)) // Must have the same number of elements as onPublishReceived

if c.config.PingHandler == nil {
c.config.defaultPinger = true
c.config.PingHandler = NewDefaultPinger()
}
if c.config.OnClientError == nil {
Expand All @@ -224,17 +227,29 @@ func (c *Client) Connect(ctx context.Context, cp *Connect) (*Connack, error) {
return nil, fmt.Errorf("client connection is nil")
}

// The connection is in c.config.Conn which is inaccessible to the user.
// The end result of `Connect` (possibly some time after it returns) will be to close the connection so calling
// Connect twice is invalid.
c.connectCalledMu.Lock()
if c.connectCalled {
c.connectCalledMu.Unlock()
return nil, fmt.Errorf("connect must only be called once")
}
c.connectCalled = true
c.connectCalledMu.Unlock()

// The passed in ctx applies to the connection process only. clientCtx applies to Client (signals
clientCtx, cancelFunc := context.WithCancel(context.Background())
done := make(chan struct{})
cleanup := func() {
close(c.stop)
cancelFunc()
close(c.publishPackets)
_ = c.config.Conn.Close()
close(c.done)
c.mu.Unlock()
close(done)
}

c.mu.Lock()
c.stop = make(chan struct{})
c.done = make(chan struct{})
c.cancelFunc = cancelFunc
c.done = done

var publishPacketsSize uint16 = math.MaxUint16
if cp.Properties != nil && cp.Properties.ReceiveMaximum != nil {
Expand Down Expand Up @@ -314,8 +329,9 @@ func (c *Client) Connect(ctx context.Context, cp *Connect) (*Connack, error) {
return ca, fmt.Errorf("session error: %w", err)
}

// no more possible calls to cleanup(), defer an unlock
defer c.mu.Unlock()
// the connection is now fully up and a nil error will be returned.
// cleanup() must not be called past this point and will be handled by `shutdown`
context.AfterFunc(clientCtx, func() { c.shutdown(done) })

if ca.Properties != nil {
if ca.Properties.ServerKeepAlive != nil {
Expand Down Expand Up @@ -347,7 +363,7 @@ func (c *Client) Connect(ctx context.Context, cp *Connect) (*Connack, error) {
go func() {
defer c.workers.Done()
defer c.debug.Println("returning from ping handler worker")
if err := c.config.PingHandler.Run(c.config.Conn, keepalive); err != nil {
if err := c.config.PingHandler.Run(clientCtx, c.config.Conn, keepalive); err != nil {
go c.error(fmt.Errorf("ping handler error: %w", err))
}
}()
Expand All @@ -367,7 +383,7 @@ func (c *Client) Connect(ctx context.Context, cp *Connect) (*Connack, error) {
go func() {
defer c.workers.Done()
defer c.debug.Println("returning from incoming worker")
c.incoming()
c.incoming(clientCtx)
}()

if c.config.EnableManualAcknowledgment {
Expand All @@ -386,7 +402,7 @@ func (c *Client) Connect(ctx context.Context, cp *Connect) (*Connack, error) {
t := time.NewTicker(sendAcksInterval)
for {
select {
case <-c.stop:
case <-clientCtx.Done():
return
case <-t.C:
c.acksTracker.flush(func(pbs []*packets.Publish) {
Expand Down Expand Up @@ -426,6 +442,8 @@ func (c *Client) ack(pb *packets.Publish) {
c.config.Session.Ack(pb)
}

// routePublishPackets listens on c.publishPackets and passes received messages to the handlers
// terminates when publishPackets closed
func (c *Client) routePublishPackets() {
for pb := range c.publishPackets {
// Copy onPublishReceived so lock is only held briefly
Expand Down Expand Up @@ -468,13 +486,13 @@ func (c *Client) routePublishPackets() {
// Disconnect, the Stop channel is closed or there is an error reading
// a packet from the network connection
// Closes `c.publishPackets` when done (should be the only thing sending on this channel)
func (c *Client) incoming() {
func (c *Client) incoming(ctx context.Context) {
defer c.debug.Println("client stopping, incoming stopping")
defer close(c.publishPackets)

for {
select {
case <-c.stop:
case <-ctx.Done():
return
default:
recv, err := packets.ReadPacket(c.config.Conn)
Expand All @@ -495,9 +513,14 @@ func (c *Client) incoming() {
if c.config.AuthHandler != nil {
go c.config.AuthHandler.Authenticated()
}
c.authResponseMu.Lock()
if c.authResponse != nil {
c.authResponse <- *recv
select { // authResponse must be buffered, and we should only receive 1 AUTH packet a time
case c.authResponse <- *recv:
default:
}
}
c.authResponseMu.Unlock()
case packets.AuthContinueAuthentication:
if c.config.AuthHandler != nil {
if _, err := c.config.AuthHandler.Authenticate(AuthFromPacketAuth(ap)).Packet().WriteTo(c.config.Conn); err != nil {
Expand All @@ -513,24 +536,25 @@ func (c *Client) incoming() {
c.config.Session.PacketReceived(recv, c.publishPackets)
} else {
c.debug.Printf("received QoS%d PUBLISH", pb.QoS)
c.mu.Lock()
select {
case <-c.stop:
c.mu.Unlock()
case <-ctx.Done():
return
default:
c.publishPackets <- pb
c.mu.Unlock()
case c.publishPackets <- pb:
}
}
case packets.PUBACK, packets.PUBCOMP, packets.SUBACK, packets.UNSUBACK, packets.PUBREC, packets.PUBREL:
c.config.Session.PacketReceived(recv, c.publishPackets)
case packets.DISCONNECT:
pd := recv.Content.(*packets.Disconnect)
c.debug.Println("received DISCONNECT")
c.authResponseMu.Lock()
if c.authResponse != nil {
c.authResponse <- *recv
select { // authResponse must be buffered, and we should only receive 1 AUTH packet a time
case c.authResponse <- *recv:
default:
}
}
c.authResponseMu.Unlock()
c.config.Session.ConnectionLost(pd) // this may impact the session state
go func() {
if c.config.OnServerDisconnect != nil {
Expand All @@ -548,23 +572,17 @@ func (c *Client) incoming() {
}
}

// close terminates the connection and waits for a clean shutdown
// may be called multiple times (subsequent calls will wait on previously requested shutdown)
func (c *Client) close() {
c.mu.Lock()
defer c.mu.Unlock()

select {
case <-c.stop:
// already shutting down, return when shutdown complete
<-c.done
return
default:
}

close(c.stop)
c.cancelFunc() // cleanup handled by AfterFunc defined in Connect
<-c.done
}

c.debug.Println("client stopped")
c.config.PingHandler.Stop()
c.debug.Println("ping stopped")
// shutdown cleanly shutdown the client
// This should only be called via the AfterFunc in `Connect` (shutdown must not be called more than once)
func (c *Client) shutdown(done chan<- struct{}) {
c.debug.Println("client stop requested")
_ = c.config.Conn.Close()
c.debug.Println("conn closed")
c.acksTracker.reset()
Expand All @@ -578,7 +596,7 @@ func (c *Client) close() {
c.debug.Println("session updated, waiting on workers")
c.workers.Wait()
c.debug.Println("workers done")
close(c.done)
close(done)
}

// error is called to signify that an error situation has occurred, this
Expand All @@ -605,17 +623,17 @@ func (c *Client) serverDisconnect(d *Disconnect) {
func (c *Client) Authenticate(ctx context.Context, a *Auth) (*AuthResponse, error) {
c.debug.Println("client initiated reauthentication")
authResp := make(chan packets.ControlPacket, 1)
c.mu.Lock()
c.authResponseMu.Lock()
if c.authResponse != nil {
c.mu.Unlock()
c.authResponseMu.Unlock()
return nil, fmt.Errorf("previous authentication is still in progress")
}
c.authResponse = authResp
c.mu.Unlock()
c.authResponseMu.Unlock()
defer func() {
c.mu.Lock()
c.authResponseMu.Lock()
c.authResponse = nil
c.mu.Unlock()
c.authResponseMu.Unlock()
}()

c.debug.Println("sending AUTH")
Expand Down
Loading

0 comments on commit b368787

Please sign in to comment.