Skip to content

Commit

Permalink
Add unit tests for CreateChildContext (#5547)
Browse files Browse the repository at this point in the history
* change func() to emptyCancelFunc

* add edge cases to CreateChildContext

* add unit tests
  • Loading branch information
arzonus authored Dec 27, 2023
1 parent f2b177d commit 845a806
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 3 deletions.
20 changes: 17 additions & 3 deletions common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,25 +334,39 @@ func IsValidContext(ctx context.Context) error {
return nil
}

// emptyCancelFunc wraps an empty func by context.CancelFunc interface
var emptyCancelFunc = context.CancelFunc(func() {})

// CreateChildContext creates a child context which shorted context timeout
// from the given parent context
// tailroom must be in range [0, 1] and
// (1-tailroom) * parent timeout will be the new child context timeout
// if tailroom is less 0, tailroom will be considered as 0
// if tailroom is greater than 1, tailroom wil be considered as 1
func CreateChildContext(
parent context.Context,
tailroom float64,
) (context.Context, context.CancelFunc) {
if parent == nil {
return nil, func() {}
return nil, emptyCancelFunc
}
if parent.Err() != nil {
return parent, func() {}
return parent, emptyCancelFunc
}

now := time.Now()
deadline, ok := parent.Deadline()
if !ok || deadline.Before(now) {
return parent, func() {}
return parent, emptyCancelFunc
}

// if tailroom is about or less 0, then return a context with the same deadline as parent
if tailroom <= 0 {
return context.WithDeadline(parent, deadline)
}
// if tailroom is about or greater 1, then return a context with deadline of now
if tailroom >= 1 {
return context.WithDeadline(parent, now)
}

newDeadline := now.Add(time.Duration(math.Ceil(float64(deadline.Sub(now)) * (1.0 - tailroom))))
Expand Down
112 changes: 112 additions & 0 deletions common/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import (
"errors"
"fmt"
"math/rand"
"reflect"
"runtime"
"strconv"
"sync"
"testing"
Expand Down Expand Up @@ -799,3 +801,113 @@ func TestIsValidContext(t *testing.T) {
require.NoError(t, IsValidContext(ctx), "nil should be returned, because context timeout is later than now + contextExpireThreshold")
})
}

func TestCreateChildContext(t *testing.T) {
t.Run("nil parent", func(t *testing.T) {
gotCtx, gotFunc := CreateChildContext(nil, 0)
require.Nil(t, gotCtx)
require.Equal(t, funcName(emptyCancelFunc), funcName(gotFunc))
})
t.Run("canceled parent", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
gotCtx, gotFunc := CreateChildContext(ctx, 0)
require.Equal(t, ctx, gotCtx)
require.Equal(t, funcName(emptyCancelFunc), funcName(gotFunc))
})
t.Run("non-canceled parent without deadline", func(t *testing.T) {
ctx, _ := context.WithCancel(context.Background())
gotCtx, gotFunc := CreateChildContext(ctx, 0)
require.Equal(t, ctx, gotCtx)
require.Equal(t, funcName(emptyCancelFunc), funcName(gotFunc))
})
t.Run("context with deadline exceeded", func(t *testing.T) {
ctx, _ := context.WithTimeout(context.Background(), -time.Second)
gotCtx, gotFunc := CreateChildContext(ctx, 0)
require.Equal(t, ctx, gotCtx)
require.Equal(t, funcName(emptyCancelFunc), funcName(gotFunc))
})

t.Run("tailroom is less or equal to 0", func(t *testing.T) {
testCase := func(t *testing.T, tailroom float64) {
deadline := time.Now().Add(time.Hour)
ctx, _ := context.WithDeadline(context.Background(), deadline)
gotCtx, gotFunc := CreateChildContext(ctx, tailroom)

gotDeadline, ok := gotCtx.Deadline()
require.True(t, ok)
require.Equal(t, deadline, gotDeadline, "deadline should be equal to parent deadline")

require.NotEqual(t, ctx, gotCtx)
require.NotEqual(t, funcName(emptyCancelFunc), funcName(gotFunc))
}

t.Run("0", func(t *testing.T) {
testCase(t, 0)
})
t.Run("-1", func(t *testing.T) {
testCase(t, -1)
})

})

t.Run("tailroom is greater or equal to 1", func(t *testing.T) {
testCase := func(t *testing.T, tailroom float64) {
deadline := time.Now().Add(time.Hour)
ctx, _ := context.WithDeadline(context.Background(), deadline)

// we can't mock time.Now, but we know that the deadline should be in
// range between the start and finish of function's execution
beforeNow := time.Now()
gotCtx, gotFunc := CreateChildContext(ctx, tailroom)
afterNow := time.Now()

gotDeadline, ok := gotCtx.Deadline()
require.True(t, ok)
require.NotEqual(t, deadline, gotDeadline)
require.Less(t, gotDeadline, deadline)

// gotDeadline should be between beforeNow and afterNow (exclusive)
require.GreaterOrEqual(t, afterNow, gotDeadline)
require.LessOrEqual(t, beforeNow, gotDeadline)

require.NotEqual(t, ctx, gotCtx)
require.NotEqual(t, funcName(emptyCancelFunc), funcName(gotFunc))
}
t.Run("1", func(t *testing.T) {
testCase(t, 1)
})
t.Run("2", func(t *testing.T) {
testCase(t, 2)
})
})
t.Run("tailroom is 0.5", func(t *testing.T) {
now := time.Now()
deadline := now.Add(time.Hour)

ctx, _ := context.WithDeadline(context.Background(), deadline)
gotCtx, gotFunc := CreateChildContext(ctx, 0.5)

gotDeadline, ok := gotCtx.Deadline()
require.True(t, ok)
require.NotEqual(t, deadline, gotDeadline)
require.Less(t, gotDeadline, deadline)

// we can't mock time.Now, so we assume that the deadline should be
// in range 29:59 and 30:01 minutes after start
minDeadline := now.Add(30*time.Minute - time.Second)
maxDeadline := now.Add(30*time.Minute + time.Second)

// gotDeadline should be between minDeadline and maxDeadline (exclusive)
require.GreaterOrEqual(t, maxDeadline, gotDeadline)
require.LessOrEqual(t, minDeadline, gotDeadline)

require.NotEqual(t, ctx, gotCtx)
require.NotEqual(t, funcName(emptyCancelFunc), funcName(gotFunc))
})
}

// funcName returns the name of the function
func funcName(fn any) string {
return runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
}

0 comments on commit 845a806

Please sign in to comment.