From 4da14c5f5bcf69ef39fa6228c8c1ec3306f95a9c Mon Sep 17 00:00:00 2001 From: Moshe Immerman Date: Thu, 19 Oct 2023 10:34:26 +0300 Subject: [PATCH] feat: add more oauth options (#88) --- http/client.go | 10 ++++- http/examples_test.go | 8 +++- http/middlewares/oauth.go | 77 +++++++++++++++++++++++++++++---------- 3 files changed, 73 insertions(+), 22 deletions(-) diff --git a/http/client.go b/http/client.go index fd843a1..cfe8e3a 100644 --- a/http/client.go +++ b/http/client.go @@ -16,6 +16,12 @@ import ( type TraceConfig = middlewares.TraceConfig +type OauthConfig = middlewares.OauthConfig + +var AuthStyleInHeader = middlewares.AuthStyleInHeader +var AuthStyleInParams = middlewares.AuthStyleInParams +var AuthStyleAutoDetect = middlewares.AuthStyleAutoDetect + var TraceAll = TraceConfig{ MaxBodyLength: 4096, Body: true, @@ -210,8 +216,8 @@ func (c *Client) Auth(username, password string) *Client { return c } -func (c *Client) OAuth(clientID, clientSecret, tokenURL string, scopes ...string) *Client { - c.Use(middlewares.NewOauthTransport(clientID, clientSecret, tokenURL, scopes...).RoundTripper) +func (c *Client) OAuth(config middlewares.OauthConfig) *Client { + c.Use(middlewares.NewOauthTransport(config).RoundTripper) return c } diff --git a/http/examples_test.go b/http/examples_test.go index 4b16017..917a2d6 100644 --- a/http/examples_test.go +++ b/http/examples_test.go @@ -31,7 +31,13 @@ func TestExample(t *testing.T) { scopes = []string{"https://graph.microsoft.com/.default"} ) - req := http.NewClient().OAuth(clientID, clientSecret, tokenURL, scopes...).R(ctx) + req := http.NewClient().OAuth( + http.OauthConfig{ + ClientID: clientID, + ClientSecret: clientSecret, + TokenURL: tokenURL, + Scopes: scopes}). + R(ctx) response, err := req.Get("https://graph.microsoft.com/v1.0/users") if err != nil { t.Fatalf("error: %v", err) diff --git a/http/middlewares/oauth.go b/http/middlewares/oauth.go index 0af525d..a0e679f 100644 --- a/http/middlewares/oauth.go +++ b/http/middlewares/oauth.go @@ -3,6 +3,7 @@ package middlewares import ( "fmt" netHttp "net/http" + "net/url" "time" "github.com/flanksource/commons/hash" @@ -11,34 +12,72 @@ import ( "golang.org/x/oauth2/clientcredentials" ) -func NewOauthTransport(clientID, clientSecret, tokenURL string, scopes ...string) *oauthConfig { - return &oauthConfig{ - clientID: clientID, - clientSecret: clientSecret, - tokenURL: tokenURL, - scopes: scopes, - cache: cache.New(time.Minute*15, time.Hour), +func NewOauthTransport(config OauthConfig) *oauthRoundTripper { + return &oauthRoundTripper{OauthConfig: config, cache: cache.New(time.Minute*15, time.Hour)} +} + +type AuthStyle oauth2.AuthStyle + +var AuthStyleInHeader = AuthStyle(oauth2.AuthStyleInHeader) +var AuthStyleInParams = AuthStyle(oauth2.AuthStyleInParams) +var AuthStyleAutoDetect = AuthStyle(oauth2.AuthStyleAutoDetect) + +type OauthConfig struct { + ClientID string + ClientSecret string + TokenURL string + Scopes []string + Params map[string]string + AuthStyle AuthStyle + Tracer func(msg string) +} + +func (c *OauthConfig) AuthStyleInHeader() *OauthConfig { + c.AuthStyle = AuthStyleInHeader + return c +} + +func (c *OauthConfig) AuthStyleInParams() *OauthConfig { + c.AuthStyle = AuthStyleInParams + return c +} + +func (c *OauthConfig) getSanitizedSecret() string { + if len(c.ClientSecret) <= 4 { + return c.ClientSecret } + return c.ClientSecret[0:4] + "****" +} + +func (c OauthConfig) String() string { + return fmt.Sprintf("url=%s id=%s, secret=%s scopes=%s params=%s", c.TokenURL, c.ClientID, c.getSanitizedSecret(), c.Scopes, c.Params) } -type oauthConfig struct { - clientID string - clientSecret string - tokenURL string - scopes []string - cache *cache.Cache +type oauthRoundTripper struct { + OauthConfig + cache *cache.Cache +} + +func toUrlValues(m map[string]string) url.Values { + values := url.Values{} + for k, v := range m { + values[k] = []string{v} + } + return values } -func (t *oauthConfig) RoundTripper(rt netHttp.RoundTripper) netHttp.RoundTripper { +func (t *oauthRoundTripper) RoundTripper(rt netHttp.RoundTripper) netHttp.RoundTripper { return RoundTripperFunc(func(ogRequest *netHttp.Request) (*netHttp.Response, error) { config := clientcredentials.Config{ - ClientID: t.clientID, - ClientSecret: t.clientSecret, - TokenURL: t.tokenURL, - Scopes: t.scopes, + ClientID: t.ClientID, + ClientSecret: t.ClientSecret, + TokenURL: t.TokenURL, + Scopes: t.Scopes, + EndpointParams: toUrlValues(t.Params), + AuthStyle: oauth2.AuthStyle(t.AuthStyle), } - cacheKey := oauthCacheKey(t.clientID, t.clientSecret, t.tokenURL, t.scopes) + cacheKey := oauthCacheKey(t.ClientID, t.ClientSecret, t.TokenURL, t.Scopes) var token *oauth2.Token if val, ok := t.cache.Get(cacheKey); ok { token, _ = val.(*oauth2.Token)