Skip to content

Commit

Permalink
Merge pull request #607 from sledigabel/saml_caching
Browse files Browse the repository at this point in the history
Adding SAML Assertion caching feature
  • Loading branch information
Mark Wolfe authored Mar 11, 2021
2 parents 3251e9c + 0f1a450 commit eb7bc3c
Show file tree
Hide file tree
Showing 9 changed files with 330 additions and 6 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ The process goes something like this:
* Prompt user for credentials
* Log in to Identity Provider using form based authentication
* Build a SAML assertion containing AWS roles
* Optionally cache the SAML assertion (the cache is not encrypted)
* Exchange the role and SAML assertion with [AWS STS service](https://docs.aws.amazon.com/STS/latest/APIReference/Welcome.html) to get a temporary set of credentials
* Save these credentials to an aws profile named "saml"

Expand Down Expand Up @@ -630,6 +631,14 @@ credential_process = saml2aws login --skip-prompt --quiet --credential-process -

When using the aws cli with the `mybucket` profile, the authentication process will be run and the aws will then be executed based on the returned credentials.

# Caching the saml2aws SAML assertion for immediate reuse

You can use the flag `--cache-saml` in order to cache the SAML assertion at authentication time. The SAML assertion cache has a very short validity (5 min) and can be used to authenticate to several roles with a single MFA validation.

there is a file per saml2aws profile, the cache directory is called `saml2aws` and is located in your `.aws` directory in your user homedir.

You can toggle `--cache-saml` during `login` or during `list-roles`, and you can set it once during `configure` and use it implicitly.

# License

This code is Copyright (c) 2018 [Versent](http://versent.com.au) and released under the MIT license. All rights not explicitly granted in the MIT license are reserved. See the included LICENSE.md file for more details.
Expand Down
31 changes: 28 additions & 3 deletions cmd/saml2aws/commands/list_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/versent/saml2aws/v2"
"github.com/versent/saml2aws/v2/helper/credentials"
"github.com/versent/saml2aws/v2/pkg/flags"
"github.com/versent/saml2aws/v2/pkg/samlcache"
)

// ListRoles will list available role ARNs
Expand All @@ -23,6 +24,12 @@ func ListRoles(loginFlags *flags.LoginExecFlags) error {
return errors.Wrap(err, "error building login details")
}

// creates a cacheProvider, only used when --cache is set
// cacheProvider := samlcache.NewSAMLCacheProvider(account.Name, "")
cacheProvider := &samlcache.SAMLCacheProvider{
Account: account.Name,
}

loginDetails, err := resolveLoginDetails(account, loginFlags)
if err != nil {
log.Printf("%+v", err)
Expand All @@ -41,10 +48,28 @@ func ListRoles(loginFlags *flags.LoginExecFlags) error {
return errors.Wrap(err, "error validating login details")
}

samlAssertion, err := provider.Authenticate(loginDetails)
if err != nil {
return errors.Wrap(err, "error authenticating to IdP")
var samlAssertion string
if account.SAMLCache {
if cacheProvider.IsValid() {
samlAssertion, err = cacheProvider.Read()
if err != nil {
logger.Debug("Could not read cache:", err)
}
}
}

if samlAssertion == "" {
// samlAssertion was not cached
samlAssertion, err = provider.Authenticate(loginDetails)
if err != nil {
return errors.Wrap(err, "error authenticating to IdP")
}
if account.SAMLCache {
err = cacheProvider.Write(samlAssertion)
if err != nil {
logger.Error("Could not write samlAssertion:", err)
}
}
}

if samlAssertion == "" {
Expand Down
31 changes: 28 additions & 3 deletions cmd/saml2aws/commands/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/versent/saml2aws/v2/pkg/cfg"
"github.com/versent/saml2aws/v2/pkg/creds"
"github.com/versent/saml2aws/v2/pkg/flags"
"github.com/versent/saml2aws/v2/pkg/samlcache"
)

// Login login to ADFS
Expand All @@ -31,6 +32,10 @@ func Login(loginFlags *flags.LoginExecFlags) error {
}

sharedCreds := awsconfig.NewSharedCredentials(account.Profile, account.CredentialsFile)
// creates a cacheProvider, only used when --cache is set
cacheProvider := &samlcache.SAMLCacheProvider{
Account: account.Name,
}

logger.Debug("check if Creds Exist")

Expand Down Expand Up @@ -79,10 +84,30 @@ func Login(loginFlags *flags.LoginExecFlags) error {

log.Printf("Authenticating as %s ...", loginDetails.Username)

samlAssertion, err := provider.Authenticate(loginDetails)
if err != nil {
return errors.Wrap(err, "error authenticating to IdP")
var samlAssertion string
if account.SAMLCache {
if cacheProvider.IsValid() {
samlAssertion, err = cacheProvider.Read()
if err != nil {
return errors.Wrap(err, "Could not read saml cache")
}
} else {
logger.Debug("Cache is invalid")
}
}

if samlAssertion == "" {
// samlAssertion was not cached
samlAssertion, err = provider.Authenticate(loginDetails)
if err != nil {
return errors.Wrap(err, "error authenticating to IdP")
}
if account.SAMLCache {
err = cacheProvider.Write(samlAssertion)
if err != nil {
return errors.Wrap(err, "Could not write saml cache")
}
}
}

if samlAssertion == "" {
Expand Down
3 changes: 3 additions & 0 deletions cmd/saml2aws/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ func main() {
cmdConfigure.Flag("profile", "The AWS profile to save the temporary credentials. (env: SAML2AWS_PROFILE)").Envar("SAML2AWS_PROFILE").Short('p').StringVar(&commonFlags.Profile)
cmdConfigure.Flag("resource-id", "F5APM SAML resource ID of your company account. (env: SAML2AWS_F5APM_RESOURCE_ID)").Envar("SAML2AWS_F5APM_RESOURCE_ID").StringVar(&commonFlags.ResourceID)
cmdConfigure.Flag("credentials-file", "The file that will cache the credentials retrieved from AWS. When not specified, will use the default AWS credentials file location. (env: SAML2AWS_CREDENTIALS_FILE)").Envar("SAML2AWS_CREDENTIALS_FILE").StringVar(&commonFlags.CredentialsFile)
cmdConfigure.Flag("cache-saml", "Caches the SAML response").Envar("SAML2AWS_CACHE_SAML").BoolVar(&commonFlags.SAMLCache)
configFlags := commonFlags

// `login` command and settings
Expand All @@ -105,6 +106,7 @@ func main() {
cmdLogin.Flag("force", "Refresh credentials even if not expired.").BoolVar(&loginFlags.Force)
cmdLogin.Flag("credential-process", "Enables AWS Credential Process support by outputting credentials to STDOUT in a JSON message.").BoolVar(&loginFlags.CredentialProcess)
cmdLogin.Flag("credentials-file", "The file that will cache the credentials retrieved from AWS. When not specified, will use the default AWS credentials file location. (env: SAML2AWS_CREDENTIALS_FILE)").Envar("SAML2AWS_CREDENTIALS_FILE").StringVar(&commonFlags.CredentialsFile)
cmdLogin.Flag("cache-saml", "Caches the SAML response").Envar("SAML2AWS_CACHE_SAML").BoolVar(&commonFlags.SAMLCache)

// `exec` command and settings
cmdExec := app.Command("exec", "Exec the supplied command with env vars from STS token.")
Expand All @@ -128,6 +130,7 @@ func main() {

// `list` command and settings
cmdListRoles := app.Command("list-roles", "List available role ARNs.")
cmdListRoles.Flag("cache-saml", "Caches the SAML response").Envar("SAML2AWS_CACHE_SAML").BoolVar(&commonFlags.SAMLCache)
listRolesFlags := new(flags.LoginExecFlags)
listRolesFlags.CommonFlags = commonFlags

Expand Down
5 changes: 5 additions & 0 deletions pkg/cfg/cfg.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ const (

// IDPAccount saml IDP account
type IDPAccount struct {
Name string
AppID string `ini:"app_id"` // used by OneLogin and AzureAD
URL string `ini:"url"`
Username string `ini:"username"`
Expand All @@ -47,6 +48,7 @@ type IDPAccount struct {
HttpAttemptsCount string `ini:"http_attempts_count"`
HttpRetryDelay string `ini:"http_retry_delay"`
CredentialsFile string `ini:"credentials_file"`
SAMLCache bool `ini:"saml_cache"`
}

func (ia IDPAccount) String() string {
Expand Down Expand Up @@ -195,6 +197,9 @@ func (cm *ConfigManager) LoadIDPAccount(idpAccountName string) (*IDPAccount, err
return nil, errors.Wrap(err, "Unable to read idp account")
}

// adding Name at Load time for the IdpAccount to have awareness of "self"
account.Name = idpAccountName

return account, nil
}

Expand Down
2 changes: 2 additions & 0 deletions pkg/cfg/cfg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ func TestNewConfigManagerLoad(t *testing.T) {
idpAccount, err := cfgm.LoadIDPAccount("test123")
require.Nil(t, err)
require.Equal(t, &IDPAccount{
Name: "test123",
URL: "https://id.whatever.com",
Username: "abc@whatever.com",
Provider: "keycloak",
Expand Down Expand Up @@ -61,6 +62,7 @@ func TestNewConfigManagerSave(t *testing.T) {
idpAccount, err := cfgm.LoadIDPAccount("testing2")
require.Nil(t, err)
require.Equal(t, &IDPAccount{
Name: "testing2",
URL: "https://id.whatever.com",
Username: "abc@whatever.com",
Provider: "keycloak",
Expand Down
4 changes: 4 additions & 0 deletions pkg/flags/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type CommonFlags struct {
DisableKeychain bool
Region string
CredentialsFile string
SAMLCache bool
}

// LoginExecFlags flags for the Login / Exec commands
Expand Down Expand Up @@ -98,4 +99,7 @@ func ApplyFlagOverrides(commonFlags *CommonFlags, account *cfg.IDPAccount) {
if commonFlags.CredentialsFile != "" {
account.CredentialsFile = commonFlags.CredentialsFile
}
if commonFlags.SAMLCache {
account.SAMLCache = commonFlags.SAMLCache
}
}
142 changes: 142 additions & 0 deletions pkg/samlcache/samlcache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package samlcache

import (
"fmt"
"io/ioutil"
"os"
"path"
"path/filepath"
"runtime"
"time"

homedir "github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)

var (
ErrInvalidCachePath = errors.New("Cannot evaluate Cache file path")
logger = logrus.WithField("pkg", "samlcache")
)

const (
SAMLAssertionValidityTimeout = 5 * time.Minute
SAMLCacheFilePermissions = 0600
SAMLCacheDirPermissions = 0700
SAMLCacheDir = "saml2aws"
)

// SAMLCacheProvider loads aws credentials file
type SAMLCacheProvider struct {
Filename string
Account string
}

func resolveSymlink(filename string) (string, error) {
sympath, err := filepath.EvalSymlinks(filename)

// return the un modified filename
if os.IsNotExist(err) {
return filename, nil
}
if err != nil {
return "", err
}

return sympath, nil
}

func (p *SAMLCacheProvider) IsValid() bool {
var cache_path string
var err error
if p.Filename == "" {
cache_path, err = locateCacheFile(p.Account)
if err != nil {
return false
}
} else {
cache_path = p.Filename
}

fileInfo, err := os.Stat(cache_path)
if err != nil {
return false
}

return time.Since(fileInfo.ModTime()) < SAMLAssertionValidityTimeout
}

func locateCacheFile(account string) (string, error) {

var name, filename string
var err error
if account == "" {
filename = "cache"
} else {
filename = fmt.Sprintf("cache_%s", account)
}
if runtime.GOOS == "windows" {
name = path.Join(os.Getenv("USERPROFILE"), ".aws", SAMLCacheDir, filename)
} else {
name, err = homedir.Expand(path.Join("~", ".aws", SAMLCacheDir, filename))
if err != nil {
return "", ErrInvalidCachePath
}
}
// is the filename a symlink?
name, err = resolveSymlink(name)
if err != nil {
return "", errors.Wrap(err, "unable to resolve symlink")
}

logger.WithField("name", name).Debug("resolveSymlink")

return name, nil
}

func (p *SAMLCacheProvider) Read() (string, error) {

var cache_path string
var err error
if p.Filename == "" {
cache_path, err = locateCacheFile(p.Account)
if err != nil {
return "", errors.Wrap(err, "Could not retrieve cache file path")
}
} else {
cache_path = p.Filename
}

content, err := ioutil.ReadFile(cache_path)
if err != nil {
return "", errors.Wrap(err, "Could not read the cache file path")
}

return string(content), nil
}

func (p *SAMLCacheProvider) Write(samlAssertion string) error {

var cache_path string
var err error
if p.Filename == "" {
cache_path, err = locateCacheFile(p.Account)
if err != nil {
return errors.Wrap(err, "Could not retrieve cache file path")
}
} else {
cache_path = p.Filename
}

// create the directory if it doesn't exist
err = os.MkdirAll(path.Dir(cache_path), SAMLCacheDirPermissions)
if err != nil {
return errors.Wrap(err, "Could not write the cache file directory")
}
err = ioutil.WriteFile(cache_path, []byte(samlAssertion), SAMLCacheFilePermissions)
if err != nil {
return errors.Wrap(err, "Could not write the cache file path")
}

return nil
}
Loading

0 comments on commit eb7bc3c

Please sign in to comment.