Skip to content

Commit

Permalink
Read oauth2 secret from file (#293)
Browse files Browse the repository at this point in the history
* Read oauth2 secret from file

Signed-off-by: Julien Pivotto <roidelapluie@inuits.eu>
  • Loading branch information
roidelapluie authored Apr 26, 2021
1 parent 10f0b67 commit 2270f5d
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 20 deletions.
105 changes: 86 additions & 19 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,20 @@ func (u URL) MarshalYAML() (interface{}, error) {

// OAuth2 is the oauth2 client configuration.
type OAuth2 struct {
ClientID string `yaml:"client_id"`
ClientSecret Secret `yaml:"client_secret"`
Scopes []string `yaml:"scopes,omitempty"`
TokenURL string `yaml:"token_url"`
EndpointParams map[string]string `yaml:"endpoint_params,omitempty"`
ClientID string `yaml:"client_id"`
ClientSecret Secret `yaml:"client_secret"`
ClientSecretFile string `yaml:"client_secret_file"`
Scopes []string `yaml:"scopes,omitempty"`
TokenURL string `yaml:"token_url"`
EndpointParams map[string]string `yaml:"endpoint_params,omitempty"`
}

// SetDirectory joins any relative file paths with dir.
func (a *OAuth2) SetDirectory(dir string) {
if a == nil {
return
}
a.ClientSecretFile = JoinDir(dir, a.ClientSecretFile)
}

// HTTPClientConfig configures an HTTP client.
Expand Down Expand Up @@ -151,6 +160,7 @@ func (c *HTTPClientConfig) SetDirectory(dir string) {
c.TLSConfig.SetDirectory(dir)
c.BasicAuth.SetDirectory(dir)
c.Authorization.SetDirectory(dir)
c.OAuth2.SetDirectory(dir)
c.BearerTokenFile = JoinDir(dir, c.BearerTokenFile)
}

Expand Down Expand Up @@ -196,8 +206,13 @@ func (c *HTTPClientConfig) Validate() error {
c.BearerTokenFile = ""
}
}
if c.BasicAuth != nil && c.OAuth2 != nil {
return fmt.Errorf("at most one of basic_auth, oauth2 & authorization must be configured")
if c.OAuth2 != nil {
if c.BasicAuth != nil {
return fmt.Errorf("at most one of basic_auth, oauth2 & authorization must be configured")
}
if len(c.OAuth2.ClientSecret) > 0 && len(c.OAuth2.ClientSecretFile) > 0 {
return fmt.Errorf("at most one of oauth2 client_secret & client_secret_file must be configured")
}
}
return nil
}
Expand Down Expand Up @@ -347,7 +362,7 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
}

if cfg.OAuth2 != nil {
rt = cfg.OAuth2.NewOAuth2RoundTripper(context.Background(), rt)
rt = NewOAuth2RoundTripper(cfg.OAuth2, rt)
}
// Return a new configured RoundTripper.
return rt, nil
Expand Down Expand Up @@ -462,20 +477,72 @@ func (rt *basicAuthRoundTripper) CloseIdleConnections() {
}
}

func (c *OAuth2) NewOAuth2RoundTripper(ctx context.Context, next http.RoundTripper) http.RoundTripper {
config := &clientcredentials.Config{
ClientID: c.ClientID,
ClientSecret: string(c.ClientSecret),
Scopes: c.Scopes,
TokenURL: c.TokenURL,
EndpointParams: mapToValues(c.EndpointParams),
type oauth2RoundTripper struct {
config *OAuth2
rt http.RoundTripper
next http.RoundTripper
secret string
mtx sync.RWMutex
}

func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper) http.RoundTripper {
return &oauth2RoundTripper{
config: config,
next: next,
}
}

tokenSource := config.TokenSource(ctx)
func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
var (
secret string
changed bool
)

return &oauth2.Transport{
Base: next,
Source: tokenSource,
if rt.config.ClientSecretFile != "" {
data, err := ioutil.ReadFile(rt.config.ClientSecretFile)
if err != nil {
return nil, fmt.Errorf("unable to read oauth2 client secret file %s: %s", rt.config.ClientSecretFile, err)
}
secret = strings.TrimSpace(string(data))
rt.mtx.RLock()
changed = secret != rt.secret
rt.mtx.RUnlock()
}

if changed || rt.rt == nil {
if rt.config.ClientSecret != "" {
secret = string(rt.config.ClientSecret)
}

config := &clientcredentials.Config{
ClientID: rt.config.ClientID,
ClientSecret: secret,
Scopes: rt.config.Scopes,
TokenURL: rt.config.TokenURL,
EndpointParams: mapToValues(rt.config.EndpointParams),
}

tokenSource := config.TokenSource(context.Background())

rt.mtx.Lock()
rt.secret = secret
rt.rt = &oauth2.Transport{
Base: rt.next,
Source: tokenSource,
}
rt.mtx.Unlock()
}

rt.mtx.RLock()
currentRT := rt.rt
rt.mtx.RUnlock()
return currentRT.RoundTrip(req)
}

func (rt *oauth2RoundTripper) CloseIdleConnections() {
// OAuth2 RT does not support CloseIdleConnections() but the next RT might.
if ci, ok := rt.next.(closeIdler); ok {
ci.CloseIdleConnections()
}
}

Expand Down
118 changes: 117 additions & 1 deletion config/http_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ var invalidHTTPClientConfigs = []struct {
httpClientConfigFile: "testdata/http.conf.auth-creds-no-basic.bad.yaml",
errMsg: `authorization type cannot be set to "basic", use "basic_auth" instead`,
},
{
httpClientConfigFile: "testdata/http.conf.oauth2-secret-and-file-set.bad.yml",
errMsg: "at most one of oauth2 client_secret & client_secret_file must be configured",
},
}

func newTestServer(handler func(w http.ResponseWriter, r *http.Request)) (*httptest.Server, error) {
Expand Down Expand Up @@ -1136,7 +1140,7 @@ endpoint_params:
t.Fatalf("Got unmarshalled config %q, expected %q", unmarshalledConfig, expectedConfig)
}

rt := expectedConfig.NewOAuth2RoundTripper(context.Background(), http.DefaultTransport)
rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport)

client := http.Client{
Transport: rt,
Expand All @@ -1148,3 +1152,115 @@ endpoint_params:
t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization)
}
}

func TestOAuth2WithFile(t *testing.T) {
var expectedAuth *string
var previousAuth string
tokenTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if auth != *expectedAuth {
t.Fatalf("bad auth, expected %s, got %s", *expectedAuth, auth)
}
if auth == previousAuth {
t.Fatal("token endpoint called twice")
}
previousAuth = auth
res, _ := json.Marshal(testServerResponse{
AccessToken: "12345",
TokenType: "Bearer",
})
w.Header().Add("Content-Type", "application/json")
_, _ = w.Write(res)
}))
defer tokenTS.Close()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if auth != "Bearer 12345" {
t.Fatalf("bad auth, expected %s, got %s", "Bearer 12345", auth)
}
fmt.Fprintln(w, "Hello, client")
}))
defer ts.Close()

secretFile, err := ioutil.TempFile("", "oauth2_secret")
if err != nil {
t.Fatal(err)
}
defer os.Remove(secretFile.Name())

var yamlConfig = fmt.Sprintf(`
client_id: 1
client_secret_file: %s
scopes:
- A
- B
token_url: %s
endpoint_params:
hi: hello
`, secretFile.Name(), tokenTS.URL)
expectedConfig := OAuth2{
ClientID: "1",
ClientSecretFile: secretFile.Name(),
Scopes: []string{"A", "B"},
EndpointParams: map[string]string{"hi": "hello"},
TokenURL: tokenTS.URL,
}

var unmarshalledConfig OAuth2
err = yaml.Unmarshal([]byte(yamlConfig), &unmarshalledConfig)
if err != nil {
t.Fatalf("Expected no error unmarshalling yaml, got %v", err)
}
if !reflect.DeepEqual(unmarshalledConfig, expectedConfig) {
t.Fatalf("Got unmarshalled config %q, expected %q", unmarshalledConfig, expectedConfig)
}

rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport)

client := http.Client{
Transport: rt,
}

tk := "Basic MToxMjM0NTY="
expectedAuth = &tk
if _, err := secretFile.Write([]byte("123456")); err != nil {
t.Fatal(err)
}
resp, err := client.Get(ts.URL)
if err != nil {
t.Fatal(err)
}

authorization := resp.Request.Header.Get("Authorization")
if authorization != "Bearer 12345" {
t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization)
}

// Making a second request with the same file content should not re-call the token API.
resp, err = client.Get(ts.URL)
if err != nil {
t.Fatal(err)
}

tk = "Basic MToxMjM0NTY3"
expectedAuth = &tk
if _, err := secretFile.Write([]byte("7")); err != nil {
t.Fatal(err)
}

_, err = client.Get(ts.URL)
if err != nil {
t.Fatal(err)
}

// Making a second request with the same file content should not re-call the token API.
_, err = client.Get(ts.URL)
if err != nil {
t.Fatal(err)
}

authorization = resp.Request.Header.Get("Authorization")
if authorization != "Bearer 12345" {
t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization)
}
}
3 changes: 3 additions & 0 deletions config/testdata/http.conf.oauth2-secret-and-file-set.bad.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
oauth2:
client_secret: "mysecret"
client_secret_file: "mysecret"

0 comments on commit 2270f5d

Please sign in to comment.