diff --git a/xds/internal/balancer/outlierdetection/balancer.go b/xds/internal/balancer/outlierdetection/balancer.go index 5630b6fcce7f..46903bb7acc9 100644 --- a/xds/internal/balancer/outlierdetection/balancer.go +++ b/xds/internal/balancer/outlierdetection/balancer.go @@ -142,7 +142,6 @@ func (bb) Name() string { type scUpdate struct { scw *subConnWrapper state balancer.SubConnState - cb func(balancer.SubConnState) } type ejectionUpdate struct { @@ -346,7 +345,7 @@ func (b *outlierDetectionBalancer) ResolverError(err error) { b.child.ResolverError(err) } -func (b *outlierDetectionBalancer) updateSubConnState(sc balancer.SubConn, state balancer.SubConnState, cb func(balancer.SubConnState)) { +func (b *outlierDetectionBalancer) updateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { b.mu.Lock() defer b.mu.Unlock() scw, ok := b.scWrappers[sc] @@ -362,12 +361,11 @@ func (b *outlierDetectionBalancer) updateSubConnState(sc balancer.SubConn, state b.scUpdateCh.Put(&scUpdate{ scw: scw, state: state, - cb: cb, }) } func (b *outlierDetectionBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { - b.updateSubConnState(sc, state, nil) + b.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", sc, state) } func (b *outlierDetectionBalancer) Close() { @@ -474,7 +472,7 @@ func (b *outlierDetectionBalancer) UpdateState(s balancer.State) { func (b *outlierDetectionBalancer) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { var sc balancer.SubConn oldListener := opts.StateListener - opts.StateListener = func(state balancer.SubConnState) { b.updateSubConnState(sc, state, oldListener) } + opts.StateListener = func(state balancer.SubConnState) { b.updateSubConnState(sc, state) } sc, err := b.cc.NewSubConn(addrs, opts) if err != nil { return nil, err @@ -483,6 +481,7 @@ func (b *outlierDetectionBalancer) NewSubConn(addrs []resolver.Address, opts bal SubConn: sc, addresses: addrs, scUpdateCh: b.scUpdateCh, + listener: oldListener, } b.mu.Lock() defer b.mu.Unlock() @@ -624,8 +623,8 @@ func (b *outlierDetectionBalancer) handleSubConnUpdate(u *scUpdate) { scw.latestState = u.state if !scw.ejected { b.childMu.Lock() - if u.cb != nil { - u.cb(u.state) + if scw.listener != nil { + scw.listener(u.state) } else { b.child.UpdateSubConnState(scw, u.state) } @@ -647,7 +646,11 @@ func (b *outlierDetectionBalancer) handleEjectedUpdate(u *ejectionUpdate) { } } b.childMu.Lock() - b.child.UpdateSubConnState(scw, stateToUpdate) + if scw.listener != nil { + scw.listener(stateToUpdate) + } else { + b.child.UpdateSubConnState(scw, stateToUpdate) + } b.childMu.Unlock() } diff --git a/xds/internal/balancer/outlierdetection/balancer_test.go b/xds/internal/balancer/outlierdetection/balancer_test.go index 4c3051f8fef3..16dd395404bb 100644 --- a/xds/internal/balancer/outlierdetection/balancer_test.go +++ b/xds/internal/balancer/outlierdetection/balancer_test.go @@ -1085,7 +1085,7 @@ func (s) TestEjectUnejectSuccessRate(t *testing.T) { // Since no addresses are ejected, a SubConn update should forward down // to the child. - od.UpdateSubConnState(scw1.(*subConnWrapper).SubConn, balancer.SubConnState{ + od.updateSubConnState(scw1.(*subConnWrapper).SubConn, balancer.SubConnState{ ConnectivityState: connectivity.Connecting, }) @@ -1147,7 +1147,7 @@ func (s) TestEjectUnejectSuccessRate(t *testing.T) { // that address should not be forwarded downward. These SubConn updates // will be cached to update the child sometime in the future when the // address gets unejected. - od.UpdateSubConnState(pi.SubConn, balancer.SubConnState{ + od.updateSubConnState(pi.SubConn, balancer.SubConnState{ ConnectivityState: connectivity.Connecting, }) sCtx, cancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) @@ -1564,7 +1564,7 @@ func (s) TestConcurrentOperations(t *testing.T) { // Call balancer.Balancers synchronously in this goroutine, upholding the // balancer.Balancer API guarantee. - od.UpdateSubConnState(scw1.(*subConnWrapper).SubConn, balancer.SubConnState{ + od.updateSubConnState(scw1.(*subConnWrapper).SubConn, balancer.SubConnState{ ConnectivityState: connectivity.Connecting, }) od.ResolverError(errors.New("some error")) diff --git a/xds/internal/balancer/outlierdetection/subconn_wrapper.go b/xds/internal/balancer/outlierdetection/subconn_wrapper.go index 71a996f29ae0..0fa422d8f262 100644 --- a/xds/internal/balancer/outlierdetection/subconn_wrapper.go +++ b/xds/internal/balancer/outlierdetection/subconn_wrapper.go @@ -31,6 +31,7 @@ import ( // whether or not this SubConn is ejected. type subConnWrapper struct { balancer.SubConn + listener func(balancer.SubConnState) // addressInfo is a pointer to the subConnWrapper's corresponding address // map entry, if the map entry exists.