Skip to content

Commit

Permalink
Add API support for cancelation contexts passed via QueryOptions and …
Browse files Browse the repository at this point in the history
…WriteOptions (hashicorp#8836)

Copy Consul API's format: QueryOptions.WithContext(context) will now return
a new QueryOption whose HTTP requests will be canceled with the context
provided (and similar for WriteOptions)
  • Loading branch information
benbuzbee authored and fredrikhgrelland committed Sep 28, 2020
1 parent d4ad874 commit a181747
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 2 deletions.
57 changes: 56 additions & 1 deletion api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package api
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"encoding/json"
"errors"
Expand Down Expand Up @@ -63,6 +64,10 @@ type QueryOptions struct {

// AuthToken is the secret ID of an ACL token
AuthToken string

// ctx is an optional context pass through to the underlying HTTP
// request layer. Use Context() and WithContext() to manage this.
ctx context.Context
}

// WriteOptions are used to parametrize a write
Expand All @@ -76,6 +81,10 @@ type WriteOptions struct {

// AuthToken is the secret ID of an ACL token
AuthToken string

// ctx is an optional context pass through to the underlying HTTP
// request layer. Use Context() and WithContext() to manage this.
ctx context.Context
}

// QueryMeta is used to return meta data about a query
Expand Down Expand Up @@ -517,6 +526,7 @@ type request struct {
token string
body io.Reader
obj interface{}
ctx context.Context
}

// setQueryOptions is used to annotate the request with
Expand Down Expand Up @@ -549,6 +559,7 @@ func (r *request) setQueryOptions(q *QueryOptions) {
for k, v := range q.Params {
r.params.Set(k, v)
}
r.ctx = q.Context()
}

// durToMsec converts a duration to a millisecond specified string
Expand All @@ -571,6 +582,7 @@ func (r *request) setWriteOptions(q *WriteOptions) {
if q.AuthToken != "" {
r.token = q.AuthToken
}
r.ctx = q.Context()
}

// toHTTP converts the request to an HTTP request
Expand All @@ -587,8 +599,15 @@ func (r *request) toHTTP() (*http.Request, error) {
}
}

ctx := func() context.Context {
if r.ctx != nil {
return r.ctx
}
return context.Background()
}()

// Create the HTTP request
req, err := http.NewRequest(r.method, r.url.RequestURI(), r.body)
req, err := http.NewRequestWithContext(ctx, r.method, r.url.RequestURI(), r.body)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -982,3 +1001,39 @@ func requireOK(d time.Duration, resp *http.Response, e error) (time.Duration, *h
}
return d, resp, nil
}

// Context returns the context used for canceling HTTP requests related to this query
func (o *QueryOptions) Context() context.Context {
if o != nil && o.ctx != nil {
return o.ctx
}
return context.Background()
}

// WithContext creates a copy of the query options using the provided context to cancel related HTTP requests
func (o *QueryOptions) WithContext(ctx context.Context) *QueryOptions {
o2 := new(QueryOptions)
if o != nil {
*o2 = *o
}
o2.ctx = ctx
return o2
}

// Context returns the context used for canceling HTTP requests related to this write
func (o *WriteOptions) Context() context.Context {
if o != nil && o.ctx != nil {
return o.ctx
}
return context.Background()
}

// WithContext creates a copy of the write options using the provided context to cancel related HTTP requests
func (o *WriteOptions) WithContext(ctx context.Context) *WriteOptions {
o2 := new(WriteOptions)
if o != nil {
*o2 = *o
}
o2.ctx = ctx
return o2
}
49 changes: 49 additions & 0 deletions api/api_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package api

import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -198,6 +200,53 @@ func TestSetQueryOptions(t *testing.T) {
}
}

func TestQueryOptionsContext(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
c, s := makeClient(t, nil, nil)
defer s.Stop()
q := (&QueryOptions{
WaitIndex: 10000,
}).WithContext(ctx)

if q.ctx != ctx {
t.Fatalf("expected context to be set")
}

go func() {
cancel()
}()
_, _, err := c.Jobs().List(q)
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected job wait to fail with canceled, got %s", err)
}
}

func TestWriteOptionsContext(t *testing.T) {
// No blocking query to test a real cancel of a pending request so
// just test that if we pass a pre-canceled context, writes fail quickly
t.Parallel()

c, err := NewClient(DefaultConfig())
if err != nil {
t.Fatalf("failed to initialize client: %s", err)
}

ctx, cancel := context.WithCancel(context.Background())
w := (&WriteOptions{}).WithContext(ctx)

if w.ctx != ctx {
t.Fatalf("expected context to be set")
}

cancel()

_, _, err = c.Jobs().Deregister("jobid", true, w)
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected job to fail with canceled, got %s", err)
}
}

func TestSetWriteOptions(t *testing.T) {
t.Parallel()
c, s := makeClient(t, nil, nil)
Expand Down
57 changes: 56 additions & 1 deletion vendor/github.com/hashicorp/nomad/api/api.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit a181747

Please sign in to comment.