Skip to content
This repository has been archived by the owner on May 19, 2020. It is now read-only.

Fix GetValidToken to save refreshed access tokens #1241

Merged
merged 1 commit into from
Oct 4, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion controllers/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (c *Context) Ping(rw web.ResponseWriter, req *web.Request) {

// LoginHandshake is the handler where we authenticate the user and the user authorizes this application access to information.
func (c *Context) LoginHandshake(rw web.ResponseWriter, req *web.Request) {
if token := helpers.GetValidToken(req.Request, c.Settings); token != nil {
if token := helpers.GetValidToken(req.Request, rw, c.Settings); token != nil {
// We should just go to dashboard if the user already has a valid token.
dashboardURL := fmt.Sprintf("%s%s", c.Settings.AppURL, "/#/dashboard")
http.Redirect(rw, req.Request, dashboardURL, http.StatusFound)
Expand Down
4 changes: 2 additions & 2 deletions controllers/secure.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type ResponseHandler func(http.ResponseWriter, *http.Response)
// If the token is 1) present and expired or 2) not present, it will return unauthorized.
func (c *SecureContext) OAuth(rw web.ResponseWriter, req *web.Request, next web.NextMiddlewareFunc) {
// Get valid token if it exists from session store.
if token := helpers.GetValidToken(req.Request, c.Settings); token != nil {
if token := helpers.GetValidToken(req.Request, rw, c.Settings); token != nil {
c.Token = *token
} else {
// If no token, return unauthorized.
Expand All @@ -53,7 +53,7 @@ func (c *SecureContext) LoginRequired(rw web.ResponseWriter, r *web.Request, nex
rw.Header().Set("pragma", "no-cache")
rw.Header().Set("expires", "-1")

token := helpers.GetValidToken(r.Request, c.Settings)
token := helpers.GetValidToken(r.Request, rw, c.Settings)
if token != nil {
next(rw, r)
} else {
Expand Down
42 changes: 16 additions & 26 deletions helpers/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package helpers
import (
"crypto/rand"
"encoding/base64"
"fmt"
"net/http"
"time"

Expand All @@ -16,7 +15,7 @@ import (
var TimeoutConstant = time.Second * 20

// GetValidToken is a helper function that returns a token struct only if it finds a non expired token for the session.
func GetValidToken(req *http.Request, settings *Settings) *oauth2.Token {
func GetValidToken(req *http.Request, rw http.ResponseWriter, settings *Settings) *oauth2.Token {
// Get session from session store.
session, _ := settings.Sessions.Get(req, "session")
// If for some reason we can't get or create a session, bail out.
Expand All @@ -25,33 +24,24 @@ func GetValidToken(req *http.Request, settings *Settings) *oauth2.Token {
}

// Attempt to get the token from this session.
if token, ok := session.Values["token"].(oauth2.Token); ok {
// If valid, just return.
if token.Valid() {
return &token
}
token, ok := session.Values["token"].(oauth2.Token)
if !ok {
return nil
}

// Will ensure not expired
rv, err := settings.OAuthConfig.TokenSource(settings.CreateContext(), &token).Token()
if err != nil {
return nil
}

// Attempt to refresh token using oauth2 Client
// https://godoc.org/golang.org/x/oauth2#Config.Client
reqURL := fmt.Sprintf("%s%s", settings.ConsoleAPI, "/v2/info")
request, _ := http.NewRequest("GET", reqURL, nil)
request.Close = true
client := settings.OAuthConfig.Client(settings.CreateContext(), &token)
// Prevents lingering goroutines from living forever.
// http://stackoverflow.com/questions/16895294/how-to-set-timeout-for-http-get-requests-in-golang/25344458#25344458
client.Timeout = TimeoutConstant
resp, err := client.Do(request)
if resp != nil {
defer resp.Body.Close()
}
if err != nil {
return nil
}
return &token
// Did it change? if so save it in our cookie so we don't have to refresh again on every request
if rv.AccessToken != token.AccessToken || !rv.Expiry.Equal(token.Expiry) {
session.Values["token"] = *rv
session.Save(req, rw)
}

// If couldn't find token or if it's expired, return nil
return nil
return rv
}

// GenerateRandomBytes returns securely generated random bytes.
Expand Down
5 changes: 4 additions & 1 deletion helpers/helpers_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package helpers_test

import (
"net/http/httptest"

"github.com/18F/cg-dashboard/helpers"
"github.com/18F/cg-dashboard/helpers/testhelpers"

Expand Down Expand Up @@ -41,14 +43,15 @@ var getValidTokenTests = []tokenTestData{
func TestGetValidToken(t *testing.T) {
mockRequest, _ := http.NewRequest("GET", "", nil)
mockSettings := helpers.Settings{}
mockResponse := httptest.NewRecorder()

for _, test := range getValidTokenTests {
// Initialize a new session store.
store := testhelpers.MockSessionStore{}
store.ResetSessionData(test.sessionData, test.sessionName)
mockSettings.Sessions = store

value := helpers.GetValidToken(mockRequest, &mockSettings)
value := helpers.GetValidToken(mockRequest, mockResponse, &mockSettings)
if (value == nil) == test.returnValueNull {
} else {
t.Errorf("Test %s did not meet expected value. Expected: %t. Actual: %t\n", test.testName, test.returnValueNull, (value == nil))
Expand Down