Skip to content

Commit

Permalink
added github token support (#366)
Browse files Browse the repository at this point in the history
added github token support via ACTIONS_ID_TOKEN_REQUEST_URL and ACTIONS_ID_TOKEN_REQUEST_TOKEN in workload identity login
  • Loading branch information
weinong committed Dec 1, 2023
1 parent a7eddd9 commit 837674f
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 88 deletions.
150 changes: 103 additions & 47 deletions pkg/token/federatedIdentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,30 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"

"github.com/Azure/go-autorest/autorest/adal"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
)

const (
actionsIDTokenRequestToken = "ACTIONS_ID_TOKEN_REQUEST_TOKEN"
actionsIDTokenRequestURL = "ACTIONS_ID_TOKEN_REQUEST_URL"
azureADAudience = "api://AzureADTokenExchange"
defaultScope = "/.default"
)

type workloadIdentityToken struct {
clientID string
tenantID string
federatedTokenFile string
authorityHost string
serverID string
serverID string
client confidential.Client
}

type githubTokenResponse struct {
Value string `json:"value"`
}

func newWorkloadIdentityToken(clientID, federatedTokenFile, authorityHost, serverID, tenantID string) (TokenProvider, error) {
Expand All @@ -27,8 +38,9 @@ func newWorkloadIdentityToken(clientID, federatedTokenFile, authorityHost, serve
if tenantID == "" {
return nil, errors.New("tenantID cannot be empty")
}
if federatedTokenFile == "" {
return nil, errors.New("federatedTokenFile cannot be empty")
hasActionsIDToken := os.Getenv(actionsIDTokenRequestToken) != "" && os.Getenv(actionsIDTokenRequestURL) != ""
if federatedTokenFile == "" && !hasActionsIDToken {
return nil, errors.New("either ACTIONS_ID_TOKEN_REQUEST_TOKEN and ACTIONS_ID_TOKEN_REQUEST_URL environment variables have to be set or federated token file has to be provided")
}
if authorityHost == "" {
return nil, errors.New("authorityHost cannot be empty")
Expand All @@ -37,35 +49,34 @@ func newWorkloadIdentityToken(clientID, federatedTokenFile, authorityHost, serve
return nil, errors.New("serverID cannot be empty")
}

var cred confidential.Credential
if federatedTokenFile != "" {
cred = newCredentialFromTokenFile(federatedTokenFile)
} else {
cred = newCredentialFromGithub()
}

client, err := confidential.New(fmt.Sprintf("%s%s/oauth2/token", authorityHost, tenantID), clientID, cred)
if err != nil {
return nil, fmt.Errorf("failed to create confidential client for federated workload identity. %s", err)
}

return &workloadIdentityToken{
clientID: clientID,
tenantID: tenantID,
federatedTokenFile: federatedTokenFile,
authorityHost: authorityHost,
serverID: serverID,
serverID: serverID,
client: client,
}, nil
}

func (p *workloadIdentityToken) Token() (adal.Token, error) {
emptyToken := adal.Token{}
cred, err := newCredential(p.federatedTokenFile)
if err != nil {
return emptyToken, err
}

// create the confidential client to request an AAD token
confidentialClientApp, err := createClient(p.authorityHost, p.tenantID, p.clientID, cred)
if err != nil {
return emptyToken, err
}

resource := strings.TrimSuffix(p.serverID, "/")
// .default needs to be added to the scope
if !strings.HasSuffix(resource, ".default") {
resource += "/.default"
resource += defaultScope
}

result, err := confidentialClientApp.AcquireTokenByCredential(context.Background(), []string{resource})
result, err := p.client.AcquireTokenByCredential(context.Background(), []string{resource})
if err != nil {
return emptyToken, fmt.Errorf("failed to acquire token. %s", err)
}
Expand All @@ -77,33 +88,20 @@ func (p *workloadIdentityToken) Token() (adal.Token, error) {
}, nil
}

// newCredential creates a confidential.Credential from the provided token file
func newCredential(federatedTokenFile string) (confidential.Credential, error) {
signedAssertion, err := readJWTFromFS(federatedTokenFile)
if err != nil {
return confidential.Credential{}, fmt.Errorf("failed to read signed assertion from token file: %s", err)
}
// Having the callback return the string read from the token file most closely resembles the implementation
// used in NewCredFromAssertion which was deprecated and used previously in this code.
signedAssertionCallback := func(_ context.Context, _ confidential.AssertionRequestOptions) (string, error) {
return signedAssertion, nil
// newCredentialFromTokenFile creates a confidential.Credential from provided token file
func newCredentialFromTokenFile(federatedTokenFile string) confidential.Credential {
cb := func(_ context.Context, _ confidential.AssertionRequestOptions) (string, error) {
return readJWTFromFS(federatedTokenFile)
}
return confidential.NewCredFromAssertionCallback(signedAssertionCallback), nil
return confidential.NewCredFromAssertionCallback(cb)
}

// createClient creates a confidential.Client
func createClient(authorityHost string, tenantID string, clientID string, cred confidential.Credential) (confidential.Client, error) {
authority := fmt.Sprintf("%s%s/oauth2/token", authorityHost, tenantID)
confidentialClientApp, err := confidential.New(
authority,
clientID,
cred)

if err != nil {
return confidential.Client{}, fmt.Errorf("failed to create confidential client app. %s", err)
// newCredentialFromGithub creates a confidential.Credential from github id token
func newCredentialFromGithub() confidential.Credential {
cb := func(ctx context.Context, _ confidential.AssertionRequestOptions) (string, error) {
return getGitHubToken(ctx)
}

return confidentialClientApp, err
return confidential.NewCredFromAssertionCallback(cb)
}

// readJWTFromFS reads the jwt from file system
Expand All @@ -114,3 +112,61 @@ func readJWTFromFS(tokenFilePath string) (string, error) {
}
return string(token), nil
}

func getGitHubToken(ctx context.Context) (string, error) {
reqToken := os.Getenv(actionsIDTokenRequestToken)
reqURL := os.Getenv(actionsIDTokenRequestURL)

if reqToken == "" || reqURL == "" {
return "", errors.New("ACTIONS_ID_TOKEN_REQUEST_TOKEN or ACTIONS_ID_TOKEN_REQUEST_URL is not set")
}

u, err := url.Parse(reqURL)
if err != nil {
return "", fmt.Errorf("unable to parse ACTIONS_ID_TOKEN_REQUEST_URL: %w", err)
}
q := u.Query()
q.Set("audience", azureADAudience)
u.RawQuery = q.Encode()

req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return "", err
}

// reference:
// https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/about-security-hardening-with-openid-connect
req.Header.Set("Authorization", fmt.Sprintf("bearer %s", reqToken))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json; api-version=2.0")

client := http.Client{}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
var body string
b, err := io.ReadAll(resp.Body)
if err != nil {
body = err.Error()
} else {
body = string(b)
}

return "", fmt.Errorf("github actions ID token request failed with status code: %d, response body: %s", resp.StatusCode, body)
}

var tokenResp githubTokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return "", err
}

if tokenResp.Value == "" {
return "", errors.New("github actions ID token is empty")
}

return tokenResp.Value, nil
}
74 changes: 56 additions & 18 deletions pkg/token/federatedIdentity_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package token

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"

Expand All @@ -19,7 +23,7 @@ func TestNewWorkloadIdentityTokenProviderEmpty(t *testing.T) {
name: "tenantID cannot be empty",
},
{
name: "federatedTokenFile cannot be empty",
name: "either ACTIONS_ID_TOKEN_REQUEST_TOKEN and ACTIONS_ID_TOKEN_REQUEST_URL environment variables have to be set or federated token file has to be provided",
},
{
name: "authorityHost cannot be empty",
Expand All @@ -38,16 +42,14 @@ func TestNewWorkloadIdentityTokenProviderEmpty(t *testing.T) {
switch {
case strings.Contains(name, "clientID"):
_, err = newWorkloadIdentityToken("", "", "", "", "")
case strings.Contains(name, "federatedTokenFile"):
case strings.Contains(name, "federated token file"):
_, err = newWorkloadIdentityToken("test", "", "", "", "test")
case strings.Contains(name, "authorityHost"):
_, err = newWorkloadIdentityToken("test", "test", "", "", "test")
case strings.Contains(name, "serverID"):
_, err = newWorkloadIdentityToken("test", "test", "test", "", "test")
case strings.Contains(name, "tenantID"):
_, err = newWorkloadIdentityToken("test", "test", "test", "test", "")
default:
fmt.Println(false)
}

if !testutils.ErrorContains(err, data.name) {
Expand All @@ -57,25 +59,61 @@ func TestNewWorkloadIdentityTokenProviderEmpty(t *testing.T) {
}
}

func TestNewWorkloadIdentityToken(t *testing.T) {
workloadIdentityToken := workloadIdentityToken{}
_, err := workloadIdentityToken.Token()

if !testutils.ErrorContains(err, "failed to read signed assertion from token file:") {
func TestReadJWTFromFSEmptyString(t *testing.T) {
_, err := readJWTFromFS("")
if !testutils.ErrorContains(err, "no such file or directory") {
t.Errorf("unexpected error: %v", err)
}
}

func TestNewCredentialEmptyString(t *testing.T) {
_, err := newCredential("")
if !testutils.ErrorContains(err, "failed to read signed assertion from token file:") {
t.Errorf("unexpected error: %v", err)
}
func invalidHttpRequest(w http.ResponseWriter, msg string) {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(msg))
}

func TestReadJWTFromFSEmptyString(t *testing.T) {
_, err := readJWTFromFS("")
if !testutils.ErrorContains(err, "no such file or directory") {
t.Errorf("unexpected error: %v", err)
func TestUseGitHubToken(t *testing.T) {
var (
ghToken = "foo-bar"
oidcToken = "oidc-token"
)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
invalidHttpRequest(w, fmt.Sprintf("unexpected method: %s", r.Method))
return
}
if r.URL.Query().Get("audience") != azureADAudience {
invalidHttpRequest(w, fmt.Sprintf("unexpected audience: %s", r.URL.Query().Get("audience")))
return
}
if r.Header.Get("Authorization") != fmt.Sprintf("bearer %s", ghToken) {
invalidHttpRequest(w, fmt.Sprintf("unexpected Authorization header: %s", r.Header.Get("Authorization")))
return
}
if r.Header.Get("Content-Type") != "application/json" {
invalidHttpRequest(w, fmt.Sprintf("unexpected Content-Type header: %s", r.Header.Get("Content-Type")))
return
}
if r.Header.Get("Accept") != "application/json; api-version=2.0" {
invalidHttpRequest(w, fmt.Sprintf("unexpected Accept header: %s", r.Header.Get("Accept")))
return
}
tokenResponse := githubTokenResponse{
Value: oidcToken,
}

json.NewEncoder(w).Encode(tokenResponse)
}))
defer ts.Close()

t.Setenv(actionsIDTokenRequestURL, ts.URL)
t.Setenv(actionsIDTokenRequestToken, ghToken)

token, err := getGitHubToken(context.Background())
if err != nil {
t.Fatalf("getGitHubToken returned unexpected error: %s", err)
}
if token != oidcToken {
t.Fatalf("got token: %s, expected: %s", token, oidcToken)
}

}
2 changes: 1 addition & 1 deletion pkg/token/interactive.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (p *InteractiveToken) TokenWithOptions(options *azcore.ClientOptions) (adal

// Request a new Interactive token provider
authorityFromConfig := p.oAuthConfig.AuthorityEndpoint
scopes := []string{p.resourceID + "/.default"}
scopes := []string{p.resourceID + defaultScope}
clientOpts := azcore.ClientOptions{Cloud: cloud.Configuration{
ActiveDirectoryAuthorityHost: authorityFromConfig.String(),
}}
Expand Down
41 changes: 20 additions & 21 deletions pkg/token/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,28 +227,27 @@ func TestNewTokenProvider(t *testing.T) {
ClientID: "testclient",
ServerID: "testserver",
FederatedTokenFile: "testfile",
AuthorityHost: "testauthority",
AuthorityHost: "https://testauthority",
LoginMethod: "workloadidentity",
}
provider, err := newTokenProvider(options)
if err != nil || provider == nil {
t.Errorf("expected no error but got: %s", err)
}
workloadId := provider.(*workloadIdentityToken)
if workloadId.clientID != options.ClientID {
t.Errorf("expected provider client ID to be: %s but got: %s", options.ClientID, workloadId.clientID)
}
if workloadId.serverID != options.ServerID {
t.Errorf("expected provider server ID to be: %s but got: %s", options.ServerID, workloadId.serverID)
}
if workloadId.tenantID != options.TenantID {
t.Errorf("expected provider tenant ID to be: %s but got: %s", options.TenantID, workloadId.tenantID)
}
if workloadId.federatedTokenFile != options.FederatedTokenFile {
t.Errorf("expected provider federated token file to be: %s but got: %s", options.FederatedTokenFile, workloadId.federatedTokenFile)
}
if workloadId.authorityHost != options.AuthorityHost {
t.Errorf("expected provider authority host to be: %s but got: %s", options.AuthorityHost, workloadId.authorityHost)
}
t.Run("with token file", func(t *testing.T) {
provider, err := newTokenProvider(options)
if err != nil || provider == nil {
t.Errorf("expected no error but got: %s", err)
}
workloadId := provider.(*workloadIdentityToken)
if workloadId.serverID != options.ServerID {
t.Errorf("expected provider server ID to be: %s but got: %s", options.ServerID, workloadId.serverID)
}
})
t.Run("with Github token", func(t *testing.T) {
options.FederatedTokenFile = ""
t.Setenv(actionsIDTokenRequestToken, "fake-token")
t.Setenv(actionsIDTokenRequestURL, "fake-url")
provider, err := newTokenProvider(options)
if err != nil || provider == nil {
t.Errorf("expected no error but got: %s", err)
}
})
})
}
2 changes: 1 addition & 1 deletion pkg/token/serviceprincipaltoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (p *servicePrincipalToken) TokenWithOptions(options *azcore.ClientOptions)
var accessToken string
var expirationTimeUnix int64
var err error
scopes := []string{p.resourceID + "/.default"}
scopes := []string{p.resourceID + defaultScope}

// Request a new Azure token provider for service principal
if p.clientSecret != "" {
Expand Down

0 comments on commit 837674f

Please sign in to comment.