Skip to content

Commit

Permalink
Validate S3 source (#1715)
Browse files Browse the repository at this point in the history
This PR adds S3 source validation. This is accomplished by factoring out common "bucket visiting" logic to be used by both scanning and validation.
  • Loading branch information
rosecodym authored Sep 5, 2023
1 parent c9e6086 commit afe7085
Show file tree
Hide file tree
Showing 2 changed files with 247 additions and 39 deletions.
158 changes: 119 additions & 39 deletions pkg/sources/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
)

const (
defaultAWSRegion = "us-east-1"
defaultMaxObjectSize = 250 * 1024 * 1024 // 250 MiB
maxObjectSizeLimit = 250 * 1024 * 1024 // 250 MiB
)
Expand All @@ -53,6 +54,7 @@ type Source struct {
// Ensure the Source satisfies the interfaces at compile time
var _ sources.Source = (*Source)(nil)
var _ sources.SourceUnitUnmarshaller = (*Source)(nil)
var _ sources.Validator = (*Source)(nil)

// Type returns the type of source
func (s *Source) Type() sourcespb.SourceType {
Expand Down Expand Up @@ -93,6 +95,23 @@ func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64,
return nil
}

// Validate checks that each configured role (or the bare credentials, when no
// roles are configured) can list objects in at least one accessible bucket.
// It returns one error per problem found; an empty slice means validation
// passed.
func (s *Source) Validate(ctx context.Context) []error {
	// Named errs (not errors) so the errors package used elsewhere in this
	// file (errors.WrapPrefix) is not shadowed.
	var errs []error
	visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
		// Appending an empty slice is a no-op, so no length guard is needed.
		errs = append(errs, s.validateBucketAccess(c, defaultRegionClient, roleArn, buckets)...)
	}

	if err := s.visitRoles(ctx, visitor); err != nil {
		errs = append(errs, err)
	}

	return errs
}

// setMaxObjectSize sets the maximum size of objects that will be scanned. If
// not set, set to a negative number, or set larger than the
// maxObjectSizeLimit, the defaultMaxObjectSizeLimit will be used.
Expand Down Expand Up @@ -153,7 +172,7 @@ func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {

res, err := client.ListBuckets(&s3.ListBucketsInput{})
if err != nil {
return nil, fmt.Errorf("could not list s3 buckets: %w", err)
return nil, err
}

var bucketsToScan []string
Expand All @@ -163,34 +182,30 @@ func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {
return bucketsToScan, nil
}

func (s *Source) scanBuckets(ctx context.Context, client *s3.S3, role string, bucketsToScan []string, chunksChan chan *sources.Chunk) error {
const defaultAWSRegion = "us-east-1"
func (s *Source) scanBuckets(ctx context.Context, client *s3.S3, role string, bucketsToScan []string, chunksChan chan *sources.Chunk) {
objectCount := uint64(0)

logger := s.log
if role != "" {
logger = logger.WithValues("roleArn", role)
}

for i, bucket := range bucketsToScan {
logger := logger.WithValues("bucket", bucket)

if common.IsDone(ctx) {
return nil
return
}

s.SetProgressComplete(i, len(bucketsToScan), fmt.Sprintf("Bucket: %s", bucket), "")
s.log.Info("Scanning bucket", "bucket", bucket)
region, err := s3manager.GetBucketRegionWithClient(ctx, client, bucket)
logger.Info("Scanning bucket")

regionalClient, err := s.getRegionalClientForBucket(ctx, client, role, bucket)
if err != nil {
s.log.Error(err, "could not get s3 region for bucket", "bucket: ", bucket)
logger.Error(err, "could not get regional client for bucket")
continue
}

var regionalClient *s3.S3
if region != defaultAWSRegion {
regionalClient, err = s.newClient(region, role)
if err != nil {
s.log.Error(err, "could not make regional s3 client")
continue
}
} else {
regionalClient = client
}

errorCount := sync.Map{}

err = regionalClient.ListObjectsV2PagesWithContext(
Expand All @@ -201,40 +216,46 @@ func (s *Source) scanBuckets(ctx context.Context, client *s3.S3, role string, bu
})

if err != nil {
s.log.Error(err, "could not list objects in s3 bucket", "bucket: ", bucket)
continue
if role == "" {
logger.Error(err, "could not list objects in bucket")
} else {
// Our documentation blesses specifying a role to assume without specifying buckets to scan, which will
// often cause this to happen a lot (because in that case the scanner tries to scan every bucket in the
// account, but the role probably doesn't have access to all of them). This makes it expected behavior
// and therefore not an error.
logger.V(3).Info("could not list objects in bucket",
"err", err)
}
}
}
s.SetProgressComplete(len(bucketsToScan), len(bucketsToScan), fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, objectCount), "")
return nil
}

// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error {
const defaultAWSRegion = "us-east-1"

roles := s.conn.Roles
if len(roles) == 0 {
roles = []string{""}
visitor := func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string) {
s.scanBuckets(c, defaultRegionClient, roleArn, buckets, chunksChan)
}

for _, role := range roles {
client, err := s.newClient(defaultAWSRegion, role)
if err != nil {
return errors.WrapPrefix(err, "could not create s3 client", 0)
}
return s.visitRoles(ctx, visitor)
}

bucketsToScan, err := s.getBucketsToScan(client)
if err != nil {
return err
}
// getRegionalClientForBucket returns an S3 client for the region the given
// bucket lives in. The default-region client is used to look up the bucket's
// region, and is returned as-is when the bucket is already in the default
// region, avoiding an unnecessary second client.
//
// NOTE(review): removed-diff residue lines (leftovers of the old Chunks role
// loop) that were interleaved into this body have been dropped.
func (s *Source) getRegionalClientForBucket(ctx context.Context, defaultRegionClient *s3.S3, role, bucket string) (*s3.S3, error) {
	region, err := s3manager.GetBucketRegionWithClient(ctx, defaultRegionClient, bucket)
	if err != nil {
		return nil, errors.WrapPrefix(err, "could not get s3 region for bucket", 0)
	}

	if region == defaultAWSRegion {
		return defaultRegionClient, nil
	}

	regionalClient, err := s.newClient(region, role)
	if err != nil {
		return nil, errors.WrapPrefix(err, "could not create regional s3 client", 0)
	}

	return regionalClient, nil
}

// pageChunker emits chunks onto the given channel from a page
Expand Down Expand Up @@ -396,6 +417,65 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan
_ = s.jobPool.Wait()
}

// validateBucketAccess checks that the given client (assuming roleArn, if
// non-empty) can list objects in the given buckets. When no role is
// specified, the credentials are expected to reach every listed bucket, so
// each inaccessible bucket is its own error; when a role is specified, it is
// only an error if the role cannot list objects in any bucket at all (roles
// commonly have access to just a subset of the account's buckets).
func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleArn string, buckets []string) []error {
	shouldHaveAccessToAllBuckets := roleArn == ""
	wasAbleToListAnyBucket := false
	// Named errs (not errors) to avoid shadowing the errors package used
	// elsewhere in this file (errors.WrapPrefix).
	var errs []error

	for _, bucket := range buckets {
		if common.IsDone(ctx) {
			return append(errs, ctx.Err())
		}

		regionalClient, err := s.getRegionalClientForBucket(ctx, client, roleArn, bucket)
		if err != nil {
			errs = append(errs, fmt.Errorf("could not get regional client for bucket %q: %w", bucket, err))
			continue
		}

		// A single ListObjectsV2 call is enough to prove list access.
		_, err = regionalClient.ListObjectsV2(&s3.ListObjectsV2Input{Bucket: &bucket})

		if err == nil {
			wasAbleToListAnyBucket = true
		} else if shouldHaveAccessToAllBuckets {
			errs = append(errs, fmt.Errorf("could not list objects in bucket %q: %w", bucket, err))
		}
	}

	if !wasAbleToListAnyBucket {
		if roleArn == "" {
			errs = append(errs, fmt.Errorf("could not list objects in any bucket"))
		} else {
			errs = append(errs, fmt.Errorf("role %q could not list objects in any bucket", roleArn))
		}
	}

	return errs
}

// visitRoles enumerates the configured role ARNs (or a single empty role when
// none are configured), builds a default-region client plus the list of
// buckets to scan for each, and invokes f with them. It stops at the first
// role that cannot be set up.
func (s *Source) visitRoles(ctx context.Context, f func(c context.Context, defaultRegionClient *s3.S3, roleArn string, buckets []string)) error {
	roleArns := s.conn.Roles
	if len(roleArns) == 0 {
		// An empty role ARN means "use the credentials directly, without
		// assuming any role."
		roleArns = []string{""}
	}

	for _, roleArn := range roleArns {
		defaultRegionClient, err := s.newClient(defaultAWSRegion, roleArn)
		if err != nil {
			return errors.WrapPrefix(err, "could not create s3 client", 0)
		}

		buckets, err := s.getBucketsToScan(defaultRegionClient)
		if err != nil {
			return fmt.Errorf("role %q could not list any s3 buckets for scanning: %w", roleArn, err)
		}

		f(ctx, defaultRegionClient, roleArn, buckets)
	}

	return nil
}

// S3 links currently have the general format of:
// https://[bucket].s3[.region unless us-east-1].amazonaws.com/[key]
func makeS3Link(bucket, region, key string) string {
Expand Down
128 changes: 128 additions & 0 deletions pkg/sources/s3/s3_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
package s3

import (
"fmt"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
"google.golang.org/protobuf/types/known/anypb"

"github.com/trufflesecurity/trufflehog/v3/pkg/context"
Expand Down Expand Up @@ -45,3 +49,127 @@ func TestSource_ChunksCount(t *testing.T) {
}
assert.Greater(t, got, wantChunkCount)
}

// TestSource_Validate is an integration test for Source.Validate. It runs
// against live AWS resources: credentials come from the shared test secret,
// and the buckets/roles referenced in the cases must exist in the test
// account (619888638459). Each case asserts only the *count* of validation
// errors returned, not their contents.
func TestSource_Validate(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
	defer cancel()

	secret, err := common.GetTestSecret(ctx)
	if err != nil {
		t.Fatal(fmt.Errorf("failed to access secret: %v", err))
	}

	s3key := secret.MustGetField("AWS_S3_KEY")
	s3secret := secret.MustGetField("AWS_S3_SECRET")

	tests := []struct {
		name         string
		roles        []string // role ARNs to assume; empty means use the raw credentials
		buckets      []string // buckets to validate; empty means all buckets visible to the credentials/role
		wantErrCount int      // expected number of errors from Validate
	}{
		{
			name: "buckets without roles, can access all buckets",
			buckets: []string{
				"truffletestbucket-s3-tests",
			},
			wantErrCount: 0,
		},
		{
			name: "buckets without roles, one error per inaccessible bucket",
			buckets: []string{
				"truffletestbucket-s3-tests",
				"truffletestbucket-s3-role-assumption",
				"truffletestbucket-no-access",
			},
			wantErrCount: 2,
		},
		{
			name: "roles without buckets, all can access at least one account bucket",
			roles: []string{
				"arn:aws:iam::619888638459:role/s3-test-assume-role",
			},
			wantErrCount: 0,
		},
		{
			name: "roles without buckets, one error per role that cannot access any account buckets",
			roles: []string{
				"arn:aws:iam::619888638459:role/s3-test-assume-role",
				"arn:aws:iam::619888638459:role/test-no-access",
			},
			wantErrCount: 1,
		},
		{
			name: "role and buckets, can access at least one bucket",
			roles: []string{
				"arn:aws:iam::619888638459:role/s3-test-assume-role",
			},
			buckets: []string{
				"truffletestbucket-s3-role-assumption",
				"truffletestbucket-no-access",
			},
			wantErrCount: 0,
		},
		{
			name: "roles and buckets, one error per role that cannot access at least one bucket",
			roles: []string{
				"arn:aws:iam::619888638459:role/s3-test-assume-role",
				"arn:aws:iam::619888638459:role/test-no-access",
			},
			buckets: []string{
				"truffletestbucket-s3-role-assumption",
				"truffletestbucket-no-access",
			},
			wantErrCount: 1,
		},
		{
			name: "role and buckets, a bucket doesn't even exist",
			roles: []string{
				"arn:aws:iam::619888638459:role/s3-test-assume-role",
			},
			buckets: []string{
				"truffletestbucket-s3-role-assumption",
				"not-a-real-bucket-asljdhmglasjgvklhsdaljfh", // need a bucket name that nobody is likely to ever create
			},
			wantErrCount: 1,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Per-case timeout, independent of the outer context used to
			// fetch the test secret.
			ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
			var cancelOnce sync.Once
			defer cancelOnce.Do(cancel)

			// These are used by the tests that assume roles
			t.Setenv("AWS_ACCESS_KEY_ID", s3key)
			t.Setenv("AWS_SECRET_ACCESS_KEY", s3secret)

			s := &Source{}

			conn, err := anypb.New(&sourcespb.S3{
				// These are used by the tests that don't assume roles
				Credential: &sourcespb.S3_AccessKey{
					AccessKey: &credentialspb.KeySecret{
						Key: s3key,
						Secret: s3secret,
					},
				},
				Buckets: tt.buckets,
				Roles: tt.roles,
			})
			if err != nil {
				t.Fatal(err)
			}

			err = s.Init(ctx, tt.name, 0, 0, false, conn, 0)
			if err != nil {
				t.Fatal(err)
			}

			errs := s.Validate(ctx)

			// Only the error count is asserted; the error messages depend on
			// AWS responses and are not stable enough to pin here.
			assert.Equal(t, tt.wantErrCount, len(errs))
		})
	}
}

0 comments on commit afe7085

Please sign in to comment.