Skip to content

Commit

Permalink
Merge pull request #1608 from liggitt/oauth_csrf_validation
Browse files Browse the repository at this point in the history
Merged by openshift-bot
  • Loading branch information
OpenShift Bot committed Apr 6, 2015
2 parents 444fb5f + 575e412 commit 9b0b0f8
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 21 deletions.
76 changes: 57 additions & 19 deletions pkg/auth/oauth/external/handler.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package external

import (
"encoding/base64"
"errors"
"fmt"
"net/http"
Expand All @@ -12,6 +13,7 @@ import (

authapi "github.com/openshift/origin/pkg/auth/api"
"github.com/openshift/origin/pkg/auth/oauth/handlers"
"github.com/openshift/origin/pkg/auth/server/csrf"
)

// Handler exposes an external oauth provider flow (including the call back) as an oauth.handlers.AuthenticationHandler to allow our internal oauth
Expand Down Expand Up @@ -78,7 +80,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
authData, err := authReq.HandleRequest(req)
if err != nil {
glog.V(4).Infof("Error handling request: %v", err)
h.errorHandler.AuthenticationError(err, w, req)
h.handleError(err, w, req)
return
}

Expand All @@ -89,7 +91,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
accessData, err := accessReq.GetToken()
if err != nil {
glog.V(4).Infof("Error getting access token:", err)
h.errorHandler.AuthenticationError(err, w, req)
h.handleError(err, w, req)
return
}

Expand All @@ -98,66 +100,89 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
identity, ok, err := h.provider.GetUserIdentity(accessData)
if err != nil {
glog.V(4).Infof("Error getting userIdentityInfo info: %v", err)
h.errorHandler.AuthenticationError(err, w, req)
h.handleError(err, w, req)
return
}
if !ok {
glog.V(4).Infof("Could not get userIdentityInfo info from access token")
h.errorHandler.AuthenticationError(errors.New("Could not get userIdentityInfo info from access token"), w, req)
err := errors.New("Could not get userIdentityInfo info from access token")
h.handleError(err, w, req)
return
}

user, err := h.mapper.UserFor(identity)
glog.V(4).Infof("Got userIdentityMapping: %#v", user)
if err != nil {
glog.V(4).Infof("Error creating or updating mapping for: %#v due to %v", identity, err)
h.errorHandler.AuthenticationError(err, w, req)
h.handleError(err, w, req)
return
}

ok, err = h.state.Check(authData.State, w, req)
if !ok {
glog.V(4).Infof("State is invalid")
h.errorHandler.AuthenticationError(errors.New("State is invalid"), w, req)
err := errors.New("State is invalid")
h.handleError(err, w, req)
return
}
if err != nil {
glog.V(4).Infof("Error verifying state: %v", err)
h.errorHandler.AuthenticationError(err, w, req)
h.handleError(err, w, req)
return
}

_, err = h.success.AuthenticationSucceeded(user, authData.State, w, req)
if err != nil {
glog.V(4).Infof("Error calling success handler: %v", err)
h.errorHandler.AuthenticationError(err, w, req)
h.handleError(err, w, req)
return
}
}

func (h *Handler) handleError(err error, w http.ResponseWriter, req *http.Request) {
handled, err := h.errorHandler.AuthenticationError(err, w, req)
if handled {
return
}
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`An error occurred`))
}

// defaultState provides default state-building, validation, and parsing to contain CSRF and "then" redirection
type defaultState struct{}
type defaultState struct {
csrf csrf.CSRF
}

func DefaultState() State {
return defaultState{}
func DefaultState(csrf csrf.CSRF) State {
return &defaultState{csrf}
}

func (defaultState) Generate(w http.ResponseWriter, req *http.Request) (string, error) {
func (d *defaultState) Generate(w http.ResponseWriter, req *http.Request) (string, error) {
csrfToken, err := d.csrf.Generate(w, req)
if err != nil {
return "", err
}

state := url.Values{
"csrf": {"..."}, // TODO: get csrf
"csrf": {csrfToken},
"then": {req.URL.String()},
}
return state.Encode(), nil
return encodeState(state)
}

func (defaultState) Check(state string, w http.ResponseWriter, req *http.Request) (bool, error) {
values, err := url.ParseQuery(state)
func (d *defaultState) Check(state string, w http.ResponseWriter, req *http.Request) (bool, error) {
values, err := decodeState(state)
if err != nil {
return false, err
}
csrf := values.Get("csrf")
if csrf != "..." {
return false, fmt.Errorf("State did not contain valid CSRF token (expected %s, got %s)", "...", csrf)

ok, err := d.csrf.Check(req, csrf)
if err != nil {
return false, err
}
if !ok {
return false, fmt.Errorf("State did not contain a valid CSRF token")
}

then := values.Get("then")
Expand All @@ -169,7 +194,7 @@ func (defaultState) Check(state string, w http.ResponseWriter, req *http.Request
}

func (defaultState) AuthenticationSucceeded(user user.Info, state string, w http.ResponseWriter, req *http.Request) (bool, error) {
values, err := url.ParseQuery(state)
values, err := decodeState(state)
if err != nil {
return false, err
}
Expand All @@ -182,3 +207,16 @@ func (defaultState) AuthenticationSucceeded(user user.Info, state string, w http
http.Redirect(w, req, then, http.StatusFound)
return true, nil
}

// URL-encode, then base-64 encode for OAuth providers that don't do a good job of treating the state param like an opaque value
func encodeState(values url.Values) (string, error) {
return base64.URLEncoding.EncodeToString([]byte(values.Encode())), nil
}

func decodeState(state string) (url.Values, error) {
decodedState, err := base64.URLEncoding.DecodeString(state)
if err != nil {
return nil, err
}
return url.ParseQuery(string(decodedState))
}
4 changes: 2 additions & 2 deletions pkg/cmd/server/origin/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ func (c *AuthConfig) getAuthenticationHandler(mux cmdutil.Mux, errorHandler hand
return nil, fmt.Errorf("unexpected oauth provider %#v", provider)
}

state := external.DefaultState()
state := external.DefaultState(getCSRF())
oauthHandler, err := external.NewExternalOAuthRedirector(oauthProvider, state, c.Options.MasterPublicURL+callbackPath, successHandler, errorHandler, identityMapper)
if err != nil {
return nil, fmt.Errorf("unexpected error: %v", err)
Expand Down Expand Up @@ -406,7 +406,7 @@ func (c *AuthConfig) getAuthenticationSuccessHandler() handlers.AuthenticationSu

switch identityProvider.Provider.Object.(type) {
case (*configapi.OAuthRedirectingIdentityProvider):
successHandlers = append(successHandlers, external.DefaultState().(handlers.AuthenticationSuccessHandler))
successHandlers = append(successHandlers, external.DefaultState(getCSRF()).(handlers.AuthenticationSuccessHandler))
}

if !addedRedirectSuccessHandler && configapi.IsPasswordAuthenticator(identityProvider) {
Expand Down

0 comments on commit 9b0b0f8

Please sign in to comment.