From 29daa10f3efc4baf55cdc54072ab7a5e77f1734e Mon Sep 17 00:00:00 2001 From: Joshua MacDonald Date: Tue, 18 Jun 2024 12:20:52 -0700 Subject: [PATCH] [otelarrowreceiver] Ensure consume operations are not canceled at stream EOF (#33570) **Description:** Fixes a bug in the OTel Arrow receiver. When a stream reaches its end-of-life, the exporter closes the send channel and the receiver's `Recv()` loop receives an EOF error. This was inadvertently canceling a context too soon, such that requests in-flight during EOF would be canceled before finishing. **Link to tracking Issue:** #26491 **Testing:** Several tests cover this scenario, and they had to change for this fix. There is now an extra assertion in the healthy test channel to ensure that consumers never receive data on a canceled context (in testing). --- .chloggen/otelarrow-eof-cancel.yaml | 27 +++++ .../otelarrowreceiver/internal/arrow/arrow.go | 44 +++++--- .../internal/arrow/arrow_test.go | 102 ++++++++++-------- 3 files changed, 114 insertions(+), 59 deletions(-) create mode 100644 .chloggen/otelarrow-eof-cancel.yaml diff --git a/.chloggen/otelarrow-eof-cancel.yaml b/.chloggen/otelarrow-eof-cancel.yaml new file mode 100644 index 000000000000..fe50b49d6fc7 --- /dev/null +++ b/.chloggen/otelarrow-eof-cancel.yaml @@ -0,0 +1,27 @@ +# Use this changelog template to create an entry for release notes. + +# One of 'breaking', 'deprecation', 'new_component', 'enhancement', 'bug_fix' +change_type: bug_fix + +# The name of the component, or a single word describing the area of concern, (e.g. filelogreceiver) +component: otelarrowreceiver + +# A brief description of the change. Surround your text with quotes ("") if it needs to start with a backtick (`). +note: Ensure consume operations are not canceled at stream EOF. + +# Mandatory: One or more tracking issues related to the change. You can use the PR number here if no issue exists. +issues: [33570] + +# (Optional) One or more lines of additional information to render under the primary note. +# These lines will be padded with 2 spaces and then inserted directly into the document. +# Use pipe (|) for multiline entries. +subtext: + +# If your change doesn't affect end users or the exported elements of any package, +# you should instead start your pull request title with [chore] or use the "Skip Changelog" label. +# Optional: The change log or logs in which this entry should be included. +# e.g. '[user]' or '[user, api]' +# Include 'user' if the change is relevant to end users. +# Include 'api' if there is a change to a library API. +# Default: '[user]' +change_logs: [user] diff --git a/receiver/otelarrowreceiver/internal/arrow/arrow.go b/receiver/otelarrowreceiver/internal/arrow/arrow.go index 1aee357d9df1..facd400c29a2 100644 --- a/receiver/otelarrowreceiver/internal/arrow/arrow.go +++ b/receiver/otelarrowreceiver/internal/arrow/arrow.go @@ -378,9 +378,8 @@ func (r *Receiver) anyStream(serverStream anyStreamServer, method string) (retEr doneCtx, doneCancel := context.WithCancel(streamCtx) defer doneCancel() - // streamErrCh returns up to two errors from the sender and - // receiver threads started below. - streamErrCh := make(chan error, 2) + recvErrCh := make(chan error, 1) + sendErrCh := make(chan error, 1) pendingCh := make(chan batchResp, runtime.NumCPU()) // wg is used to ensure this thread returns after both @@ -390,6 +389,11 @@ func (r *Receiver) anyStream(serverStream anyStreamServer, method string) (retEr sendWG.Add(1) recvWG.Add(1) + // flushCtx controls the start of flushing. when this is canceled + // after the receiver finishes, the flush operation begins. + flushCtx, flushCancel := context.WithCancel(doneCtx) + defer flushCancel() + rstream := &receiverStream{ Receiver: r, } @@ -399,27 +403,41 @@ func (r *Receiver) anyStream(serverStream anyStreamServer, method string) (retEr defer recvWG.Done() defer r.recoverErr(&err) err = rstream.srvReceiveLoop(doneCtx, serverStream, pendingCh, method, ac) - streamErrCh <- err + recvErrCh <- err }() go func() { var err error defer sendWG.Done() defer r.recoverErr(&err) - err = rstream.srvSendLoop(doneCtx, serverStream, &recvWG, pendingCh) - streamErrCh <- err + // the sender receives flushCtx, which is canceled after the + // receiver returns (success or no). + err = rstream.srvSendLoop(flushCtx, serverStream, &recvWG, pendingCh) + sendErrCh <- err }() // Wait for sender/receiver threads to return before returning. defer recvWG.Wait() defer sendWG.Wait() - select { - case <-doneCtx.Done(): - return status.Error(codes.Canceled, "server stream shutdown") - case retErr = <-streamErrCh: - doneCancel() - return + for { + select { + case <-doneCtx.Done(): + return status.Error(codes.Canceled, "server stream shutdown") + case err := <-recvErrCh: + flushCancel() + if errors.Is(err, io.EOF) { + // the receiver returned EOF, next we + // expect the sender to finish. + continue + } + return err + case err := <-sendErrCh: + // explicit cancel here, in case the sender fails before + // the receiver does. break the receiver loop here: + doneCancel() + return err + } } } @@ -555,7 +573,7 @@ func (r *receiverStream) recvOne(streamCtx context.Context, serverStream anyStre if err != nil { if errors.Is(err, io.EOF) { - return status.Error(codes.Canceled, "client stream shutdown") + return err } else if errors.Is(err, context.Canceled) { return status.Error(codes.Canceled, "server stream shutdown") } diff --git a/receiver/otelarrowreceiver/internal/arrow/arrow_test.go b/receiver/otelarrowreceiver/internal/arrow/arrow_test.go index a11362789b01..a8a856ed93e1 100644 --- a/receiver/otelarrowreceiver/internal/arrow/arrow_test.go +++ b/receiver/otelarrowreceiver/internal/arrow/arrow_test.go @@ -97,19 +97,42 @@ type commonTestCase struct { } type testChannel interface { - onConsume() error + onConsume(ctx context.Context) error } -type healthyTestChannel struct{} +type healthyTestChannel struct { + t *testing.T +} + +func newHealthyTestChannel(t *testing.T) *healthyTestChannel { + return &healthyTestChannel{t: t} +} -func (healthyTestChannel) onConsume() error { - return nil +func (h healthyTestChannel) onConsume(ctx context.Context) error { + select { + case <-ctx.Done(): + h.t.Error("unexpected consume with canceled request") + return ctx.Err() + default: + return nil + } } -type unhealthyTestChannel struct{} +type unhealthyTestChannel struct { + t *testing.T +} -func (unhealthyTestChannel) onConsume() error { - return status.Errorf(codes.Unavailable, "consumer unhealthy") +func newUnhealthyTestChannel(t *testing.T) *unhealthyTestChannel { + return &unhealthyTestChannel{t: t} +} + +func (u unhealthyTestChannel) onConsume(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return status.Errorf(codes.Unavailable, "consumer unhealthy") + } } type recvResult struct { @@ -160,7 +183,7 @@ func (ctc *commonTestCase) doAndReturnConsumeTraces(tc testChannel) func(ctx con Ctx: ctx, Data: traces, } - return tc.onConsume() + return tc.onConsume(ctx) } } @@ -170,7 +193,7 @@ func (ctc *commonTestCase) doAndReturnConsumeMetrics(tc testChannel) func(ctx co Ctx: ctx, Data: metrics, } - return tc.onConsume() + return tc.onConsume(ctx) } } @@ -180,7 +203,7 @@ func (ctc *commonTestCase) doAndReturnConsumeLogs(tc testChannel) func(ctx conte Ctx: ctx, Data: logs, } - return tc.onConsume() + return tc.onConsume(ctx) } } @@ -420,7 +443,7 @@ func TestBoundedQueueWithPdataHeaders(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tc := healthyTestChannel{} + tc := newHealthyTestChannel(t) ctc := newCommonTestCase(t, tc) td := testdata.GenerateTraces(tt.numTraces) @@ -468,7 +491,7 @@ func TestBoundedQueueWithPdataHeaders(t *testing.T) { func TestReceiverTraces(t *testing.T) { stdTesting := otelAssert.NewStdUnitTest(t) - tc := healthyTestChannel{} + tc := newHealthyTestChannel(t) ctc := newCommonTestCase(t, tc) td := testdata.GenerateTraces(2) @@ -491,7 +514,7 @@ func TestReceiverTraces(t *testing.T) { } func TestReceiverLogs(t *testing.T) { - tc := healthyTestChannel{} + tc := newHealthyTestChannel(t) ctc := newCommonTestCase(t, tc) ld := testdata.GenerateLogs(2) @@ -510,7 +533,7 @@ func TestReceiverLogs(t *testing.T) { } func TestReceiverMetrics(t *testing.T) { - tc := healthyTestChannel{} + tc := newHealthyTestChannel(t) ctc := newCommonTestCase(t, tc) stdTesting := otelAssert.NewStdUnitTest(t) @@ -534,7 +557,7 @@ func TestReceiverMetrics(t *testing.T) { } func TestReceiverRecvError(t *testing.T) { - tc := healthyTestChannel{} + tc := newHealthyTestChannel(t) ctc := newCommonTestCase(t, tc) ctc.start(ctc.newRealConsumer, defaultBQ()) @@ -547,7 +570,7 @@ func TestReceiverRecvError(t *testing.T) { } func TestReceiverSendError(t *testing.T) { - tc := healthyTestChannel{} + tc := newHealthyTestChannel(t) ctc := newCommonTestCase(t, tc) ld := testdata.GenerateLogs(2) @@ -587,7 +610,7 @@ func TestReceiverConsumeError(t *testing.T) { } for _, item := range data { - tc := unhealthyTestChannel{} + tc := newUnhealthyTestChannel(t) ctc := newCommonTestCase(t, tc) var batch *arrowpb.BatchArrowRecords @@ -646,7 +669,7 @@ func TestReceiverInvalidData(t *testing.T) { } for _, item := range data { - tc := healthyTestChannel{} + tc := newHealthyTestChannel(t) ctc := newCommonTestCase(t, tc) var batch *arrowpb.BatchArrowRecords @@ -682,7 +705,7 @@ func TestReceiverMemoryLimit(t *testing.T) { } for _, item := range data { - tc := healthyTestChannel{} + tc := newHealthyTestChannel(t) ctc := newCommonTestCase(t, tc) var batch *arrowpb.BatchArrowRecords @@ -738,7 +761,7 @@ func copyBatch(in *arrowpb.BatchArrowRecords) *arrowpb.BatchArrowRecords { } func TestReceiverEOF(t *testing.T) { - tc := healthyTestChannel{} + tc := newHealthyTestChannel(t) ctc := newCommonTestCase(t, tc) stdTesting := otelAssert.NewStdUnitTest(t) @@ -771,9 +794,7 @@ func TestReceiverEOF(t *testing.T) { wg.Add(1) go func() { - err := ctc.wait() - // EOF is treated the same as Canceled. - requireCanceledStatus(t, err) + require.NoError(t, ctc.wait()) wg.Done() }() @@ -800,7 +821,7 @@ func TestReceiverHeadersNoAuth(t *testing.T) { } func testReceiverHeaders(t *testing.T, includeMeta bool) { - tc := healthyTestChannel{} + tc := newHealthyTestChannel(t) ctc := newCommonTestCase(t, tc) expectData := []map[string][]string{ @@ -855,9 +876,7 @@ func testReceiverHeaders(t *testing.T, includeMeta bool) { wg.Add(1) go func() { - err := ctc.wait() - // EOF is treated the same as Canceled. - requireCanceledStatus(t, err) + require.NoError(t, ctc.wait()) wg.Done() }() @@ -883,7 +902,7 @@ func testReceiverHeaders(t *testing.T, includeMeta bool) { } func TestReceiverCancel(t *testing.T) { - tc := healthyTestChannel{} + tc := newHealthyTestChannel(t) ctc := newCommonTestCase(t, tc) ctc.cancel() @@ -1159,7 +1178,7 @@ func TestReceiverAuthHeadersStream(t *testing.T) { } func testReceiverAuthHeaders(t *testing.T, includeMeta bool, dataAuth bool) { - tc := healthyTestChannel{} + tc := newHealthyTestChannel(t) ctc := newCommonTestCase(t, tc) expectData := []map[string][]string{ @@ -1245,7 +1264,7 @@ func testReceiverAuthHeaders(t *testing.T, includeMeta bool, dataAuth bool) { close(ctc.receive) }() - var expectErrs []bool + var expectCodes []arrowpb.StatusCode for _, testInput := range expectData { // The static stream context contains one extra variable. @@ -1256,7 +1275,7 @@ func testReceiverAuthHeaders(t *testing.T, includeMeta bool, dataAuth bool) { cpy[k] = v } - expectErr := false + expectCode := arrowpb.StatusCode_OK if dataAuth { hasAuth := false for _, val := range cpy["auth"] { @@ -1265,13 +1284,13 @@ func testReceiverAuthHeaders(t *testing.T, includeMeta bool, dataAuth bool) { if hasAuth { cpy["has_auth"] = []string{":+1:", ":100:"} } else { - expectErr = true + expectCode = arrowpb.StatusCode_UNAUTHENTICATED } } - expectErrs = append(expectErrs, expectErr) + expectCodes = append(expectCodes, expectCode) - if expectErr { + if expectCode != arrowpb.StatusCode_OK { continue } @@ -1286,23 +1305,14 @@ func testReceiverAuthHeaders(t *testing.T, includeMeta bool, dataAuth bool) { } } - err := ctc.wait() - // EOF is treated the same as Canceled - requireCanceledStatus(t, err) - - // Add in expectErrs for when receiver sees EOF, - // the status code will not be arrowpb.StatusCode_OK. - expectErrs = append(expectErrs, true) + require.NoError(t, ctc.wait()) + require.Equal(t, len(expectCodes), dataCount) require.Equal(t, len(expectData), dataCount) require.Equal(t, len(recvBatches), dataCount) for idx, batch := range recvBatches { - if expectErrs[idx] { - require.NotEqual(t, arrowpb.StatusCode_OK, batch.StatusCode) - } else { - require.Equal(t, arrowpb.StatusCode_OK, batch.StatusCode) - } + require.Equal(t, expectCodes[idx], batch.StatusCode) } }