diff --git a/contrib/token-server/main.go b/contrib/token-server/main.go index 138793c732a..8f9029eaed3 100644 --- a/contrib/token-server/main.go +++ b/contrib/token-server/main.go @@ -245,7 +245,7 @@ func (ts *tokenServer) getToken(ctx context.Context, w http.ResponseWriter, r *h // Get response context. ctx, w = dcontext.WithResponseWriter(ctx, w) - challenge.SetHeaders(w) + challenge.SetHeaders(r, w) handleError(ctx, errcode.ErrorCodeUnauthorized.WithDetail(challenge.Error()), w) dcontext.GetResponseLogger(ctx).Info("get token authentication challenge") diff --git a/registry/auth/auth.go b/registry/auth/auth.go index 91c7af3faec..835eff73dcf 100644 --- a/registry/auth/auth.go +++ b/registry/auth/auth.go @@ -21,7 +21,7 @@ // if ctx, err := accessController.Authorized(ctx, access); err != nil { // if challenge, ok := err.(auth.Challenge) { // // Let the challenge write the response. -// challenge.SetHeaders(w) +// challenge.SetHeaders(r, w) // w.WriteHeader(http.StatusUnauthorized) // return // } else { @@ -87,7 +87,7 @@ type Challenge interface { // adding the an HTTP challenge header on the response message. Callers // are expected to set the appropriate HTTP status code (e.g. 401) // themselves. - SetHeaders(w http.ResponseWriter) + SetHeaders(r *http.Request, w http.ResponseWriter) } // AccessController controls access to registry resources based on a request diff --git a/registry/auth/htpasswd/access.go b/registry/auth/htpasswd/access.go index eddf7ac3d30..2611a23be86 100644 --- a/registry/auth/htpasswd/access.go +++ b/registry/auth/htpasswd/access.go @@ -111,7 +111,7 @@ type challenge struct { var _ auth.Challenge = challenge{} // SetHeaders sets the basic challenge header on the response. -func (ch challenge) SetHeaders(w http.ResponseWriter) { +func (ch challenge) SetHeaders(r *http.Request, w http.ResponseWriter) { w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=%q", ch.realm)) } diff --git a/registry/auth/htpasswd/access_test.go b/registry/auth/htpasswd/access_test.go index 7a3d411e8c9..0bfc427e4e4 100644 --- a/registry/auth/htpasswd/access_test.go +++ b/registry/auth/htpasswd/access_test.go @@ -50,7 +50,7 @@ func TestBasicAccessController(t *testing.T) { if err != nil { switch err := err.(type) { case auth.Challenge: - err.SetHeaders(w) + err.SetHeaders(r, w) w.WriteHeader(http.StatusUnauthorized) return default: diff --git a/registry/auth/silly/access.go b/registry/auth/silly/access.go index f7bbe6e0816..3ead560dff6 100644 --- a/registry/auth/silly/access.go +++ b/registry/auth/silly/access.go @@ -82,7 +82,7 @@ type challenge struct { var _ auth.Challenge = challenge{} // SetHeaders sets a simple bearer challenge on the response. -func (ch challenge) SetHeaders(w http.ResponseWriter) { +func (ch challenge) SetHeaders(r *http.Request, w http.ResponseWriter) { header := fmt.Sprintf("Bearer realm=%q,service=%q", ch.realm, ch.service) if ch.scope != "" { diff --git a/registry/auth/silly/access_test.go b/registry/auth/silly/access_test.go index 0a5103e6c34..19824494962 100644 --- a/registry/auth/silly/access_test.go +++ b/registry/auth/silly/access_test.go @@ -21,7 +21,7 @@ func TestSillyAccessController(t *testing.T) { if err != nil { switch err := err.(type) { case auth.Challenge: - err.SetHeaders(w) + err.SetHeaders(r, w) w.WriteHeader(http.StatusUnauthorized) return default: diff --git a/registry/auth/token/accesscontroller.go b/registry/auth/token/accesscontroller.go index 3086c2cfb0f..33d18a48c07 100644 --- a/registry/auth/token/accesscontroller.go +++ b/registry/auth/token/accesscontroller.go @@ -76,10 +76,11 @@ var ( // authChallenge implements the auth.Challenge interface. type authChallenge struct { - err error - realm string - service string - accessSet accessSet + err error + realm string + autoRedirect bool + service string + accessSet accessSet } var _ auth.Challenge = authChallenge{} @@ -97,8 +98,14 @@ func (ac authChallenge) Status() int { // challengeParams constructs the value to be used in // the WWW-Authenticate response challenge header. // See https://tools.ietf.org/html/rfc6750#section-3 -func (ac authChallenge) challengeParams() string { - str := fmt.Sprintf("Bearer realm=%q,service=%q", ac.realm, ac.service) +func (ac authChallenge) challengeParams(r *http.Request) string { + var realm string + if ac.autoRedirect { + realm = fmt.Sprintf("https://%s/auth/token", r.Host) + } else { + realm = ac.realm + } + str := fmt.Sprintf("Bearer realm=%q,service=%q", realm, ac.service) if scope := ac.accessSet.scopeParam(); scope != "" { str = fmt.Sprintf("%s,scope=%q", str, scope) @@ -114,23 +121,25 @@ func (ac authChallenge) challengeParams() string { } // SetChallenge sets the WWW-Authenticate value for the response. -func (ac authChallenge) SetHeaders(w http.ResponseWriter) { - w.Header().Add("WWW-Authenticate", ac.challengeParams()) +func (ac authChallenge) SetHeaders(r *http.Request, w http.ResponseWriter) { + w.Header().Add("WWW-Authenticate", ac.challengeParams(r)) } // accessController implements the auth.AccessController interface. type accessController struct { - realm string - issuer string - service string - rootCerts *x509.CertPool - trustedKeys map[string]libtrust.PublicKey + realm string + autoRedirect bool + issuer string + service string + rootCerts *x509.CertPool + trustedKeys map[string]libtrust.PublicKey } // tokenAccessOptions is a convenience type for handling // options to the contstructor of an accessController. type tokenAccessOptions struct { realm string + autoRedirect bool issuer string service string rootCertBundle string @@ -153,6 +162,12 @@ func checkOptions(options map[string]interface{}) (tokenAccessOptions, error) { opts.realm, opts.issuer, opts.service, opts.rootCertBundle = vals[0], vals[1], vals[2], vals[3] + autoRedirect, ok := options["autoredirect"].(bool) + if !ok { + return opts, fmt.Errorf("token auth requires a valid option bool: autoredirect") + } + opts.autoRedirect = autoRedirect + return opts, nil } @@ -205,11 +220,12 @@ func newAccessController(options map[string]interface{}) (auth.AccessController, } return &accessController{ - realm: config.realm, - issuer: config.issuer, - service: config.service, - rootCerts: rootPool, - trustedKeys: trustedKeys, + realm: config.realm, + autoRedirect: config.autoRedirect, + issuer: config.issuer, + service: config.service, + rootCerts: rootPool, + trustedKeys: trustedKeys, }, nil } @@ -217,9 +233,10 @@ func newAccessController(options map[string]interface{}) (auth.AccessController, // for actions on resources described by the given access items. func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.Access) (context.Context, error) { challenge := &authChallenge{ - realm: ac.realm, - service: ac.service, - accessSet: newAccessSet(accessItems...), + realm: ac.realm, + autoRedirect: ac.autoRedirect, + service: ac.service, + accessSet: newAccessSet(accessItems...), } req, err := dcontext.GetRequest(ctx) diff --git a/registry/auth/token/token_test.go b/registry/auth/token/token_test.go index 03dce6fa6d9..69f3e78b6f2 100644 --- a/registry/auth/token/token_test.go +++ b/registry/auth/token/token_test.go @@ -333,6 +333,7 @@ func TestAccessController(t *testing.T) { "issuer": issuer, "service": service, "rootcertbundle": rootCertBundleFilename, + "autoredirect": false, } accessController, err := newAccessController(options) @@ -518,6 +519,7 @@ func TestNewAccessControllerPemBlock(t *testing.T) { "issuer": issuer, "service": service, "rootcertbundle": rootCertBundleFilename, + "autoredirect": false, } ac, err := newAccessController(options) diff --git a/registry/handlers/app.go b/registry/handlers/app.go index ece0e52fff8..978851bb369 100644 --- a/registry/handlers/app.go +++ b/registry/handlers/app.go @@ -847,7 +847,7 @@ func (app *App) authorized(w http.ResponseWriter, r *http.Request, context *Cont switch err := err.(type) { case auth.Challenge: // Add the appropriate WWW-Auth header - err.SetHeaders(w) + err.SetHeaders(r, w) if err := errcode.ServeJSON(w, errcode.ErrorCodeUnauthorized.WithDetail(accessRecords)); err != nil { dcontext.GetLogger(context).Errorf("error serving error json: %v (from %v)", err, context.Errors)