Skip to content

Commit

Permalink
Merge pull request #270 from avenga/oauth2-locking
Browse files Browse the repository at this point in the history
Oauth 2.0: Locking for token requests
  • Loading branch information
Marcel Ludwig authored Aug 5, 2021
2 parents 7ec6d3e + cafd693 commit 0f8a835
Show file tree
Hide file tree
Showing 3 changed files with 222 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Unreleased changes are available as `avenga/couper:edge` container.

* **Fixed**
* No GZIP compression for small response bodies ([#186](https://github.com/avenga/couper/issues/186))
* Missing synchronization for OAuth2 access token requests ([#270](https://github.com/avenga/couper/issues/270))
* Missing error type for [request](docs/REFERENCE.md#request-block)/[response](docs/REFERENCE.md#response-block) body, json_body or form_body related HCL evaluation errors ([#276](https://github.com/avenga/couper/pull/276))
* [`request.url`](./docs/REFERENCE.md#request) and [`backend_requests.<label>.url`](./docs/REFERENCE.md#backend_requests) now contain a query string if present ([#278](https://github.com/avenga/couper/pull/278))
* [`backend_responses.<label>.status`](./docs/REFERENCE.md#backend_responses) is now integer ([#278](https://github.com/avenga/couper/pull/278))
Expand Down
50 changes: 36 additions & 14 deletions handler/transport/oauth2_req_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"sync"

"github.com/avenga/couper/cache"
"github.com/avenga/couper/config"
Expand All @@ -19,6 +20,7 @@ type OAuth2ReqAuth struct {
oauth2Client *oauth2.Client
config *config.OAuth2ReqAuth
memStore *cache.MemoryStore
locks sync.Map
next http.RoundTripper
}

Expand All @@ -29,23 +31,36 @@ func NewOAuth2ReqAuth(conf *config.OAuth2ReqAuth, memStore *cache.MemoryStore,
config: conf,
oauth2Client: oauth2Client,
memStore: memStore,
locks: sync.Map{},
next: next,
}, nil
}

// RoundTrip implements the <http.RoundTripper> interface.
func (oa *OAuth2ReqAuth) RoundTrip(req *http.Request) (*http.Response, error) {
storageKey := fmt.Sprintf("%p|%s|%s", &oa.oauth2Client.Backend, oa.config.ClientID, oa.config.ClientSecret)
if data := oa.memStore.Get(storageKey); data != "" {
token, terr := oa.readAccessToken(data)
if terr != nil {
// TODO this error is not connected to the OAuth2 client's backend
// In fact this can only be a JSON parse error or a missing access_token,
// which will occur after having requested the token from the authorization
// server. So the erroneous response will never be stored.
return nil, errors.Backend.Label(oa.config.BackendName).Message("token read error").With(terr)
}

if token, terr := oa.readAccessToken(storageKey); terr != nil {
// TODO this error is not connected to the OAuth2 client's backend
// In fact this can only be a JSON parse error or a missing access_token,
// which will occur after having requested the token from the authorization
// server. So the erroneous response will never be stored.
return nil, errors.Backend.Label(oa.config.BackendName).Message("token read error").With(terr)
} else if token != "" {
req.Header.Set("Authorization", "Bearer "+token)

return oa.next.RoundTrip(req)
}

value, _ := oa.locks.LoadOrStore(storageKey, &sync.Mutex{})
mutex := value.(*sync.Mutex)
mutex.Lock()

if token, terr := oa.readAccessToken(storageKey); terr != nil {
mutex.Unlock()
return nil, errors.Backend.Label(oa.config.BackendName).Message("token read error").With(terr)
} else if token != "" {
mutex.Unlock()
req.Header.Set("Authorization", "Bearer "+token)

return oa.next.RoundTrip(req)
Expand All @@ -54,11 +69,14 @@ func (oa *OAuth2ReqAuth) RoundTrip(req *http.Request) (*http.Response, error) {
ctx := req.Context()
tokenResponse, tokenResponseData, token, err := oa.oauth2Client.GetTokenResponse(ctx)
if err != nil {
mutex.Unlock()
return nil, errors.Backend.Label(oa.config.BackendName).Message("token request error").With(err)
}

oa.updateAccessToken(tokenResponse, tokenResponseData, storageKey)

mutex.Unlock()

req.Header.Set("Authorization", "Bearer "+token)

res, err := oa.next.RoundTrip(req)
Expand All @@ -79,13 +97,17 @@ func (oa *OAuth2ReqAuth) RoundTrip(req *http.Request) (*http.Response, error) {
return res, err
}

func (oa *OAuth2ReqAuth) readAccessToken(data string) (string, error) {
_, token, err := oauth2.ParseTokenResponse([]byte(data))
if err != nil {
return "", err
func (oa *OAuth2ReqAuth) readAccessToken(key string) (string, error) {
if data := oa.memStore.Get(key); data != "" {
_, token, err := oauth2.ParseTokenResponse([]byte(data))
if err != nil {
return "", errors.Backend.Label(oa.config.BackendName).Message("token read error").With(err)
}

return token, nil
}

return token, nil
return "", nil
}

func (oa *OAuth2ReqAuth) updateAccessToken(jsonBytes []byte, jData map[string]interface{}, key string) {
Expand Down
186 changes: 185 additions & 1 deletion server/http_oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -199,7 +201,7 @@ func TestEndpoints_OAuth2_Options(t *testing.T) {
}
}

func TestOAuth2AccessControl(t *testing.T) {
func TestOAuth2_AccessControl(t *testing.T) {
client := newClient()
helper := test.New(t)

Expand Down Expand Up @@ -385,3 +387,185 @@ func TestOAuth2AccessControl(t *testing.T) {
})
}
}

func TestOAuth2_Locking(t *testing.T) {
helper := test.New(t)
client := newClient()

token := "token-"
var oauthRequestCount int32
oauthOrigin := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
atomic.AddInt32(&oauthRequestCount, 1)
if req.URL.Path == "/oauth2" {
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(http.StatusOK)

n := fmt.Sprintf("%d", atomic.LoadInt32(&oauthRequestCount))
body := []byte(`{
"access_token": "` + token + n + `",
"token_type": "bearer",
"expires_in": 1.5
}`)

// Slow down token request to test locking.
time.Sleep(1 * time.Second)

_, werr := rw.Write(body)
helper.Must(werr)

return
}

rw.WriteHeader(http.StatusBadRequest)
}))
defer oauthOrigin.Close()

ResourceOrigin := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == "/resource" {
if auth := req.Header.Get("Authorization"); auth != "" {
rw.Header().Set("Token", auth[len("Bearer "):])
rw.WriteHeader(http.StatusNoContent)
}

return
}

rw.WriteHeader(http.StatusNotFound)
}))
defer ResourceOrigin.Close()

confPath := "testdata/oauth2/1_retries_couper.hcl"
shutdown, hook := newCouperWithTemplate(
confPath, test.New(t), map[string]interface{}{
"asOrigin": oauthOrigin.URL,
"rsOrigin": ResourceOrigin.URL,
},
)

defer func() {
if t.Failed() {
for _, e := range hook.Entries {
println(e.String())
}
}
shutdown()
}()

req, err := http.NewRequest(http.MethodGet, "http://anyserver:8080/", nil)
helper.Must(err)

hook.Reset()

req.URL.Path = "/"

var responses []*http.Response
var wg sync.WaitGroup

addLock := &sync.Mutex{}
// Fire 5 requests in parallel...
waitCh := make(chan struct{})
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
<-waitCh
res, err := client.Do(req)
helper.Must(err)

addLock.Lock()
responses = append(responses, res)
addLock.Unlock()

}()
}
close(waitCh)
wg.Wait()

for _, res := range responses {
if res.StatusCode != http.StatusNoContent {
t.Errorf("Expected status NoContent, got: %d", res.StatusCode)
}

if token+"1" != res.Header.Get("Token") {
t.Errorf("Invalid token given: want %s1, got: %s", token, res.Header.Get("Token"))
}
}

if oauthRequestCount != 1 {
t.Errorf("Too many OAuth2 requests: want 1, got: %d", oauthRequestCount)
}

t.Run("Lock is effective", func(st *testing.T) {
// Wait until token has expired.
time.Sleep(2 * time.Second)
h := test.New(st)

// Fetch new token.
go func() {
res, err := client.Do(req)
h.Must(err)

if token+"2" != res.Header.Get("Token") {
st.Errorf("Received wrong token: want %s2, got: %s", token, res.Header.Get("Token"))
}
}()

// Slow response due to lock
go func() {
start := time.Now()
res, err := client.Do(req)
h.Must(err)
timeElapsed := time.Since(start).Seconds()

if token+"2" != res.Header.Get("Token") {
st.Errorf("Received wrong token: want %s2, got: %s", token, res.Header.Get("Token"))
}

if timeElapsed < 1 {
st.Errorf("Response came too fast: dysfunctional lock?! (%v s)", timeElapsed)
}
}()
})

t.Run("Mem store expiry", func(st *testing.T) {
// Wait again until token has expired.
time.Sleep(2 * time.Second)
h := test.New(st)
// Request fresh token and store in memstore
res, err := client.Do(req)
h.Must(err)
if res.StatusCode != http.StatusNoContent {
st.Errorf("Unexpected response status: want %d, got: %d", http.StatusNoContent, res.StatusCode)
}

if token+"3" != res.Header.Get("Token") {
st.Errorf("Received wrong token: want %s3, got: %s", token, res.Header.Get("Token"))
}

if oauthRequestCount != 3 {
st.Errorf("Unexpected number of OAuth2 requests: want 3, got: %d", oauthRequestCount)
}

// Disconnect OAuth server
oauthOrigin.Close()

// Next request gets token from memstore
res, err = client.Do(req)
h.Must(err)
if res.StatusCode != http.StatusNoContent {
st.Errorf("Unexpected response status: want %d, got: %d", http.StatusNoContent, res.StatusCode)
}

if token+"3" != res.Header.Get("Token") {
st.Errorf("Wrong token from mem store: want %s3, got: %s", token, res.Header.Get("Token"))
}

// Wait until token has expired. Next request accesses the OAuth server again.
time.Sleep(2 * time.Second)
res, err = newClient().Do(req)
h.Must(err)
if res.StatusCode != http.StatusBadGateway {
st.Errorf("Unexpected response status: want %d, got: %d", http.StatusBadGateway, res.StatusCode)
}
})
}

0 comments on commit 0f8a835

Please sign in to comment.