Skip to content

Commit

Permalink
feat: add more oauth options (#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
moshloop authored Oct 19, 2023
1 parent dcd9c66 commit 4da14c5
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 22 deletions.
10 changes: 8 additions & 2 deletions http/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ import (

type TraceConfig = middlewares.TraceConfig

type OauthConfig = middlewares.OauthConfig

var AuthStyleInHeader = middlewares.AuthStyleInHeader
var AuthStyleInParams = middlewares.AuthStyleInParams
var AuthStyleAutoDetect = middlewares.AuthStyleAutoDetect

var TraceAll = TraceConfig{
MaxBodyLength: 4096,
Body: true,
Expand Down Expand Up @@ -210,8 +216,8 @@ func (c *Client) Auth(username, password string) *Client {
return c
}

func (c *Client) OAuth(clientID, clientSecret, tokenURL string, scopes ...string) *Client {
c.Use(middlewares.NewOauthTransport(clientID, clientSecret, tokenURL, scopes...).RoundTripper)
func (c *Client) OAuth(config middlewares.OauthConfig) *Client {
c.Use(middlewares.NewOauthTransport(config).RoundTripper)
return c
}

Expand Down
8 changes: 7 additions & 1 deletion http/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@ func TestExample(t *testing.T) {
scopes = []string{"https://graph.microsoft.com/.default"}
)

req := http.NewClient().OAuth(clientID, clientSecret, tokenURL, scopes...).R(ctx)
req := http.NewClient().OAuth(
http.OauthConfig{
ClientID: clientID,
ClientSecret: clientSecret,
TokenURL: tokenURL,
Scopes: scopes}).
R(ctx)
response, err := req.Get("https://graph.microsoft.com/v1.0/users")
if err != nil {
t.Fatalf("error: %v", err)
Expand Down
77 changes: 58 additions & 19 deletions http/middlewares/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package middlewares
import (
"fmt"
netHttp "net/http"
"net/url"
"time"

"github.com/flanksource/commons/hash"
Expand All @@ -11,34 +12,72 @@ import (
"golang.org/x/oauth2/clientcredentials"
)

func NewOauthTransport(clientID, clientSecret, tokenURL string, scopes ...string) *oauthConfig {
return &oauthConfig{
clientID: clientID,
clientSecret: clientSecret,
tokenURL: tokenURL,
scopes: scopes,
cache: cache.New(time.Minute*15, time.Hour),
func NewOauthTransport(config OauthConfig) *oauthRoundTripper {
return &oauthRoundTripper{OauthConfig: config, cache: cache.New(time.Minute*15, time.Hour)}
}

type AuthStyle oauth2.AuthStyle

var AuthStyleInHeader = AuthStyle(oauth2.AuthStyleInHeader)
var AuthStyleInParams = AuthStyle(oauth2.AuthStyleInParams)
var AuthStyleAutoDetect = AuthStyle(oauth2.AuthStyleAutoDetect)

type OauthConfig struct {
ClientID string
ClientSecret string
TokenURL string
Scopes []string
Params map[string]string
AuthStyle AuthStyle
Tracer func(msg string)
}

func (c *OauthConfig) AuthStyleInHeader() *OauthConfig {
c.AuthStyle = AuthStyleInHeader
return c
}

func (c *OauthConfig) AuthStyleInParams() *OauthConfig {
c.AuthStyle = AuthStyleInParams
return c
}

func (c *OauthConfig) getSanitizedSecret() string {
if len(c.ClientSecret) <= 4 {
return c.ClientSecret
}
return c.ClientSecret[0:4] + "****"
}

func (c OauthConfig) String() string {
return fmt.Sprintf("url=%s id=%s, secret=%s scopes=%s params=%s", c.TokenURL, c.ClientID, c.getSanitizedSecret(), c.Scopes, c.Params)
}

type oauthConfig struct {
clientID string
clientSecret string
tokenURL string
scopes []string
cache *cache.Cache
type oauthRoundTripper struct {
OauthConfig
cache *cache.Cache
}

func toUrlValues(m map[string]string) url.Values {
values := url.Values{}
for k, v := range m {
values[k] = []string{v}
}
return values
}

func (t *oauthConfig) RoundTripper(rt netHttp.RoundTripper) netHttp.RoundTripper {
func (t *oauthRoundTripper) RoundTripper(rt netHttp.RoundTripper) netHttp.RoundTripper {
return RoundTripperFunc(func(ogRequest *netHttp.Request) (*netHttp.Response, error) {
config := clientcredentials.Config{
ClientID: t.clientID,
ClientSecret: t.clientSecret,
TokenURL: t.tokenURL,
Scopes: t.scopes,
ClientID: t.ClientID,
ClientSecret: t.ClientSecret,
TokenURL: t.TokenURL,
Scopes: t.Scopes,
EndpointParams: toUrlValues(t.Params),
AuthStyle: oauth2.AuthStyle(t.AuthStyle),
}

cacheKey := oauthCacheKey(t.clientID, t.clientSecret, t.tokenURL, t.scopes)
cacheKey := oauthCacheKey(t.ClientID, t.ClientSecret, t.TokenURL, t.Scopes)
var token *oauth2.Token
if val, ok := t.cache.Get(cacheKey); ok {
token, _ = val.(*oauth2.Token)
Expand Down

0 comments on commit 4da14c5

Please sign in to comment.