diff --git a/README.md b/README.md index f6f0c6d..d72e912 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ | Local provider | ✅ Implemented | | [HashiCorp Vault](https://www.vaultproject.io) | ✅ Implemented | | [OpenBao](https://github.com/openbao/openbao) | ✅ Implemented | -| [AWS Secrets Manager](https://aws.amazon.com/secrets-manager)| Upcoming | +| [AWS Secrets Manager](https://aws.amazon.com/secrets-manager)| ✅ Implemented | ## Getting started diff --git a/docker-compose.yaml b/docker-compose.yaml index 2a7b1a3..819a457 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,5 +1,3 @@ -version: "3.9" - services: vault: container_name: secret-init-vault diff --git a/env_store.go b/env_store.go index a06ec75..1ec1b03 100644 --- a/env_store.go +++ b/env_store.go @@ -24,6 +24,7 @@ import ( "github.com/bank-vaults/secret-init/pkg/common" "github.com/bank-vaults/secret-init/pkg/provider" + "github.com/bank-vaults/secret-init/pkg/provider/aws" "github.com/bank-vaults/secret-init/pkg/provider/bao" "github.com/bank-vaults/secret-init/pkg/provider/file" "github.com/bank-vaults/secret-init/pkg/provider/vault" @@ -33,6 +34,7 @@ var supportedProviders = []string{ file.ProviderName, vault.ProviderName, bao.ProviderName, + aws.ProviderName, } // EnvStore is a helper for managing interactions between environment variables and providers, @@ -57,45 +59,45 @@ func NewEnvStore(appConfig *common.Config) *EnvStore { } } -// GetProviderPaths returns a map of secret paths for each provider -func (s *EnvStore) GetProviderPaths() map[string][]string { - providerPaths := make(map[string][]string) +// GetSecretReferences returns a map of secret key=value pairs for each provider +func (s *EnvStore) GetSecretReferences() map[string][]string { + secretReferences := make(map[string][]string) - for envKey, path := range s.data { - providerName, path := getProviderPath(path) + for envKey, envPath := range s.data { + providerName, envSecretReference := getProviderPath(envPath) + envSecretReference = envKey + "=" + envSecretReference switch providerName { case file.ProviderName: - providerPaths[file.ProviderName] = append(providerPaths[file.ProviderName], path) + secretReferences[file.ProviderName] = append(secretReferences[file.ProviderName], envSecretReference) case vault.ProviderName: - // The injector function expects a map of key:value pairs - path = envKey + "=" + path - providerPaths[vault.ProviderName] = append(providerPaths[vault.ProviderName], path) + secretReferences[vault.ProviderName] = append(secretReferences[vault.ProviderName], envSecretReference) case bao.ProviderName: - // The injector function expects a map of key:value pairs - path = envKey + "=" + path - providerPaths[bao.ProviderName] = append(providerPaths[bao.ProviderName], path) + secretReferences[bao.ProviderName] = append(secretReferences[bao.ProviderName], envSecretReference) + + case aws.ProviderName: + secretReferences[aws.ProviderName] = append(secretReferences[aws.ProviderName], envSecretReference) } } - return providerPaths + return secretReferences } // LoadProviderSecrets creates a new provider for each detected provider using a specified config. // It then asynchronously loads secrets using each provider and it's corresponding paths. -// The secrets from each provider are then placed into a map with the provider name as the key. -func (s *EnvStore) LoadProviderSecrets(providerPaths map[string][]string) (map[string][]provider.Secret, error) { +// The secrets from each provider are then placed into a single slice. +func (s *EnvStore) LoadProviderSecrets(ctx context.Context, providerPaths map[string][]string) ([]provider.Secret, error) { // At most, we will have one error per provider errCh := make(chan error, len(supportedProviders)) - providerSecrets := make(map[string][]provider.Secret) + var providerSecrets []provider.Secret // Workaround for openBao // Remove once openBao uses BAO_ADDR in their client, instead of VAULT_ADDR vaultPaths, ok := providerPaths[vault.ProviderName] if ok { var err error - providerSecrets[vault.ProviderName], err = s.workaroundForBao(vaultPaths) + providerSecrets, err = s.workaroundForBao(vaultPaths) if err != nil { return nil, fmt.Errorf("failed to workaround for bao: %w", err) } @@ -119,14 +121,14 @@ func (s *EnvStore) LoadProviderSecrets(providerPaths map[string][]string) (map[s return } - secrets, err := provider.LoadSecrets(context.Background(), paths) + secrets, err := provider.LoadSecrets(ctx, paths) if err != nil { errCh <- fmt.Errorf("failed to load secrets for provider %s: %w", providerName, err) return } mu.Lock() - providerSecrets[providerName] = secrets + providerSecrets = append(providerSecrets, secrets...) mu.Unlock() }(providerName, paths, errCh) } @@ -167,25 +169,11 @@ func (s *EnvStore) workaroundForBao(vaultPaths []string) ([]provider.Secret, err } // ConvertProviderSecrets converts the loaded secrets to environment variables -func (s *EnvStore) ConvertProviderSecrets(providerSecrets map[string][]provider.Secret) ([]string, error) { +func (s *EnvStore) ConvertProviderSecrets(providerSecrets []provider.Secret) ([]string, error) { var secretsEnv []string - for providerName, secrets := range providerSecrets { - switch providerName { - case vault.ProviderName, bao.ProviderName: - // The Vault and Bao providers already returns the secrets with the environment variable keys - for _, secret := range secrets { - secretsEnv = append(secretsEnv, fmt.Sprintf("%s=%s", secret.Path, secret.Value)) - } - - default: - secrets, err := createSecretEnvsFrom(s.data, secrets) - if err != nil { - return nil, fmt.Errorf("failed to create secret environment variables: %w", err) - } - - secretsEnv = append(secretsEnv, secrets...) - } + for _, secret := range providerSecrets { + secretsEnv = append(secretsEnv, fmt.Sprintf("%s=%s", secret.Key, secret.Value)) } return secretsEnv, nil @@ -194,24 +182,28 @@ func (s *EnvStore) ConvertProviderSecrets(providerSecrets map[string][]provider. // Returns the detected provider name and path with removed prefix func getProviderPath(path string) (string, string) { if strings.HasPrefix(path, "file:") { - var fileProviderName = file.ProviderName - return fileProviderName, strings.TrimPrefix(path, "file:") + return file.ProviderName, path } // If the path contains some string formatted as "vault:{STR}#{STR}" // it is most probably a vault path if vault.ProviderEnvRegex.MatchString(path) { - // Do not remove the prefix since it will be processed during injection return vault.ProviderName, path } // If the path contains some string formatted as "bao:{STR}#{STR}" // it is most probably a vault path if bao.ProviderEnvRegex.MatchString(path) { - // Do not remove the prefix since it will be processed during injection return bao.ProviderName, path } + // Example AWS prefixes: + // arn:aws:secretsmanager:us-west-2:123456789012:secret:my-secret + // arn:aws:ssm:us-west-2:123456789012:parameter/my-parameter + if strings.HasPrefix(path, "arn:aws:secretsmanager:") || strings.HasPrefix(path, "arn:aws:ssm:") { + return aws.ProviderName, path + } + return "", path } @@ -249,33 +241,16 @@ func newProvider(providerName string, appConfig *common.Config) (provider.Provid } return provider, nil - default: - return nil, fmt.Errorf("provider %s is not supported", providerName) - } -} - -func createSecretEnvsFrom(envs map[string]string, secrets []provider.Secret) ([]string, error) { - // Reverse the map so we can match - // the environment variable key to the secret - // by using the secret path - reversedEnvs := make(map[string]string) - for envKey, path := range envs { - providerName, path := getProviderPath(path) - if providerName != "" { - reversedEnvs[path] = envKey + case aws.ProviderName: + config, err := aws.LoadConfig() + if err != nil { + return nil, fmt.Errorf("failed to create aws config: %w", err) } - } - var secretsEnv []string - for _, secret := range secrets { - path := secret.Path - key, ok := reversedEnvs[path] - if !ok { - return nil, fmt.Errorf("failed to find environment variable key for secret path: %s", path) - } + provider := aws.NewProvider(config) + return provider, nil - secretsEnv = append(secretsEnv, fmt.Sprintf("%s=%s", key, secret.Value)) + default: + return nil, fmt.Errorf("provider %s is not supported", providerName) } - - return secretsEnv, nil } diff --git a/env_store_test.go b/env_store_test.go index 63e1290..19c9642 100644 --- a/env_store_test.go +++ b/env_store_test.go @@ -15,6 +15,7 @@ package main import ( + "context" "fmt" "os" "testing" @@ -25,7 +26,7 @@ import ( "github.com/bank-vaults/secret-init/pkg/provider" ) -func TestEnvStore_GetProviderPaths(t *testing.T) { +func TestEnvStore_GetSecretReferences(t *testing.T) { tests := []struct { name string envs map[string]string @@ -38,7 +39,7 @@ func TestEnvStore_GetProviderPaths(t *testing.T) { }, wantPaths: map[string][]string{ "file": { - "secret/data/test/aws", + "AWS_SECRET_ACCESS_KEY_ID=file:secret/data/test/aws", }, }, }, @@ -65,20 +66,68 @@ func TestEnvStore_GetProviderPaths(t *testing.T) { }, }, }, + { + name: "bao provider", + envs: map[string]string{ + "ACCOUNT_PASSWORD_1": "bao:secret/data/account#password#1", + "ACCOUNT_PASSWORD": "bao:secret/data/account#password", + "ROOT_CERT": ">>bao:pki/root/generate/internal#certificate", + "ROOT_CERT_CACHED": ">>bao:pki/root/generate/internal#certificate", + "INLINE_SECRET": "scheme://${bao:secret/data/account#username}:${bao:secret/data/account#password}@127.0.0.1:8080", + "INLINE_SECRET_EMBEDDED_TEMPLATE": "scheme://${bao:secret/data/account#username}:${bao:secret/data/account#${.password | urlquery}}@127.0.0.1:8080", + "INLINE_DYNAMIC_SECRET": "${>>bao:pki/root/generate/internal#certificate}__${>>bao:pki/root/generate/internal#certificate}", + }, + wantPaths: map[string][]string{ + "bao": { + "ACCOUNT_PASSWORD_1=bao:secret/data/account#password#1", + "ACCOUNT_PASSWORD=bao:secret/data/account#password", + "ROOT_CERT=>>bao:pki/root/generate/internal#certificate", + "ROOT_CERT_CACHED=>>bao:pki/root/generate/internal#certificate", + "INLINE_SECRET=scheme://${bao:secret/data/account#username}:${bao:secret/data/account#password}@127.0.0.1:8080", + "INLINE_SECRET_EMBEDDED_TEMPLATE=scheme://${bao:secret/data/account#username}:${bao:secret/data/account#${.password | urlquery}}@127.0.0.1:8080", + "INLINE_DYNAMIC_SECRET=${>>bao:pki/root/generate/internal#certificate}__${>>bao:pki/root/generate/internal#certificate}", + }, + }, + }, + { + name: "aws provider", + envs: map[string]string{ + "AWS_SECRET1": "arn:aws:secretsmanager:us-west-2:123456789012:secret:my-secret", + "AWS_SECRET2": "arn:aws:ssm:us-west-2:123456789012:parameter/my-parameter", + }, + wantPaths: map[string][]string{ + "aws": { + "AWS_SECRET1=arn:aws:secretsmanager:us-west-2:123456789012:secret:my-secret", + "AWS_SECRET2=arn:aws:ssm:us-west-2:123456789012:parameter/my-parameter", + }, + }, + }, { name: "multi provider", envs: map[string]string{ "AWS_SECRET_ACCESS_KEY_ID": "file:secret/data/test/aws", "MYSQL_PASSWORD": "vault:secret/data/test/mysql#MYSQL_PASSWORD", "AWS_SECRET_ACCESS_KEY": "vault:secret/data/test/aws#AWS_SECRET_ACCESS_KEY", + "RABBITMQ_USERNAME": "bao:secret/data/test/rabbitmq#RABBITMQ_USERNAME", + "RABBITMQ_PASSWORD": "bao:secret/data/test/rabbitmq#RABBITMQ_PASSWORD", + "AWS_SECRET1": "arn:aws:secretsmanager:us-west-2:123456789012:secret:my-secret", + "AWS_SECRET2": "arn:aws:ssm:us-west-2:123456789012:parameter/my-parameter", }, wantPaths: map[string][]string{ + "file": { + "AWS_SECRET_ACCESS_KEY_ID=file:secret/data/test/aws", + }, "vault": { "MYSQL_PASSWORD=vault:secret/data/test/mysql#MYSQL_PASSWORD", "AWS_SECRET_ACCESS_KEY=vault:secret/data/test/aws#AWS_SECRET_ACCESS_KEY", }, - "file": { - "secret/data/test/aws", + "bao": { + "RABBITMQ_USERNAME=bao:secret/data/test/rabbitmq#RABBITMQ_USERNAME", + "RABBITMQ_PASSWORD=bao:secret/data/test/rabbitmq#RABBITMQ_PASSWORD", + }, + "aws": { + "AWS_SECRET1=arn:aws:secretsmanager:us-west-2:123456789012:secret:my-secret", + "AWS_SECRET2=arn:aws:ssm:us-west-2:123456789012:parameter/my-parameter", }, }, }, @@ -95,7 +144,7 @@ func TestEnvStore_GetProviderPaths(t *testing.T) { os.Clearenv() }) - paths := NewEnvStore(&common.Config{}).GetProviderPaths() + paths := NewEnvStore(&common.Config{}).GetSecretReferences() for key, expectedSlice := range ttp.wantPaths { actualSlice, ok := paths[key] @@ -113,7 +162,7 @@ func TestEnvStore_LoadProviderSecrets(t *testing.T) { tests := []struct { name string providerPaths map[string][]string - wantProviderSecrets map[string][]provider.Secret + wantProviderSecrets []provider.Secret addvault bool err error }{ @@ -121,15 +170,13 @@ func TestEnvStore_LoadProviderSecrets(t *testing.T) { name: "Load secrets successfully", providerPaths: map[string][]string{ "file": { - secretFile, + "AWS_SECRET_ACCESS_KEY_ID=file:" + secretFile, }, }, - wantProviderSecrets: map[string][]provider.Secret{ - "file": { - { - Path: secretFile, - Value: "secretId", - }, + wantProviderSecrets: []provider.Secret{ + { + Key: "AWS_SECRET_ACCESS_KEY_ID", + Value: "secretId", }, }, addvault: false, @@ -138,7 +185,7 @@ func TestEnvStore_LoadProviderSecrets(t *testing.T) { name: "Fail to create provider", providerPaths: map[string][]string{ "invalid": { - secretFile, + "AWS_SECRET_ACCESS_KEY_ID=file:" + secretFile, }, }, addvault: false, @@ -151,7 +198,7 @@ func TestEnvStore_LoadProviderSecrets(t *testing.T) { t.Run(ttp.name, func(t *testing.T) { createEnvsForProvider(ttp.addvault, secretFile) - providerSecrets, err := NewEnvStore(&common.Config{}).LoadProviderSecrets(ttp.providerPaths) + providerSecrets, err := NewEnvStore(&common.Config{}).LoadProviderSecrets(context.Background(), ttp.providerPaths) if err != nil { assert.EqualError(t, ttp.err, err.Error(), "Unexpected error message") } @@ -168,19 +215,17 @@ func TestEnvStore_ConvertProviderSecrets(t *testing.T) { tests := []struct { name string - providerSecrets map[string][]provider.Secret + providerSecrets []provider.Secret wantSecretsEnv []string addvault bool err error }{ { name: "Convert secrets successfully", - providerSecrets: map[string][]provider.Secret{ - "file": { - { - Path: secretFile, - Value: "secretId", - }, + providerSecrets: []provider.Secret{ + { + Key: "AWS_SECRET_ACCESS_KEY_ID", + Value: "secretId", }, }, wantSecretsEnv: []string{ @@ -188,19 +233,6 @@ func TestEnvStore_ConvertProviderSecrets(t *testing.T) { }, addvault: false, }, - { - name: "Fail to convert secrets due to fail to find env-key", - providerSecrets: map[string][]provider.Secret{ - "file": { - { - Path: secretFile + "/invalid", - Value: "secretId", - }, - }, - }, - addvault: false, - err: fmt.Errorf("failed to create secret environment variables: failed to find environment variable key for secret path: " + secretFile + "/invalid"), - }, } for _, tt := range tests { diff --git a/examples/README.md b/examples/README.md index a10c2a3..6a2a193 100644 --- a/examples/README.md +++ b/examples/README.md @@ -7,6 +7,7 @@ Discover a range of examples that highlight the functionalities of **secret-init - [File provider](file-provider.md) - [Vault provider](vault-provider.md) - [Bao provider](bao-provider.md) +- [AWS provider](aws-provider.md) ## Multi provider use-case diff --git a/examples/aws-provider.md b/examples/aws-provider.md new file mode 100644 index 0000000..5521967 --- /dev/null +++ b/examples/aws-provider.md @@ -0,0 +1,59 @@ +# AWS-provider + +## Overview + +The AWS Provider in Secret-init can load secrets from AWS Secrets Manager and AWS Systems Manager (SSM) Parameter Store as well. + +## Prerequisites + +- Golang `>= 1.21` +- Makefile +- Access to AWS services + +## Environment setup + +```bash +# Secret-ini requires atleast these environment variables to be set properly +export AWS_ACCESS_KEY_ID +export AWS_SECRET_ACCESS_KEY +export AWS_REGION +``` + +## Define secrets to inject + +```bash +# Export environment variables +export MYSQL_PASSWORD=arn:aws:secretsmanager:eu-north-1:123456789:secret:secret/test/mysql-ASD123 +export SM_JSON=arn:aws:secretsmanager:eu-north-1:123456789:secret:test/secret/JSON-ASD123 +export SSM_SECRET=arn:aws:ssm:eu-north-1:123456789:parameter/bank-vaults/test + +# NOTE: Secret-init is designed to identify any secret-reference that starts with "arn:aws:secretsmanager:" or "arn:aws:ssm:" +``` + +## Run secret-init + +```bash +# Build the secret-init binary +go build + +# Use in daemon mode +SECRET_INIT_DAEMON="true" + +# Run secret-init with a command e.g. +./secret-init env | grep 'MYSQL_PASSWORD\|SM_JSON\|SSM_SECRET' + +# JSON secrets are loaded as is: +# SM_JSON="{"firsts3cr3t":"s3cr3ton3","seconds3cr3t":"s3cr3ttwo"}" +``` + +## Cleanup + +```bash +# Remove binary +rm -rf secret-init + +# Unset the environment variables +unset MYSQL_PASSWORD +unset SM_JSON +unset SSM_SECRET +``` diff --git a/go.mod b/go.mod index d66d307..d961f4e 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/bank-vaults/secret-init go 1.21.1 require ( - emperror.dev/errors v0.8.1 + github.com/aws/aws-sdk-go v1.51.6 github.com/bank-vaults/internal v0.3.0 github.com/bank-vaults/vault-sdk v0.9.3 github.com/hashicorp/vault/api v1.12.2 @@ -20,6 +20,7 @@ require ( cloud.google.com/go/iam v1.1.6 // indirect cloud.google.com/go/kms v1.15.8 // indirect cloud.google.com/go/storage v1.39.1 // indirect + emperror.dev/errors v0.8.1 // indirect github.com/Azure/azure-sdk-for-go/sdk/azcore v1.10.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.1 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 // indirect @@ -30,7 +31,6 @@ require ( github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver/v3 v3.2.0 // indirect github.com/Masterminds/sprig/v3 v3.2.3 // indirect - github.com/aws/aws-sdk-go v1.51.6 // indirect github.com/aws/aws-sdk-go-v2 v1.25.3 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1 // indirect github.com/aws/aws-sdk-go-v2/config v1.27.7 // indirect diff --git a/main.go b/main.go index 28b8155..64c73e5 100644 --- a/main.go +++ b/main.go @@ -52,9 +52,9 @@ func main() { // Fetch all provider secrets and assemble env variables using envstore envStore := NewEnvStore(config) - providerPaths := envStore.GetProviderPaths() + secretReferences := envStore.GetSecretReferences() - providerSecrets, err := envStore.LoadProviderSecrets(providerPaths) + providerSecrets, err := envStore.LoadProviderSecrets(context.Background(), secretReferences) if err != nil { slog.Error(fmt.Errorf("failed to extract secrets: %w", err).Error()) os.Exit(1) diff --git a/pkg/provider/aws/aws.go b/pkg/provider/aws/aws.go new file mode 100644 index 0000000..adb0d2a --- /dev/null +++ b/pkg/provider/aws/aws.go @@ -0,0 +1,154 @@ +// Copyright © 2024 Bank-Vaults Maintainers +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aws + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/secretsmanager" + "github.com/aws/aws-sdk-go/service/ssm" + + "github.com/bank-vaults/secret-init/pkg/provider" +) + +var ProviderName = "aws" + +type Provider struct { + sm *secretsmanager.SecretsManager + ssm *ssm.SSM +} + +func NewProvider(config *Config) *Provider { + return &Provider{ + sm: secretsmanager.New(config.session), + ssm: ssm.New(config.session), + } +} + +func (p *Provider) LoadSecrets(ctx context.Context, paths []string) ([]provider.Secret, error) { + var secrets []provider.Secret + + for _, path := range paths { + split := strings.SplitN(path, "=", 2) + originalKey, secretID := split[0], split[1] + + // valid secretsmanager secret examples: + // arn:aws:secretsmanager:region:account-id:secret:secret-name + // secretsmanager:secret-name + if strings.Contains(secretID, "secretsmanager:") { + secret, err := p.sm.GetSecretValueWithContext( + ctx, + &secretsmanager.GetSecretValueInput{ + SecretId: aws.String(secretID), + }) + if err != nil { + return nil, fmt.Errorf("failed to get secret from AWS secrets manager: %w", err) + } + + secretBytes, err := extractSecretValueFromSM(secret) + if err != nil { + return nil, fmt.Errorf("failed to extract secret value from AWS secrets manager: %w", err) + } + + secretValue, err := parseSecretValueFromSM(secretBytes) + if err != nil { + return nil, fmt.Errorf("failed to parse secret value from AWS secrets manager: %w", err) + } + + secrets = append(secrets, provider.Secret{ + Key: originalKey, + Value: string(secretValue), + }) + } + + // Valid ssm parameter examples: + // arn:aws:ssm:region:account-id:parameter/path/to/parameter-name + // arn:aws:ssm:us-west-2:123456789012:parameter/my-parameter + if strings.Contains(secretID, "ssm:") { + parameteredSecret, err := p.ssm.GetParameterWithContext( + ctx, + &ssm.GetParameterInput{ + Name: aws.String(secretID), + WithDecryption: aws.Bool(true), + }) + if err != nil { + return nil, fmt.Errorf("failed to get secret from AWS SSM: %w", err) + } + + secrets = append(secrets, provider.Secret{ + Key: originalKey, + Value: aws.StringValue(parameteredSecret.Parameter.Value), + }) + } + } + + return secrets, nil +} + +// AWS Secrets Manager can store secrets in two formats: +// - SecretString: for text-based secrets, returned as a byte slice. +// - SecretBinary: for binary secrets, returned as a byte slice without additional encoding. +// If neither is available, the function returns an error. +// +// Ref: https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html +func extractSecretValueFromSM(secret *secretsmanager.GetSecretValueOutput) ([]byte, error) { + // Secret available as string + if secret.SecretString != nil { + return []byte(aws.StringValue(secret.SecretString)), nil + } + + // Secret available as binary + if secret.SecretBinary != nil { + return secret.SecretBinary, nil + } + + // Handle the case where neither SecretString nor SecretBinary is available + return []byte{}, fmt.Errorf("secret does not contain a value in expected formats") +} + +// parseSecretValueFromSM takes a secret and attempts to parse it. +// It unifies the handling of all secrets coming from AWS SM, +// ensuring the output is consistent in the form of a []byte slice. +func parseSecretValueFromSM(secretBytes []byte) ([]byte, error) { + // If the secret is not a JSON object, append it as a single secret + if !json.Valid(secretBytes) { + return secretBytes, nil + } + + var secretValue map[string]interface{} + err := json.Unmarshal(secretBytes, &secretValue) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal secret from AWS Secrets Manager: %w", err) + } + + // If the JSON object contains a single key-value pair, the value is the actual secret + if len(secretValue) == 1 { + for _, value := range secretValue { + valueBytes, err := json.Marshal(value) + if err != nil { + return nil, fmt.Errorf("failed to marshal secret from map: %w", err) + } + + return valueBytes, nil + } + } + + // For JSON objects with multiple key-value pairs, the original JSON is returned as is + return secretBytes, nil +} diff --git a/pkg/provider/aws/aws_test.go b/pkg/provider/aws/aws_test.go new file mode 100644 index 0000000..5ae5897 --- /dev/null +++ b/pkg/provider/aws/aws_test.go @@ -0,0 +1,55 @@ +// Copyright © 2024 Bank-Vaults Maintainers +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aws + +import ( + "testing" + + "github.com/aws/aws-sdk-go/aws/session" + "github.com/stretchr/testify/assert" +) + +func TestNewProvider(t *testing.T) { + tests := []struct { + name string + config *Config + wantType bool + }{ + { + name: "Valid config", + config: &Config{ + session: createSession(), + }, + wantType: true, + }, + } + + for _, tt := range tests { + ttp := tt + t.Run(ttp.name, func(t *testing.T) { + provider := NewProvider(ttp.config) + if ttp.wantType { + assert.Equal(t, ttp.wantType, provider != nil, "Unexpected provider type") + } + }) + } + +} + +func createSession() *session.Session { + return session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigDisable, + })) +} diff --git a/pkg/provider/aws/config.go b/pkg/provider/aws/config.go new file mode 100644 index 0000000..23db45a --- /dev/null +++ b/pkg/provider/aws/config.go @@ -0,0 +1,75 @@ +// Copyright © 2024 Bank-Vaults Maintainers +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aws + +import ( + "fmt" + "os" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/spf13/cast" +) + +const ( + LoadFromSharedConfigEnv = "AWS_LOAD_FROM_SHARED_CONFIG" + DefaultRegionEnv = "AWS_DEFAULT_REGION" + RegionEnv = "AWS_REGION" +) + +type Config struct { + session *session.Session +} + +func LoadConfig() (*Config, error) { + // Loading session data from shared config is disabled by default and needs to be + // explicitly enabled via AWS_LOAD_FROM_SHARED_CONFIG + options := session.Options{ + SharedConfigState: session.SharedConfigDisable, + } + + // Override session options from env configs + if cast.ToBool(os.Getenv(LoadFromSharedConfigEnv)) { + options.SharedConfigState = session.SharedConfigEnable + } + + if region := getRegionEnv(); region != nil { + options.Config = aws.Config{Region: region} + } + + // Create session + sess, err := session.NewSessionWithOptions(options) + if err != nil { + return nil, fmt.Errorf("failed to create AWS session: %w", err) + } + + return &Config{session: sess}, nil +} + +func getRegionEnv() *string { + region, hasRegion := os.LookupEnv(RegionEnv) + if hasRegion { + return aws.String(region) + } + + defaultRegion, hasDefaultRegion := os.LookupEnv(DefaultRegionEnv) + if hasDefaultRegion { + return aws.String(defaultRegion) + } + + // Return nil if no region is found, allowing the AWS SDK to attempt to + // determine the region from the shared config or environment variables. + return nil +} diff --git a/pkg/provider/bao/bao.go b/pkg/provider/bao/bao.go index e712c3f..78498df 100644 --- a/pkg/provider/bao/bao.go +++ b/pkg/provider/bao/bao.go @@ -56,10 +56,9 @@ func (s *sanitized) append(key string, value string) { // it is not a BAO_* variable. // Additionally, in a login scenario, we include BAO_* variables in the secrets list. if !ok || (s.login && envType.login) { - // Path here is actually the secret's key, - // An example of this can be found at the LoadSecrets() function below + // Example can be found at the LoadSecrets() function below secret := provider.Secret{ - Path: key, + Key: key, Value: value, } @@ -159,8 +158,7 @@ func parsePathsToMap(paths []string) map[string]string { for _, path := range paths { split := strings.SplitN(path, "=", 2) - key := split[0] - value := split[1] + key, value := split[0], split[1] baoEnviron[key] = value } diff --git a/pkg/provider/bao/daemon_secret_renewer.go b/pkg/provider/bao/daemon_secret_renewer.go index fa1e7e0..fb0c270 100644 --- a/pkg/provider/bao/daemon_secret_renewer.go +++ b/pkg/provider/bao/daemon_secret_renewer.go @@ -15,12 +15,12 @@ package bao import ( + "fmt" "log/slog" "os" "syscall" "time" - "emperror.dev/errors" bao "github.com/bank-vaults/vault-sdk/vault" baoapi "github.com/hashicorp/vault/api" ) @@ -34,7 +34,7 @@ func (r daemonSecretRenewer) Renew(path string, secret *baoapi.Secret) error { watcherInput := baoapi.LifetimeWatcherInput{Secret: secret} watcher, err := r.client.RawClient().NewLifetimeWatcher(&watcherInput) if err != nil { - return errors.Wrap(err, "failed to create secret watcher") + return fmt.Errorf("failed to create lifetime watcher: %w", err) } go watcher.Start() diff --git a/pkg/provider/file/file.go b/pkg/provider/file/file.go index 2869b00..7543ad1 100644 --- a/pkg/provider/file/file.go +++ b/pkg/provider/file/file.go @@ -48,23 +48,27 @@ func (p *Provider) LoadSecrets(_ context.Context, paths []string) ([]provider.Se var secrets []provider.Secret for _, path := range paths { - secret, err := p.getSecretFromFile(path) + split := strings.SplitN(path, "=", 2) + key, valuePath := split[0], split[1] + valuePath = strings.TrimPrefix(valuePath, "file:") + + secretValue, err := p.getSecretFromFile(valuePath) if err != nil { return nil, fmt.Errorf("failed to get secret from file: %w", err) } secrets = append(secrets, provider.Secret{ - Path: path, - Value: secret, + Key: key, + Value: secretValue, }) } return secrets, nil } -func (p *Provider) getSecretFromFile(path string) (string, error) { - path = strings.TrimLeft(path, "/") - content, err := fs.ReadFile(p.fs, path) +func (p *Provider) getSecretFromFile(valuePath string) (string, error) { + valuePath = strings.TrimLeft(valuePath, "/") + content, err := fs.ReadFile(p.fs, valuePath) if err != nil { return "", fmt.Errorf("failed to read file: %w", err) } diff --git a/pkg/provider/file/file_test.go b/pkg/provider/file/file_test.go index de31025..f253e8e 100644 --- a/pkg/provider/file/file_test.go +++ b/pkg/provider/file/file_test.go @@ -91,22 +91,22 @@ func TestLoadSecrets(t *testing.T) { { name: "Load secrets successfully", paths: []string{ - "test/secrets/sqlpass.txt", - "test/secrets/awsaccess.txt", - "test/secrets/awsid.txt", + "MYSQL_PASSWORD=file:test/secrets/sqlpass.txt", + "AWS_SECRET_ACCESS_KEY=file:test/secrets/awsaccess.txt", + "AWS_ACCESS_KEY_ID=file:test/secrets/awsid.txt", }, wantSecrets: []provider.Secret{ - {Path: "test/secrets/sqlpass.txt", Value: "3xtr3ms3cr3t"}, - {Path: "test/secrets/awsaccess.txt", Value: "s3cr3t"}, - {Path: "test/secrets/awsid.txt", Value: "secretId"}, + {Key: "MYSQL_PASSWORD", Value: "3xtr3ms3cr3t"}, + {Key: "AWS_SECRET_ACCESS_KEY", Value: "s3cr3t"}, + {Key: "AWS_ACCESS_KEY_ID", Value: "secretId"}, }, }, { name: "Fail to load secrets due to invalid path", paths: []string{ - "test/secrets/mistake/sqlpass.txt", - "test/secrets/mistake/awsaccess.txt", - "test/secrets/mistake/awsid.txt", + "MYSQL_PASSWORD=file:test/secrets/mistake/sqlpass.txt", + "AWS_SECRET_ACCESS_KEY=file:test/secrets/mistake/awsaccess.txt", + "AWS_ACCESS_KEY_ID=file:test/secrets/mistake/awsid.txt", }, err: fmt.Errorf("failed to get secret from file: failed to read file: open test/secrets/mistake/sqlpass.txt: file does not exist"), }, diff --git a/pkg/provider/provider.go b/pkg/provider/provider.go index 14e07fc..65b0cd9 100644 --- a/pkg/provider/provider.go +++ b/pkg/provider/provider.go @@ -23,6 +23,6 @@ type Provider interface { // Secret holds Provider-specific secret data. type Secret struct { - Path string + Key string Value string } diff --git a/pkg/provider/vault/daemon_secret_renewer.go b/pkg/provider/vault/daemon_secret_renewer.go index 2ed5ba5..5708e05 100644 --- a/pkg/provider/vault/daemon_secret_renewer.go +++ b/pkg/provider/vault/daemon_secret_renewer.go @@ -15,12 +15,12 @@ package vault import ( + "fmt" "log/slog" "os" "syscall" "time" - "emperror.dev/errors" "github.com/bank-vaults/vault-sdk/vault" vaultapi "github.com/hashicorp/vault/api" ) @@ -34,7 +34,7 @@ func (r daemonSecretRenewer) Renew(path string, secret *vaultapi.Secret) error { watcherInput := vaultapi.LifetimeWatcherInput{Secret: secret} watcher, err := r.client.RawClient().NewLifetimeWatcher(&watcherInput) if err != nil { - return errors.Wrap(err, "failed to create secret watcher") + return fmt.Errorf("failed to create lifetime watcher: %w", err) } go watcher.Start() diff --git a/pkg/provider/vault/vault.go b/pkg/provider/vault/vault.go index e83d1ee..27657ec 100644 --- a/pkg/provider/vault/vault.go +++ b/pkg/provider/vault/vault.go @@ -56,10 +56,9 @@ func (s *sanitized) append(key string, value string) { // it is not a VAULT_* variable. // Additionally, in a login scenario, we include VAULT_* variables in the secrets list. if !ok || (s.login && envType.login) { - // Path here is actually the secret's key, - // An example of this can be found at the LoadSecrets() function below + // Example can be found at the LoadSecrets() function below secret := provider.Secret{ - Path: key, + Key: key, Value: value, } @@ -159,8 +158,7 @@ func parsePathsToMap(paths []string) map[string]string { for _, path := range paths { split := strings.SplitN(path, "=", 2) - key := split[0] - value := split[1] + key, value := split[0], split[1] vaultEnviron[key] = value }