Skip to content

Commit

Permalink
balancer: automatically stop producers on subchannel state changes
Browse files Browse the repository at this point in the history
  • Loading branch information
dfawley committed Sep 20, 2024
1 parent 8ea3460 commit 5d945cf
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 361 deletions.
5 changes: 3 additions & 2 deletions balancer/balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,9 @@ type SubConn interface {
Connect()
// GetOrBuildProducer returns a reference to the existing Producer for this
// ProducerBuilder in this SubConn, or, if one does not currently exist,
// creates a new one and returns it. Returns a close function which must
// be called when the Producer is no longer needed.
// creates a new one and returns it. Returns a close function which may be
// called when the Producer is no longer needed. Otherwise the producer
// will automatically be closed upon connection loss or subchannel close.
GetOrBuildProducer(ProducerBuilder) (p Producer, close func())
// Shutdown shuts down the SubConn gracefully. Any started RPCs will be
// allowed to complete. No future calls should be made on the SubConn.
Expand Down
22 changes: 12 additions & 10 deletions balancer/weightedroundrobin/balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,17 +526,21 @@ func (w *weightedSubConn) updateConfig(cfg *lbConfig) {
w.cfg = cfg
w.mu.Unlock()

newPeriod := cfg.OOBReportingPeriod
if cfg.EnableOOBLoadReport == oldCfg.EnableOOBLoadReport &&
newPeriod == oldCfg.OOBReportingPeriod {
cfg.OOBReportingPeriod == oldCfg.OOBReportingPeriod {
// Load reporting wasn't enabled before or after, or load reporting was
// enabled before and after, and had the same period. (Note that with
// load reporting disabled, OOBReportingPeriod is always 0.)
return
}
// (Optionally stop and) start the listener to use the new config's
// settings for OOB reporting.
if w.connectivityState == connectivity.Ready {
// (Re)start the listener to use the new config's settings for OOB
// reporting.
w.updateORCAListener(cfg)
}
}

func (w *weightedSubConn) updateORCAListener(cfg *lbConfig) {
if w.stopORCAListener != nil {
w.stopORCAListener()
}
Expand All @@ -545,9 +549,9 @@ func (w *weightedSubConn) updateConfig(cfg *lbConfig) {
return
}
if w.logger.V(2) {
w.logger.Infof("Registering ORCA listener for %v with interval %v", w.SubConn, newPeriod)
w.logger.Infof("Registering ORCA listener for %v with interval %v", w.SubConn, cfg.OOBReportingPeriod)
}
opts := orca.OOBListenerOptions{ReportInterval: time.Duration(newPeriod)}
opts := orca.OOBListenerOptions{ReportInterval: time.Duration(cfg.OOBReportingPeriod)}
w.stopORCAListener = orca.RegisterOOBListener(w.SubConn, w, opts)
}

Expand All @@ -569,11 +573,9 @@ func (w *weightedSubConn) updateConnectivityState(cs connectivity.State) connect
w.mu.Lock()
w.nonEmptySince = time.Time{}
w.lastUpdated = time.Time{}
cfg := w.cfg
w.mu.Unlock()
case connectivity.Shutdown:
if w.stopORCAListener != nil {
w.stopORCAListener()
}
w.updateORCAListener(cfg)
}

oldCS := w.connectivityState
Expand Down
39 changes: 22 additions & 17 deletions balancer_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,17 +256,20 @@ type acBalancerWrapper struct {
ccb *ccBalancerWrapper // read-only
stateListener func(balancer.SubConnState)

mu sync.Mutex
producers map[balancer.ProducerBuilder]*refCountedProducer
producersMu sync.Mutex
producers map[balancer.ProducerBuilder]*refCountedProducer
}

// updateState is invoked by grpc to push a subConn state update to the
// underlying balancer.
func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolver.Address, err error, readyChan chan struct{}) {
func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolver.Address, err error) {
acbw.ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || acbw.ccb.balancer == nil {
return
}
// Invalidate all producers on any state change.
acbw.closeProducers()

// Even though it is optional for balancers, gracefulswitch ensures
// opts.StateListener is set, so this cannot ever be nil.
// TODO: delete this comment when UpdateSubConnState is removed.
Expand All @@ -275,15 +278,6 @@ func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolve
setConnectedAddress(&scs, curAddr)
}
acbw.stateListener(scs)
acbw.ac.mu.Lock()
defer acbw.ac.mu.Unlock()
if s == connectivity.Ready {
// When changing states to READY, close stateReadyChan. Wait until
// after we notify the LB policy's listener(s) in order to prevent
// ac.getTransport() from unblocking before the LB policy starts
// tracking the subchannel as READY.
close(readyChan)
}
})
}

Expand All @@ -300,14 +294,15 @@ func (acbw *acBalancerWrapper) Connect() {
}

func (acbw *acBalancerWrapper) Shutdown() {
acbw.closeProducers()
acbw.ccb.cc.removeAddrConn(acbw.ac, errConnDrain)
}

// NewStream begins a streaming RPC on the addrConn. If the addrConn is not
// ready, blocks until it is or ctx expires. Returns an error when the context
// expires or the addrConn is shut down.
func (acbw *acBalancerWrapper) NewStream(ctx context.Context, desc *StreamDesc, method string, opts ...CallOption) (ClientStream, error) {
transport, err := acbw.ac.getTransport(ctx)
transport, err := acbw.ac.getTransport()
if err != nil {
return nil, err
}
Expand All @@ -334,8 +329,8 @@ type refCountedProducer struct {
}

func (acbw *acBalancerWrapper) GetOrBuildProducer(pb balancer.ProducerBuilder) (balancer.Producer, func()) {
acbw.mu.Lock()
defer acbw.mu.Unlock()
acbw.producersMu.Lock()
defer acbw.producersMu.Unlock()

// Look up existing producer from this builder.
pData := acbw.producers[pb]
Expand All @@ -352,13 +347,23 @@ func (acbw *acBalancerWrapper) GetOrBuildProducer(pb balancer.ProducerBuilder) (
// and delete the refCountedProducer from the map if the total reference
// count goes to zero.
unref := func() {
acbw.mu.Lock()
acbw.producersMu.Lock()
pData.refs--
if pData.refs == 0 {
defer pData.close() // Run outside the acbw mutex
delete(acbw.producers, pb)
}
acbw.mu.Unlock()
acbw.producersMu.Unlock()
}
return pData.producer, grpcsync.OnceFunc(unref)
}

func (acbw *acBalancerWrapper) closeProducers() {
acbw.producersMu.Lock()
defer acbw.producersMu.Unlock()
for pb, pData := range acbw.producers {
pData.refs = 0
pData.close()
delete(acbw.producers, pb)
}
}
54 changes: 17 additions & 37 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -825,14 +825,13 @@ func (cc *ClientConn) newAddrConnLocked(addrs []resolver.Address, opts balancer.
}

ac := &addrConn{
state: connectivity.Idle,
cc: cc,
addrs: copyAddresses(addrs),
scopts: opts,
dopts: cc.dopts,
channelz: channelz.RegisterSubChannel(cc.channelz, ""),
resetBackoff: make(chan struct{}),
stateReadyChan: make(chan struct{}),
state: connectivity.Idle,
cc: cc,
addrs: copyAddresses(addrs),
scopts: opts,
dopts: cc.dopts,
channelz: channelz.RegisterSubChannel(cc.channelz, ""),
resetBackoff: make(chan struct{}),
}
ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
// Start with our address set to the first address; this may be updated if
Expand Down Expand Up @@ -1179,8 +1178,7 @@ type addrConn struct {
addrs []resolver.Address // All addresses that the resolver resolved to.

// Use updateConnectivityState for updating addrConn's connectivity state.
state connectivity.State
stateReadyChan chan struct{} // closed and recreated on every READY state change.
state connectivity.State

backoffIdx int // Needs to be stateful for resetConnectBackoff.
resetBackoff chan struct{}
Expand All @@ -1193,22 +1191,14 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error)
if ac.state == s {
return
}
if ac.state == connectivity.Ready {
// When leaving ready, re-create the ready channel.
ac.stateReadyChan = make(chan struct{})
}
if s == connectivity.Shutdown {
// Wake any producer waiting to create a stream on the transport.
close(ac.stateReadyChan)
}
ac.state = s
ac.channelz.ChannelMetrics.State.Store(&s)
if lastErr == nil {
channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v", s)
} else {
channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v, last error: %s", s, lastErr)
}
ac.acbw.updateState(s, ac.curAddr, lastErr, ac.stateReadyChan)
ac.acbw.updateState(s, ac.curAddr, lastErr)
}

// adjustParams updates parameters used to create transports upon
Expand Down Expand Up @@ -1515,26 +1505,16 @@ func (ac *addrConn) getReadyTransport() transport.ClientTransport {
// getTransport waits until the addrconn is ready and returns the transport.
// If the context expires first, returns an appropriate status. If the
// addrConn is stopped first, returns an Unavailable status error.
func (ac *addrConn) getTransport(ctx context.Context) (transport.ClientTransport, error) {
for ctx.Err() == nil {
ac.mu.Lock()
t, state, readyChan := ac.transport, ac.state, ac.stateReadyChan
ac.mu.Unlock()
if state == connectivity.Shutdown {
// Return an error immediately in only this case since a connection
// will never occur.
return nil, status.Errorf(codes.Unavailable, "SubConn shutting down")
}
func (ac *addrConn) getTransport() (transport.ClientTransport, error) {
ac.mu.Lock()
t, state := ac.transport, ac.state
ac.mu.Unlock()

select {
case <-ctx.Done():
case <-readyChan:
if state == connectivity.Ready {
return t, nil
}
}
if state != connectivity.Ready {
return nil, status.Errorf(codes.Unavailable, "SubConn state is %v; not Ready", state)
}
return nil, status.FromContextError(ctx.Err()).Err()

return t, nil
}

// tearDown starts to tear down the addrConn.
Expand Down
18 changes: 13 additions & 5 deletions orca/producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ func (*producerBuilder) Build(cci any) (balancer.Producer, func()) {
backoff: internal.DefaultBackoffFunc,
}
return p, func() {
p.mu.Lock()
if p.stop != nil {
p.stop()
p.stop = nil
}
p.mu.Unlock()
<-p.stopped
}
}
Expand Down Expand Up @@ -77,9 +83,6 @@ func RegisterOOBListener(sc balancer.SubConn, l OOBListener, opts OOBListenerOpt

p.registerListener(l, opts.ReportInterval)

// TODO: When we can register for SubConn state updates, automatically call
// stop() on SHUTDOWN.

// If stop is called multiple times, prevent it from having any effect on
// subsequent calls.
return grpcsync.OnceFunc(func() {
Expand All @@ -96,13 +99,13 @@ type producer struct {
// is incremented when stream errors occur and is reset when the stream
// reports a result.
backoff func(int) time.Duration
stopped chan struct{} // closed when the run goroutine exits

mu sync.Mutex
intervals map[time.Duration]int // map from interval time to count of listeners requesting that time
listeners map[OOBListener]struct{} // set of registered listeners
minInterval time.Duration
stop func() // stops the current run goroutine
stopped chan struct{} // closed when the run goroutine exits
stop func() // stops the current run goroutine
}

// registerListener adds the listener and its requested report interval to the
Expand Down Expand Up @@ -169,8 +172,10 @@ func (p *producer) updateRunLocked() {
// run manages the ORCA OOB stream on the subchannel.
func (p *producer) run(ctx context.Context, done chan struct{}, interval time.Duration) {
defer close(done)
logger.Info("XXXXXXXXXXXXXXXXXXX run?")

runStream := func() error {
logger.Info("XXXXXXXXXXXXXXXXXXX runStream?")
resetBackoff, err := p.runStream(ctx, interval)
if status.Code(err) == codes.Unimplemented {
// Unimplemented; do not retry.
Expand Down Expand Up @@ -205,14 +210,17 @@ func (p *producer) runStream(ctx context.Context, interval time.Duration) (reset
ReportInterval: durationpb.New(interval),
})
if err != nil {
logger.Info("XXXXXXXXXXXXXXXX stream err:", err.Error())
return false, err
}
logger.Info("XXXXXXXXXXXXXXXX started stream")

for {
report, err := stream.Recv()
if err != nil {
return resetBackoff, err
}
logger.Info("XXXXXXXXXXXXXXXX got report")
resetBackoff = true
p.mu.Lock()
for l := range p.listeners {
Expand Down
13 changes: 10 additions & 3 deletions orca/producer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/roundrobin"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/testutils"
Expand Down Expand Up @@ -64,13 +65,19 @@ func (w *ccWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubCon
if len(addrs) != 1 {
panic(fmt.Sprintf("got addrs=%v; want len(addrs) == 1", addrs))
}
var sc balancer.SubConn
opts.StateListener = func(scs balancer.SubConnState) {
if scs.ConnectivityState != connectivity.Ready {
return
}
l := getListenerInfo(addrs[0])
l.listener.cleanup = orca.RegisterOOBListener(sc, l.listener, l.opts)
l.sc = sc
}
sc, err := w.ClientConn.NewSubConn(addrs, opts)
if err != nil {
return sc, err
}
l := getListenerInfo(addrs[0])
l.listener.cleanup = orca.RegisterOOBListener(sc, l.listener, l.opts)
l.sc = sc
return sc, nil
}

Expand Down
Loading

0 comments on commit 5d945cf

Please sign in to comment.