From 82f6983c39aa956604b86225c14eed00d5da0834 Mon Sep 17 00:00:00 2001 From: James Tucker Date: Mon, 24 Apr 2023 16:55:34 -0700 Subject: [PATCH] oidc: add UserInfoEndpoint returning the discocvered URL This enables users detect if the provider.UserInfo method would fail ahead of time, by checking for the empty string in UserInfoEndpoint. Fixes #373 Fixes #374 --- oidc/oidc.go | 6 ++++++ oidc/oidc_test.go | 45 ++++++++++++++++++++++++++++++++++++--------- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/oidc/oidc.go b/oidc/oidc.go index a9098f3c..3026c77e 100644 --- a/oidc/oidc.go +++ b/oidc/oidc.go @@ -275,6 +275,12 @@ func (p *Provider) Endpoint() oauth2.Endpoint { return oauth2.Endpoint{AuthURL: p.authURL, TokenURL: p.tokenURL} } +// UserInfoEndpoint returns the OpenID Connect userinfo endpoint for the given +// provider. +func (p *Provider) UserInfoEndpoint() string { + return p.userInfoURL +} + // UserInfo represents the OpenID Connect userinfo claims. type UserInfo struct { Subject string `json:"sub"` diff --git a/oidc/oidc_test.go b/oidc/oidc_test.go index bbbbda42..342cf075 100644 --- a/oidc/oidc_test.go +++ b/oidc/oidc_test.go @@ -362,14 +362,19 @@ func (ts *testServer) run(t *testing.T) string { ] }` + var userInfoJSON string + if ts.userInfo != "" { + userInfoJSON = fmt.Sprintf(`"userinfo_endpoint": "%s/userinfo",`, server.URL) + } + wellKnown := fmt.Sprintf(`{ "issuer": "%[1]s", "authorization_endpoint": "%[1]s/auth", "token_endpoint": "%[1]s/token", "jwks_uri": "%[1]s/keys", - "userinfo_endpoint": "%[1]s/userinfo", + %[2]s "id_token_signing_alg_values_supported": ["RS256"] - }`, server.URL) + }`, server.URL, userInfoJSON) newMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, req *http.Request) { _, err := io.WriteString(w, wellKnown) @@ -383,13 +388,15 @@ func (ts *testServer) run(t *testing.T) string { w.WriteHeader(500) } }) - newMux.HandleFunc("/userinfo", func(w http.ResponseWriter, req *http.Request) { - w.Header().Add("Content-Type", ts.contentType) - _, err := io.WriteString(w, ts.userInfo) - if err != nil { - w.WriteHeader(500) - } - }) + if ts.userInfo != "" { + newMux.HandleFunc("/userinfo", func(w http.ResponseWriter, req *http.Request) { + w.Header().Add("Content-Type", ts.contentType) + _, err := io.WriteString(w, ts.userInfo) + if err != nil { + w.WriteHeader(500) + } + }) + } t.Cleanup(server.Close) return server.URL } @@ -489,6 +496,13 @@ func TestUserInfoEndpoint(t *testing.T) { claims: []byte(userInfoJSONCognitoVariant), }, }, + { + name: "no userinfo endpoint", + server: testServer{ + contentType: "application/json", + userInfo: "", + }, + }, } for _, test := range tests { @@ -502,6 +516,19 @@ func TestUserInfoEndpoint(t *testing.T) { t.Fatalf("Failed to initialize provider for test %v", err) } + if test.server.userInfo == "" { + if provider.UserInfoEndpoint() != "" { + t.Errorf("expected UserInfoEndpoint to be empty, got %v", provider.UserInfoEndpoint()) + } + + // provider.UserInfo will error. + return + } + + if provider.UserInfoEndpoint() != serverURL+"/userinfo" { + t.Errorf("expected UserInfoEndpoint to be %v , got %v", serverURL+"/userinfo", provider.UserInfoEndpoint()) + } + fakeOauthToken := oauth2.Token{} info, err := provider.UserInfo(ctx, oauth2.StaticTokenSource(&fakeOauthToken)) if err != nil {