Skip to content

Commit

Permalink
Fixes a race in the scheduling limits. (#3417)
Browse files Browse the repository at this point in the history
* Fixes a race in the scheduling limits.

Found out the hard way that sharding might not listen correctly to the context and so may still run queries
after the upstream request ends. This can lead to a panic when scheduling work, which sends on a closed channel.

I've added a safeguard to wait for all jobs to end before ending the original request, and also attempted to fix the sharding downstreamer
so that it properly stops sending work when the context is closed.

Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com>

* typo

Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com>
  • Loading branch information
cyriltovena authored Mar 2, 2021
1 parent 3e4566d commit 20ef66d
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 18 deletions.
14 changes: 6 additions & 8 deletions pkg/querier/queryrange/downstreamer.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ type instance struct {
}

func (in instance) Downstream(ctx context.Context, queries []logql.DownstreamQuery) ([]logql.Result, error) {
return in.For(queries, func(qry logql.DownstreamQuery) (logql.Result, error) {
return in.For(ctx, queries, func(qry logql.DownstreamQuery) (logql.Result, error) {
req := ParamsToLokiRequest(qry.Params).WithShards(qry.Shards).WithQuery(qry.Expr.String()).(*LokiRequest)
logger, ctx := spanlogger.New(ctx, "DownstreamHandler.instance")
defer logger.Finish()
Expand All @@ -72,6 +72,7 @@ func (in instance) Downstream(ctx context.Context, queries []logql.DownstreamQue

// For runs a function against a list of queries, collecting the results or returning an error. The indices are preserved such that input[i] maps to output[i].
func (in instance) For(
ctx context.Context,
queries []logql.DownstreamQuery,
fn func(logql.DownstreamQuery) (logql.Result, error),
) ([]logql.Result, error) {
Expand All @@ -81,16 +82,15 @@ func (in instance) For(
err error
}

done := make(chan struct{})
defer close(done)

ctx, cancel := context.WithCancel(ctx)
defer cancel()
ch := make(chan resp)

// Make one goroutine to dispatch the other goroutines, bounded by instance parallelism
go func() {
for i := 0; i < len(queries); i++ {
select {
case <-done:
case <-ctx.Done():
break
case <-in.locks:
go func(i int) {
Expand All @@ -108,7 +108,7 @@ func (in instance) For(

// Feed the result into the channel unless the work has completed.
select {
case <-done:
case <-ctx.Done():
case ch <- response:
}
}(i)
Expand All @@ -125,7 +125,6 @@ func (in instance) For(
results[resp.i] = resp.res
}
return results, nil

}

// convert to matrix
Expand All @@ -136,7 +135,6 @@ func sampleStreamToMatrix(streams []queryrange.SampleStream) parser.Value {
x.Metric = make(labels.Labels, 0, len(stream.Labels))
for _, l := range stream.Labels {
x.Metric = append(x.Metric, labels.Label(l))

}

x.Points = make([]promql.Point, 0, len(stream.Samples))
Expand Down
9 changes: 3 additions & 6 deletions pkg/querier/queryrange/downstreamer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ func TestResponseToResult(t *testing.T) {
}

func TestDownstreamHandler(t *testing.T) {

// Pretty poor test, but this is just a passthrough struct, so ensure we create locks
// and can consume them
h := DownstreamHandler{nil}
Expand Down Expand Up @@ -220,7 +219,7 @@ func TestInstanceFor(t *testing.T) {
var ct int

// ensure we can execute queries that number more than the parallelism parameter
_, err := in.For(queries, func(_ logql.DownstreamQuery) (logql.Result, error) {
_, err := in.For(context.TODO(), queries, func(_ logql.DownstreamQuery) (logql.Result, error) {
mtx.Lock()
defer mtx.Unlock()
ct++
Expand All @@ -233,7 +232,7 @@ func TestInstanceFor(t *testing.T) {
// ensure an early error abandons the other queues queries
in = mkIn()
ct = 0
_, err = in.For(queries, func(_ logql.DownstreamQuery) (logql.Result, error) {
_, err = in.For(context.TODO(), queries, func(_ logql.DownstreamQuery) (logql.Result, error) {
mtx.Lock()
defer mtx.Unlock()
ct++
Expand All @@ -250,6 +249,7 @@ func TestInstanceFor(t *testing.T) {

in = mkIn()
results, err := in.For(
context.TODO(),
[]logql.DownstreamQuery{
{
Shards: logql.Shards{
Expand All @@ -263,7 +263,6 @@ func TestInstanceFor(t *testing.T) {
},
},
func(qry logql.DownstreamQuery) (logql.Result, error) {

return logql.Result{
Data: logql.Streams{{
Labels: qry.Shards[0].String(),
Expand All @@ -285,7 +284,6 @@ func TestInstanceFor(t *testing.T) {
results,
)
ensureParallelism(t, in, in.parallelism)

}

func TestInstanceDownstream(t *testing.T) {
Expand Down Expand Up @@ -345,5 +343,4 @@ func TestInstanceDownstream(t *testing.T) {

require.Nil(t, err)
require.Equal(t, []logql.Result{expected}, results)

}
17 changes: 14 additions & 3 deletions pkg/querier/queryrange/limits.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,13 @@ func newWork(ctx context.Context, req queryrange.Request) work {
}

func (rt limitedRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
var wg sync.WaitGroup
intermediate := make(chan work)
defer func() {
wg.Wait()
close(intermediate)
}()

ctx, cancel := context.WithCancel(r.Context())
defer cancel()

Expand All @@ -203,8 +210,6 @@ func (rt limitedRoundTripper) RoundTrip(r *http.Request) (*http.Response, error)
}

parallelism := rt.limits.MaxQueryParallelism(userid)
intermediate := make(chan work)
defer close(intermediate)

for i := 0; i < parallelism; i++ {
go func() {
Expand All @@ -222,13 +227,19 @@ func (rt limitedRoundTripper) RoundTrip(r *http.Request) (*http.Response, error)

response, err := rt.middleware.Wrap(
queryrange.HandlerFunc(func(ctx context.Context, r queryrange.Request) (queryrange.Response, error) {
wg.Add(1)
defer wg.Done()

if ctx.Err() != nil {
return nil, ctx.Err()
}
w := newWork(ctx, r)
intermediate <- w
select {
case response := <-w.result:
return response.response, response.err
case <-ctx.Done():
return nil, err
return nil, ctx.Err()
}
})).Do(ctx, request)
if err != nil {
Expand Down
30 changes: 29 additions & 1 deletion pkg/querier/queryrange/limits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func Test_seriesLimiter(t *testing.T) {
require.LessOrEqual(t, *c, 4)
}

func Test_MaxQueryPallelism(t *testing.T) {
func Test_MaxQueryParallelism(t *testing.T) {
maxQueryParallelism := 2
f, err := newfakeRoundTripper()
require.Nil(t, err)
Expand Down Expand Up @@ -186,3 +186,31 @@ func Test_MaxQueryPallelism(t *testing.T) {
maxFound := int(max.Load())
require.LessOrEqual(t, maxFound, maxQueryParallelism, "max query parallelism: ", maxFound, " went over the configured one:", maxQueryParallelism)
}

func Test_MaxQueryParallelismLateScheduling(t *testing.T) {
maxQueryParallelism := 2
f, err := newfakeRoundTripper()
require.Nil(t, err)

f.setHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
// simulate some work
time.Sleep(20 * time.Millisecond)
}))
ctx := user.InjectOrgID(context.Background(), "foo")

r, err := http.NewRequestWithContext(ctx, "GET", "/query_range", http.NoBody)
require.Nil(t, err)

_, _ = NewLimitedRoundTripper(f, lokiCodec, fakeLimits{maxQueryParallelism: maxQueryParallelism},
queryrange.MiddlewareFunc(func(next queryrange.Handler) queryrange.Handler {
return queryrange.HandlerFunc(func(c context.Context, r queryrange.Request) (queryrange.Response, error) {
for i := 0; i < 10; i++ {
go func() {
_, _ = next.Do(c, &LokiRequest{})
}()
}
return nil, nil
})
}),
).RoundTrip(r)
}

0 comments on commit 20ef66d

Please sign in to comment.