diff --git a/README.md b/README.md index ba4cadcce..6324797c1 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ The process goes something like this: * Prompt user for credentials * Log in to Identity Provider using form based authentication * Build a SAML assertion containing AWS roles +* Optionally cache the SAML assertion (the cache is not encrypted) * Exchange the role and SAML assertion with [AWS STS service](https://docs.aws.amazon.com/STS/latest/APIReference/Welcome.html) to get a temporary set of credentials * Save these credentials to an aws profile named "saml" @@ -630,6 +631,14 @@ credential_process = saml2aws login --skip-prompt --quiet --credential-process - When using the aws cli with the `mybucket` profile, the authentication process will be run and the aws will then be executed based on the returned credentials. +# Caching the saml2aws SAML assertion for immediate reuse + +You can use the flag `--cache-saml` in order to cache the SAML assertion at authentication time. The SAML assertion cache has a very short validity (5 min) and can be used to authenticate to several roles with a single MFA validation. + +there is a file per saml2aws profile, the cache directory is called `saml2aws` and is located in your `.aws` directory in your user homedir. + +You can toggle `--cache-saml` during `login` or during `list-roles`, and you can set it once during `configure` and use it implicitly. + # License This code is Copyright (c) 2018 [Versent](http://versent.com.au) and released under the MIT license. All rights not explicitly granted in the MIT license are reserved. See the included LICENSE.md file for more details. diff --git a/cmd/saml2aws/commands/list_roles.go b/cmd/saml2aws/commands/list_roles.go index ee467b4d7..8b3a6dbcd 100644 --- a/cmd/saml2aws/commands/list_roles.go +++ b/cmd/saml2aws/commands/list_roles.go @@ -11,6 +11,7 @@ import ( "github.com/versent/saml2aws/v2" "github.com/versent/saml2aws/v2/helper/credentials" "github.com/versent/saml2aws/v2/pkg/flags" + "github.com/versent/saml2aws/v2/pkg/samlcache" ) // ListRoles will list available role ARNs @@ -23,6 +24,12 @@ func ListRoles(loginFlags *flags.LoginExecFlags) error { return errors.Wrap(err, "error building login details") } + // creates a cacheProvider, only used when --cache is set + // cacheProvider := samlcache.NewSAMLCacheProvider(account.Name, "") + cacheProvider := &samlcache.SAMLCacheProvider{ + Account: account.Name, + } + loginDetails, err := resolveLoginDetails(account, loginFlags) if err != nil { log.Printf("%+v", err) @@ -41,10 +48,28 @@ func ListRoles(loginFlags *flags.LoginExecFlags) error { return errors.Wrap(err, "error validating login details") } - samlAssertion, err := provider.Authenticate(loginDetails) - if err != nil { - return errors.Wrap(err, "error authenticating to IdP") + var samlAssertion string + if account.SAMLCache { + if cacheProvider.IsValid() { + samlAssertion, err = cacheProvider.Read() + if err != nil { + logger.Debug("Could not read cache:", err) + } + } + } + if samlAssertion == "" { + // samlAssertion was not cached + samlAssertion, err = provider.Authenticate(loginDetails) + if err != nil { + return errors.Wrap(err, "error authenticating to IdP") + } + if account.SAMLCache { + err = cacheProvider.Write(samlAssertion) + if err != nil { + logger.Error("Could not write samlAssertion:", err) + } + } } if samlAssertion == "" { diff --git a/cmd/saml2aws/commands/login.go b/cmd/saml2aws/commands/login.go index be6c58cd9..7803c23a9 100644 --- a/cmd/saml2aws/commands/login.go +++ b/cmd/saml2aws/commands/login.go @@ -18,6 +18,7 @@ import ( "github.com/versent/saml2aws/v2/pkg/cfg" "github.com/versent/saml2aws/v2/pkg/creds" "github.com/versent/saml2aws/v2/pkg/flags" + "github.com/versent/saml2aws/v2/pkg/samlcache" ) // Login login to ADFS @@ -31,6 +32,10 @@ func Login(loginFlags *flags.LoginExecFlags) error { } sharedCreds := awsconfig.NewSharedCredentials(account.Profile, account.CredentialsFile) + // creates a cacheProvider, only used when --cache is set + cacheProvider := &samlcache.SAMLCacheProvider{ + Account: account.Name, + } logger.Debug("check if Creds Exist") @@ -79,10 +84,30 @@ func Login(loginFlags *flags.LoginExecFlags) error { log.Printf("Authenticating as %s ...", loginDetails.Username) - samlAssertion, err := provider.Authenticate(loginDetails) - if err != nil { - return errors.Wrap(err, "error authenticating to IdP") + var samlAssertion string + if account.SAMLCache { + if cacheProvider.IsValid() { + samlAssertion, err = cacheProvider.Read() + if err != nil { + return errors.Wrap(err, "Could not read saml cache") + } + } else { + logger.Debug("Cache is invalid") + } + } + if samlAssertion == "" { + // samlAssertion was not cached + samlAssertion, err = provider.Authenticate(loginDetails) + if err != nil { + return errors.Wrap(err, "error authenticating to IdP") + } + if account.SAMLCache { + err = cacheProvider.Write(samlAssertion) + if err != nil { + return errors.Wrap(err, "Could not write saml cache") + } + } } if samlAssertion == "" { diff --git a/cmd/saml2aws/main.go b/cmd/saml2aws/main.go index c2a145d72..00df753b2 100644 --- a/cmd/saml2aws/main.go +++ b/cmd/saml2aws/main.go @@ -92,6 +92,7 @@ func main() { cmdConfigure.Flag("profile", "The AWS profile to save the temporary credentials. (env: SAML2AWS_PROFILE)").Envar("SAML2AWS_PROFILE").Short('p').StringVar(&commonFlags.Profile) cmdConfigure.Flag("resource-id", "F5APM SAML resource ID of your company account. (env: SAML2AWS_F5APM_RESOURCE_ID)").Envar("SAML2AWS_F5APM_RESOURCE_ID").StringVar(&commonFlags.ResourceID) cmdConfigure.Flag("credentials-file", "The file that will cache the credentials retrieved from AWS. When not specified, will use the default AWS credentials file location. (env: SAML2AWS_CREDENTIALS_FILE)").Envar("SAML2AWS_CREDENTIALS_FILE").StringVar(&commonFlags.CredentialsFile) + cmdConfigure.Flag("cache-saml", "Caches the SAML response").Envar("SAML2AWS_CACHE_SAML").BoolVar(&commonFlags.SAMLCache) configFlags := commonFlags // `login` command and settings @@ -105,6 +106,7 @@ func main() { cmdLogin.Flag("force", "Refresh credentials even if not expired.").BoolVar(&loginFlags.Force) cmdLogin.Flag("credential-process", "Enables AWS Credential Process support by outputting credentials to STDOUT in a JSON message.").BoolVar(&loginFlags.CredentialProcess) cmdLogin.Flag("credentials-file", "The file that will cache the credentials retrieved from AWS. When not specified, will use the default AWS credentials file location. (env: SAML2AWS_CREDENTIALS_FILE)").Envar("SAML2AWS_CREDENTIALS_FILE").StringVar(&commonFlags.CredentialsFile) + cmdLogin.Flag("cache-saml", "Caches the SAML response").Envar("SAML2AWS_CACHE_SAML").BoolVar(&commonFlags.SAMLCache) // `exec` command and settings cmdExec := app.Command("exec", "Exec the supplied command with env vars from STS token.") @@ -128,6 +130,7 @@ func main() { // `list` command and settings cmdListRoles := app.Command("list-roles", "List available role ARNs.") + cmdListRoles.Flag("cache-saml", "Caches the SAML response").Envar("SAML2AWS_CACHE_SAML").BoolVar(&commonFlags.SAMLCache) listRolesFlags := new(flags.LoginExecFlags) listRolesFlags.CommonFlags = commonFlags diff --git a/pkg/cfg/cfg.go b/pkg/cfg/cfg.go index f0ff4eb71..7d5b83355 100644 --- a/pkg/cfg/cfg.go +++ b/pkg/cfg/cfg.go @@ -30,6 +30,7 @@ const ( // IDPAccount saml IDP account type IDPAccount struct { + Name string AppID string `ini:"app_id"` // used by OneLogin and AzureAD URL string `ini:"url"` Username string `ini:"username"` @@ -47,6 +48,7 @@ type IDPAccount struct { HttpAttemptsCount string `ini:"http_attempts_count"` HttpRetryDelay string `ini:"http_retry_delay"` CredentialsFile string `ini:"credentials_file"` + SAMLCache bool `ini:"saml_cache"` } func (ia IDPAccount) String() string { @@ -195,6 +197,9 @@ func (cm *ConfigManager) LoadIDPAccount(idpAccountName string) (*IDPAccount, err return nil, errors.Wrap(err, "Unable to read idp account") } + // adding Name at Load time for the IdpAccount to have awareness of "self" + account.Name = idpAccountName + return account, nil } diff --git a/pkg/cfg/cfg_test.go b/pkg/cfg/cfg_test.go index e039379de..f59e44bff 100644 --- a/pkg/cfg/cfg_test.go +++ b/pkg/cfg/cfg_test.go @@ -27,6 +27,7 @@ func TestNewConfigManagerLoad(t *testing.T) { idpAccount, err := cfgm.LoadIDPAccount("test123") require.Nil(t, err) require.Equal(t, &IDPAccount{ + Name: "test123", URL: "https://id.whatever.com", Username: "abc@whatever.com", Provider: "keycloak", @@ -61,6 +62,7 @@ func TestNewConfigManagerSave(t *testing.T) { idpAccount, err := cfgm.LoadIDPAccount("testing2") require.Nil(t, err) require.Equal(t, &IDPAccount{ + Name: "testing2", URL: "https://id.whatever.com", Username: "abc@whatever.com", Provider: "keycloak", diff --git a/pkg/flags/flags.go b/pkg/flags/flags.go index b3277102d..a4655e566 100644 --- a/pkg/flags/flags.go +++ b/pkg/flags/flags.go @@ -28,6 +28,7 @@ type CommonFlags struct { DisableKeychain bool Region string CredentialsFile string + SAMLCache bool } // LoginExecFlags flags for the Login / Exec commands @@ -98,4 +99,7 @@ func ApplyFlagOverrides(commonFlags *CommonFlags, account *cfg.IDPAccount) { if commonFlags.CredentialsFile != "" { account.CredentialsFile = commonFlags.CredentialsFile } + if commonFlags.SAMLCache { + account.SAMLCache = commonFlags.SAMLCache + } } diff --git a/pkg/samlcache/samlcache.go b/pkg/samlcache/samlcache.go new file mode 100644 index 000000000..b91f9b432 --- /dev/null +++ b/pkg/samlcache/samlcache.go @@ -0,0 +1,142 @@ +package samlcache + +import ( + "fmt" + "io/ioutil" + "os" + "path" + "path/filepath" + "runtime" + "time" + + homedir "github.com/mitchellh/go-homedir" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +var ( + ErrInvalidCachePath = errors.New("Cannot evaluate Cache file path") + logger = logrus.WithField("pkg", "samlcache") +) + +const ( + SAMLAssertionValidityTimeout = 5 * time.Minute + SAMLCacheFilePermissions = 0600 + SAMLCacheDirPermissions = 0700 + SAMLCacheDir = "saml2aws" +) + +// SAMLCacheProvider loads aws credentials file +type SAMLCacheProvider struct { + Filename string + Account string +} + +func resolveSymlink(filename string) (string, error) { + sympath, err := filepath.EvalSymlinks(filename) + + // return the un modified filename + if os.IsNotExist(err) { + return filename, nil + } + if err != nil { + return "", err + } + + return sympath, nil +} + +func (p *SAMLCacheProvider) IsValid() bool { + var cache_path string + var err error + if p.Filename == "" { + cache_path, err = locateCacheFile(p.Account) + if err != nil { + return false + } + } else { + cache_path = p.Filename + } + + fileInfo, err := os.Stat(cache_path) + if err != nil { + return false + } + + return time.Since(fileInfo.ModTime()) < SAMLAssertionValidityTimeout +} + +func locateCacheFile(account string) (string, error) { + + var name, filename string + var err error + if account == "" { + filename = "cache" + } else { + filename = fmt.Sprintf("cache_%s", account) + } + if runtime.GOOS == "windows" { + name = path.Join(os.Getenv("USERPROFILE"), ".aws", SAMLCacheDir, filename) + } else { + name, err = homedir.Expand(path.Join("~", ".aws", SAMLCacheDir, filename)) + if err != nil { + return "", ErrInvalidCachePath + } + } + // is the filename a symlink? + name, err = resolveSymlink(name) + if err != nil { + return "", errors.Wrap(err, "unable to resolve symlink") + } + + logger.WithField("name", name).Debug("resolveSymlink") + + return name, nil +} + +func (p *SAMLCacheProvider) Read() (string, error) { + + var cache_path string + var err error + if p.Filename == "" { + cache_path, err = locateCacheFile(p.Account) + if err != nil { + return "", errors.Wrap(err, "Could not retrieve cache file path") + } + } else { + cache_path = p.Filename + } + + content, err := ioutil.ReadFile(cache_path) + if err != nil { + return "", errors.Wrap(err, "Could not read the cache file path") + } + + return string(content), nil +} + +func (p *SAMLCacheProvider) Write(samlAssertion string) error { + + var cache_path string + var err error + if p.Filename == "" { + cache_path, err = locateCacheFile(p.Account) + if err != nil { + return errors.Wrap(err, "Could not retrieve cache file path") + } + } else { + cache_path = p.Filename + } + + // create the directory if it doesn't exist + err = os.MkdirAll(path.Dir(cache_path), SAMLCacheDirPermissions) + if err != nil { + return errors.Wrap(err, "Could not write the cache file directory") + } + err = ioutil.WriteFile(cache_path, []byte(samlAssertion), SAMLCacheFilePermissions) + if err != nil { + return errors.Wrap(err, "Could not write the cache file path") + } + + return nil +} diff --git a/pkg/samlcache/samlcache_test.go b/pkg/samlcache/samlcache_test.go new file mode 100644 index 000000000..519ddbf51 --- /dev/null +++ b/pkg/samlcache/samlcache_test.go @@ -0,0 +1,109 @@ +package samlcache + +import ( + "io/ioutil" + "os" + "path" + "testing" + "time" +) + +func TestLocateCacheDefault(t *testing.T) { + + cache_location, err := locateCacheFile("") + if err != nil { + t.Error("Could not locate cache file:", err) + } + + if cache_location == "" { + t.Error("Retrieved location is empty") + } + + if path.Base(cache_location) != "cache" { + t.Error("Filename is not the default one (cache):", path.Base(cache_location)) + } + +} + +func TestLocateCacheAccount(t *testing.T) { + + cache_location, err := locateCacheFile("myaccount") + if err != nil { + t.Error("Could not locate cache file:", err) + } + + if cache_location == "" { + t.Error("Retrieved location is empty") + } + + if path.Base(cache_location) != "cache_myaccount" { + t.Error("Filename is not the default one (cache_myaccount):", path.Base(cache_location)) + } + +} + +func TestCanWrite(t *testing.T) { + + p := SAMLCacheProvider{ + Filename: "testdir/cache_file", + } + + err := p.Write("test_write_cache") + if err != nil { + t.Error("Could not write cache:", err) + } + + if _, err := os.Stat("testdir/cache_file"); os.IsNotExist(err) { + t.Error("The cache file was not created:", err) + } + + os.RemoveAll("testdir") + +} + +func TestCanRead(t *testing.T) { + + // create a dummy file + _ = ioutil.WriteFile("example_cache", []byte("testing output"), 0700) + + p := SAMLCacheProvider{ + Filename: "example_cache", + } + + output, err := p.Read() + if err != nil { + t.Error("Could not read cache:", err) + } + + if output != "testing output" { + t.Error("Cache file does not contain the right thing", output) + } + + os.Remove("example_cache") + +} + +func TestIsValid(t *testing.T) { + + // create a dummy file + _ = ioutil.WriteFile("example_cache", []byte("testing output"), 0700) + p := SAMLCacheProvider{ + Filename: "example_cache", + } + + if !p.IsValid() { + t.Error("Cache file is not valid!") + } + + // changing the file timestamp to validate expiry + // new_time := time.Now().Sub(24 * time.Hour) + new_time := time.Now().Add(-24 * time.Hour) + _ = os.Chtimes("example_cache", new_time, new_time) + + if p.IsValid() { + t.Error("Cache file should be expired!") + } + + os.Remove("example_cache") + +}