From dad94df1c61b9c3e0a933dd09f0a981c4866c426 Mon Sep 17 00:00:00 2001 From: PlanetScale Actions Bot <60239337+planetscale-actions-bot@users.noreply.github.com> Date: Tue, 27 Feb 2024 10:07:52 -0800 Subject: [PATCH] [latest-17.0](#4549): CherryPick(#15365): CI: Address data races on memorytopo Conn.closed (#4555) * backport of 4549 * Fix import conflict Signed-off-by: Matt Lord --------- Signed-off-by: Matt Lord Co-authored-by: Matt Lord --- go/vt/topo/memorytopo/election.go | 8 ++++---- go/vt/topo/memorytopo/lock.go | 2 +- go/vt/topo/memorytopo/memorytopo.go | 7 ++++--- go/vt/topo/memorytopo/watch.go | 4 ++-- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/go/vt/topo/memorytopo/election.go b/go/vt/topo/memorytopo/election.go index 52bbe2e93ce..0a76c202de2 100644 --- a/go/vt/topo/memorytopo/election.go +++ b/go/vt/topo/memorytopo/election.go @@ -28,7 +28,7 @@ import ( func (c *Conn) NewLeaderParticipation(name, id string) (topo.LeaderParticipation, error) { c.factory.callstats.Add([]string{"NewLeaderParticipation"}, 1) - if c.closed { + if c.closed.Load() { return nil, ErrConnectionClosed } @@ -74,7 +74,7 @@ type cLeaderParticipation struct { // WaitForLeadership is part of the topo.LeaderParticipation interface. func (mp *cLeaderParticipation) WaitForLeadership() (context.Context, error) { - if mp.c.closed { + if mp.c.closed.Load() { return nil, ErrConnectionClosed } @@ -122,7 +122,7 @@ func (mp *cLeaderParticipation) Stop() { // GetCurrentLeaderID is part of the topo.LeaderParticipation interface func (mp *cLeaderParticipation) GetCurrentLeaderID(ctx context.Context) (string, error) { - if mp.c.closed { + if mp.c.closed.Load() { return "", ErrConnectionClosed } @@ -141,7 +141,7 @@ func (mp *cLeaderParticipation) GetCurrentLeaderID(ctx context.Context) (string, // WaitForNewLeader is part of the topo.LeaderParticipation interface func (mp *cLeaderParticipation) WaitForNewLeader(ctx context.Context) (<-chan string, error) { - if mp.c.closed { + if mp.c.closed.Load() { return nil, ErrConnectionClosed } diff --git a/go/vt/topo/memorytopo/lock.go b/go/vt/topo/memorytopo/lock.go index 0545ba8b182..afce7868469 100644 --- a/go/vt/topo/memorytopo/lock.go +++ b/go/vt/topo/memorytopo/lock.go @@ -116,7 +116,7 @@ func (ld *memoryTopoLockDescriptor) Unlock(ctx context.Context) error { } func (c *Conn) unlock(ctx context.Context, dirPath string) error { - if c.closed { + if c.closed.Load() { return ErrConnectionClosed } diff --git a/go/vt/topo/memorytopo/memorytopo.go b/go/vt/topo/memorytopo/memorytopo.go index 28bb035f2f8..ccc928991bc 100644 --- a/go/vt/topo/memorytopo/memorytopo.go +++ b/go/vt/topo/memorytopo/memorytopo.go @@ -25,6 +25,7 @@ import ( "math/rand" "strings" "sync" + "sync/atomic" "time" "vitess.io/vitess/go/stats" @@ -135,13 +136,13 @@ type Conn struct { factory *Factory cell string serverAddr string - closed bool + closed atomic.Bool } // dial returns immediately, unless the Conn points to the sentinel // UnreachableServerAddr, in which case it will block until the context expires. func (c *Conn) dial(ctx context.Context) error { - if c.closed { + if c.closed.Load() { return ErrConnectionClosed } if c.serverAddr == UnreachableServerAddr { @@ -154,7 +155,7 @@ func (c *Conn) dial(ctx context.Context) error { // Close is part of the topo.Conn interface. func (c *Conn) Close() { c.factory.callstats.Add([]string{"Close"}, 1) - c.closed = true + c.closed.Store(true) } type watch struct { diff --git a/go/vt/topo/memorytopo/watch.go b/go/vt/topo/memorytopo/watch.go index 8d9ef5cb54c..3651bcca9ce 100644 --- a/go/vt/topo/memorytopo/watch.go +++ b/go/vt/topo/memorytopo/watch.go @@ -27,7 +27,7 @@ import ( func (c *Conn) Watch(ctx context.Context, filePath string) (*topo.WatchData, <-chan *topo.WatchData, error) { c.factory.callstats.Add([]string{"Watch"}, 1) - if c.closed { + if c.closed.Load() { return nil, nil, ErrConnectionClosed } @@ -79,7 +79,7 @@ func (c *Conn) Watch(ctx context.Context, filePath string) (*topo.WatchData, <-c func (c *Conn) WatchRecursive(ctx context.Context, dirpath string) ([]*topo.WatchDataRecursive, <-chan *topo.WatchDataRecursive, error) { c.factory.callstats.Add([]string{"WatchRecursive"}, 1) - if c.closed { + if c.closed.Load() { return nil, nil, ErrConnectionClosed }