Skip to content

Commit

Permalink
balancer: automatically stop producers on subchannel state changes
Browse files Browse the repository at this point in the history
  • Loading branch information
dfawley committed Sep 23, 2024
1 parent 8ea3460 commit a6a5fc1
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 366 deletions.
13 changes: 9 additions & 4 deletions balancer/balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,11 @@ type SubConn interface {
Connect()
// GetOrBuildProducer returns a reference to the existing Producer for this
// ProducerBuilder in this SubConn, or, if one does not currently exist,
// creates a new one and returns it. Returns a close function which must
// be called when the Producer is no longer needed.
// creates a new one and returns it. Returns a close function which may be
// called when the Producer is no longer needed. Otherwise the producer
// will automatically be closed upon connection loss or subchannel close.
// Should only be called on a SubConn in state Ready. Otherwise the
// producer will be unable to create streams.
GetOrBuildProducer(ProducerBuilder) (p Producer, close func())
// Shutdown shuts down the SubConn gracefully. Any started RPCs will be
// allowed to complete. No future calls should be made on the SubConn.
Expand Down Expand Up @@ -452,8 +455,10 @@ type ProducerBuilder interface {
// Build creates a Producer. The first parameter is always a
// grpc.ClientConnInterface (a type to allow creating RPCs/streams on the
// associated SubConn), but is declared as `any` to avoid a dependency
// cycle. Should also return a close function that will be called when all
// references to the Producer have been given up.
// cycle. Build also returns a close function that will be called when all
// references to the Producer have been given up for a SubConn, or when a
// connectivity state change occurs on the SubConn. The close function
// should always block until all asynchronous cleanup work is completed.
Build(grpcClientConnInterface any) (p Producer, close func())
}

Expand Down
22 changes: 12 additions & 10 deletions balancer/weightedroundrobin/balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,17 +526,21 @@ func (w *weightedSubConn) updateConfig(cfg *lbConfig) {
w.cfg = cfg
w.mu.Unlock()

newPeriod := cfg.OOBReportingPeriod
if cfg.EnableOOBLoadReport == oldCfg.EnableOOBLoadReport &&
newPeriod == oldCfg.OOBReportingPeriod {
cfg.OOBReportingPeriod == oldCfg.OOBReportingPeriod {
// Load reporting wasn't enabled before or after, or load reporting was
// enabled before and after, and had the same period. (Note that with
// load reporting disabled, OOBReportingPeriod is always 0.)
return
}
// (Optionally stop and) start the listener to use the new config's
// settings for OOB reporting.
if w.connectivityState == connectivity.Ready {
// (Re)start the listener to use the new config's settings for OOB
// reporting.
w.updateORCAListener(cfg)
}
}

func (w *weightedSubConn) updateORCAListener(cfg *lbConfig) {
if w.stopORCAListener != nil {
w.stopORCAListener()
}
Expand All @@ -545,9 +549,9 @@ func (w *weightedSubConn) updateConfig(cfg *lbConfig) {
return
}
if w.logger.V(2) {
w.logger.Infof("Registering ORCA listener for %v with interval %v", w.SubConn, newPeriod)
w.logger.Infof("Registering ORCA listener for %v with interval %v", w.SubConn, cfg.OOBReportingPeriod)
}
opts := orca.OOBListenerOptions{ReportInterval: time.Duration(newPeriod)}
opts := orca.OOBListenerOptions{ReportInterval: time.Duration(cfg.OOBReportingPeriod)}
w.stopORCAListener = orca.RegisterOOBListener(w.SubConn, w, opts)
}

Expand All @@ -569,11 +573,9 @@ func (w *weightedSubConn) updateConnectivityState(cs connectivity.State) connect
w.mu.Lock()
w.nonEmptySince = time.Time{}
w.lastUpdated = time.Time{}
cfg := w.cfg
w.mu.Unlock()
case connectivity.Shutdown:
if w.stopORCAListener != nil {
w.stopORCAListener()
}
w.updateORCAListener(cfg)
}

oldCS := w.connectivityState
Expand Down
39 changes: 22 additions & 17 deletions balancer_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,17 +256,20 @@ type acBalancerWrapper struct {
ccb *ccBalancerWrapper // read-only
stateListener func(balancer.SubConnState)

mu sync.Mutex
producers map[balancer.ProducerBuilder]*refCountedProducer
producersMu sync.Mutex
producers map[balancer.ProducerBuilder]*refCountedProducer
}

// updateState is invoked by grpc to push a subConn state update to the
// underlying balancer.
func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolver.Address, err error, readyChan chan struct{}) {
func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolver.Address, err error) {
acbw.ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || acbw.ccb.balancer == nil {
return
}
// Invalidate all producers on any state change.
acbw.closeProducers()

// Even though it is optional for balancers, gracefulswitch ensures
// opts.StateListener is set, so this cannot ever be nil.
// TODO: delete this comment when UpdateSubConnState is removed.
Expand All @@ -275,15 +278,6 @@ func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolve
setConnectedAddress(&scs, curAddr)
}
acbw.stateListener(scs)
acbw.ac.mu.Lock()
defer acbw.ac.mu.Unlock()
if s == connectivity.Ready {
// When changing states to READY, close stateReadyChan. Wait until
// after we notify the LB policy's listener(s) in order to prevent
// ac.getTransport() from unblocking before the LB policy starts
// tracking the subchannel as READY.
close(readyChan)
}
})
}

Expand All @@ -300,14 +294,15 @@ func (acbw *acBalancerWrapper) Connect() {
}

func (acbw *acBalancerWrapper) Shutdown() {
acbw.closeProducers()
acbw.ccb.cc.removeAddrConn(acbw.ac, errConnDrain)
}

// NewStream begins a streaming RPC on the addrConn. If the addrConn is not
// ready, blocks until it is or ctx expires. Returns an error when the context
// expires or the addrConn is shut down.
func (acbw *acBalancerWrapper) NewStream(ctx context.Context, desc *StreamDesc, method string, opts ...CallOption) (ClientStream, error) {
transport, err := acbw.ac.getTransport(ctx)
transport, err := acbw.ac.getTransport()
if err != nil {
return nil, err
}
Expand All @@ -334,8 +329,8 @@ type refCountedProducer struct {
}

func (acbw *acBalancerWrapper) GetOrBuildProducer(pb balancer.ProducerBuilder) (balancer.Producer, func()) {
acbw.mu.Lock()
defer acbw.mu.Unlock()
acbw.producersMu.Lock()
defer acbw.producersMu.Unlock()

// Look up existing producer from this builder.
pData := acbw.producers[pb]
Expand All @@ -352,13 +347,23 @@ func (acbw *acBalancerWrapper) GetOrBuildProducer(pb balancer.ProducerBuilder) (
// and delete the refCountedProducer from the map if the total reference
// count goes to zero.
unref := func() {
acbw.mu.Lock()
acbw.producersMu.Lock()
pData.refs--
if pData.refs == 0 {
defer pData.close() // Run outside the acbw mutex
delete(acbw.producers, pb)
}
acbw.mu.Unlock()
acbw.producersMu.Unlock()
}
return pData.producer, grpcsync.OnceFunc(unref)
}

func (acbw *acBalancerWrapper) closeProducers() {
acbw.producersMu.Lock()
defer acbw.producersMu.Unlock()
for pb, pData := range acbw.producers {
pData.refs = 0
pData.close()
delete(acbw.producers, pb)
}
}
54 changes: 17 additions & 37 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -825,14 +825,13 @@ func (cc *ClientConn) newAddrConnLocked(addrs []resolver.Address, opts balancer.
}

ac := &addrConn{
state: connectivity.Idle,
cc: cc,
addrs: copyAddresses(addrs),
scopts: opts,
dopts: cc.dopts,
channelz: channelz.RegisterSubChannel(cc.channelz, ""),
resetBackoff: make(chan struct{}),
stateReadyChan: make(chan struct{}),
state: connectivity.Idle,
cc: cc,
addrs: copyAddresses(addrs),
scopts: opts,
dopts: cc.dopts,
channelz: channelz.RegisterSubChannel(cc.channelz, ""),
resetBackoff: make(chan struct{}),
}
ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
// Start with our address set to the first address; this may be updated if
Expand Down Expand Up @@ -1179,8 +1178,7 @@ type addrConn struct {
addrs []resolver.Address // All addresses that the resolver resolved to.

// Use updateConnectivityState for updating addrConn's connectivity state.
state connectivity.State
stateReadyChan chan struct{} // closed and recreated on every READY state change.
state connectivity.State

backoffIdx int // Needs to be stateful for resetConnectBackoff.
resetBackoff chan struct{}
Expand All @@ -1193,22 +1191,14 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error)
if ac.state == s {
return
}
if ac.state == connectivity.Ready {
// When leaving ready, re-create the ready channel.
ac.stateReadyChan = make(chan struct{})
}
if s == connectivity.Shutdown {
// Wake any producer waiting to create a stream on the transport.
close(ac.stateReadyChan)
}
ac.state = s
ac.channelz.ChannelMetrics.State.Store(&s)
if lastErr == nil {
channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v", s)
} else {
channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v, last error: %s", s, lastErr)
}
ac.acbw.updateState(s, ac.curAddr, lastErr, ac.stateReadyChan)
ac.acbw.updateState(s, ac.curAddr, lastErr)
}

// adjustParams updates parameters used to create transports upon
Expand Down Expand Up @@ -1515,26 +1505,16 @@ func (ac *addrConn) getReadyTransport() transport.ClientTransport {
// getTransport waits until the addrconn is ready and returns the transport.
// If the context expires first, returns an appropriate status. If the
// addrConn is stopped first, returns an Unavailable status error.
func (ac *addrConn) getTransport(ctx context.Context) (transport.ClientTransport, error) {
for ctx.Err() == nil {
ac.mu.Lock()
t, state, readyChan := ac.transport, ac.state, ac.stateReadyChan
ac.mu.Unlock()
if state == connectivity.Shutdown {
// Return an error immediately in only this case since a connection
// will never occur.
return nil, status.Errorf(codes.Unavailable, "SubConn shutting down")
}
func (ac *addrConn) getTransport() (transport.ClientTransport, error) {
ac.mu.Lock()
t, state := ac.transport, ac.state
ac.mu.Unlock()

select {
case <-ctx.Done():
case <-readyChan:
if state == connectivity.Ready {
return t, nil
}
}
if state != connectivity.Ready {
return nil, status.Errorf(codes.Unavailable, "SubConn state is %v; not Ready", state)
}
return nil, status.FromContextError(ctx.Err()).Err()

return t, nil
}

// tearDown starts to tear down the addrConn.
Expand Down
19 changes: 11 additions & 8 deletions orca/producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ func (*producerBuilder) Build(cci any) (balancer.Producer, func()) {
backoff: internal.DefaultBackoffFunc,
}
return p, func() {
p.mu.Lock()
if p.stop != nil {
p.stop()
p.stop = nil
}
p.mu.Unlock()
<-p.stopped
}
}
Expand All @@ -67,19 +73,16 @@ type OOBListenerOptions struct {
ReportInterval time.Duration
}

// RegisterOOBListener registers an out-of-band load report listener on sc.
// Any OOBListener may only be registered once per subchannel at a time. The
// returned stop function must be called when no longer needed. Do not
// RegisterOOBListener registers an out-of-band load report listener on a Ready
// sc. Any OOBListener may only be registered once per subchannel at a time.
// The returned stop function must be called when no longer needed. Do not
// register a single OOBListener more than once per SubConn.
func RegisterOOBListener(sc balancer.SubConn, l OOBListener, opts OOBListenerOptions) (stop func()) {
pr, closeFn := sc.GetOrBuildProducer(producerBuilderSingleton)
p := pr.(*producer)

p.registerListener(l, opts.ReportInterval)

// TODO: When we can register for SubConn state updates, automatically call
// stop() on SHUTDOWN.

// If stop is called multiple times, prevent it from having any effect on
// subsequent calls.
return grpcsync.OnceFunc(func() {
Expand All @@ -96,13 +99,13 @@ type producer struct {
// is incremented when stream errors occur and is reset when the stream
// reports a result.
backoff func(int) time.Duration
stopped chan struct{} // closed when the run goroutine exits

mu sync.Mutex
intervals map[time.Duration]int // map from interval time to count of listeners requesting that time
listeners map[OOBListener]struct{} // set of registered listeners
minInterval time.Duration
stop func() // stops the current run goroutine
stopped chan struct{} // closed when the run goroutine exits
stop func() // stops the current run goroutine
}

// registerListener adds the listener and its requested report interval to the
Expand Down
13 changes: 10 additions & 3 deletions orca/producer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/roundrobin"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/testutils"
Expand Down Expand Up @@ -64,13 +65,19 @@ func (w *ccWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubCon
if len(addrs) != 1 {
panic(fmt.Sprintf("got addrs=%v; want len(addrs) == 1", addrs))
}
var sc balancer.SubConn
opts.StateListener = func(scs balancer.SubConnState) {
if scs.ConnectivityState != connectivity.Ready {
return
}
l := getListenerInfo(addrs[0])
l.listener.cleanup = orca.RegisterOOBListener(sc, l.listener, l.opts)
l.sc = sc
}
sc, err := w.ClientConn.NewSubConn(addrs, opts)
if err != nil {
return sc, err
}
l := getListenerInfo(addrs[0])
l.listener.cleanup = orca.RegisterOOBListener(sc, l.listener, l.opts)
l.sc = sc
return sc, nil
}

Expand Down
Loading

0 comments on commit a6a5fc1

Please sign in to comment.