Skip to content

Commit

Permalink
Merge pull request #4457 from tonistiigi/cancel-cause
Browse files Browse the repository at this point in the history
replace context.WithCancel with WithCancelCause
  • Loading branch information
tonistiigi authored Dec 12, 2023
2 parents 4ab32be + 09648f4 commit 7eb2c8e
Show file tree
Hide file tree
Showing 59 changed files with 274 additions and 226 deletions.
4 changes: 4 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ linters-settings:
forbid:
- '^fmt\.Errorf(# use errors\.Errorf instead)?$'
- '^logrus\.(Trace|Debug|Info|Warn|Warning|Error|Fatal)(f|ln)?(# use bklog\.G or bklog\.L instead of logrus directly)?$'
- '^context\.WithCancel(# use context\.WithCancelCause instead)?$'
- '^context\.WithTimeout(# use context\.WithTimeoutCause instead)?$'
- '^context\.WithDeadline(# use context\.WithDeadline instead)?$'
- '^ctx\.Err(# use context\.Cause instead)?$'
importas:
alias:
- pkg: "github.com/opencontainers/image-spec/specs-go/v1"
Expand Down
2 changes: 1 addition & 1 deletion cache/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -1258,7 +1258,7 @@ func (cm *cacheManager) prune(ctx context.Context, ch chan client.UsageInfo, opt

select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
default:
return cm.prune(ctx, ch, opt)
}
Expand Down
10 changes: 6 additions & 4 deletions cache/remotecache/azblob/exporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,9 @@ func (ce *exporter) uploadManifest(ctx context.Context, manifestKey string, read
return errors.Wrap(err, "error creating container client")
}

ctx, cnclFn := context.WithTimeout(ctx, time.Minute*5)
defer cnclFn()
ctx, cnclFn := context.WithCancelCause(ctx)
ctx, _ = context.WithTimeoutCause(ctx, time.Minute*5, errors.WithStack(context.DeadlineExceeded))
defer cnclFn(errors.WithStack(context.Canceled))

_, err = blobClient.Upload(ctx, reader, &azblob.BlockBlobUploadOptions{})
if err != nil {
Expand All @@ -170,8 +171,9 @@ func (ce *exporter) uploadBlobIfNotExists(ctx context.Context, blobKey string, r
return errors.Wrap(err, "error creating container client")
}

uploadCtx, cnclFn := context.WithTimeout(ctx, time.Minute*5)
defer cnclFn()
uploadCtx, cnclFn := context.WithCancelCause(ctx)
uploadCtx, _ = context.WithTimeoutCause(uploadCtx, time.Minute*5, errors.WithStack(context.DeadlineExceeded))
defer cnclFn(errors.WithStack(context.Canceled))

// Only upload if the blob doesn't exist
eTagAny := azblob.ETagAny
Expand Down
15 changes: 9 additions & 6 deletions cache/remotecache/azblob/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,9 @@ func createContainerClient(ctx context.Context, config *Config) (*azblob.Contain
}
}

ctx, cnclFn := context.WithTimeout(ctx, time.Second*60)
defer cnclFn()
ctx, cnclFn := context.WithCancelCause(ctx)
ctx, _ = context.WithTimeoutCause(ctx, time.Second*60, errors.WithStack(context.DeadlineExceeded))
defer cnclFn(errors.WithStack(context.Canceled))

containerClient, err := serviceClient.NewContainerClient(config.Container)
if err != nil {
Expand All @@ -148,8 +149,9 @@ func createContainerClient(ctx context.Context, config *Config) (*azblob.Contain

var se *azblob.StorageError
if errors.As(err, &se) && se.ErrorCode == azblob.StorageErrorCodeContainerNotFound {
ctx, cnclFn := context.WithTimeout(ctx, time.Minute*5)
defer cnclFn()
ctx, cnclFn := context.WithCancelCause(ctx)
ctx, _ = context.WithTimeoutCause(ctx, time.Minute*5, errors.WithStack(context.DeadlineExceeded))
defer cnclFn(errors.WithStack(context.Canceled))
_, err := containerClient.Create(ctx, &azblob.ContainerCreateOptions{})
if err != nil {
return nil, errors.Wrapf(err, "failed to create cache container %s", config.Container)
Expand Down Expand Up @@ -177,8 +179,9 @@ func blobExists(ctx context.Context, containerClient *azblob.ContainerClient, bl
return false, errors.Wrap(err, "error creating blob client")
}

ctx, cnclFn := context.WithTimeout(ctx, time.Second*60)
defer cnclFn()
ctx, cnclFn := context.WithCancelCause(ctx)
ctx, _ = context.WithTimeoutCause(ctx, time.Second*60, errors.WithStack(context.DeadlineExceeded))
defer cnclFn(errors.WithStack(context.Canceled))
_, err = blobClient.GetProperties(ctx, &azblob.BlobGetPropertiesOptions{})
if err == nil {
return true, nil
Expand Down
5 changes: 3 additions & 2 deletions cache/remotecache/local/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ func getContentStore(ctx context.Context, sm *session.Manager, g session.Group,
if sessionID == "" {
return nil, errors.New("local cache exporter/importer requires session")
}
timeoutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
timeoutCtx, cancel := context.WithCancelCause(context.Background())
timeoutCtx, _ = context.WithTimeoutCause(timeoutCtx, 5*time.Second, errors.WithStack(context.DeadlineExceeded))
defer cancel(errors.WithStack(context.Canceled))

caller, err := sm.Get(timeoutCtx, sessionID, false)
if err != nil {
Expand Down
16 changes: 8 additions & 8 deletions client/build_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ func testClientGatewayContainerPID1Tty(t *testing.T, sb integration.Sandbox) {
output := bytes.NewBuffer(nil)

b := func(ctx context.Context, c client.Client) (*client.Result, error) {
ctx, timeout := context.WithTimeout(ctx, 10*time.Second)
ctx, timeout := context.WithTimeoutCause(ctx, 10*time.Second, nil)
defer timeout()

st := llb.Image("busybox:latest")
Expand Down Expand Up @@ -1015,7 +1015,7 @@ func testClientGatewayContainerCancelPID1Tty(t *testing.T, sb integration.Sandbo
output := bytes.NewBuffer(nil)

b := func(ctx context.Context, c client.Client) (*client.Result, error) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
ctx, cancel := context.WithTimeoutCause(ctx, 10*time.Second, nil)
defer cancel()

st := llb.Image("busybox:latest")
Expand Down Expand Up @@ -1141,7 +1141,7 @@ func testClientGatewayContainerExecTty(t *testing.T, sb integration.Sandbox) {
inputR, inputW := io.Pipe()
output := bytes.NewBuffer(nil)
b := func(ctx context.Context, c client.Client) (*client.Result, error) {
ctx, timeout := context.WithTimeout(ctx, 10*time.Second)
ctx, timeout := context.WithTimeoutCause(ctx, 10*time.Second, nil)
defer timeout()
st := llb.Image("busybox:latest")

Expand Down Expand Up @@ -1233,7 +1233,7 @@ func testClientGatewayContainerCancelExecTty(t *testing.T, sb integration.Sandbo
inputR, inputW := io.Pipe()
output := bytes.NewBuffer(nil)
b := func(ctx context.Context, c client.Client) (*client.Result, error) {
ctx, timeout := context.WithTimeout(ctx, 10*time.Second)
ctx, timeout := context.WithTimeoutCause(ctx, 10*time.Second, nil)
defer timeout()
st := llb.Image("busybox:latest")

Expand Down Expand Up @@ -1266,8 +1266,8 @@ func testClientGatewayContainerCancelExecTty(t *testing.T, sb integration.Sandbo
defer pid1.Wait()
defer ctr.Release(ctx)

execCtx, cancel := context.WithCancel(ctx)
defer cancel()
execCtx, cancel := context.WithCancelCause(ctx)
defer cancel(errors.WithStack(context.Canceled))

prompt := newTestPrompt(execCtx, t, inputW, output)
pid2, err := ctr.Start(execCtx, client.StartRequest{
Expand All @@ -1281,7 +1281,7 @@ func testClientGatewayContainerCancelExecTty(t *testing.T, sb integration.Sandbo
require.NoError(t, err)

prompt.SendExpect("echo hi", "hi")
cancel()
cancel(errors.WithStack(context.Canceled))

err = pid2.Wait()
require.ErrorIs(t, err, context.Canceled)
Expand Down Expand Up @@ -2132,7 +2132,7 @@ func testClientGatewayContainerSignal(t *testing.T, sb integration.Sandbox) {
product := "buildkit_test"

b := func(ctx context.Context, c client.Client) (*client.Result, error) {
ctx, timeout := context.WithTimeout(ctx, 10*time.Second)
ctx, timeout := context.WithTimeoutCause(ctx, 10*time.Second, nil)
defer timeout()

st := llb.Image("busybox:latest")
Expand Down
2 changes: 1 addition & 1 deletion client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func (c *Client) Wait(ctx context.Context) error {

select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case <-time.After(time.Second):
}
c.conn.ResetConnectBackoff()
Expand Down
6 changes: 3 additions & 3 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7407,8 +7407,8 @@ func testInvalidExporter(t *testing.T, sb integration.Sandbox) {

// moby/buildkit#492
func testParallelLocalBuilds(t *testing.T, sb integration.Sandbox) {
ctx, cancel := context.WithCancel(sb.Context())
defer cancel()
ctx, cancel := context.WithCancelCause(sb.Context())
defer cancel(errors.WithStack(context.Canceled))

c, err := New(ctx, sb.Address())
require.NoError(t, err)
Expand Down Expand Up @@ -9832,7 +9832,7 @@ func testLLBMountPerformance(t *testing.T, sb integration.Sandbox) {
def, err := st.Marshal(sb.Context())
require.NoError(t, err)

timeoutCtx, cancel := context.WithTimeout(sb.Context(), time.Minute)
timeoutCtx, cancel := context.WithTimeoutCause(sb.Context(), time.Minute, nil)
defer cancel()
_, err = c.Solve(timeoutCtx, def, SolveOpt{}, nil)
require.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion client/llb/async.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (as *asyncState) Do(ctx context.Context, c *Constraints) error {
if err != nil {
select {
case <-ctx.Done():
if errors.Is(err, ctx.Err()) {
if errors.Is(err, context.Cause(ctx)) {
return res, err
}
default:
Expand Down
12 changes: 6 additions & 6 deletions client/solve.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ func (c *Client) solve(ctx context.Context, def *llb.Definition, runGateway runG
}
eg, ctx := errgroup.WithContext(ctx)

statusContext, cancelStatus := context.WithCancel(context.Background())
defer cancelStatus()
statusContext, cancelStatus := context.WithCancelCause(context.Background())
defer cancelStatus(errors.WithStack(context.Canceled))

if span := trace.SpanFromContext(ctx); span.SpanContext().IsValid() {
statusContext = trace.ContextWithSpan(statusContext, span)
Expand Down Expand Up @@ -230,16 +230,16 @@ func (c *Client) solve(ctx context.Context, def *llb.Definition, runGateway runG
frontendAttrs[k] = v
}

solveCtx, cancelSolve := context.WithCancel(ctx)
solveCtx, cancelSolve := context.WithCancelCause(ctx)
var res *SolveResponse
eg.Go(func() error {
ctx := solveCtx
defer cancelSolve()
defer cancelSolve(errors.WithStack(context.Canceled))

defer func() { // make sure the Status ends cleanly on build errors
go func() {
<-time.After(3 * time.Second)
cancelStatus()
cancelStatus(errors.WithStack(context.Canceled))
}()
if !opt.SessionPreInitialized {
bklog.G(ctx).Debugf("stopping session")
Expand Down Expand Up @@ -298,7 +298,7 @@ func (c *Client) solve(ctx context.Context, def *llb.Definition, runGateway runG
select {
case <-solveCtx.Done():
case <-time.After(5 * time.Second):
cancelSolve()
cancelSolve(errors.WithStack(context.Canceled))
}

return err
Expand Down
5 changes: 3 additions & 2 deletions cmd/buildctl/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,10 @@ func ResolveClient(c *cli.Context) (*client.Client, error) {

timeout := time.Duration(c.GlobalInt("timeout"))
if timeout > 0 {
ctx2, cancel := context.WithTimeout(ctx, timeout*time.Second)
ctx2, cancel := context.WithCancelCause(ctx)
ctx2, _ = context.WithTimeoutCause(ctx2, timeout*time.Second, errors.WithStack(context.DeadlineExceeded))
ctx = ctx2
defer cancel()
defer cancel(errors.WithStack(context.Canceled))
}

cl, err := client.New(ctx, c.GlobalString("addr"), opts...)
Expand Down
14 changes: 7 additions & 7 deletions cmd/buildkitd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ func main() {
if os.Geteuid() > 0 {
return errors.New("rootless mode requires to be executed as the mapped root in a user namespace; you may use RootlessKit for setting up the namespace")
}
ctx, cancel := context.WithCancel(appcontext.Context())
defer cancel()
ctx, cancel := context.WithCancelCause(appcontext.Context())
defer cancel(errors.WithStack(context.Canceled))

cfg, err := config.LoadFile(c.GlobalString("config"))
if err != nil {
Expand Down Expand Up @@ -344,9 +344,9 @@ func main() {
select {
case serverErr := <-errCh:
err = serverErr
cancel()
cancel(err)
case <-ctx.Done():
err = ctx.Err()
err = context.Cause(ctx)
}

bklog.G(ctx).Infof("stopping server")
Expand Down Expand Up @@ -634,14 +634,14 @@ func unaryInterceptor(globalCtx context.Context, tp trace.TracerProvider) grpc.U
withTrace := otelgrpc.UnaryServerInterceptor(otelgrpc.WithTracerProvider(tp), otelgrpc.WithPropagators(propagators))

return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(errors.WithStack(context.Canceled))

go func() {
select {
case <-ctx.Done():
case <-globalCtx.Done():
cancel()
cancel(context.Cause(globalCtx))
}
}()

Expand Down
4 changes: 2 additions & 2 deletions control/control.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,10 +505,10 @@ func (c *Controller) Session(stream controlapi.Control_SessionServer) error {
conn, closeCh, opts := grpchijack.Hijack(stream)
defer conn.Close()

ctx, cancel := context.WithCancel(stream.Context())
ctx, cancel := context.WithCancelCause(stream.Context())
go func() {
<-closeCh
cancel()
cancel(errors.WithStack(context.Canceled))
}()

err := c.opt.SessionManager.HandleConn(ctx, conn, opts)
Expand Down
5 changes: 3 additions & 2 deletions control/gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ func (gwf *GatewayForwarder) lookupForwarder(ctx context.Context) (gateway.LLBBr
return nil, errors.New("no buildid found in context")
}

ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
ctx, cancel := context.WithCancelCause(ctx)
ctx, _ = context.WithTimeoutCause(ctx, 3*time.Second, errors.WithStack(context.DeadlineExceeded))
defer cancel(errors.WithStack(context.Canceled))

go func() {
<-ctx.Done()
Expand Down
17 changes: 9 additions & 8 deletions executor/containerdexecutor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ func (w *containerdExecutor) Exec(ctx context.Context, id string, process execut
}
select {
case <-ctx.Done():
return ctx.Err()
return context.Cause(ctx)
case err, ok := <-details.done:
if !ok || err == nil {
return errors.Errorf("container %s has stopped", id)
Expand Down Expand Up @@ -336,8 +336,8 @@ func (w *containerdExecutor) runProcess(ctx context.Context, p containerd.Proces

// handle signals (and resize) in separate go loop so it does not
// potentially block the container cancel/exit status loop below.
eventCtx, eventCancel := context.WithCancel(ctx)
defer eventCancel()
eventCtx, eventCancel := context.WithCancelCause(ctx)
defer eventCancel(errors.WithStack(context.Canceled))
go func() {
for {
select {
Expand Down Expand Up @@ -371,21 +371,22 @@ func (w *containerdExecutor) runProcess(ctx context.Context, p containerd.Proces
}
}()

var cancel func()
var cancel func(error)
var killCtxDone <-chan struct{}
ctxDone := ctx.Done()
for {
select {
case <-ctxDone:
ctxDone = nil
var killCtx context.Context
killCtx, cancel = context.WithTimeout(context.Background(), 10*time.Second)
killCtx, cancel = context.WithCancelCause(context.Background())
killCtx, _ = context.WithTimeoutCause(killCtx, 10*time.Second, errors.WithStack(context.DeadlineExceeded))
killCtxDone = killCtx.Done()
p.Kill(killCtx, syscall.SIGKILL)
io.Cancel()
case status := <-statusCh:
if cancel != nil {
cancel()
cancel(errors.WithStack(context.Canceled))
}
trace.SpanFromContext(ctx).AddEvent(
"Container exited",
Expand All @@ -403,15 +404,15 @@ func (w *containerdExecutor) runProcess(ctx context.Context, p containerd.Proces
}
select {
case <-ctx.Done():
exitErr.Err = errors.Wrap(ctx.Err(), exitErr.Error())
exitErr.Err = errors.Wrap(context.Cause(ctx), exitErr.Error())
default:
}
return exitErr
}
return nil
case <-killCtxDone:
if cancel != nil {
cancel()
cancel(errors.WithStack(context.Canceled))
}
io.Cancel()
return errors.Errorf("failed to kill process on cancel")
Expand Down
Loading

0 comments on commit 7eb2c8e

Please sign in to comment.