From 20ef66df65b629ce93e3a04d5d40f98de71c0982 Mon Sep 17 00:00:00 2001
From: Cyril Tovena
Date: Tue, 2 Mar 2021 15:10:15 +0100
Subject: [PATCH] Fixes a race in the scheduling limits. (#3417)

* Fixes a race in the scheduling limits.

Found out the hard way that sharding might not listen to the context
correctly and so can keep running queries after the upstream request
ends. This can lead to a panic when scheduling work that sends on a
closed channel. I've added a safeguard that waits for all jobs to
finish before the original request ends, and also attempted to fix the
sharding downstreamer so it properly stops sending work once the
context is closed.

Signed-off-by: Cyril Tovena

* typo

Signed-off-by: Cyril Tovena
---
 pkg/querier/queryrange/downstreamer.go      | 14 +++++-----
 pkg/querier/queryrange/downstreamer_test.go |  9 +++----
 pkg/querier/queryrange/limits.go            | 17 +++++++++---
 pkg/querier/queryrange/limits_test.go       | 30 ++++++++++++++++++++-
 4 files changed, 52 insertions(+), 18 deletions(-)

diff --git a/pkg/querier/queryrange/downstreamer.go b/pkg/querier/queryrange/downstreamer.go
index c494790a13fa..e04602b2e06b 100644
--- a/pkg/querier/queryrange/downstreamer.go
+++ b/pkg/querier/queryrange/downstreamer.go
@@ -56,7 +56,7 @@ type instance struct {
 }
 
 func (in instance) Downstream(ctx context.Context, queries []logql.DownstreamQuery) ([]logql.Result, error) {
-	return in.For(queries, func(qry logql.DownstreamQuery) (logql.Result, error) {
+	return in.For(ctx, queries, func(qry logql.DownstreamQuery) (logql.Result, error) {
 		req := ParamsToLokiRequest(qry.Params).WithShards(qry.Shards).WithQuery(qry.Expr.String()).(*LokiRequest)
 		logger, ctx := spanlogger.New(ctx, "DownstreamHandler.instance")
 		defer logger.Finish()
@@ -72,6 +72,7 @@ func (in instance) Downstream(ctx context.Context, queries []logql.DownstreamQue
 
 // For runs a function against a list of queries, collecting the results or returning an error. The indices are preserved such that input[i] maps to output[i].
 func (in instance) For(
+	ctx context.Context,
 	queries []logql.DownstreamQuery,
 	fn func(logql.DownstreamQuery) (logql.Result, error),
 ) ([]logql.Result, error) {
@@ -81,16 +82,15 @@ func (in instance) For(
 		err error
 	}
 
-	done := make(chan struct{})
-	defer close(done)
-
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
 	ch := make(chan resp)
 
 	// Make one goroutine to dispatch the other goroutines, bounded by instance parallelism
 	go func() {
 		for i := 0; i < len(queries); i++ {
 			select {
-			case <-done:
+			case <-ctx.Done():
 				break
 			case <-in.locks:
 				go func(i int) {
@@ -108,7 +108,7 @@ func (in instance) For(
 					// Feed the result into the channel unless the work has completed.
select { - case <-done: + case <-ctx.Done(): case ch <- response: } }(i) @@ -125,7 +125,6 @@ func (in instance) For( results[resp.i] = resp.res } return results, nil - } // convert to matrix @@ -136,7 +135,6 @@ func sampleStreamToMatrix(streams []queryrange.SampleStream) parser.Value { x.Metric = make(labels.Labels, 0, len(stream.Labels)) for _, l := range stream.Labels { x.Metric = append(x.Metric, labels.Label(l)) - } x.Points = make([]promql.Point, 0, len(stream.Samples)) diff --git a/pkg/querier/queryrange/downstreamer_test.go b/pkg/querier/queryrange/downstreamer_test.go index 0c84f024a9d5..3c89eb84dacf 100644 --- a/pkg/querier/queryrange/downstreamer_test.go +++ b/pkg/querier/queryrange/downstreamer_test.go @@ -184,7 +184,6 @@ func TestResponseToResult(t *testing.T) { } func TestDownstreamHandler(t *testing.T) { - // Pretty poor test, but this is just a passthrough struct, so ensure we create locks // and can consume them h := DownstreamHandler{nil} @@ -220,7 +219,7 @@ func TestInstanceFor(t *testing.T) { var ct int // ensure we can execute queries that number more than the parallelism parameter - _, err := in.For(queries, func(_ logql.DownstreamQuery) (logql.Result, error) { + _, err := in.For(context.TODO(), queries, func(_ logql.DownstreamQuery) (logql.Result, error) { mtx.Lock() defer mtx.Unlock() ct++ @@ -233,7 +232,7 @@ func TestInstanceFor(t *testing.T) { // ensure an early error abandons the other queues queries in = mkIn() ct = 0 - _, err = in.For(queries, func(_ logql.DownstreamQuery) (logql.Result, error) { + _, err = in.For(context.TODO(), queries, func(_ logql.DownstreamQuery) (logql.Result, error) { mtx.Lock() defer mtx.Unlock() ct++ @@ -250,6 +249,7 @@ func TestInstanceFor(t *testing.T) { in = mkIn() results, err := in.For( + context.TODO(), []logql.DownstreamQuery{ { Shards: logql.Shards{ @@ -263,7 +263,6 @@ func TestInstanceFor(t *testing.T) { }, }, func(qry logql.DownstreamQuery) (logql.Result, error) { - return logql.Result{ Data: logql.Streams{{ Labels: qry.Shards[0].String(), @@ -285,7 +284,6 @@ func TestInstanceFor(t *testing.T) { results, ) ensureParallelism(t, in, in.parallelism) - } func TestInstanceDownstream(t *testing.T) { @@ -345,5 +343,4 @@ func TestInstanceDownstream(t *testing.T) { require.Nil(t, err) require.Equal(t, []logql.Result{expected}, results) - } diff --git a/pkg/querier/queryrange/limits.go b/pkg/querier/queryrange/limits.go index fffec3feee63..ac3a9814e0c4 100644 --- a/pkg/querier/queryrange/limits.go +++ b/pkg/querier/queryrange/limits.go @@ -186,6 +186,13 @@ func newWork(ctx context.Context, req queryrange.Request) work { } func (rt limitedRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + var wg sync.WaitGroup + intermediate := make(chan work) + defer func() { + wg.Wait() + close(intermediate) + }() + ctx, cancel := context.WithCancel(r.Context()) defer cancel() @@ -203,8 +210,6 @@ func (rt limitedRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) } parallelism := rt.limits.MaxQueryParallelism(userid) - intermediate := make(chan work) - defer close(intermediate) for i := 0; i < parallelism; i++ { go func() { @@ -222,13 +227,19 @@ func (rt limitedRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) response, err := rt.middleware.Wrap( queryrange.HandlerFunc(func(ctx context.Context, r queryrange.Request) (queryrange.Response, error) { + wg.Add(1) + defer wg.Done() + + if ctx.Err() != nil { + return nil, ctx.Err() + } w := newWork(ctx, r) intermediate <- w select { case response := <-w.result: 
return response.response, response.err case <-ctx.Done(): - return nil, err + return nil, ctx.Err() } })).Do(ctx, request) if err != nil { diff --git a/pkg/querier/queryrange/limits_test.go b/pkg/querier/queryrange/limits_test.go index 055c7e21ea30..ceaa1146f5eb 100644 --- a/pkg/querier/queryrange/limits_test.go +++ b/pkg/querier/queryrange/limits_test.go @@ -147,7 +147,7 @@ func Test_seriesLimiter(t *testing.T) { require.LessOrEqual(t, *c, 4) } -func Test_MaxQueryPallelism(t *testing.T) { +func Test_MaxQueryParallelism(t *testing.T) { maxQueryParallelism := 2 f, err := newfakeRoundTripper() require.Nil(t, err) @@ -186,3 +186,31 @@ func Test_MaxQueryPallelism(t *testing.T) { maxFound := int(max.Load()) require.LessOrEqual(t, maxFound, maxQueryParallelism, "max query parallelism: ", maxFound, " went over the configured one:", maxQueryParallelism) } + +func Test_MaxQueryParallelismLateScheduling(t *testing.T) { + maxQueryParallelism := 2 + f, err := newfakeRoundTripper() + require.Nil(t, err) + + f.setHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + // simulate some work + time.Sleep(20 * time.Millisecond) + })) + ctx := user.InjectOrgID(context.Background(), "foo") + + r, err := http.NewRequestWithContext(ctx, "GET", "/query_range", http.NoBody) + require.Nil(t, err) + + _, _ = NewLimitedRoundTripper(f, lokiCodec, fakeLimits{maxQueryParallelism: maxQueryParallelism}, + queryrange.MiddlewareFunc(func(next queryrange.Handler) queryrange.Handler { + return queryrange.HandlerFunc(func(c context.Context, r queryrange.Request) (queryrange.Response, error) { + for i := 0; i < 10; i++ { + go func() { + _, _ = next.Do(c, &LokiRequest{}) + }() + } + return nil, nil + }) + }), + ).RoundTrip(r) +}
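
Note (not part of the patch): the fix combines two ideas: re-checking the
context before scheduling new work, and waiting for all in-flight jobs before
closing the channel they send on. Below is a minimal, self-contained Go sketch
of that pattern; every name in it (work, jobs, the sleep-based timing) is
illustrative, not a Loki API. Unlike the patch, which calls wg.Add inside the
wrapped handler, the sketch registers each producer with the WaitGroup before
spawning it, the placement that guarantees Wait observes every producer.

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// work stands in for the unit of work the round tripper schedules.
type work struct{ id int }

func main() {
	var wg sync.WaitGroup
	jobs := make(chan work)
	ctx, cancel := context.WithCancel(context.Background())

	consumerDone := make(chan struct{})
	go func() { // single consumer; drains until the channel is closed
		defer close(consumerDone)
		for w := range jobs {
			fmt.Println("processed", w.id)
		}
	}()

	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			// A job scheduled after cancellation bails out instead of
			// sending (mirrors the ctx.Err() check added in limits.go).
			if ctx.Err() != nil {
				return
			}
			select {
			case jobs <- work{id: i}: // normal path
			case <-ctx.Done(): // request ended while waiting to send
			}
		}(i)
	}

	time.Sleep(10 * time.Millisecond)
	cancel()    // the upstream request ends
	wg.Wait()   // wait for every in-flight producer...
	close(jobs) // ...and only then close the channel: no send can panic
	<-consumerDone
}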