diff --git a/pkg/sources/s3/s3.go b/pkg/sources/s3/s3.go index b43d2500093c..f5df0fd3a0d1 100644 --- a/pkg/sources/s3/s3.go +++ b/pkg/sources/s3/s3.go @@ -342,21 +342,16 @@ func (s *Source) pageChunker( ctx.Logger().V(2).Info("Skipped due to excessive errors") return nil } - - // Use an anonymous function to retrieve the S3 object with a dedicated timeout context. + // Make sure we use a separate context for the GetObjectWithContext call. // This ensures that the timeout is isolated and does not affect any downstream operations. (e.g. HandleFile) - getObject := func() (*s3.GetObjectOutput, error) { - const getObjectTimeout = 30 * time.Second - objCtx, cancel := context.WithTimeout(ctx, getObjectTimeout) - defer cancel() - - return client.GetObjectWithContext(objCtx, &s3.GetObjectInput{ - Bucket: &bucket, - Key: obj.Key, - }) - } + const getObjectTimeout = 30 * time.Second + objCtx, cancel := context.WithTimeout(ctx, getObjectTimeout) + defer cancel() - res, err := getObject() + res, err := client.GetObjectWithContext(objCtx, &s3.GetObjectInput{ + Bucket: &bucket, + Key: obj.Key, + }) if err != nil { if !strings.Contains(err.Error(), "AccessDenied") { ctx.Logger().Error(err, "could not get S3 object") diff --git a/pkg/sources/s3/s3_integration_test.go b/pkg/sources/s3/s3_integration_test.go index fc160af9a62a..1832eeb30c9e 100644 --- a/pkg/sources/s3/s3_integration_test.go +++ b/pkg/sources/s3/s3_integration_test.go @@ -10,9 +10,10 @@ import ( "time" "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/anypb" + "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb" - "google.golang.org/protobuf/types/known/anypb" "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" @@ -50,6 +51,37 @@ func TestSource_ChunksCount(t *testing.T) { assert.Greater(t, got, wantChunkCount) } +func TestSource_ChunksLarge(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + + s := Source{} + connection := &sourcespb.S3{ + Credential: &sourcespb.S3_Unauthenticated{}, + Buckets: []string{"trufflesec-ahrav-test"}, + } + conn, err := anypb.New(connection) + if err != nil { + t.Fatal(err) + } + + err = s.Init(ctx, "test name", 0, 0, false, conn, 1) + chunksCh := make(chan *sources.Chunk) + go func() { + defer close(chunksCh) + err = s.Chunks(ctx, chunksCh) + assert.Nil(t, err) + }() + + wantChunkCount := 9637 + got := 0 + + for range chunksCh { + got++ + } + assert.Equal(t, got, wantChunkCount) +} + func TestSource_Validate(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel()