Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix backend ctx #529

Merged
merged 10 commits into from
Jun 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ Unreleased changes are available as `avenga/couper:edge` container.

* **Fixed**
* configuration related panic while loading backends with [`oauth2` block](./docs/REFERENCE.md#oauth2-cc-block) which depends on other defined backends ([#524](https://github.com/avenga/couper/pull/524))
* erroneous retries for [`oauth2`](./docs/REFERENCE.md#oauth2-cc-block) backend authorization with `retries = 0` ([#528](https://github.com/avenga/couper/pull/528))
* erroneous retries for [`oauth2`](./docs/REFERENCE.md#oauth2-cc-block) backend authorization ([#529](https://github.com/avenga/couper/pull/529))
* with `retries = 0` ([#528](https://github.com/avenga/couper/pull/528))
* with `retries` > `0` and related origin configuration ([#529](https://github.com/avenga/couper/pull/529))
* race condition resulting in empty [`backends.<label>.health.state` variable](docs/REFERENCE.md#backends) ([#530](https://github.com/avenga/couper/pull/530))

---
Expand Down
42 changes: 27 additions & 15 deletions handler/transport/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,9 @@ func (b *Backend) RoundTrip(req *http.Request) (*http.Response, error) {
}, err
}

// for token-request retry purposes
originalReq := req.Clone(req.Context())

if err := b.withTokenRequest(req); err != nil {
// originalReq for token-request retry purposes
originalReq, err := b.withTokenRequest(req)
if err != nil {
return nil, err
}

Expand All @@ -136,23 +135,34 @@ func (b *Backend) RoundTrip(req *http.Request) (*http.Response, error) {

// Execute before <b.evalTransport()> due to right
// handling of query-params in the URL attribute.
if err := eval.ApplyRequestContext(hclCtx, ctxBody, req); err != nil {
if err = eval.ApplyRequestContext(hclCtx, ctxBody, req); err != nil {
return nil, err
}

// TODO: split timing eval
tc, err := b.evalTransport(hclCtx, ctxBody, req)
if err != nil {
return nil, err
}

// first traffic pins the origin settings to transportConfResult
b.transportOnce.Do(func() {
b.initOnce(tc)
})

deadlineErr := b.withTimeout(req, tc)
// use result and apply context timings
b.healthyMu.RLock()
tconf := b.transportConfResult
b.healthyMu.RUnlock()
tconf.ConnectTimeout = tc.ConnectTimeout
tconf.TTFBTimeout = tc.TTFBTimeout
tconf.Timeout = tc.Timeout

deadlineErr := b.withTimeout(req, &tconf)

req.URL.Host = tc.Origin
req.URL.Scheme = tc.Scheme
req.Host = tc.Hostname
req.URL.Host = tconf.Origin
req.URL.Scheme = tconf.Scheme
req.Host = tconf.Hostname

// handler.Proxy marks proxy round-trips since we should not handle headers twice.
_, isProxyReq := req.Context().Value(request.RoundTripProxy).(bool)
Expand Down Expand Up @@ -187,9 +197,9 @@ func (b *Backend) RoundTrip(req *http.Request) (*http.Response, error) {

var beresp *http.Response
if b.openAPIValidator != nil {
beresp, err = b.openAPIValidate(req, tc, deadlineErr)
beresp, err = b.openAPIValidate(req, &tconf, deadlineErr)
} else {
beresp, err = b.innerRoundTrip(req, tc, deadlineErr)
beresp, err = b.innerRoundTrip(req, &tconf, deadlineErr)
}

if err != nil {
Expand Down Expand Up @@ -307,23 +317,25 @@ func (b *Backend) innerRoundTrip(req *http.Request, tc *Config, deadlineErr <-ch
return beresp, nil
}

func (b *Backend) withTokenRequest(req *http.Request) error {
func (b *Backend) withTokenRequest(req *http.Request) (*http.Request, error) {
if b.tokenRequest == nil {
return nil
return nil, nil
}

trValue, _ := req.Context().Value(request.BackendTokenRequest).(string)
if trValue != "" { // prevent loop
return nil
return nil, nil
}

ctx := context.WithValue(req.Context(), request.BackendTokenRequest, "tr")
// Reset for upstream transport; prevent mixing values.
// tokenRequest will have their own backend configuration.
ctx = context.WithValue(ctx, request.BackendParams, nil)

originalReq := req.Clone(req.Context())

// WithContext() instead of Clone() due to header-map modification.
return b.tokenRequest.WithToken(req.WithContext(ctx))
return originalReq, b.tokenRequest.WithToken(req.WithContext(ctx))
}

func (b *Backend) withRetryTokenRequest(req *http.Request, res *http.Response) (bool, error) {
Expand Down
10 changes: 2 additions & 8 deletions handler/transport/oauth2_req_auth.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package transport

import (
"context"
"fmt"
"net/http"
"sync"
Expand Down Expand Up @@ -82,14 +81,9 @@ func (oa *OAuth2ReqAuth) RetryWithToken(req *http.Request, res *http.Response) (

oa.memStore.Del(oa.storageKey)

if *oa.config.Retries < 1 {
return false, nil
}

ctx := req.Context()
if retries, ok := ctx.Value(request.TokenRequestRetries).(uint8); !ok || retries < *oa.config.Retries {
ctx = context.WithValue(ctx, request.TokenRequestRetries, retries+1)

if retries, ok := ctx.Value(request.TokenRequestRetries).(*uint8); !ok || *retries < *oa.config.Retries {
*retries++ // increase ptr value instead of context value
req.Header.Del("Authorization")
err := oa.WithToken(req.WithContext(ctx)) // WithContext due to header manipulation
return true, err
Expand Down
8 changes: 5 additions & 3 deletions logging/upstream_log.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ func (u *UpstreamLog) RoundTrip(req *http.Request) (*http.Response, error) {
fields["request"] = requestFields

berespBytes := int64(0)
logCtxCh := make(chan hcl.Body, 10)
logCtxCh := make(chan hcl.Body, 17) // TODO: Will block with oauth2 token retries >= 17
tokenRetries := uint8(0)
outctx := context.WithValue(req.Context(), request.LogCustomUpstream, logCtxCh)
outctx = context.WithValue(outctx, request.BackendBytes, &berespBytes)
outctx = context.WithValue(outctx, request.TokenRequestRetries, &tokenRetries)
oCtx, openAPIContext := validation.NewWithContext(outctx)
outreq := req.WithContext(httptrace.WithClientTrace(oCtx, clientTrace))

Expand Down Expand Up @@ -123,8 +125,8 @@ func (u *UpstreamLog) RoundTrip(req *http.Request) (*http.Response, error) {
if tr, ok := outreq.Context().Value(request.TokenRequest).(string); ok && tr != "" {
fields["token_request"] = tr

if retries, exist := outreq.Context().Value(request.TokenRequestRetries).(uint8); exist && retries > 0 {
fields["token_request_retry"] = retries
if tokenRetries > 0 {
fields["token_request_retry"] = tokenRetries
}
}

Expand Down
67 changes: 67 additions & 0 deletions server/http_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,70 @@ func TestBackend_Unhealthy(t *testing.T) {
})
}
}

func TestBackend_Oauth2_TokenEndpoint(t *testing.T) {
helper := test.New(t)

requestCount := 0
origin := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(http.StatusUnauthorized)
_, werr := rw.Write([]byte(`{"path": "` + r.URL.Path + `"}`))
requestCount++
helper.Must(werr)
}))
defer origin.Close()

tokenEndpoint := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set("Content-Type", "application/json")
_, werr := rw.Write([]byte(`{
"access_token": "my-token",
"expires_in": 120
}`))
helper.Must(werr)
}))
defer origin.Close()

retries := 3
shutdown, _ := newCouperWithTemplate("testdata/integration/backends/07_couper.hcl", helper,
map[string]interface{}{
"origin": origin.URL,
"token_endpoint": tokenEndpoint.URL,
"retries": retries,
})
defer shutdown()

client := test.NewHTTPClient()

req, err := http.NewRequest(http.MethodGet, "http://couper.dev:8080/test-path", nil)
helper.Must(err)
res, err := client.Do(req)
helper.Must(err)

if res.StatusCode != http.StatusUnauthorized {
t.Errorf("want status %d, got: %d", http.StatusUnauthorized, res.StatusCode)
}

if res.Header.Get("Content-Type") != "application/json" {
t.Errorf("want json content-type")
return
}

type result struct {
Path string
}

b, err := io.ReadAll(res.Body)
helper.Must(res.Body.Close())

r := &result{}
helper.Must(json.Unmarshal(b, r))

if r.Path != "/test-path" {
t.Errorf("path property want: %q, got: %q", "/test-path", r.Path)
}

if requestCount != retries+1 {
t.Errorf("unexpected number of requests, want: %d, got: %d", retries+1, requestCount)
}
}
3 changes: 2 additions & 1 deletion server/http_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1864,14 +1864,15 @@ func TestHTTPServer_Endpoint_Evaluation(t *testing.T) {
exp expectation
}

// first traffic pins the origin (transport conf)
for _, tc := range []testCase{
{"/my-waffik/my.host.de/" + testBackend.Addr()[7:], expectation{
Host: "my.host.de",
Origin: testBackend.Addr()[7:],
Path: "/anything",
}},
{"/my-respo/my.host.com/" + testBackend.Addr()[7:], expectation{
Host: "my.host.com",
Host: "my.host.de",
Origin: testBackend.Addr()[7:],
Path: "/anything",
}},
Expand Down
23 changes: 23 additions & 0 deletions server/testdata/integration/backends/07_couper.hcl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
server {
api {
endpoint "/**" {
proxy {
backend = "rs"
}
}
}
}

definitions {
backend "rs" {
origin = "{{ .origin }}"

oauth2 {
grant_type = "client_credentials"
client_id = "cli"
client_secret = "cls"
token_endpoint = "{{ .token_endpoint }}/token"
retries = {{ .retries }}
}
}
}