diff --git a/src/aggregator/aggregator/forwarded_writer.go b/src/aggregator/aggregator/forwarded_writer.go index 882fc54aa3..43a98c2534 100644 --- a/src/aggregator/aggregator/forwarded_writer.go +++ b/src/aggregator/aggregator/forwarded_writer.go @@ -33,6 +33,7 @@ import ( xerrors "github.com/m3db/m3/src/x/errors" "github.com/uber-go/tally" + "go.uber.org/atomic" ) const ( @@ -141,7 +142,7 @@ type forwardedWriter struct { shard uint32 client client.AdminClient - closed bool + closed atomic.Bool aggregations map[idKey]*forwardedAggregation // Aggregations for each forward metric id metrics forwardedWriterMetrics aggregationMetrics *forwardedAggregationMetrics @@ -168,7 +169,7 @@ func (w *forwardedWriter) Register( metricID id.RawID, aggKey aggregationKey, ) (writeForwardedMetricFn, onForwardedAggregationDoneFn, error) { - if w.closed { + if w.closed.Load() { w.metrics.registerWriterClosed.Inc(1) return nil, nil, errForwardedWriterClosed } @@ -188,7 +189,7 @@ func (w *forwardedWriter) Unregister( metricID id.RawID, aggKey aggregationKey, ) error { - if w.closed { + if w.closed.Load() { w.metrics.unregisterWriterClosed.Inc(1) return errForwardedWriterClosed } @@ -219,6 +220,10 @@ func (w *forwardedWriter) Prepare() { } func (w *forwardedWriter) Flush() error { + if w.closed.Load() { + return errForwardedWriterClosed + } + if err := w.client.Flush(); err != nil { w.metrics.flushErrorsClient.Inc(1) return err @@ -230,12 +235,9 @@ func (w *forwardedWriter) Flush() error { // NB: Do not close the client here as it is shared by all the forward // writers. The aggregator is responsible for closing the client. func (w *forwardedWriter) Close() error { - if w.closed { + if w.closed.Swap(true) { return errForwardedWriterClosed } - w.closed = true - w.client = nil - w.aggregations = nil return nil } diff --git a/src/aggregator/aggregator/forwarded_writer_test.go b/src/aggregator/aggregator/forwarded_writer_test.go index ac8c806daa..9fc916f959 100644 --- a/src/aggregator/aggregator/forwarded_writer_test.go +++ b/src/aggregator/aggregator/forwarded_writer_test.go @@ -21,7 +21,9 @@ package aggregator import ( + "sync" "testing" + "time" "github.com/m3db/m3/src/aggregator/client" "github.com/m3db/m3/src/metrics/aggregation" @@ -32,6 +34,7 @@ import ( "github.com/m3db/m3/src/metrics/policy" "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/uber-go/tally" ) @@ -54,11 +57,43 @@ func TestForwardedWriterRegisterWriterClosed(t *testing.T) { mt = metric.CounterType mid = id.RawID("foo") aggKey = testForwardedWriterAggregationKey + closed = make(chan struct{}) + wg sync.WaitGroup ) - w.Close() - _, _, err := w.Register(mt, mid, aggKey) - require.Equal(t, errForwardedWriterClosed, err) + c.EXPECT().Flush().AnyTimes() + + wg.Add(1) + go func() { + defer wg.Done() + + for { + var err error + + assert.NotPanics(t, func() { + if err = w.Flush(); err != nil { + require.Equal(t, errForwardedWriterClosed, err) + } + }) + + if err != nil { + break + } + time.Sleep(1 * time.Microsecond) + } + + assert.NotPanics(t, func() { + _, _, err := w.Register(mt, mid, aggKey) + require.Equal(t, errForwardedWriterClosed, err) + + err = w.Flush() + require.Equal(t, errForwardedWriterClosed, err) + }) + }() + + require.NoError(t, w.Close()) + close(closed) + wg.Wait() } func TestForwardedWriterRegisterNewAggregation(t *testing.T) { @@ -390,12 +425,10 @@ func TestForwardedWriterCloseWriterClosed(t *testing.T) { w = newForwardedWriter(0, c, tally.NoopScope) ) - // Close the writer and validate that the fields are nil'ed out. + // Close the writer require.NoError(t, w.Close()) fw := w.(*forwardedWriter) - require.True(t, fw.closed) - require.Nil(t, fw.client) - require.Nil(t, fw.aggregations) + require.True(t, fw.closed.Load()) // Closing the writer a second time results in an error. require.Equal(t, errForwardedWriterClosed, w.Close())