diff --git a/services/replicator/inconn.go b/services/replicator/inconn.go index 2131cdfa..2c7e8ff2 100644 --- a/services/replicator/inconn.go +++ b/services/replicator/inconn.go @@ -44,13 +44,11 @@ type ( metricsScope int perDestMetricsScope int - closeChannel chan struct{} // channel to indicate the connection should be closed - creditsCh chan int32 // channel to pass credits from readCreditsStream to writeMsgsStream - creditFlowExpiration time.Time // credit expiration is used to close the stream if we don't receive any credit for some period of time + creditsCh chan int32 // channel to pass credits from readCreditsStream to writeMsgsStream + creditFlowExpiration time.Time // credit expiration is used to close the stream if we don't receive any credit for some period of time - lk sync.Mutex - opened bool - closed bool + wg sync.WaitGroup + shutdownCh chan struct{} } ) @@ -75,68 +73,59 @@ func newInConnection(extUUID string, destPath string, stream storeStream.BStoreO destM3Client: metrics.NewClientWithTags(m3Client, metrics.Replicator, common.GetDestinationTags(destPath, localLogger)), metricsScope: metricsScope, perDestMetricsScope: perDestMetricsScope, - closeChannel: make(chan struct{}), creditsCh: make(chan int32, 5), creditFlowExpiration: time.Now().Add(creditFlowTimeout), + shutdownCh: make(chan struct{}), } return conn } func (conn *inConnection) open() { - conn.lk.Lock() - defer conn.lk.Unlock() - - if !conn.opened { - go conn.writeMsgsStream() - go conn.readCreditsStream() - - conn.opened = true - } + conn.wg.Add(2) + go conn.writeMsgsStream() + go conn.readCreditsStream() conn.logger.Info("in connection opened") } -func (conn *inConnection) close() { - conn.lk.Lock() - defer conn.lk.Unlock() +func (conn *inConnection) WaitUntilDone() { + conn.wg.Wait() +} - if !conn.closed { - close(conn.closeChannel) - conn.closed = true - } - conn.logger.Info("in connection closed") +func (conn *inConnection) shutdown() { + close(conn.shutdownCh) + conn.logger.Info(`in connection shutdown`) } func (conn *inConnection) readCreditsStream() { + defer conn.wg.Done() + defer close(conn.creditsCh) for { - select { - case <-conn.closeChannel: + msg, err := conn.stream.Read() + if err != nil { + conn.logger.WithField(common.TagErr, err).Info("read credit failed") return - default: - msg, err := conn.stream.Read() - if err != nil { - conn.logger.WithField(common.TagErr, err).Info("read credit failed") - go conn.close() - return - } + } - conn.m3Client.AddCounter(conn.metricsScope, metrics.ReplicatorInConnCreditsReceived, int64(msg.GetCredits())) + conn.m3Client.AddCounter(conn.metricsScope, metrics.ReplicatorInConnCreditsReceived, int64(msg.GetCredits())) - // send this to writeMsgsPump which keeps track of the local credits - // Make this non-blocking because writeMsgsStream could be closed before this - select { - case conn.creditsCh <- msg.GetCredits(): - default: - conn.logger. - WithField(`channelLen`, len(conn.creditsCh)). - WithField(`credits`, msg.GetCredits()). - Warn(`Dropped credits because of blocked channel`) - } + // send this to writeMsgsPump which keeps track of the local credits + // Make this non-blocking because writeMsgsStream could be closed before this + select { + case conn.creditsCh <- msg.GetCredits(): + case <-conn.shutdownCh: + return + default: + conn.logger. + WithField(`channelLen`, len(conn.creditsCh)). + WithField(`credits`, msg.GetCredits()). + Warn(`Dropped credits because of blocked channel`) } } } func (conn *inConnection) writeMsgsStream() { + defer conn.wg.Done() defer conn.stream.Done() flushTicker := time.NewTicker(flushTimeout) @@ -146,24 +135,31 @@ func (conn *inConnection) writeMsgsStream() { for { if localCredits == 0 { select { - case credit := <-conn.creditsCh: + case credit, ok := <-conn.creditsCh: + if !ok { + conn.logger.Info(`internal credit channel closed`) + return + } conn.extentCreditExpiration() localCredits += credit case <-time.After(creditFlowTimeout): conn.logger.Warn("credit flow timeout") if conn.isCreditFlowExpired() { conn.logger.Warn("credit flow expired") - go conn.close() + return } - case <-conn.closeChannel: + case <-conn.shutdownCh: return } } else { select { - case msg := <-conn.msgCh: + case msg, ok := <-conn.msgCh: + if !ok { + conn.logger.Info("msg channel closed") + return + } if err := conn.stream.Write(msg); err != nil { conn.logger.Error("write msg failed") - go conn.close() return } @@ -182,15 +178,19 @@ func (conn *inConnection) writeMsgsStream() { } localCredits-- - case credit := <-conn.creditsCh: + case credit, ok := <-conn.creditsCh: + if !ok { + conn.logger.Info(`internal credit channel closed`) + return + } conn.extentCreditExpiration() localCredits += credit case <-flushTicker.C: if err := conn.stream.Flush(); err != nil { conn.logger.Error(`flush msg failed`) - go conn.close() + return } - case <-conn.closeChannel: + case <-conn.shutdownCh: return } } diff --git a/services/replicator/outconn.go b/services/replicator/outconn.go index e5bb0b8e..2e62efea 100644 --- a/services/replicator/outconn.go +++ b/services/replicator/outconn.go @@ -49,12 +49,10 @@ type ( lastMsgReplicatedTime int64 totalMsgReplicated int32 - readMsgCountChannel chan int32 // channel to pass read msg count from readMsgStream to writeCreditsStream in order to issue more credits - closeChannel chan struct{} // channel to indicate the connection should be closed + readMsgCountChannel chan int32 // channel to pass read msg count from readMsgStream to writeCreditsStream in order to issue more credits - lk sync.Mutex - opened bool - closed bool + wg sync.WaitGroup + shutdownCh chan struct{} } ) @@ -66,6 +64,25 @@ const ( creditBatchSize = initialCreditSize / 10 ) +// Design philosophies (applies to InConnection as well): +// read pump reads from stream, write pump writes to stream +// read pump communicate with write pump using an internal channel. Read pump writes to the internal channel, and write pump reads from it +// +// Read pump close: +// trigger: gets a stream read error (remote shuts down the connection) +// action: +// 1. close the internal channel +// 2. (for outConn only) close msg channel +// +// +// Write pump close: +// trigger: +// 1. gets a stream write error (remote shuts down the connection) +// 2. internal channel is closed(caused by read pump close) +// 3. (for inConn only) msg channel is closed +// action: call stream.Done() +// + func newOutConnection(extUUID string, destPath string, stream storeStream.BStoreOpenReadStreamOutCall, logger bark.Logger, m3Client metrics.Client, metricsScope int) *outConnection { localLogger := logger.WithFields(bark.Fields{ common.TagExt: extUUID, @@ -82,43 +99,34 @@ func newOutConnection(extUUID string, destPath string, stream storeStream.BStore m3Client: m3Client, metricsScope: metricsScope, readMsgCountChannel: make(chan int32, 10), - closeChannel: make(chan struct{}), + shutdownCh: make(chan struct{}), } return conn } func (conn *outConnection) open() { - conn.lk.Lock() - defer conn.lk.Unlock() - - if !conn.opened { - go conn.writeCreditsStream() - go conn.readMsgStream() - - conn.opened = true - } + conn.wg.Add(2) + go conn.writeCreditsStream() + go conn.readMsgStream() conn.logger.Info("out connection opened") } -func (conn *outConnection) close() { - conn.lk.Lock() - defer conn.lk.Unlock() - - if !conn.closed { - close(conn.closeChannel) - conn.closed = true - } +func (conn *outConnection) WaitUntilDone() { + conn.wg.Wait() +} - conn.logger.Info("out connection closed") +func (conn *outConnection) shutdown() { + close(conn.shutdownCh) + conn.logger.Info(`out connection shutdown`) } func (conn *outConnection) writeCreditsStream() { + defer conn.wg.Done() defer conn.stream.Done() + if err := conn.sendCredits(initialCreditSize); err != nil { conn.logger.Error(`error writing initial credits`) - - go conn.close() return } @@ -128,17 +136,19 @@ func (conn *outConnection) writeCreditsStream() { if numMsgsRead > 0 { if err := conn.sendCredits(numMsgsRead); err != nil { conn.logger.Error(`error sending credits`) - - go conn.close() return } numMsgsRead = 0 } else { select { // Note: this will block until readMsgStream sends msg count to the channel, or the connection is closed - case msgsRead := <-conn.readMsgCountChannel: + case msgsRead, ok := <-conn.readMsgCountChannel: numMsgsRead += msgsRead - case <-conn.closeChannel: + if !ok { + conn.logger.Info(`read msg count channel closed`) + return + } + case <-conn.shutdownCh: return } } @@ -146,98 +156,45 @@ func (conn *outConnection) writeCreditsStream() { } func (conn *outConnection) readMsgStream() { - // lastSeqNum is used to track whether our sequence numbers are - // monotonically increasing - // We initialize this to -1 to skip the first message check - var lastSeqNum int64 = -1 + defer conn.wg.Done() + defer close(conn.readMsgCountChannel) + defer close(conn.msgsCh) - var sealMsgRead bool var numMsgsRead int32 // Note we must continue read until we hit an error before returning from this function // Because the websocket client only tear down the underlying connection when it gets a read error -readloop: for { rmc, err := conn.stream.Read() if err != nil { conn.logger.WithField(common.TagErr, err).Error(`Error reading msg`) - go conn.close() return } - switch rmc.GetType() { - case store.ReadMessageContentType_MESSAGE: - msg := rmc.GetMessage() - - if sealMsgRead { - conn.logger.WithFields(bark.Fields{ - "seqNum": msg.Message.GetSequenceNumber(), - }).Error("regular message read after seal message") - go conn.close() - continue readloop - } - - // Sequence number check to make sure we get monotonically increasing sequence number. - if lastSeqNum+1 != msg.Message.GetSequenceNumber() && lastSeqNum != -1 { - expectedSeqNum := 1 + lastSeqNum - - conn.logger.WithFields(bark.Fields{ - "seqNum": msg.Message.GetSequenceNumber(), - "expectedSeqNum": expectedSeqNum, - }).Error("sequence number out of order") - go conn.close() - continue readloop - } - - // update the lastSeqNum to this value - lastSeqNum = msg.Message.GetSequenceNumber() - + if rmc.GetType() == store.ReadMessageContentType_MESSAGE { conn.m3Client.IncCounter(conn.metricsScope, metrics.ReplicatorOutConnMsgRead) + } + if rmc.GetType() == store.ReadMessageContentType_SEALED { + conn.logger.WithField(`SequenceNumber`, rmc.GetSealed().GetSequenceNumber()).Info(`extent sealed`) + } - // now push msg to the msg channel (which will in turn be pushed to client) - // Note this is a blocking call here - select { - case conn.msgsCh <- rmc: - numMsgsRead++ - atomic.AddInt32(&conn.totalMsgReplicated, 1) - atomic.StoreInt64(&conn.lastMsgReplicatedTime, time.Now().UnixNano()) - case <-conn.closeChannel: - conn.logger.Info(`writing msg to the channel failed because of shutdown`) - continue readloop - } - - case store.ReadMessageContentType_SEALED: - seal := rmc.GetSealed() - conn.logger.WithField(`SequenceNumber`, seal.GetSequenceNumber()).Info(`extent sealed`) - sealMsgRead = true - - // now push msg to the msg channel (which will in turn be pushed to client) - // Note this is a blocking call here - select { - case conn.msgsCh <- rmc: - numMsgsRead++ - atomic.AddInt32(&conn.totalMsgReplicated, 1) - atomic.StoreInt64(&conn.lastMsgReplicatedTime, time.Now().UnixNano()) - case <-conn.closeChannel: - conn.logger.Info(`writing msg to the channel failed because of shutdown`) - } - - continue readloop - - case store.ReadMessageContentType_ERROR: - msgErr := rmc.GetError() - conn.logger.WithField(`Message`, msgErr.GetMessage()).Error(`received error from reading msg`) - go conn.close() - continue readloop - - default: - conn.logger.WithField(`Type`, rmc.GetType()).Error(`received ReadMessageContent with unrecognized type`) + // now push msg to the msg channel (which will in turn be pushed to client) + // Note this is a blocking call here + select { + case conn.msgsCh <- rmc: + numMsgsRead++ + atomic.AddInt32(&conn.totalMsgReplicated, 1) + atomic.StoreInt64(&conn.lastMsgReplicatedTime, time.Now().UnixNano()) + case <-conn.shutdownCh: + return } if numMsgsRead >= creditBatchSize { select { case conn.readMsgCountChannel <- numMsgsRead: numMsgsRead = 0 + case <-conn.shutdownCh: + return default: // Not the end of world if the channel is blocked conn.logger.WithField(`credit`, numMsgsRead).Info("readMsgStream: blocked sending credits; accumulating credits to send later") diff --git a/services/replicator/replicator.go b/services/replicator/replicator.go index 5afdd65e..5495039e 100644 --- a/services/replicator/replicator.go +++ b/services/replicator/replicator.go @@ -158,14 +158,13 @@ func (r *Replicator) Start(thriftService []thrift.TChanServer) { func (r *Replicator) Stop() { r.hostIDHeartbeater.Stop() for _, conn := range r.remoteReplicatorConn { - conn.close() + conn.shutdown() } for _, conn := range r.storehostConn { - conn.close() + conn.shutdown() } r.metadataReconciler.Stop() r.SCommon.Stop() - } // RegisterWSHandler is the implementation of WSService interface @@ -258,10 +257,8 @@ func (r *Replicator) OpenReplicationReadStreamHandler(w http.ResponseWriter, req inConn := newInConnection(extUUID, destDesc.GetPath(), inStream, outConn.msgsCh, r.logger, r.m3Client, metrics.OpenReplicationReadScope, metrics.OpenReplicationReadPerDestScope) inConn.open() - go r.manageInOutConn(inConn, outConn) - <-inConn.closeChannel - <-outConn.closeChannel - + outConn.WaitUntilDone() + inConn.WaitUntilDone() return } @@ -339,10 +336,8 @@ func (r *Replicator) OpenReplicationRemoteReadStreamHandler(w http.ResponseWrite inConn := newInConnection(extUUID, destDesc.GetPath(), inStream, outConn.msgsCh, r.logger, r.m3Client, metrics.OpenReplicationRemoteReadScope, metrics.OpenReplicationRemoteReadPerDestScope) inConn.open() - go r.manageInOutConn(inConn, outConn) - <-inConn.closeChannel - <-outConn.closeChannel - + outConn.WaitUntilDone() + inConn.WaitUntilDone() return } @@ -1438,16 +1433,3 @@ func (r *Replicator) createStoreHostReadStream(destUUID string, extUUID string, return } - -func (r *Replicator) manageInOutConn(inConn *inConnection, outConn *outConnection) { - for { - select { - case <-inConn.closeChannel: - go outConn.close() - return - case <-outConn.closeChannel: - go inConn.close() - return - } - } -}