diff --git a/Makefile b/Makefile
index 9eb3d3f540a..fd8452e6846 100644
--- a/Makefile
+++ b/Makefile
@@ -24,7 +24,7 @@ SDK_COMPA_PKGS=${SDK_CORE_PKGS} ${SDK_CLIENT_PKGS}
SDK_EXAMPLES_PKGS=
SDK_ALL_PKGS=${SDK_COMPA_PKGS} ${SDK_EXAMPLES_PKGS}
-RUN_NONE=-run '^$$'
+RUN_NONE=-run NONE
RUN_INTEG=-run '^TestInteg_'
CODEGEN_RESOURCES_PATH=$(shell pwd)/codegen/smithy-aws-go-codegen/src/main/resources/software/amazon/smithy/aws/go/codegen
@@ -98,7 +98,7 @@ gen-endpoint-prefix.json:
# Unit Testing #
################
-unit: lint unit-modules-.
+unit: lint unit-modules-.
unit-race: lint unit-race-modules-.
unit-test: test-modules-.
@@ -194,7 +194,7 @@ integ-modules-%:
"go test -timeout=10m -tags "integration" -v ${RUN_INTEG} -count 1 ./..."
cleanup-integ-buckets:
- @echo "Cleaning up SDK integraiton resources"
+ @echo "Cleaning up SDK integration resources"
go run -tags "integration" ./internal/awstesting/cmd/bucket_cleanup/main.go "aws-sdk-go-integration"
##############
diff --git a/feature/s3/manager/api.go b/feature/s3/manager/api.go
new file mode 100644
index 00000000000..4059f9851d7
--- /dev/null
+++ b/feature/s3/manager/api.go
@@ -0,0 +1,37 @@
+package manager
+
+import (
+ "context"
+
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+)
+
+// DeleteObjectsAPIClient is an S3 API client that can invoke the DeleteObjects operation.
+type DeleteObjectsAPIClient interface {
+ DeleteObjects(context.Context, *s3.DeleteObjectsInput, ...func(*s3.Options)) (*s3.DeleteObjectsOutput, error)
+}
+
+// DownloadAPIClient is an S3 API client that can invoke the GetObject operation.
+type DownloadAPIClient interface {
+ GetObject(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)
+}
+
+// HeadBucketAPIClient is an S3 API client that can invoke the HeadBucket operation.
+type HeadBucketAPIClient interface {
+ HeadBucket(context.Context, *s3.HeadBucketInput, ...func(*s3.Options)) (*s3.HeadBucketOutput, error)
+}
+
+// ListObjectsV2APIClient is an S3 API client that can invoke the ListObjectsV2 operation.
+type ListObjectsV2APIClient interface {
+ ListObjectsV2(context.Context, *s3.ListObjectsV2Input, ...func(*s3.Options)) (*s3.ListObjectsV2Output, error)
+}
+
+// UploadAPIClient is an S3 API client that can invoke PutObject, UploadPart, CreateMultipartUpload,
+// CompleteMultipartUpload, and AbortMultipartUpload operations.
+type UploadAPIClient interface {
+ PutObject(context.Context, *s3.PutObjectInput, ...func(*s3.Options)) (*s3.PutObjectOutput, error)
+ UploadPart(context.Context, *s3.UploadPartInput, ...func(*s3.Options)) (*s3.UploadPartOutput, error)
+ CreateMultipartUpload(context.Context, *s3.CreateMultipartUploadInput, ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error)
+ CompleteMultipartUpload(context.Context, *s3.CompleteMultipartUploadInput, ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error)
+ AbortMultipartUpload(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error)
+}
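
These narrow, per-operation interfaces let callers satisfy the transfer managers with anything that implements the listed methods, which makes unit testing cheap. A minimal sketch of the idea, assuming a hypothetical stubGetObjectClient name (download_test.go later in this diff uses the same function-type pattern):

package manager_test

import (
	"context"

	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
	"github.com/aws/aws-sdk-go-v2/service/s3"
)

// stubGetObjectClient adapts a plain function to the DownloadAPIClient
// interface, letting tests stub GetObject without a real S3 client.
type stubGetObjectClient func(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)

func (f stubGetObjectClient) GetObject(ctx context.Context, in *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
	return f(ctx, in, optFns...)
}

// Compile-time check that the stub satisfies the interface.
var _ manager.DownloadAPIClient = stubGetObjectClient(nil)
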
diff --git a/feature/s3/manager/bucket_region.go b/feature/s3/manager/bucket_region.go
new file mode 100644
index 00000000000..e810877bc4d
--- /dev/null
+++ b/feature/s3/manager/bucket_region.go
@@ -0,0 +1,133 @@
+package manager
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/http"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+ "github.com/awslabs/smithy-go/middleware"
+ smithyhttp "github.com/awslabs/smithy-go/transport/http"
+)
+
+const bucketRegionHeader = "X-Amz-Bucket-Region"
+
+// GetBucketRegion will attempt to get the region for a bucket using the
+// client's configured region to determine which AWS partition to perform the query on.
+//
+// The request will not be signed, and will not use your AWS credentials.
+//
+// A BucketNotFound error will be returned if the bucket does not exist in the
+// AWS partition the client region belongs to.
+//
+// For example, to get the region of a bucket that exists in "eu-central-1",
+// you could provide a region hint of "us-west-2".
+//
+// cfg := config.LoadDefaultConfig()
+//
+// bucket := "my-bucket"
+// region, err := s3manager.GetBucketRegion(ctx, s3.NewFromConfig(cfg), bucket)
+// if err != nil {
+// var bnf BucketNotFound
+// if errors.As(err, &bnf) {
+// fmt.Fprintf(os.Stderr, "unable to find bucket %s's region\n", bucket)
+// }
+// }
+// fmt.Printf("Bucket %s is in %s region\n", bucket, region)
+//
+// By default, the request will be made to the Amazon S3 endpoint using virtual-hosted-style addressing.
+//
+// bucketname.s3.us-west-2.amazonaws.com/
+//
+// To configure GetBucketRegion to make a request via the Amazon S3 FIPS
+// endpoints directly when a FIPS region name is not available (e.g.
+// fips-us-gov-west-1), set the EndpointResolver on the config or client the
+// utility is called with.
+//
+// cfg, err := config.LoadDefaultConfig(config.WithEndpointResolver{
+// EndpointResolver: aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) {
+// return aws.Endpoint{URL: "https://s3-fips.us-west-2.amazonaws.com"}, nil
+// }),
+// })
+// if err != nil {
+// panic(err)
+// }
+func GetBucketRegion(ctx context.Context, client HeadBucketAPIClient, bucket string, optFns ...func(*s3.Options)) (string, error) {
+ var captureBucketRegion deserializeBucketRegion
+
+ clientOptionFns := make([]func(*s3.Options), len(optFns)+1)
+ clientOptionFns[0] = func(options *s3.Options) {
+ options.Credentials = aws.AnonymousCredentials{}
+ options.APIOptions = append(options.APIOptions, captureBucketRegion.RegisterMiddleware)
+ }
+ copy(clientOptionFns[1:], optFns)
+
+ _, err := client.HeadBucket(ctx, &s3.HeadBucketInput{
+ Bucket: aws.String(bucket),
+ }, clientOptionFns...)
+ if len(captureBucketRegion.BucketRegion) == 0 && err != nil {
+ var httpStatusErr interface {
+ HTTPStatusCode() int
+ }
+ if !errors.As(err, &httpStatusErr) {
+ return "", err
+ }
+
+ if httpStatusErr.HTTPStatusCode() == http.StatusNotFound {
+ return "", &bucketNotFound{}
+ }
+
+ return "", err
+ }
+
+ return captureBucketRegion.BucketRegion, nil
+}
+
+type deserializeBucketRegion struct {
+ BucketRegion string
+}
+
+func (d *deserializeBucketRegion) RegisterMiddleware(stack *middleware.Stack) error {
+ return stack.Deserialize.Add(d, middleware.After)
+}
+
+func (d *deserializeBucketRegion) ID() string {
+ return "DeserializeBucketRegion"
+}
+
+func (d *deserializeBucketRegion) HandleDeserialize(ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) (
+ out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
+) {
+ out, metadata, err = next.HandleDeserialize(ctx, in)
+ if err != nil {
+ return out, metadata, err
+ }
+
+ resp, ok := out.RawResponse.(*smithyhttp.Response)
+ if !ok {
+ return out, metadata, fmt.Errorf("unknown transport type %T", out.RawResponse)
+ }
+
+ d.BucketRegion = resp.Header.Get(bucketRegionHeader)
+
+ return out, metadata, err
+}
+
+// BucketNotFound indicates the bucket was not found in the partition when calling GetBucketRegion.
+type BucketNotFound interface {
+ error
+
+ isBucketNotFound()
+}
+
+type bucketNotFound struct{}
+
+func (b *bucketNotFound) Error() string {
+ return "bucket not found"
+}
+
+func (b *bucketNotFound) isBucketNotFound() {}
+
+var _ BucketNotFound = (*bucketNotFound)(nil)
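
The deserializeBucketRegion type above also serves as a template for capturing any response header through the Smithy middleware stack. A sketch following the same pattern; the captureRequestID type and the choice of the x-amz-request-id header are illustrative, not part of this change:

package manager

import (
	"context"
	"fmt"

	"github.com/awslabs/smithy-go/middleware"
	smithyhttp "github.com/awslabs/smithy-go/transport/http"
)

// captureRequestID mirrors deserializeBucketRegion, recording the
// x-amz-request-id response header instead of the bucket region.
type captureRequestID struct {
	RequestID string
}

func (c *captureRequestID) ID() string { return "CaptureRequestID" }

// RegisterMiddleware adds the handler to the deserialize step, running after
// earlier deserialize handlers have produced the raw response.
func (c *captureRequestID) RegisterMiddleware(stack *middleware.Stack) error {
	return stack.Deserialize.Add(c, middleware.After)
}

func (c *captureRequestID) HandleDeserialize(ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) (
	out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
) {
	out, metadata, err = next.HandleDeserialize(ctx, in)
	if err != nil {
		return out, metadata, err
	}

	resp, ok := out.RawResponse.(*smithyhttp.Response)
	if !ok {
		return out, metadata, fmt.Errorf("unknown transport type %T", out.RawResponse)
	}

	c.RequestID = resp.Header.Get("x-amz-request-id")
	return out, metadata, err
}
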
diff --git a/feature/s3/manager/bucket_region_test.go b/feature/s3/manager/bucket_region_test.go
new file mode 100644
index 00000000000..dd56d7d1a69
--- /dev/null
+++ b/feature/s3/manager/bucket_region_test.go
@@ -0,0 +1,120 @@
+package manager
+
+import (
+ "context"
+ "errors"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "testing"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ s3testing "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/testing"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+)
+
+var mockErrResponse = []byte(`<?xml version="1.0" encoding="UTF-8"?>
+<Error>
+  <Code>MockCode</Code>
+  <Message>The error message</Message>
+  <RequestId>4442587FB7D0A2F9</RequestId>
+</Error>`)
+
+func testSetupGetBucketRegionServer(region string, statusCode int, incHeader bool) *httptest.Server {
+ return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ io.Copy(ioutil.Discard, r.Body)
+ if incHeader {
+ w.Header().Set(bucketRegionHeader, region)
+ }
+ if statusCode >= 300 {
+ w.Header().Set("Content-Length", strconv.Itoa(len(mockErrResponse)))
+ w.WriteHeader(statusCode)
+ w.Write(mockErrResponse)
+ } else {
+ w.WriteHeader(statusCode)
+ }
+ }))
+}
+
+var testGetBucketRegionCases = []struct {
+ RespRegion string
+ StatusCode int
+ ExpectReqRegion string
+}{
+ {
+ RespRegion: "bucket-region",
+ StatusCode: 301,
+ },
+ {
+ RespRegion: "bucket-region",
+ StatusCode: 403,
+ },
+ {
+ RespRegion: "bucket-region",
+ StatusCode: 200,
+ },
+ {
+ RespRegion: "bucket-region",
+ StatusCode: 200,
+ ExpectReqRegion: "default-region",
+ },
+}
+
+func TestGetBucketRegion_Exists(t *testing.T) {
+ for i, c := range testGetBucketRegionCases {
+ server := testSetupGetBucketRegionServer(c.RespRegion, c.StatusCode, true)
+
+ client := s3.New(s3.Options{
+ EndpointResolver: s3testing.EndpointResolverFunc(func(region string, options s3.ResolverOptions) (aws.Endpoint, error) {
+ return aws.Endpoint{
+ URL: server.URL,
+ }, nil
+ }),
+ })
+
+ region, err := GetBucketRegion(context.Background(), client, "bucket", func(o *s3.Options) {
+ o.UsePathStyle = true
+ })
+ if err != nil {
+ t.Errorf("%d, expect no error, got %v", i, err)
+ goto closeServer
+ }
+ if e, a := c.RespRegion, region; e != a {
+ t.Errorf("%d, expect %q region, got %q", i, e, a)
+ }
+
+ closeServer:
+ server.Close()
+ }
+}
+
+func TestGetBucketRegion_NotExists(t *testing.T) {
+ server := testSetupGetBucketRegionServer("ignore-region", 404, false)
+ defer server.Close()
+
+ client := s3.New(s3.Options{
+ EndpointResolver: s3testing.EndpointResolverFunc(func(region string, options s3.ResolverOptions) (aws.Endpoint, error) {
+ return aws.Endpoint{
+ URL: server.URL,
+ }, nil
+ }),
+ })
+
+ region, err := GetBucketRegion(context.Background(), client, "bucket", func(o *s3.Options) {
+ o.UsePathStyle = true
+ })
+ if err == nil {
+ t.Fatalf("expect error, but did not get one")
+ }
+
+ var bnf BucketNotFound
+ if !errors.As(err, &bnf) {
+ t.Errorf("expect %T error, got %v", bnf, err)
+ }
+
+ if len(region) != 0 {
+ t.Errorf("expect region not to be set, got %q", region)
+ }
+}
diff --git a/feature/s3/manager/buffered_read_seeker.go b/feature/s3/manager/buffered_read_seeker.go
new file mode 100644
index 00000000000..e781aef610d
--- /dev/null
+++ b/feature/s3/manager/buffered_read_seeker.go
@@ -0,0 +1,79 @@
+package manager
+
+import (
+ "io"
+)
+
+// BufferedReadSeeker is a buffered io.ReadSeeker.
+type BufferedReadSeeker struct {
+ r io.ReadSeeker
+ buffer []byte
+ readIdx, writeIdx int
+}
+
+// NewBufferedReadSeeker returns a new BufferedReadSeeker.
+// If len(b) == 0, the buffer will be initialized to 64 KiB.
+func NewBufferedReadSeeker(r io.ReadSeeker, b []byte) *BufferedReadSeeker {
+ if len(b) == 0 {
+ b = make([]byte, 64*1024)
+ }
+ return &BufferedReadSeeker{r: r, buffer: b}
+}
+
+func (b *BufferedReadSeeker) reset(r io.ReadSeeker) {
+ b.r = r
+ b.readIdx, b.writeIdx = 0, 0
+}
+
+// Read will read up to len(p) bytes into p and will return
+// the number of bytes read and any error that occurred.
+// If len(p) is greater than the buffer size, then a single read request
+// will be issued to the underlying io.ReadSeeker for len(p) bytes.
+// A Read request will at most perform a single Read to the underlying
+// io.ReadSeeker, and may return < len(p) if serviced from the buffer.
+func (b *BufferedReadSeeker) Read(p []byte) (n int, err error) {
+ if len(p) == 0 {
+ return n, err
+ }
+
+ if b.readIdx == b.writeIdx {
+ if len(p) >= len(b.buffer) {
+ n, err = b.r.Read(p)
+ return n, err
+ }
+ b.readIdx, b.writeIdx = 0, 0
+
+ n, err = b.r.Read(b.buffer)
+ if n == 0 {
+ return n, err
+ }
+
+ b.writeIdx += n
+ }
+
+ n = copy(p, b.buffer[b.readIdx:b.writeIdx])
+ b.readIdx += n
+
+ return n, err
+}
+
+// Seek will position the underlying io.ReadSeeker to the given offset
+// and will clear the buffer.
+func (b *BufferedReadSeeker) Seek(offset int64, whence int) (int64, error) {
+ n, err := b.r.Seek(offset, whence)
+
+ b.reset(b.r)
+
+ return n, err
+}
+
+// ReadAt will read up to len(p) bytes at the given file offset.
+// This will result in the buffer being cleared.
+func (b *BufferedReadSeeker) ReadAt(p []byte, off int64) (int, error) {
+ _, err := b.Seek(off, io.SeekStart)
+ if err != nil {
+ return 0, err
+ }
+
+ return b.Read(p)
+}
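
A quick usage sketch of the buffering behavior described above: with a 4-byte internal buffer, the first 2-byte Read issues a single 4-byte read against the underlying reader, and the second 2-byte Read is served entirely from the buffer. Names and values here are illustrative only:

package main

import (
	"bytes"
	"fmt"

	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
)

func main() {
	src := bytes.NewReader([]byte("example payload"))

	// Wrap the reader with a 4-byte internal buffer.
	brs := manager.NewBufferedReadSeeker(src, make([]byte, 4))

	p := make([]byte, 2)
	n, _ := brs.Read(p) // fills the buffer with "exam", returns "ex"
	fmt.Println(n, string(p[:n]))

	n, _ = brs.Read(p) // served from the buffer, returns "am"
	fmt.Println(n, string(p[:n]))
}
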
diff --git a/feature/s3/manager/buffered_read_seeker_test.go b/feature/s3/manager/buffered_read_seeker_test.go
new file mode 100644
index 00000000000..ed46668395f
--- /dev/null
+++ b/feature/s3/manager/buffered_read_seeker_test.go
@@ -0,0 +1,79 @@
+package manager
+
+import (
+ "bytes"
+ "io"
+ "testing"
+)
+
+func TestBufferedReadSeekerRead(t *testing.T) {
+ expected := []byte("testData")
+
+ readSeeker := NewBufferedReadSeeker(bytes.NewReader(expected), make([]byte, 4))
+
+ var (
+ actual []byte
+ buffer = make([]byte, 2)
+ )
+
+ for {
+ n, err := readSeeker.Read(buffer)
+ actual = append(actual, buffer[:n]...)
+ if err != nil && err == io.EOF {
+ break
+ } else if err != nil {
+ t.Fatalf("failed to read from reader: %v", err)
+ }
+ }
+
+ if !bytes.Equal(expected, actual) {
+ t.Errorf("expected %v, got %v", expected, actual)
+ }
+}
+
+func TestBufferedReadSeekerSeek(t *testing.T) {
+ content := []byte("testData")
+
+ readSeeker := NewBufferedReadSeeker(bytes.NewReader(content), make([]byte, 4))
+
+ _, err := readSeeker.Seek(4, io.SeekStart)
+ if err != nil {
+ t.Fatalf("failed to seek reader: %v", err)
+ }
+
+ var (
+ actual []byte
+ buffer = make([]byte, 4)
+ )
+
+ for {
+ n, err := readSeeker.Read(buffer)
+ actual = append(actual, buffer[:n]...)
+ if err != nil && err == io.EOF {
+ break
+ } else if err != nil {
+ t.Fatalf("failed to read from reader: %v", err)
+ }
+ }
+
+ if e := []byte("Data"); !bytes.Equal(e, actual) {
+ t.Errorf("expected %v, got %v", e, actual)
+ }
+}
+
+func TestBufferedReadSeekerReadAt(t *testing.T) {
+ content := []byte("testData")
+
+ readSeeker := NewBufferedReadSeeker(bytes.NewReader(content), make([]byte, 2))
+
+ buffer := make([]byte, 4)
+
+ _, err := readSeeker.ReadAt(buffer, 0)
+ if err != nil {
+ t.Fatalf("failed to seek reader: %v", err)
+ }
+
+ if e := content[:4]; !bytes.Equal(e, buffer) {
+ t.Errorf("expected %v, got %v", e, buffer)
+ }
+}
diff --git a/feature/s3/manager/default_read_seeker_write_to.go b/feature/s3/manager/default_read_seeker_write_to.go
new file mode 100644
index 00000000000..6d1dc6d2c42
--- /dev/null
+++ b/feature/s3/manager/default_read_seeker_write_to.go
@@ -0,0 +1,7 @@
+// +build !windows
+
+package manager
+
+func defaultUploadBufferProvider() ReadSeekerWriteToProvider {
+ return nil
+}
diff --git a/feature/s3/manager/default_read_seeker_write_to_windows.go b/feature/s3/manager/default_read_seeker_write_to_windows.go
new file mode 100644
index 00000000000..1ae881c104a
--- /dev/null
+++ b/feature/s3/manager/default_read_seeker_write_to_windows.go
@@ -0,0 +1,5 @@
+package manager
+
+func defaultUploadBufferProvider() ReadSeekerWriteToProvider {
+ return NewBufferedReadSeekerWriteToPool(1024 * 1024)
+}
diff --git a/feature/s3/manager/default_writer_read_from.go b/feature/s3/manager/default_writer_read_from.go
new file mode 100644
index 00000000000..d5518145219
--- /dev/null
+++ b/feature/s3/manager/default_writer_read_from.go
@@ -0,0 +1,7 @@
+// +build !windows
+
+package manager
+
+func defaultDownloadBufferProvider() WriterReadFromProvider {
+ return nil
+}
diff --git a/feature/s3/manager/default_writer_read_from_windows.go b/feature/s3/manager/default_writer_read_from_windows.go
new file mode 100644
index 00000000000..88887ff586e
--- /dev/null
+++ b/feature/s3/manager/default_writer_read_from_windows.go
@@ -0,0 +1,5 @@
+package manager
+
+func defaultDownloadBufferProvider() WriterReadFromProvider {
+ return NewPooledBufferedWriterReadFromProvider(1024 * 1024)
+}
diff --git a/feature/s3/manager/doc.go b/feature/s3/manager/doc.go
new file mode 100644
index 00000000000..31171a69875
--- /dev/null
+++ b/feature/s3/manager/doc.go
@@ -0,0 +1,3 @@
+// Package manager provides utilities to upload and download objects from
+// S3 concurrently. Helpful when working with large objects.
+package manager
diff --git a/feature/s3/manager/download.go b/feature/s3/manager/download.go
new file mode 100644
index 00000000000..60e5a051f7c
--- /dev/null
+++ b/feature/s3/manager/download.go
@@ -0,0 +1,493 @@
+package manager
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "strconv"
+ "strings"
+ "sync"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/aws/middleware"
+ "github.com/aws/aws-sdk-go-v2/internal/awsutil"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+)
+
+const userAgentKey = "S3Manager"
+
+// DefaultDownloadPartSize is the default range of bytes to get at a time when
+// using Download().
+const DefaultDownloadPartSize = 1024 * 1024 * 5
+
+// DefaultDownloadConcurrency is the default number of goroutines to spin up
+// when using Download().
+const DefaultDownloadConcurrency = 5
+
+// DefaultPartBodyMaxRetries is the default number of retries to make when a part fails to download.
+const DefaultPartBodyMaxRetries = 3
+
+type errReadingBody struct {
+ err error
+}
+
+func (e *errReadingBody) Error() string {
+ return fmt.Sprintf("failed to read part body: %v", e.err)
+}
+
+func (e *errReadingBody) Unwrap() error {
+ return e.err
+}
+
+// Downloader is a structure for downloading objects from S3 via Download().
+// It is safe to call Download() on this structure for multiple objects and
+// across concurrent goroutines. Mutating its properties concurrently is not safe.
+type Downloader struct {
+ // The size (in bytes) to request from S3 for each part.
+ // The minimum allowed part size is 5MB, and if this value is set to zero,
+ // the DefaultDownloadPartSize value will be used.
+ //
+ // PartSize is ignored if the Range input parameter is provided.
+ PartSize int64
+
+ // PartBodyMaxRetries is the number of retry attempts to make for failed part downloads.
+ PartBodyMaxRetries int
+
+ // The number of goroutines to spin up in parallel when sending parts.
+ // If this is set to zero, the DefaultDownloadConcurrency value will be used.
+ //
+ // Concurrency of 1 will download the parts sequentially.
+ //
+ // Concurrency is ignored if the Range input parameter is provided.
+ Concurrency int
+
+ // An S3 client to use when performing downloads.
+ S3 DownloadAPIClient
+
+ // List of client options that will be passed down to individual API
+ // operation requests made by the downloader.
+ ClientOptions []func(*s3.Options)
+
+ // Defines the buffer strategy used when downloading a part.
+ //
+ // If a WriterReadFromProvider is given, the Download manager
+ // will pass the io.WriterAt of the Download request to the provider
+ // and will use the returned WriterReadFrom from the provider as the
+ // destination writer when copying from the http response body.
+ BufferProvider WriterReadFromProvider
+}
+
+// WithDownloaderClientOptions appends to the Downloader's API request options.
+func WithDownloaderClientOptions(opts ...func(*s3.Options)) func(*Downloader) {
+ return func(d *Downloader) {
+ d.ClientOptions = append(d.ClientOptions, opts...)
+ }
+}
+
+// NewDownloader creates a new Downloader instance to download objects from
+// S3 in concurrent chunks. Pass in additional functional options to customize
+// the downloader behavior. Requires an S3 API client that satisfies the
+// DownloadAPIClient interface; the client returned by s3.NewFromConfig
+// satisfies it.
+//
+// Example:
+// // Load AWS Config
+// cfg, err := config.LoadDefaultConfig()
+// if err != nil {
+// panic(err)
+// }
+//
+// // Create an S3 client using the loaded configuration
+// s3.NewFromConfig(cfg)
+//
+// // Create a downloader passing it the S3 client
+// downloader := s3manager.NewDownloader(s3.NewFromConfig(cfg))
+//
+// // Create a downloader with the client and custom downloader options
+// downloader := s3manager.NewDownloader(client, func(d *s3manager.Downloader) {
+// d.PartSize = 64 * 1024 * 1024 // 64MB per part
+// })
+func NewDownloader(c DownloadAPIClient, options ...func(*Downloader)) *Downloader {
+ d := &Downloader{
+ S3: c,
+ PartSize: DefaultDownloadPartSize,
+ PartBodyMaxRetries: DefaultPartBodyMaxRetries,
+ Concurrency: DefaultDownloadConcurrency,
+ BufferProvider: defaultDownloadBufferProvider(),
+ }
+ for _, option := range options {
+ option(d)
+ }
+
+ return d
+}
+
+// Download downloads an object in S3 and writes the payload into w
+// using concurrent GET requests. The n int64 returned is the size of the object downloaded
+// in bytes.
+//
+// The Context must not be nil. A nil Context will cause a panic. Use the
+// Context to add deadlines, timeouts, etc. Download may create sub-contexts
+// for individual underlying requests.
+//
+// Additional functional options can be provided to configure the individual
+// download. These options are copies of the Downloader instance Download is
+// called from. Modifying the options will not impact the original Downloader
+// instance. Use the WithDownloaderClientOptions helper function to pass in request
+// options that will be applied to all API operations made with this downloader.
+//
+// The w io.WriterAt can be satisfied by an os.File to do multipart concurrent
+// downloads, or an in-memory []byte wrapper such as aws.WriteAtBuffer.
+//
+// Specifying a Downloader.Concurrency of 1 will cause the Downloader to
+// download the parts from S3 sequentially.
+//
+// It is safe to call this method concurrently across goroutines.
+//
+// If the GetObjectInput's Range value is provided, the downloader will perform
+// a single GetObject request for that object's range, causing the part size
+// and concurrency configurations to be ignored.
+func (d Downloader) Download(ctx context.Context, w io.WriterAt, input *s3.GetObjectInput, options ...func(*Downloader)) (n int64, err error) {
+ impl := downloader{w: w, in: input, cfg: d, ctx: ctx}
+
+ // Copy ClientOptions
+ clientOptions := make([]func(*s3.Options), 0, len(impl.cfg.ClientOptions)+1)
+ clientOptions = append(clientOptions, func(o *s3.Options) {
+ o.APIOptions = append(o.APIOptions, middleware.AddUserAgentKey(userAgentKey))
+ })
+ clientOptions = append(clientOptions, impl.cfg.ClientOptions...)
+ impl.cfg.ClientOptions = clientOptions
+
+ for _, option := range options {
+ option(&impl.cfg)
+ }
+
+ impl.partBodyMaxRetries = d.PartBodyMaxRetries
+
+ impl.totalBytes = -1
+ if impl.cfg.Concurrency == 0 {
+ impl.cfg.Concurrency = DefaultDownloadConcurrency
+ }
+
+ if impl.cfg.PartSize == 0 {
+ impl.cfg.PartSize = DefaultDownloadPartSize
+ }
+
+ return impl.download()
+}
+
+// downloader is the implementation structure used internally by Downloader.
+type downloader struct {
+ ctx context.Context
+ cfg Downloader
+
+ in *s3.GetObjectInput
+ w io.WriterAt
+
+ wg sync.WaitGroup
+ m sync.Mutex
+
+ pos int64
+ totalBytes int64
+ written int64
+ err error
+
+ partBodyMaxRetries int
+}
+
+// download performs the implementation of the object download across ranged
+// GETs.
+func (d *downloader) download() (n int64, err error) {
+ // If a range is specified, fall back to a single download of that range;
+ // this enables the functionality of ranged gets with the downloader, but
+ // at the cost of no multipart downloads.
+ if rng := aws.ToString(d.in.Range); len(rng) > 0 {
+ d.downloadRange(rng)
+ return d.written, d.err
+ }
+
+ // Spin off first worker to check additional header information
+ d.getChunk()
+
+ if total := d.getTotalBytes(); total >= 0 {
+ // Spin up workers
+ ch := make(chan dlchunk, d.cfg.Concurrency)
+
+ for i := 0; i < d.cfg.Concurrency; i++ {
+ d.wg.Add(1)
+ go d.downloadPart(ch)
+ }
+
+ // Assign work
+ for d.getErr() == nil {
+ if d.pos >= total {
+ break // We're finished queuing chunks
+ }
+
+ // Queue the next range of bytes to read.
+ ch <- dlchunk{w: d.w, start: d.pos, size: d.cfg.PartSize}
+ d.pos += d.cfg.PartSize
+ }
+
+ // Wait for completion
+ close(ch)
+ d.wg.Wait()
+ } else {
+ // Checking if we read anything new
+ for d.err == nil {
+ d.getChunk()
+ }
+
+ // We expect a 416 error letting us know we are done downloading the
+ // total bytes. Since we do not know the content's length, this will
+ // keep grabbing chunks of data until the range of bytes specified in
+ // the request is out of range of the content. Once this happens, a
+ // 416 should occur.
+ var responseError interface {
+ HTTPStatusCode() int
+ }
+ if errors.As(d.err, &responseError) {
+ if responseError.HTTPStatusCode() == http.StatusRequestedRangeNotSatisfiable {
+ d.err = nil
+ }
+ }
+ }
+
+ // Return error
+ return d.written, d.err
+}
+
+// downloadPart is an individual goroutine worker reading from the ch channel
+// and performing a GetObject request on the data with a given byte range.
+//
+// If this is the first worker, this operation also resolves the total number
+// of bytes to be read so that the worker manager knows when it is finished.
+func (d *downloader) downloadPart(ch chan dlchunk) {
+ defer d.wg.Done()
+ for {
+ chunk, ok := <-ch
+ if !ok {
+ break
+ }
+ if d.getErr() != nil {
+ // Drain the channel if there is an error, to prevent deadlocking
+ // of download producer.
+ continue
+ }
+
+ if err := d.downloadChunk(chunk); err != nil {
+ d.setErr(err)
+ }
+ }
+}
+
+// getChunk grabs a chunk of data from the body.
+// Not thread safe. Should only be used when grabbing data on a single thread.
+func (d *downloader) getChunk() {
+ if d.getErr() != nil {
+ return
+ }
+
+ chunk := dlchunk{w: d.w, start: d.pos, size: d.cfg.PartSize}
+ d.pos += d.cfg.PartSize
+
+ if err := d.downloadChunk(chunk); err != nil {
+ d.setErr(err)
+ }
+}
+
+// downloadRange downloads an object given the passed-in Byte-Range value.
+// The chunk used to download the range will be configured for that range.
+func (d *downloader) downloadRange(rng string) {
+ if d.getErr() != nil {
+ return
+ }
+
+ chunk := dlchunk{w: d.w, start: d.pos}
+ // Ranges specified will short circuit the multipart download
+ chunk.withRange = rng
+
+ if err := d.downloadChunk(chunk); err != nil {
+ d.setErr(err)
+ }
+
+ // Update the position based on the amount of data received.
+ d.pos = d.written
+}
+
+// downloadChunk downloads the chunk from S3.
+func (d *downloader) downloadChunk(chunk dlchunk) error {
+ in := &s3.GetObjectInput{}
+ awsutil.Copy(in, d.in)
+
+ // Get the next byte range of data
+ in.Range = aws.String(chunk.ByteRange())
+
+ var n int64
+ var err error
+ for retry := 0; retry <= d.partBodyMaxRetries; retry++ {
+ n, err = d.tryDownloadChunk(in, &chunk)
+ if err == nil {
+ break
+ }
+ // Check if the returned error is an errReadingBody.
+ // If err is errReadingBody this indicates that an error
+ // occurred while copying the http response body.
+ // If this occurs we unwrap the err to set the underlying error
+ // and attempt any remaining retries.
+ if bodyErr, ok := err.(*errReadingBody); ok {
+ err = bodyErr.Unwrap()
+ } else {
+ return err
+ }
+
+ chunk.cur = 0
+
+ // TODO: Add Logging
+ //logMessage(d.cfg.S3, aws.LogDebugWithRequestRetries,
+ // fmt.Sprintf("DEBUG: object part body download interrupted %s, err, %v, retrying attempt %d",
+ // aws.StringValue(in.Key), err, retry))
+ }
+
+ d.incrWritten(n)
+
+ return err
+}
+
+func (d *downloader) tryDownloadChunk(in *s3.GetObjectInput, w io.Writer) (int64, error) {
+ cleanup := func() {}
+ if d.cfg.BufferProvider != nil {
+ w, cleanup = d.cfg.BufferProvider.GetReadFrom(w)
+ }
+ defer cleanup()
+
+ resp, err := d.cfg.S3.GetObject(d.ctx, in, d.cfg.ClientOptions...)
+ if err != nil {
+ return 0, err
+ }
+ d.setTotalBytes(resp) // Set total if not yet set.
+
+ n, err := io.Copy(w, resp.Body)
+ resp.Body.Close()
+ if err != nil {
+ return n, &errReadingBody{err: err}
+ }
+
+ return n, nil
+}
+
+// getTotalBytes is a thread-safe getter for retrieving the total byte status.
+func (d *downloader) getTotalBytes() int64 {
+ d.m.Lock()
+ defer d.m.Unlock()
+
+ return d.totalBytes
+}
+
+// setTotalBytes is a thread-safe setter for setting the total byte status.
+// The object's total bytes are extracted from the Content-Range header when
+// the download is chunked, or from Content-Length when the response does not
+// include a Content-Range, meaning the object was not chunked. This occurs
+// when the full file fits within the PartSize directive.
+func (d *downloader) setTotalBytes(resp *s3.GetObjectOutput) {
+ d.m.Lock()
+ defer d.m.Unlock()
+
+ if d.totalBytes >= 0 {
+ return
+ }
+
+ if resp.ContentRange == nil {
+ // ContentRange is nil when the full file contents are provided, and
+ // is not chunked. Use ContentLength instead.
+ if resp.ContentLength != nil {
+ d.totalBytes = *resp.ContentLength
+ return
+ }
+ } else {
+ parts := strings.Split(*resp.ContentRange, "/")
+
+ total := int64(-1)
+ var err error
+ // Checking for whether or not a numbered total exists
+ // If one does not exist, we will assume the total to be -1, undefined,
+ // and sequentially download each chunk until hitting a 416 error
+ totalStr := parts[len(parts)-1]
+ if totalStr != "*" {
+ total, err = strconv.ParseInt(totalStr, 10, 64)
+ if err != nil {
+ d.err = err
+ return
+ }
+ }
+
+ d.totalBytes = total
+ }
+}
+
+func (d *downloader) incrWritten(n int64) {
+ d.m.Lock()
+ defer d.m.Unlock()
+
+ d.written += n
+}
+
+// getErr is a thread-safe getter for the error object
+func (d *downloader) getErr() error {
+ d.m.Lock()
+ defer d.m.Unlock()
+
+ return d.err
+}
+
+// setErr is a thread-safe setter for the error object
+func (d *downloader) setErr(e error) {
+ d.m.Lock()
+ defer d.m.Unlock()
+
+ d.err = e
+}
+
+// dlchunk represents a single chunk of data to write by the worker routine.
+// This structure also implements an io.SectionReader style interface for
+// io.WriterAt, effectively making it an io.SectionWriter (which does not
+// exist).
+type dlchunk struct {
+ w io.WriterAt
+ start int64
+ size int64
+ cur int64
+
+ // specifies the byte range the chunk should be downloaded with.
+ withRange string
+}
+
+// Write wraps io.WriterAt for the dlchunk, writing from the dlchunk's start
+// position to its end (or EOF).
+//
+// If a range is specified on the dlchunk, the size will be ignored when
+// writing, as the total size may not be known ahead of time.
+func (c *dlchunk) Write(p []byte) (n int, err error) {
+ if c.cur >= c.size && len(c.withRange) == 0 {
+ return 0, io.EOF
+ }
+
+ n, err = c.w.WriteAt(p, c.start+c.cur)
+ c.cur += int64(n)
+
+ return
+}
+
+// ByteRange returns an HTTP Byte-Range header value that should be used by the
+// client to request the chunk's range.
+func (c *dlchunk) ByteRange() string {
+ if len(c.withRange) != 0 {
+ return c.withRange
+ }
+
+ return fmt.Sprintf("bytes=%d-%d", c.start, c.start+c.size-1)
+}
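
For reference, a minimal sketch of using the Downloader to fetch an object into a local file. The bucket, key, and file names are placeholders; os.File satisfies the io.WriterAt that concurrent part writes require:

package main

import (
	"context"
	"fmt"
	"os"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
	"github.com/aws/aws-sdk-go-v2/service/s3"
)

func main() {
	cfg, err := config.LoadDefaultConfig()
	if err != nil {
		panic(err)
	}

	// os.File satisfies io.WriterAt, so parts can be written concurrently.
	f, err := os.Create("local-object")
	if err != nil {
		panic(err)
	}
	defer f.Close()

	downloader := manager.NewDownloader(s3.NewFromConfig(cfg), func(d *manager.Downloader) {
		d.PartSize = 10 * 1024 * 1024 // 10 MiB parts
		d.Concurrency = 3             // three ranged GETs in flight at once
	})

	n, err := downloader.Download(context.Background(), f, &s3.GetObjectInput{
		Bucket: aws.String("examplebucket"),
		Key:    aws.String("largeobject"),
	})
	if err != nil {
		panic(err)
	}
	fmt.Printf("downloaded %d bytes\n", n)
}
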
diff --git a/feature/s3/manager/download_test.go b/feature/s3/manager/download_test.go
new file mode 100644
index 00000000000..7aa2de1ce33
--- /dev/null
+++ b/feature/s3/manager/download_test.go
@@ -0,0 +1,746 @@
+package manager_test
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "reflect"
+ "regexp"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
+ managertesting "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/testing"
+ "github.com/aws/aws-sdk-go-v2/internal/awstesting"
+ "github.com/aws/aws-sdk-go-v2/internal/sdkio"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+)
+
+type downloadCaptureClient struct {
+ GetObjectFn func(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)
+ GetObjectInvocations int
+
+ RetrievedRanges []string
+
+ lock sync.Mutex
+}
+
+func (c *downloadCaptureClient) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ c.GetObjectInvocations++
+
+ if params.Range != nil {
+ c.RetrievedRanges = append(c.RetrievedRanges, aws.ToString(params.Range))
+ }
+
+ return c.GetObjectFn(ctx, params, optFns...)
+}
+
+var rangeValueRegex = regexp.MustCompile(`bytes=(\d+)-(\d+)`)
+
+func parseRange(rangeValue string) (start, fin int64) {
+ rng := rangeValueRegex.FindStringSubmatch(rangeValue)
+ start, _ = strconv.ParseInt(rng[1], 10, 64)
+ fin, _ = strconv.ParseInt(rng[2], 10, 64)
+ return start, fin
+}
+
+func newDownloadRangeClient(data []byte) (*downloadCaptureClient, *int, *[]string) {
+ capture := &downloadCaptureClient{}
+
+ capture.GetObjectFn = func(_ context.Context, params *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
+ start, fin := parseRange(aws.ToString(params.Range))
+ fin++
+
+ if fin >= int64(len(data)) {
+ fin = int64(len(data))
+ }
+
+ bodyBytes := data[start:fin]
+
+ return &s3.GetObjectOutput{
+ Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)),
+ ContentRange: aws.String(fmt.Sprintf("bytes %d-%d/%d", start, fin-1, len(data))),
+ ContentLength: aws.Int64(int64(len(bodyBytes))),
+ }, nil
+ }
+
+ return capture, &capture.GetObjectInvocations, &capture.RetrievedRanges
+}
+
+func newDownloadNonRangeClient(data []byte) (*downloadCaptureClient, *int) {
+ capture := &downloadCaptureClient{}
+
+ capture.GetObjectFn = func(_ context.Context, params *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
+ return &s3.GetObjectOutput{
+ Body: ioutil.NopCloser(bytes.NewReader(data[:])),
+ ContentLength: aws.Int64(int64(len(data))),
+ }, nil
+ }
+
+ return capture, &capture.GetObjectInvocations
+}
+
+type mockHTTPStatusError struct {
+ StatusCode int
+}
+
+func (m *mockHTTPStatusError) Error() string {
+ return fmt.Sprintf("http status code: %v", m.StatusCode)
+}
+
+func (m *mockHTTPStatusError) HTTPStatusCode() int {
+ return m.StatusCode
+}
+
+func newDownloadContentRangeTotalAnyClient(data []byte) (*downloadCaptureClient, *int) {
+ capture := &downloadCaptureClient{}
+ completed := false
+
+ capture.GetObjectFn = func(_ context.Context, params *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
+ if completed {
+ return nil, &mockHTTPStatusError{StatusCode: 416}
+ }
+
+ start, fin := parseRange(aws.ToString(params.Range))
+ fin++
+
+ if fin >= int64(len(data)) {
+ fin = int64(len(data))
+ completed = true
+ }
+
+ bodyBytes := data[start:fin]
+
+ return &s3.GetObjectOutput{
+ Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)),
+ ContentRange: aws.String(fmt.Sprintf("bytes %d-%d/*", start, fin-1)),
+ }, nil
+ }
+
+ return capture, &capture.GetObjectInvocations
+}
+
+func newDownloadWithErrReaderClient(cases []testErrReader) (*downloadCaptureClient, *int) {
+ var index int
+
+ c := &downloadCaptureClient{}
+ c.GetObjectFn = func(_ context.Context, params *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
+ c := cases[index]
+ out := &s3.GetObjectOutput{
+ Body: ioutil.NopCloser(&c),
+ ContentRange: aws.String(fmt.Sprintf("bytes %d-%d/%d", 0, c.Len-1, c.Len)),
+ ContentLength: aws.Int64(c.Len),
+ }
+ index++
+ return out, nil
+ }
+
+ return c, &c.GetObjectInvocations
+}
+
+func TestDownloadOrder(t *testing.T) {
+ c, invocations, ranges := newDownloadRangeClient(buf12MB)
+
+ d := manager.NewDownloader(c, func(d *manager.Downloader) {
+ d.Concurrency = 1
+ })
+
+ w := aws.NewWriteAtBuffer(make([]byte, len(buf12MB)))
+ n, err := d.Download(context.Background(), w, &s3.GetObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ })
+
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if e, a := int64(len(buf12MB)), n; e != a {
+ t.Errorf("expect %d buffer length, got %d", e, a)
+ }
+
+ if e, a := 3, *invocations; e != a {
+ t.Errorf("expect %v API calls, got %v", e, a)
+ }
+
+ expectRngs := []string{"bytes=0-5242879", "bytes=5242880-10485759", "bytes=10485760-15728639"}
+ if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) {
+ t.Errorf("expect %v ranges, got %v", e, a)
+ }
+}
+
+func TestDownloadZero(t *testing.T) {
+ c, invocations, ranges := newDownloadRangeClient([]byte{})
+
+ d := manager.NewDownloader(c)
+ w := &aws.WriteAtBuffer{}
+ n, err := d.Download(context.Background(), w, &s3.GetObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ })
+
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if n != 0 {
+ t.Errorf("expect 0 bytes read, got %d", n)
+ }
+ if e, a := 1, *invocations; e != a {
+ t.Errorf("expect %v API calls, got %v", e, a)
+ }
+
+ expectRngs := []string{"bytes=0-5242879"}
+ if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) {
+ t.Errorf("expect %v ranges, got %v", e, a)
+ }
+}
+
+func TestDownloadSetPartSize(t *testing.T) {
+ c, invocations, ranges := newDownloadRangeClient([]byte{1, 2, 3})
+
+ d := manager.NewDownloader(c, func(d *manager.Downloader) {
+ d.Concurrency = 1
+ d.PartSize = 1
+ })
+ w := &aws.WriteAtBuffer{}
+ n, err := d.Download(context.Background(), w, &s3.GetObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ })
+
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if e, a := int64(3), n; e != a {
+ t.Errorf("expect %d bytes read, got %d", e, a)
+ }
+ if e, a := 3, *invocations; e != a {
+ t.Errorf("expect %v API calls, got %v", e, a)
+ }
+ expectRngs := []string{"bytes=0-0", "bytes=1-1", "bytes=2-2"}
+ if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) {
+ t.Errorf("expect %v ranges, got %v", e, a)
+ }
+ expectBytes := []byte{1, 2, 3}
+ if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) {
+ t.Errorf("expect %v bytes, got %v", e, a)
+ }
+}
+
+func TestDownloadError(t *testing.T) {
+ c, invocations, _ := newDownloadRangeClient([]byte{1, 2, 3})
+
+ num := 0
+ orig := c.GetObjectFn
+ c.GetObjectFn = func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
+ out, err := orig(ctx, params, optFns...)
+ num++
+ if num > 1 {
+ return &s3.GetObjectOutput{}, fmt.Errorf("s3 service error")
+ }
+ return out, err
+ }
+
+ d := manager.NewDownloader(c, func(d *manager.Downloader) {
+ d.Concurrency = 1
+ d.PartSize = 1
+ })
+ w := &aws.WriteAtBuffer{}
+ n, err := d.Download(context.Background(), w, &s3.GetObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ })
+
+ if err == nil {
+ t.Fatalf("expect error, got none")
+ }
+ if e, a := "s3 service error", err.Error(); e != a {
+ t.Errorf("expect %s error code, got %s", e, a)
+ }
+ if e, a := int64(1), n; e != a {
+ t.Errorf("expect %d bytes read, got %d", e, a)
+ }
+ if e, a := 2, *invocations; e != a {
+ t.Errorf("expect %v API calls, got %v", e, a)
+ }
+ expectBytes := []byte{1}
+ if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) {
+ t.Errorf("expect %v bytes, got %v", e, a)
+ }
+}
+
+func TestDownloadNonChunk(t *testing.T) {
+ c, invocations := newDownloadNonRangeClient(buf2MB)
+
+ d := manager.NewDownloader(c, func(d *manager.Downloader) {
+ d.Concurrency = 1
+ })
+ w := &aws.WriteAtBuffer{}
+ n, err := d.Download(context.Background(), w, &s3.GetObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ })
+
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if e, a := int64(len(buf2MB)), n; e != a {
+ t.Errorf("expect %d bytes read, got %d", e, a)
+ }
+ if e, a := 1, *invocations; e != a {
+ t.Errorf("expect %v API calls, got %v", e, a)
+ }
+
+ count := 0
+ for _, b := range w.Bytes() {
+ count += int(b)
+ }
+ if count != 0 {
+ t.Errorf("expect 0 count, got %d", count)
+ }
+}
+
+func TestDownloadNoContentRangeLength(t *testing.T) {
+ s, invocations, _ := newDownloadRangeClient(buf2MB)
+
+ d := manager.NewDownloader(s, func(d *manager.Downloader) {
+ d.Concurrency = 1
+ })
+ w := &aws.WriteAtBuffer{}
+ n, err := d.Download(context.Background(), w, &s3.GetObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ })
+
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if e, a := int64(len(buf2MB)), n; e != a {
+ t.Errorf("expect %d bytes read, got %d", e, a)
+ }
+ if e, a := 1, *invocations; e != a {
+ t.Errorf("expect %v API calls, got %v", e, a)
+ }
+
+ count := 0
+ for _, b := range w.Bytes() {
+ count += int(b)
+ }
+ if count != 0 {
+ t.Errorf("expect 0 count, got %d", count)
+ }
+}
+
+func TestDownloadContentRangeTotalAny(t *testing.T) {
+ s, invocations := newDownloadContentRangeTotalAnyClient(buf2MB)
+
+ d := manager.NewDownloader(s, func(d *manager.Downloader) {
+ d.Concurrency = 1
+ })
+ w := &aws.WriteAtBuffer{}
+ n, err := d.Download(context.Background(), w, &s3.GetObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ })
+
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if e, a := int64(len(buf2MB)), n; e != a {
+ t.Errorf("expect %d bytes read, got %d", e, a)
+ }
+ if e, a := 2, *invocations; e != a {
+ t.Errorf("expect %v API calls, got %v", e, a)
+ }
+
+ count := 0
+ for _, b := range w.Bytes() {
+ count += int(b)
+ }
+ if count != 0 {
+ t.Errorf("expect 0 count, got %d", count)
+ }
+}
+
+func TestDownloadPartBodyRetry_SuccessRetry(t *testing.T) {
+ c, invocations := newDownloadWithErrReaderClient([]testErrReader{
+ {Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF},
+ {Buf: []byte("123"), Len: 3, Err: io.EOF},
+ })
+
+ d := manager.NewDownloader(c, func(d *manager.Downloader) {
+ d.Concurrency = 1
+ })
+
+ w := &aws.WriteAtBuffer{}
+ n, err := d.Download(context.Background(), w, &s3.GetObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ })
+
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if e, a := int64(3), n; e != a {
+ t.Errorf("expect %d bytes read, got %d", e, a)
+ }
+ if e, a := 2, *invocations; e != a {
+ t.Errorf("expect %v API calls, got %v", e, a)
+ }
+ if e, a := "123", string(w.Bytes()); e != a {
+ t.Errorf("expect %q response, got %q", e, a)
+ }
+}
+
+func TestDownloadPartBodyRetry_SuccessNoRetry(t *testing.T) {
+ c, invocations := newDownloadWithErrReaderClient([]testErrReader{
+ {Buf: []byte("abc"), Len: 3, Err: io.EOF},
+ })
+
+ d := manager.NewDownloader(c, func(d *manager.Downloader) {
+ d.Concurrency = 1
+ })
+
+ w := &aws.WriteAtBuffer{}
+ n, err := d.Download(context.Background(), w, &s3.GetObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ })
+
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if e, a := int64(3), n; e != a {
+ t.Errorf("expect %d bytes read, got %d", e, a)
+ }
+ if e, a := 1, *invocations; e != a {
+ t.Errorf("expect %v API calls, got %v", e, a)
+ }
+ if e, a := "abc", string(w.Bytes()); e != a {
+ t.Errorf("expect %q response, got %q", e, a)
+ }
+}
+
+func TestDownloadPartBodyRetry_FailRetry(t *testing.T) {
+ c, invocations := newDownloadWithErrReaderClient([]testErrReader{
+ {Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF},
+ })
+
+ d := manager.NewDownloader(c, func(d *manager.Downloader) {
+ d.Concurrency = 1
+ d.PartBodyMaxRetries = 0
+ })
+
+ w := &aws.WriteAtBuffer{}
+ n, err := d.Download(context.Background(), w, &s3.GetObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ })
+
+ if err == nil {
+ t.Fatalf("expect error, got none")
+ }
+ if e, a := "unexpected EOF", err.Error(); !strings.Contains(a, e) {
+ t.Errorf("expect %q error message to be in %q", e, a)
+ }
+ if e, a := int64(2), n; e != a {
+ t.Errorf("expect %d bytes read, got %d", e, a)
+ }
+ if e, a := 1, *invocations; e != a {
+ t.Errorf("expect %v API calls, got %v", e, a)
+ }
+ if e, a := "ab", string(w.Bytes()); e != a {
+ t.Errorf("expect %q response, got %q", e, a)
+ }
+}
+
+func TestDownloadWithContextCanceled(t *testing.T) {
+ d := manager.NewDownloader(s3.New(s3.Options{}))
+
+ params := s3.GetObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("Key"),
+ }
+
+ ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
+ ctx.Error = fmt.Errorf("context canceled")
+ close(ctx.DoneCh)
+
+ w := &aws.WriteAtBuffer{}
+
+ _, err := d.Download(ctx, w, &params)
+ if err == nil {
+ t.Fatalf("expected error, did not get one")
+ }
+ if e, a := "canceled", err.Error(); !strings.Contains(a, e) {
+ t.Errorf("expected error message to contain %q, but did not %q", e, a)
+ }
+}
+
+func TestDownload_WithRange(t *testing.T) {
+ c, invocations, ranges := newDownloadRangeClient([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})
+
+ d := manager.NewDownloader(c, func(d *manager.Downloader) {
+ d.Concurrency = 10 // should be ignored
+ d.PartSize = 1 // should be ignored
+ })
+
+ w := &aws.WriteAtBuffer{}
+ n, err := d.Download(context.Background(), w, &s3.GetObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ Range: aws.String("bytes=2-6"),
+ })
+
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if e, a := int64(5), n; e != a {
+ t.Errorf("expect %d bytes read, got %d", e, a)
+ }
+ if e, a := 1, *invocations; e != a {
+ t.Errorf("expect %v API calls, got %v", e, a)
+ }
+ expectRngs := []string{"bytes=2-6"}
+ if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) {
+ t.Errorf("expect %v ranges, got %v", e, a)
+ }
+ expectBytes := []byte{2, 3, 4, 5, 6}
+ if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) {
+ t.Errorf("expect %v bytes, got %v", e, a)
+ }
+}
+
+type mockDownloadClient func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error)
+
+func (m mockDownloadClient) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
+ return m(ctx, params, optFns...)
+}
+
+func TestDownload_WithFailure(t *testing.T) {
+ reqCount := int64(0)
+ startingByte := 0
+
+ client := mockDownloadClient(func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (out *s3.GetObjectOutput, err error) {
+ switch atomic.LoadInt64(&reqCount) {
+ case 1:
+ // Give a chance for the multipart chunks to be queued up
+ time.Sleep(1 * time.Second)
+ err = fmt.Errorf("some connection error")
+ default:
+ body := bytes.NewReader(make([]byte, manager.DefaultDownloadPartSize))
+ out = &s3.GetObjectOutput{
+ Body: ioutil.NopCloser(body),
+ ContentLength: aws.Int64(int64(body.Len())),
+ ContentRange: aws.String(fmt.Sprintf("bytes %d-%d/%d", startingByte, body.Len()-1, body.Len()*10)),
+ }
+
+ startingByte += body.Len()
+ if reqCount > 0 {
+ // sleep here to ensure context switching between goroutines
+ time.Sleep(25 * time.Millisecond)
+ }
+ }
+ atomic.AddInt64(&reqCount, 1)
+ return out, err
+ })
+
+ d := manager.NewDownloader(client, func(d *manager.Downloader) {
+ d.Concurrency = 2
+ })
+
+ w := &aws.WriteAtBuffer{}
+ params := s3.GetObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ }
+
+ // Expect this request to exit quickly after failure
+ _, err := d.Download(context.Background(), w, &params)
+ if err == nil {
+ t.Fatalf("expect error, got none")
+ }
+
+ if atomic.LoadInt64(&reqCount) > 3 {
+ t.Errorf("expect no more than 3 requests, but received %d", reqCount)
+ }
+}
+
+func TestDownloadBufferStrategy(t *testing.T) {
+ cases := map[string]struct {
+ partSize int64
+ strategy *recordedWriterReadFromProvider
+ expectedSize int64
+ }{
+ "no strategy": {
+ partSize: manager.DefaultDownloadPartSize,
+ expectedSize: 10 * sdkio.MebiByte,
+ },
+ "partSize modulo bufferSize == 0": {
+ partSize: 5 * sdkio.MebiByte,
+ strategy: &recordedWriterReadFromProvider{
+ WriterReadFromProvider: manager.NewPooledBufferedWriterReadFromProvider(int(sdkio.MebiByte)), // 1 MiB
+ },
+ expectedSize: 10 * sdkio.MebiByte, // 10 MiB
+ },
+ "partSize modulo bufferSize > 0": {
+ partSize: 5 * 1024 * 1204, // ~5 MiB; intentionally not a multiple of the 2 MiB buffer
+ strategy: &recordedWriterReadFromProvider{
+ WriterReadFromProvider: manager.NewPooledBufferedWriterReadFromProvider(2 * int(sdkio.MebiByte)), // 2 MiB
+ },
+ expectedSize: 10 * sdkio.MebiByte, // 10 MiB
+ },
+ }
+
+ for name, tCase := range cases {
+ t.Run(name, func(t *testing.T) {
+ expected := managertesting.GetTestBytes(int(tCase.expectedSize))
+
+ client, _, _ := newDownloadRangeClient(expected)
+
+ d := manager.NewDownloader(client, func(d *manager.Downloader) {
+ d.PartSize = tCase.partSize
+ if tCase.strategy != nil {
+ d.BufferProvider = tCase.strategy
+ }
+ })
+
+ buffer := aws.NewWriteAtBuffer(make([]byte, len(expected)))
+
+ n, err := d.Download(context.Background(), buffer, &s3.GetObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ })
+ if err != nil {
+ t.Errorf("failed to download: %v", err)
+ }
+
+ if e, a := len(expected), int(n); e != a {
+ t.Errorf("expected %v, got %v downloaded bytes", e, a)
+ }
+
+ if e, a := expected, buffer.Bytes(); !bytes.Equal(e, a) {
+ t.Errorf("downloaded bytes did not match expected")
+ }
+
+ if tCase.strategy != nil {
+ if e, a := tCase.strategy.callbacksVended, tCase.strategy.callbacksExecuted; e != a {
+ t.Errorf("expected %v, got %v", e, a)
+ }
+ }
+ })
+ }
+}
+
+type testErrReader struct {
+ Buf []byte
+ Err error
+ Len int64
+
+ off int
+}
+
+func (r *testErrReader) Read(p []byte) (int, error) {
+ to := len(r.Buf) - r.off
+
+ n := copy(p, r.Buf[r.off:to])
+ r.off += n
+
+ if n < len(p) {
+ return n, r.Err
+ }
+
+ return n, nil
+}
+
+func TestDownloadBufferStrategy_Errors(t *testing.T) {
+ expected := managertesting.GetTestBytes(int(10 * sdkio.MebiByte))
+
+ client, _, _ := newDownloadRangeClient(expected)
+ strat := &recordedWriterReadFromProvider{
+ WriterReadFromProvider: manager.NewPooledBufferedWriterReadFromProvider(int(2 * sdkio.MebiByte)),
+ }
+
+ seenOps := make(map[string]struct{})
+ orig := client.GetObjectFn
+ client.GetObjectFn = func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
+ out, err := orig(ctx, params, optFns...)
+
+ fingerPrint := fmt.Sprintf("%s/%s/%s", *params.Bucket, *params.Key, *params.Range)
+ if _, ok := seenOps[fingerPrint]; ok {
+ return out, err
+ }
+ seenOps[fingerPrint] = struct{}{}
+
+ _, _ = io.Copy(ioutil.Discard, out.Body)
+
+ out.Body = ioutil.NopCloser(&badReader{err: io.ErrUnexpectedEOF})
+
+ return out, err
+ }
+
+ d := manager.NewDownloader(client, func(d *manager.Downloader) {
+ d.PartSize = 5 * sdkio.MebiByte
+ d.BufferProvider = strat
+ d.Concurrency = 1
+ })
+
+ buffer := aws.NewWriteAtBuffer(make([]byte, len(expected)))
+
+ n, err := d.Download(context.Background(), buffer, &s3.GetObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ })
+ if err != nil {
+ t.Errorf("failed to download: %v", err)
+ }
+
+ if e, a := len(expected), int(n); e != a {
+ t.Errorf("expected %v, got %v downloaded bytes", e, a)
+ }
+
+ if e, a := expected, buffer.Bytes(); !bytes.Equal(e, a) {
+ t.Errorf("downloaded bytes did not match expected")
+ }
+
+ if e, a := strat.callbacksVended, strat.callbacksExecuted; e != a {
+ t.Errorf("expected %v, got %v", e, a)
+ }
+}
+
+type recordedWriterReadFromProvider struct {
+ callbacksVended uint32
+ callbacksExecuted uint32
+ manager.WriterReadFromProvider
+}
+
+func (r *recordedWriterReadFromProvider) GetReadFrom(writer io.Writer) (manager.WriterReadFrom, func()) {
+ w, cleanup := r.WriterReadFromProvider.GetReadFrom(writer)
+
+ atomic.AddUint32(&r.callbacksVended, 1)
+ return w, func() {
+ atomic.AddUint32(&r.callbacksExecuted, 1)
+ cleanup()
+ }
+}
+
+type badReader struct {
+ err error
+}
+
+func (b *badReader) Read(p []byte) (int, error) {
+ tb := managertesting.GetTestBytes(len(p))
+ copy(p, tb)
+
+ return len(p), b.err
+}
diff --git a/feature/s3/manager/examples_test.go b/feature/s3/manager/examples_test.go
new file mode 100644
index 00000000000..ef4162fdb20
--- /dev/null
+++ b/feature/s3/manager/examples_test.go
@@ -0,0 +1,69 @@
+package manager_test
+
+import (
+ "bytes"
+ "context"
+ "net/http"
+ "time"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/config"
+ "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+)
+
+// ExampleNewUploader_overrideReadSeekerProvider gives an example of
+// how a custom ReadSeekerWriteToProvider can be provided to Uploader
+// to define how parts will be buffered in memory.
+func ExampleNewUploader_overrideReadSeekerProvider() {
+ cfg, err := config.LoadDefaultConfig()
+ if err != nil {
+ panic(err)
+ }
+
+ uploader := manager.NewUploader(s3.NewFromConfig(cfg), func(u *manager.Uploader) {
+ // Define a strategy that will buffer 25 MiB in memory
+ u.BufferProvider = manager.NewBufferedReadSeekerWriteToPool(25 * 1024 * 1024)
+ })
+
+ _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("examplebucket"),
+ Key: aws.String("largeobject"),
+ Body: bytes.NewReader([]byte("large_multi_part_upload")),
+ })
+ if err != nil {
+ panic(err)
+ }
+}
+
+// ExampleNewUploader_overrideTransport gives an example of
+// how to override the default HTTP transport. This can be
+// used to tune timeouts, such as the response header timeout,
+// or the write and read buffer sizes used by the net/http
+// transport.
+func ExampleNewUploader_overrideTransport() {
+ cfg, err := config.LoadDefaultConfig()
+ if err != nil {
+ panic(err)
+ }
+
+ client := s3.NewFromConfig(cfg, func(o *s3.Options) {
+ // Override Default Transport Values
+ o.HTTPClient = aws.NewBuildableHTTPClient().WithTransportOptions(func(tr *http.Transport) {
+ tr.ResponseHeaderTimeout = 1 * time.Second
+ tr.WriteBufferSize = 1024 * 1024
+ tr.ReadBufferSize = 1024 * 1024
+ })
+ })
+
+ uploader := manager.NewUploader(client)
+
+ _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("examplebucket"),
+ Key: aws.String("largeobject"),
+ Body: bytes.NewReader([]byte("large_multi_part_upload")),
+ })
+ if err != nil {
+ panic(err)
+ }
+}
diff --git a/feature/s3/manager/go.mod b/feature/s3/manager/go.mod
new file mode 100644
index 00000000000..99b133cbca1
--- /dev/null
+++ b/feature/s3/manager/go.mod
@@ -0,0 +1,21 @@
+module github.com/aws/aws-sdk-go-v2/feature/s3/manager
+
+go 1.15
+
+require (
+ github.com/aws/aws-sdk-go-v2 v0.26.0
+ github.com/aws/aws-sdk-go-v2/config v0.1.1
+ github.com/aws/aws-sdk-go-v2/service/s3 v0.26.0
+ github.com/awslabs/smithy-go v0.1.2-0.20201012175301-b4d8737f29d1
+ github.com/google/go-cmp v0.4.1
+)
+
+replace (
+ github.com/aws/aws-sdk-go-v2 => ../../../
+ github.com/aws/aws-sdk-go-v2/config => ../../../config/
+ github.com/aws/aws-sdk-go-v2/credentials => ../../../credentials/
+ github.com/aws/aws-sdk-go-v2/ec2imds => ../../../ec2imds
+ github.com/aws/aws-sdk-go-v2/service/internal/s3shared => ../../../service/internal/s3shared
+ github.com/aws/aws-sdk-go-v2/service/s3 => ../../../service/s3/
+ github.com/aws/aws-sdk-go-v2/service/sts => ../../../service/sts
+)
diff --git a/feature/s3/manager/go.sum b/feature/s3/manager/go.sum
new file mode 100644
index 00000000000..8ecbc067ee5
--- /dev/null
+++ b/feature/s3/manager/go.sum
@@ -0,0 +1,27 @@
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v0.0.0-20200930084954-897dfb99530c h1:v1H0WQmb+pNOZ/xDXGT3wXn6aceSN3I2wqK0VpQM/ZM=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v0.0.0-20200930084954-897dfb99530c/go.mod h1:GRJ/IvA6A00/2tAw9KgMTM8as5gAlNI0FVCKBc+aRnA=
+github.com/awslabs/smithy-go v0.0.0-20200930175536-2cd7f70a8c2f/go.mod h1:hPOQwnmBLHsUphH13tVSjQhTAFma0/0XoZGbBcOuABI=
+github.com/awslabs/smithy-go v0.0.0-20201009221937-21015eb9ec4b/go.mod h1:hPOQwnmBLHsUphH13tVSjQhTAFma0/0XoZGbBcOuABI=
+github.com/awslabs/smithy-go v0.1.0/go.mod h1:hPOQwnmBLHsUphH13tVSjQhTAFma0/0XoZGbBcOuABI=
+github.com/awslabs/smithy-go v0.1.1 h1:v1hUSAYf3w2ClKr58C+AtwoyPVoBjWyWT8thf7/VRtU=
+github.com/awslabs/smithy-go v0.1.1/go.mod h1:hPOQwnmBLHsUphH13tVSjQhTAFma0/0XoZGbBcOuABI=
+github.com/awslabs/smithy-go v0.1.2-0.20201012175301-b4d8737f29d1 h1:5eAoxqWUc2VMuT3ob/pUYCLliBYEk3dccw6P/reTuRY=
+github.com/awslabs/smithy-go v0.1.2-0.20201012175301-b4d8737f29d1/go.mod h1:hPOQwnmBLHsUphH13tVSjQhTAFma0/0XoZGbBcOuABI=
+github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/google/go-cmp v0.4.1 h1:/exdXoGamhu5ONeUJH0deniYLWYvQwW66yvlfiiKTu0=
+github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
+github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
+github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
+github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
+gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
diff --git a/feature/s3/manager/integ_bucket_region_test.go b/feature/s3/manager/integ_bucket_region_test.go
new file mode 100644
index 00000000000..d3f886619b7
--- /dev/null
+++ b/feature/s3/manager/integ_bucket_region_test.go
@@ -0,0 +1,25 @@
+// +build integration
+
+package manager_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+)
+
+func TestGetBucketRegion(t *testing.T) {
+ expectRegion := integConfig.Region
+
+ region, err := manager.GetBucketRegion(context.Background(), s3.NewFromConfig(integConfig), aws.ToString(bucketName))
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+
+ if e, a := expectRegion, region; e != a {
+ t.Errorf("expect %s bucket region, got %s", e, a)
+ }
+}
diff --git a/feature/s3/manager/integ_shared_test.go b/feature/s3/manager/integ_shared_test.go
new file mode 100644
index 00000000000..68377617359
--- /dev/null
+++ b/feature/s3/manager/integ_shared_test.go
@@ -0,0 +1,104 @@
+// +build integration
+
+package manager_test
+
+import (
+ "context"
+ "crypto/md5"
+ "flag"
+ "fmt"
+ "io"
+ "os"
+ "testing"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/config"
+ "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
+ "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/integration"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+)
+
+var integConfig aws.Config
+
+func init() {
+ var err error
+
+ integConfig, err = config.LoadDefaultConfig(config.WithDefaultRegion("us-west-2"))
+ if err != nil {
+ panic(err)
+ }
+}
+
+var bucketName *string
+var client *s3.Client
+
+func TestMain(m *testing.M) {
+ flag.Parse()
+ flag.CommandLine.Visit(func(f *flag.Flag) {
+ if !(f.Name == "run" || f.Name == "test.run") {
+ return
+ }
+ value := f.Value.String()
+ if value == `NONE` {
+ os.Exit(0)
+ }
+ })
+
+ client = s3.NewFromConfig(integConfig)
+ bucketName = aws.String(integration.GenerateBucketName())
+ if err := integration.SetupBucket(client, *bucketName, integConfig.Region); err != nil {
+ panic(err)
+ }
+
+ var result int
+ defer func() {
+ if err := integration.CleanupBucket(client, *bucketName); err != nil {
+ fmt.Fprintln(os.Stderr, err)
+ }
+ if r := recover(); r != nil {
+ fmt.Fprintln(os.Stderr, "S3 integration tests panicked,", r)
+ result = 1
+ }
+ os.Exit(result)
+ }()
+
+ result = m.Run()
+}
+
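+// dlwriter is an in-memory io.WriterAt used to capture downloaded objects so
+// their contents can be validated.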
+type dlwriter struct {
+ buf []byte
+}
+
+func newDLWriter(size int) *dlwriter {
+ return &dlwriter{buf: make([]byte, size)}
+}
+
+func (d dlwriter) WriteAt(p []byte, pos int64) (n int, err error) {
+ if pos > int64(len(d.buf)) {
+ return 0, io.EOF
+ }
+
+ written := 0
+ for i, b := range p {
+ if pos+int64(i) >= int64(len(d.buf)) {
+ break
+ }
+ d.buf[pos+int64(i)] = b
+ written++
+ }
+ return written, nil
+}
+
+func validate(t *testing.T, key string, md5value string) {
+ mgr := manager.NewDownloader(client)
+ params := &s3.GetObjectInput{Bucket: bucketName, Key: &key}
+
+ w := newDLWriter(1024 * 1024 * 20)
+ n, err := mgr.Download(context.Background(), w, params)
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if e, a := md5value, fmt.Sprintf("%x", md5.Sum(w.buf[0:n])); e != a {
+ t.Errorf("expect %s md5 value, got %s", e, a)
+ }
+}
diff --git a/feature/s3/manager/integ_upload_test.go b/feature/s3/manager/integ_upload_test.go
new file mode 100644
index 00000000000..fb34c6021b9
--- /dev/null
+++ b/feature/s3/manager/integ_upload_test.go
@@ -0,0 +1,99 @@
+// +build integration
+
+package manager_test
+
+import (
+ "bytes"
+ "context"
+ "crypto/md5"
+ "errors"
+ "fmt"
+ "regexp"
+ "testing"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
+ "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+ "github.com/awslabs/smithy-go/middleware"
+)
+
+var integBuf12MB = make([]byte, 1024*1024*12)
+var integMD512MB = fmt.Sprintf("%x", md5.Sum(integBuf12MB))
+
+func TestUploadConcurrently(t *testing.T) {
+ key := "12mb-1"
+ mgr := manager.NewUploader(client)
+ out, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: bucketName,
+ Key: &key,
+ Body: bytes.NewReader(integBuf12MB),
+ })
+
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ if len(out.UploadID) == 0 {
+ t.Errorf("expect upload ID but was empty")
+ }
+
+ re := regexp.MustCompile(`^https?://.+/` + key + `$`)
+ if e, a := re.String(), out.Location; !re.MatchString(a) {
+ t.Errorf("expect URL %q to match regexp %q, did not", a, e)
+ }
+
+ validate(t, key, integMD512MB)
+}
+
+type invalidateHash struct{}
+
+func (b *invalidateHash) ID() string {
+ return "s3manager:InvalidateHash"
+}
+
+func (b *invalidateHash) RegisterMiddleware(stack *middleware.Stack) error {
+ return stack.Serialize.Add(b, middleware.After)
+}
+
+func (b *invalidateHash) HandleSerialize(ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler) (
+ out middleware.SerializeOutput, metadata middleware.Metadata, err error,
+) {
+ if input, ok := in.Parameters.(*s3.UploadPartInput); ok && aws.ToInt32(input.PartNumber) == 2 {
+ ctx = v4.SetPayloadHash(ctx, "000")
+ }
+
+ return next.HandleSerialize(ctx, in)
+}
+
+func TestUploadFailCleanup(t *testing.T) {
+ key := "12mb-leave"
+ mgr := manager.NewUploader(client, func(u *manager.Uploader) {
+ u.LeavePartsOnError = false
+ u.ClientOptions = append(u.ClientOptions, func(options *s3.Options) {
+ options.APIOptions = append(options.APIOptions, (&invalidateHash{}).RegisterMiddleware)
+ })
+ })
+ _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: bucketName,
+ Key: &key,
+ Body: bytes.NewReader(integBuf12MB),
+ })
+ if err == nil {
+ t.Fatalf("expect error, but did not get one")
+ }
+
+ uploadID := ""
+ var uf manager.MultiUploadFailure
+ if !errors.As(err, &uf) {
+ t.Errorf("expect error to satisfy MultiUploadFailure, but did not, got %v", err)
+ } else if uploadID = uf.UploadID(); len(uploadID) == 0 {
+ t.Errorf("expect upload ID to not be empty, but was")
+ }
+
+ _, err = client.ListParts(context.Background(), &s3.ListPartsInput{
+ Bucket: bucketName, Key: &key, UploadId: &uploadID,
+ })
+ if err == nil {
+ t.Errorf("expect error for list parts, but got none")
+ }
+}
diff --git a/feature/s3/manager/internal/integration/downloader/README.md b/feature/s3/manager/internal/integration/downloader/README.md
new file mode 100644
index 00000000000..50d690c2060
--- /dev/null
+++ b/feature/s3/manager/internal/integration/downloader/README.md
@@ -0,0 +1,22 @@
+## Performance Utility
+
+Downloads a test file from an S3 bucket using the SDK's S3 download manager. Allows passing
+in a custom configuration for the HTTP client and the download manager's behavior.
+
+## Build
+```sh
+go test -tags "integration perftest" -c -o download.test ./internal/integration/downloader
+```
+
+## Usage Example:
+```sh
+AWS_REGION=us-west-2 AWS_PROFILE=aws-go-sdk-team-test ./download.test \
+-test.bench=. \
+-test.benchmem \
+-test.benchtime 1x \
+-bucket aws-sdk-go-data \
+-client.idle-conns 1000 \
+-client.idle-conns-host 300 \
+-client.timeout.connect=1s \
+-client.timeout.response-header=1s
+```
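+
+The `-size`, `-part`, `-concurrency`, and `-buffer` flags accept comma-separated
+lists of values; the benchmark runs every combination. A sketch of benchmarking a
+single combination (the values are illustrative):
+```sh
+AWS_REGION=us-west-2 ./download.test \
+-test.bench=. \
+-test.benchtime 1x \
+-bucket aws-sdk-go-data \
+-size 5242880 \
+-part 5242880 \
+-concurrency 5
+```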
diff --git a/feature/s3/manager/internal/integration/downloader/client.go b/feature/s3/manager/internal/integration/downloader/client.go
new file mode 100644
index 00000000000..6e2a4b1d616
--- /dev/null
+++ b/feature/s3/manager/internal/integration/downloader/client.go
@@ -0,0 +1,34 @@
+// +build integration,perftest
+
+package downloader
+
+import (
+ "net"
+ "net/http"
+ "time"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+)
+
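+// NewHTTPClient returns an aws.HTTPClient whose underlying http.Transport is
+// tuned with the timeouts, connection pool limits, and buffer sizes from cfg.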
+func NewHTTPClient(cfg ClientConfig) aws.HTTPClient {
+ return aws.NewBuildableHTTPClient().WithTransportOptions(func(tr *http.Transport) {
+ *tr = http.Transport{
+ Proxy: http.ProxyFromEnvironment,
+ DialContext: (&net.Dialer{
+ Timeout: cfg.Timeouts.Connect,
+ KeepAlive: 30 * time.Second,
+ }).DialContext,
+ MaxIdleConns: cfg.MaxIdleConns,
+ MaxIdleConnsPerHost: cfg.MaxIdleConnsPerHost,
+ IdleConnTimeout: 90 * time.Second,
+
+ DisableKeepAlives: !cfg.KeepAlive,
+ TLSHandshakeTimeout: cfg.Timeouts.TLSHandshake,
+ ExpectContinueTimeout: cfg.Timeouts.ExpectContinue,
+ ResponseHeaderTimeout: cfg.Timeouts.ResponseHeader,
+
+ ReadBufferSize: cfg.ReadBufferSize,
+ WriteBufferSize: cfg.WriteBufferSize,
+ }
+ })
+}
diff --git a/feature/s3/manager/internal/integration/downloader/config.go b/feature/s3/manager/internal/integration/downloader/config.go
new file mode 100644
index 00000000000..fc4f5dd0018
--- /dev/null
+++ b/feature/s3/manager/internal/integration/downloader/config.go
@@ -0,0 +1,116 @@
+// +build integration,perftest
+
+package downloader
+
+import (
+ "flag"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
+)
+
+type SDKConfig struct {
+ PartSize int64
+ Concurrency int
+ BufferProvider manager.WriterReadFromProvider
+}
+
+func (c *SDKConfig) SetupFlags(prefix string, flagset *flag.FlagSet) {
+ prefix += "sdk."
+
+ flagset.Int64Var(&c.PartSize, prefix+"part-size", manager.DefaultDownloadPartSize,
+ "Specifies the `size` of parts of the object to download.")
+ flagset.IntVar(&c.Concurrency, prefix+"concurrency", manager.DefaultDownloadConcurrency,
+ "Specifies the number of parts to download `at once`.")
+}
+
+func (c *SDKConfig) Validate() error {
+ return nil
+}
+
+type ClientConfig struct {
+ KeepAlive bool
+ Timeouts Timeouts
+
+ MaxIdleConns int
+ MaxIdleConnsPerHost int
+
+ // Go 1.13
+ ReadBufferSize int
+ WriteBufferSize int
+}
+
+func (c *ClientConfig) SetupFlags(prefix string, flagset *flag.FlagSet) {
+ prefix += "client."
+
+ flagset.BoolVar(&c.KeepAlive, prefix+"http-keep-alive", true,
+ "Specifies if HTTP keep alive is enabled.")
+
+ defTR := http.DefaultTransport.(*http.Transport)
+
+ flagset.IntVar(&c.MaxIdleConns, prefix+"idle-conns", defTR.MaxIdleConns,
+ "Specifies max idle connection pool size.")
+
+ flagset.IntVar(&c.MaxIdleConnsPerHost, prefix+"idle-conns-host", http.DefaultMaxIdleConnsPerHost,
+ "Specifies max idle connection pool per host, will be truncated by idle-conns.")
+
+ flagset.IntVar(&c.ReadBufferSize, prefix+"read-buffer", defTR.ReadBufferSize, "Specifies the size of the transport read buffer.")
+ flagset.IntVar(&c.WriteBufferSize, prefix+"write-buffer", defTR.WriteBufferSize, "Specifies the size of the transport write buffer.")
+
+ c.Timeouts.SetupFlags(prefix, flagset)
+}
+
+func (c *ClientConfig) Validate() error {
+ var errs Errors
+
+ if err := c.Timeouts.Validate(); err != nil {
+ errs = append(errs, err)
+ }
+
+ if len(errs) != 0 {
+ return errs
+ }
+ return nil
+}
+
+type Timeouts struct {
+ Connect time.Duration
+ TLSHandshake time.Duration
+ ExpectContinue time.Duration
+ ResponseHeader time.Duration
+}
+
+func (c *Timeouts) SetupFlags(prefix string, flagset *flag.FlagSet) {
+ prefix += "timeout."
+
+ flagset.DurationVar(&c.Connect, prefix+"connect", 30*time.Second,
+ "The `timeout` connecting to the remote host.")
+
+ defTR := http.DefaultTransport.(*http.Transport)
+
+ flagset.DurationVar(&c.TLSHandshake, prefix+"tls", defTR.TLSHandshakeTimeout,
+ "The `timeout` waiting for the TLS handshake to complete.")
+
+ flagset.DurationVar(&c.ExpectContinue, prefix+"expect-continue", defTR.ExpectContinueTimeout,
+ "The `timeout` waiting for the 100-continue response.")
+
+ flagset.DurationVar(&c.ResponseHeader, prefix+"response-header", defTR.ResponseHeaderTimeout,
+ "The `timeout` waiting for the response headers.")
+}
+
+func (c *Timeouts) Validate() error {
+ return nil
+}
+
+type Errors []error
+
+func (es Errors) Error() string {
+ var buf strings.Builder
+ for _, e := range es {
+ buf.WriteString(e.Error())
+ }
+
+ return buf.String()
+}
diff --git a/feature/s3/manager/internal/integration/downloader/main_test.go b/feature/s3/manager/internal/integration/downloader/main_test.go
new file mode 100644
index 00000000000..38e275e00fb
--- /dev/null
+++ b/feature/s3/manager/internal/integration/downloader/main_test.go
@@ -0,0 +1,277 @@
+// +build integration,perftest
+
+package downloader
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "os"
+ "runtime"
+ "strconv"
+ "strings"
+ "testing"
+
+ "github.com/aws/aws-sdk-go-v2/config"
+ "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
+ "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/integration"
+ "github.com/aws/aws-sdk-go-v2/internal/awstesting"
+ "github.com/aws/aws-sdk-go-v2/internal/sdkio"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+)
+
+var benchConfig BenchmarkConfig
+
+type BenchmarkConfig struct {
+ bucket string
+ tempdir string
+ clientConfig ClientConfig
+ sizes string
+ parts string
+ concurrency string
+ bufferSize string
+ uploadPartSize int64
+}
+
+func (b *BenchmarkConfig) SetupFlags(prefix string, flagSet *flag.FlagSet) {
+ flagSet.StringVar(&b.bucket, "bucket", "", "Bucket to use for benchmark")
+ flagSet.StringVar(&b.tempdir, "temp", os.TempDir(), "location to create temporary files")
+
+ flagSet.StringVar(&b.sizes, "size",
+ fmt.Sprintf("%d,%d",
+ 5*sdkio.MebiByte,
+ 1*sdkio.GibiByte), "file sizes to benchmark separated by comma")
+
+ flagSet.StringVar(&b.parts, "part",
+ fmt.Sprintf("%d,%d,%d",
+ manager.DefaultDownloadPartSize,
+ 25*sdkio.MebiByte,
+ 100*sdkio.MebiByte), "part sizes to benchmark separated by comma")
+
+ flagSet.StringVar(&b.concurrency, "concurrency",
+ fmt.Sprintf("%d,%d,%d",
+ manager.DefaultDownloadConcurrency,
+ 2*manager.DefaultDownloadConcurrency,
+ 100),
+ "concurrences to benchmark separated by comma")
+
+ flagSet.StringVar(&b.bufferSize, "buffer", fmt.Sprintf("%d,%d", 0, 1*sdkio.MebiByte), "buffer sizes to benchmark separated by comma")
+ flagSet.Int64Var(&b.uploadPartSize, "upload-part-size", 0, "upload part size, defaults to download part size if not specified")
+ b.clientConfig.SetupFlags(prefix, flagSet)
+}
+
+func (b *BenchmarkConfig) BufferSizes() []int {
+ ints, err := b.stringToInt(b.bufferSize)
+ if err != nil {
+ panic(fmt.Sprintf("failed to parse buffer sizes: %v", err))
+ }
+
+ return ints
+}
+
+func (b *BenchmarkConfig) FileSizes() []int64 {
+ ints, err := b.stringToInt64(b.sizes)
+ if err != nil {
+ panic(fmt.Sprintf("failed to parse file sizes: %v", err))
+ }
+
+ return ints
+}
+
+func (b *BenchmarkConfig) PartSizes() []int64 {
+ ints, err := b.stringToInt64(b.parts)
+ if err != nil {
+ panic(fmt.Sprintf("failed to parse part sizes: %v", err))
+ }
+
+ return ints
+}
+
+func (b *BenchmarkConfig) Concurrences() []int {
+ ints, err := b.stringToInt(b.concurrency)
+ if err != nil {
+ panic(fmt.Sprintf("failed to parse concurrences: %v", err))
+ }
+
+ return ints
+}
+
+func (b *BenchmarkConfig) stringToInt(s string) ([]int, error) {
+ int64s, err := b.stringToInt64(s)
+ if err != nil {
+ return nil, err
+ }
+
+ var ints []int
+ for i := range int64s {
+ ints = append(ints, int(int64s[i]))
+ }
+
+ return ints, nil
+}
+
+func (b *BenchmarkConfig) stringToInt64(s string) ([]int64, error) {
+ var sizes []int64
+
+ split := strings.Split(s, ",")
+
+ for _, size := range split {
+ size = strings.Trim(size, " ")
+ i, err := strconv.ParseInt(size, 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("invalid integer %s: %v", size, err)
+ }
+
+ sizes = append(sizes, i)
+ }
+
+ return sizes, nil
+}
+
+func BenchmarkDownload(b *testing.B) {
+ baseSdkConfig := SDKConfig{}
+
+ for _, fileSize := range benchConfig.FileSizes() {
+ b.Run(fmt.Sprintf("%s File", integration.SizeToName(int(fileSize))), func(b *testing.B) {
+ for _, partSize := range benchConfig.PartSizes() {
+ if partSize > fileSize {
+ continue
+ }
+ uploadPartSize := getUploadPartSize(fileSize, benchConfig.uploadPartSize, partSize)
+ b.Run(fmt.Sprintf("%s PartSize", integration.SizeToName(int(partSize))), func(b *testing.B) {
+ b.Logf("setting up s3 test file")
+ key, err := setupDownloadTest(benchConfig.bucket, fileSize, uploadPartSize)
+ if err != nil {
+ b.Fatalf("failed to setup download test: %v", err)
+ }
+ for _, concurrency := range benchConfig.Concurrences() {
+ b.Run(fmt.Sprintf("%d Concurrency", concurrency), func(b *testing.B) {
+ for _, bufferSize := range benchConfig.BufferSizes() {
+ var name string
+ if bufferSize == 0 {
+ name = "unbuffered"
+ } else {
+ name = fmt.Sprintf("%s buffer", integration.SizeToName(bufferSize))
+ }
+ b.Run(name, func(b *testing.B) {
+ sdkConfig := baseSdkConfig
+ sdkConfig.Concurrency = concurrency
+ sdkConfig.PartSize = partSize
+ if bufferSize > 0 {
+ sdkConfig.BufferProvider = manager.NewPooledBufferedWriterReadFromProvider(bufferSize)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ benchDownload(b, benchConfig.bucket, key, &awstesting.DiscardAt{}, sdkConfig, benchConfig.clientConfig)
+ }
+ })
+ }
+ })
+ }
+ b.Log("removing test file")
+ err = teardownDownloadTest(benchConfig.bucket, key)
+ if err != nil {
+ b.Fatalf("failed to cleanup test file: %v", err)
+ }
+ })
+ }
+ })
+ }
+}
+
+func benchDownload(b *testing.B, bucket, key string, body io.WriterAt, sdkConfig SDKConfig, clientConfig ClientConfig) {
+ downloader := newDownloader(clientConfig, sdkConfig)
+ _, err := downloader.Download(context.Background(), body, &s3.GetObjectInput{
+ Bucket: &bucket,
+ Key: &key,
+ })
+ if err != nil {
+ b.Fatalf("failed to download object, %v", err)
+ }
+}
+
+func TestMain(m *testing.M) {
+ if strings.EqualFold(os.Getenv("BUILD_ONLY"), "true") {
+ os.Exit(0)
+ }
+ benchConfig.SetupFlags("", flag.CommandLine)
+ flag.Parse()
+ os.Exit(m.Run())
+}
+
+func setupDownloadTest(bucket string, fileSize, partSize int64) (key string, err error) {
+ er := &awstesting.EndlessReader{}
+ lr := io.LimitReader(er, fileSize)
+
+ key = integration.MustUUID()
+
+ defaultConfig, err := config.LoadDefaultConfig()
+ if err != nil {
+ return "", err
+ }
+
+ client := s3.NewFromConfig(defaultConfig)
+
+ uploader := manager.NewUploader(client, func(u *manager.Uploader) {
+ u.PartSize = partSize
+ u.Concurrency = runtime.NumCPU() * 2
+ })
+
+ _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: &bucket,
+ Body: lr,
+ Key: &key,
+ })
+ if err != nil {
+ err = fmt.Errorf("failed to upload test object to s3: %v", err)
+ }
+
+ return
+}
+
+func teardownDownloadTest(bucket, key string) error {
+ defaultConfig, err := config.LoadDefaultConfig()
+ if err != nil {
+ log.Fatalf("failed to load config: %v", err)
+ }
+
+ client := s3.NewFromConfig(defaultConfig)
+
+ _, err = client.DeleteObject(context.Background(), &s3.DeleteObjectInput{Bucket: &bucket, Key: &key})
+ return err
+}
+
+func newDownloader(clientConfig ClientConfig, sdkConfig SDKConfig) *manager.Downloader {
+ defaultConfig, err := config.LoadDefaultConfig()
+ if err != nil {
+ log.Fatalf("failed to load config: %v", err)
+ }
+
+ client := s3.NewFromConfig(defaultConfig, func(options *s3.Options) {
+ options.HTTPClient = NewHTTPClient(clientConfig)
+ })
+
+ downloader := manager.NewDownloader(client, func(d *manager.Downloader) {
+ d.PartSize = sdkConfig.PartSize
+ d.Concurrency = sdkConfig.Concurrency
+ d.BufferProvider = sdkConfig.BufferProvider
+ })
+
+ return downloader
+}
+
+func getUploadPartSize(fileSize, uploadPartSize, downloadPartSize int64) int64 {
+ partSize := uploadPartSize
+
+ if partSize == 0 {
+ partSize = downloadPartSize
+ }
+ if fileSize/partSize > int64(manager.MaxUploadParts) {
+ partSize = (fileSize / int64(manager.MaxUploadParts)) + 1
+ }
+
+ return partSize
+}
diff --git a/feature/s3/manager/internal/integration/integration.go b/feature/s3/manager/internal/integration/integration.go
new file mode 100644
index 00000000000..ded44f39ae2
--- /dev/null
+++ b/feature/s3/manager/internal/integration/integration.go
@@ -0,0 +1,204 @@
+package integration
+
+import (
+ "context"
+ "crypto/rand"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "os"
+ "time"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+ "github.com/aws/aws-sdk-go-v2/service/s3/types"
+ smithyrand "github.com/awslabs/smithy-go/rand"
+)
+
+var uuid = smithyrand.NewUUID(rand.Reader)
+
+// MustUUID returns a UUID string or panics
+func MustUUID() string {
+ uuid, err := uuid.GetUUID()
+ if err != nil {
+ panic(err)
+ }
+ return uuid
+}
+
+// CreateFileOfSize will return an *os.File that is of size bytes
+func CreateFileOfSize(dir string, size int64) (*os.File, error) {
+ file, err := ioutil.TempFile(dir, "s3integration")
+ if err != nil {
+ return nil, err
+ }
+
+ err = file.Truncate(size)
+ if err != nil {
+ file.Close()
+ os.Remove(file.Name())
+ return nil, err
+ }
+
+ return file, nil
+}
+
+// SizeToName returns a human-readable string for the given size in bytes
+func SizeToName(size int) string {
+ units := []string{"B", "KB", "MB", "GB"}
+ i := 0
+ for size >= 1024 {
+ size /= 1024
+ i++
+ }
+
+ if i > len(units)-1 {
+ i = len(units) - 1
+ }
+
+ return fmt.Sprintf("%d%s", size, units[i])
+}
+
+// BucketPrefix is the root prefix of integration test buckets.
+const BucketPrefix = "aws-sdk-go-v2-integration"
+
+// GenerateBucketName returns a unique bucket name.
+func GenerateBucketName() string {
+ var id [16]byte
+ _, err := rand.Read(id[:])
+ if err != nil {
+ panic(err)
+ }
+
+ return fmt.Sprintf("%s-%x",
+ BucketPrefix, id)
+}
+
+// SetupBucket creates a test bucket for the integration tests.
+func SetupBucket(client *s3.Client, bucketName, region string) (err error) {
+ fmt.Println("Setup: Creating test bucket,", bucketName)
+ _, err = client.CreateBucket(context.Background(), &s3.CreateBucketInput{
+ Bucket: &bucketName,
+ CreateBucketConfiguration: &types.CreateBucketConfiguration{
+ LocationConstraint: types.BucketLocationConstraint(region),
+ },
+ })
+ if err != nil {
+ return fmt.Errorf("failed to create bucket %s, %v", bucketName, err)
+ }
+
+ fmt.Println("Setup: Waiting for bucket to exist,", bucketName)
+ err = waitUntilBucketExists(context.Background(), client, &s3.HeadBucketInput{Bucket: &bucketName})
+ if err != nil {
+ return fmt.Errorf("failed waiting for bucket %s to be created, %v",
+ bucketName, err)
+ }
+
+ return nil
+}
+
+func waitUntilBucketExists(ctx context.Context, client *s3.Client, params *s3.HeadBucketInput) error {
+ for i := 0; i < 20; i++ {
+ _, err := client.HeadBucket(ctx, params)
+ if err == nil {
+ return nil
+ }
+
+ var httpErr interface{ HTTPStatusCode() int }
+
+ if !errors.As(err, &httpErr) {
+ return err
+ }
+
+ if httpErr.HTTPStatusCode() == http.StatusMovedPermanently || httpErr.HTTPStatusCode() == http.StatusForbidden {
+ return nil
+ }
+
+ if httpErr.HTTPStatusCode() != http.StatusNotFound {
+ return err
+ }
+
+ time.Sleep(5 * time.Second)
+ }
+ return fmt.Errorf("exceeded max attempts waiting for bucket %s to exist", aws.ToString(params.Bucket))
+}
+
+// CleanupBucket deletes the contents of an S3 bucket, including partial
+// multipart uploads, before deleting the bucket itself.
+func CleanupBucket(client *s3.Client, bucketName string) error {
+ var errs []error
+
+ {
+ fmt.Println("TearDown: Deleting objects from test bucket,", bucketName)
+ input := &s3.ListObjectsV2Input{Bucket: &bucketName}
+ for {
+ listObjectsV2, err := client.ListObjectsV2(context.Background(), input)
+ if err != nil {
+ return fmt.Errorf("failed to list objects, %w", err)
+ }
+
+ var delete types.Delete
+ for _, content := range listObjectsV2.Contents {
+ obj := content
+ delete.Objects = append(delete.Objects, &types.ObjectIdentifier{Key: obj.Key})
+ }
+
+ // Nothing to delete on this page; an empty Delete would fail the call.
+ if len(delete.Objects) == 0 {
+ break
+ }
+
+ deleteObjects, err := client.DeleteObjects(context.Background(), &s3.DeleteObjectsInput{
+ Bucket: &bucketName,
+ Delete: &delete,
+ })
+ if err != nil {
+ errs = append(errs, err)
+ break
+ }
+ for _, deleteError := range deleteObjects.Errors {
+ errs = append(errs, fmt.Errorf("failed to delete %s, %s", aws.ToString(deleteError.Key), aws.ToString(deleteError.Message)))
+ }
+
+ if aws.ToBool(listObjectsV2.IsTruncated) {
+ input.ContinuationToken = listObjectsV2.NextContinuationToken
+ } else {
+ break
+ }
+ }
+ }
+
+ {
+ fmt.Println("TearDown: Deleting partial uploads from test bucket,", bucketName)
+
+ input := &s3.ListMultipartUploadsInput{Bucket: &bucketName}
+ for {
+ uploads, err := client.ListMultipartUploads(context.Background(), input)
+ if err != nil {
+ return fmt.Errorf("failed to list multipart uploads, %w", err)
+ }
+
+ for _, upload := range uploads.Uploads {
+ client.AbortMultipartUpload(context.Background(), &s3.AbortMultipartUploadInput{
+ Bucket: &bucketName,
+ Key: upload.Key,
+ UploadId: upload.UploadId,
+ })
+ }
+
+ if aws.ToBool(uploads.IsTruncated) {
+ input.KeyMarker = uploads.NextKeyMarker
+ input.UploadIdMarker = uploads.NextUploadIdMarker
+ } else {
+ break
+ }
+ }
+ }
+
+ if len(errs) != 0 {
+ return fmt.Errorf("failed to delete objects, %s", errs)
+ }
+
+ fmt.Println("TearDown: Deleting test bucket,", bucketName)
+ if _, err := client.DeleteBucket(context.Background(), &s3.DeleteBucketInput{Bucket: &bucketName}); err != nil {
+ return fmt.Errorf("failed to delete test bucket %s, %w", bucketName, err)
+ }
+
+ return nil
+}
diff --git a/feature/s3/manager/internal/integration/uploader/README.md b/feature/s3/manager/internal/integration/uploader/README.md
new file mode 100644
index 00000000000..6f12095fd35
--- /dev/null
+++ b/feature/s3/manager/internal/integration/uploader/README.md
@@ -0,0 +1,22 @@
+## Performance Utility
+
+Uploads a file to an S3 bucket using the SDK's S3 upload manager. Allows passing
+in a custom configuration for the HTTP client and the upload manager's behavior.
+
+## Build
+```sh
+go test -tags "integration perftest" -c -o uploader.test ./internal/integration/uploader
+```
+
+## Usage Example:
+```sh
+AWS_REGION=us-west-2 AWS_PROFILE=aws-go-sdk-team-test ./uploader.test \
+-test.bench=. \
+-test.benchmem \
+-test.benchtime 1x \
+-bucket aws-sdk-go-data \
+-client.idle-conns 1000 \
+-client.idle-conns-host 300 \
+-client.timeout.connect=1s \
+-client.timeout.response-header=1s
+```
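+
+Setting the `BUILD_ONLY` environment variable to `true` makes the test binary
+exit immediately without contacting S3, which is useful for verifying that the
+binary builds and starts:
+```sh
+BUILD_ONLY=true ./uploader.test
+```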
diff --git a/feature/s3/manager/internal/integration/uploader/client.go b/feature/s3/manager/internal/integration/uploader/client.go
new file mode 100644
index 00000000000..e409c61f890
--- /dev/null
+++ b/feature/s3/manager/internal/integration/uploader/client.go
@@ -0,0 +1,31 @@
+// +build integration,perftest
+
+package uploader
+
+import (
+ "net"
+ "net/http"
+ "time"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+)
+
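+// NewHTTPClient returns an aws.HTTPClient whose underlying http.Transport is
+// tuned with the timeouts and connection pool limits from cfg.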
+func NewHTTPClient(cfg ClientConfig) aws.HTTPClient {
+ return aws.NewBuildableHTTPClient().WithTransportOptions(func(transport *http.Transport) {
+ *transport = http.Transport{
+ Proxy: http.ProxyFromEnvironment,
+ DialContext: (&net.Dialer{
+ Timeout: cfg.Timeouts.Connect,
+ KeepAlive: 30 * time.Second,
+ }).DialContext,
+ MaxIdleConns: cfg.MaxIdleConns,
+ MaxIdleConnsPerHost: cfg.MaxIdleConnsPerHost,
+ IdleConnTimeout: 90 * time.Second,
+
+ DisableKeepAlives: !cfg.KeepAlive,
+ TLSHandshakeTimeout: cfg.Timeouts.TLSHandshake,
+ ExpectContinueTimeout: cfg.Timeouts.ExpectContinue,
+ ResponseHeaderTimeout: cfg.Timeouts.ResponseHeader,
+ }
+ })
+}
diff --git a/feature/s3/manager/internal/integration/uploader/config.go b/feature/s3/manager/internal/integration/uploader/config.go
new file mode 100644
index 00000000000..a89f43de068
--- /dev/null
+++ b/feature/s3/manager/internal/integration/uploader/config.go
@@ -0,0 +1,109 @@
+// +build integration,perftest
+
+package uploader
+
+import (
+ "flag"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
+)
+
+type SDKConfig struct {
+ PartSize int64
+ Concurrency int
+ BufferProvider manager.ReadSeekerWriteToProvider
+}
+
+func (c *SDKConfig) SetupFlags(prefix string, flagset *flag.FlagSet) {
+ prefix += "sdk."
+
+ flagset.Int64Var(&c.PartSize, prefix+"part-size", manager.DefaultUploadPartSize,
+ "Specifies the `size` of parts of the object to upload.")
+ flagset.IntVar(&c.Concurrency, prefix+"concurrency", manager.DefaultUploadConcurrency,
+ "Specifies the number of parts to upload `at once`.")
+}
+
+func (c *SDKConfig) Validate() error {
+ return nil
+}
+
+type ClientConfig struct {
+ KeepAlive bool
+ Timeouts Timeouts
+
+ MaxIdleConns int
+ MaxIdleConnsPerHost int
+}
+
+func (c *ClientConfig) SetupFlags(prefix string, flagset *flag.FlagSet) {
+ prefix += "client."
+
+ flagset.BoolVar(&c.KeepAlive, prefix+"http-keep-alive", true,
+ "Specifies if HTTP keep alive is enabled.")
+
+ defTR := http.DefaultTransport.(*http.Transport)
+
+ flagset.IntVar(&c.MaxIdleConns, prefix+"idle-conns", defTR.MaxIdleConns,
+ "Specifies max idle connection pool size.")
+
+ flagset.IntVar(&c.MaxIdleConnsPerHost, prefix+"idle-conns-host", http.DefaultMaxIdleConnsPerHost,
+ "Specifies max idle connection pool per host, will be truncated by idle-conns.")
+
+ c.Timeouts.SetupFlags(prefix, flagset)
+}
+
+func (c *ClientConfig) Validate() error {
+ var errs Errors
+
+ if err := c.Timeouts.Validate(); err != nil {
+ errs = append(errs, err)
+ }
+
+ if len(errs) != 0 {
+ return errs
+ }
+ return nil
+}
+
+type Timeouts struct {
+ Connect time.Duration
+ TLSHandshake time.Duration
+ ExpectContinue time.Duration
+ ResponseHeader time.Duration
+}
+
+func (c *Timeouts) SetupFlags(prefix string, flagset *flag.FlagSet) {
+ prefix += "timeout."
+
+ flagset.DurationVar(&c.Connect, prefix+"connect", 30*time.Second,
+ "The `timeout` connecting to the remote host.")
+
+ defTR := http.DefaultTransport.(*http.Transport)
+
+ flagset.DurationVar(&c.TLSHandshake, prefix+"tls", defTR.TLSHandshakeTimeout,
+ "The `timeout` waiting for the TLS handshake to complete.")
+
+ flagset.DurationVar(&c.ExpectContinue, prefix+"expect-continue", defTR.ExpectContinueTimeout,
+ "The `timeout` waiting for the 100-continue response.")
+
+ flagset.DurationVar(&c.ResponseHeader, prefix+"response-header", defTR.ResponseHeaderTimeout,
+ "The `timeout` waiting for the response headers.")
+}
+
+func (c *Timeouts) Validate() error {
+ return nil
+}
+
+type Errors []error
+
+func (es Errors) Error() string {
+ var buf strings.Builder
+ for _, e := range es {
+ buf.WriteString(e.Error())
+ }
+
+ return buf.String()
+}
diff --git a/feature/s3/manager/internal/integration/uploader/main_test.go b/feature/s3/manager/internal/integration/uploader/main_test.go
new file mode 100644
index 00000000000..d333288f055
--- /dev/null
+++ b/feature/s3/manager/internal/integration/uploader/main_test.go
@@ -0,0 +1,231 @@
+// +build integration,perftest
+
+package uploader
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "os"
+ "strconv"
+ "strings"
+ "testing"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/config"
+ "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
+ "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/integration"
+ "github.com/aws/aws-sdk-go-v2/internal/awstesting"
+ "github.com/aws/aws-sdk-go-v2/internal/sdkio"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+)
+
+func newUploader(clientConfig ClientConfig, sdkConfig SDKConfig) *manager.Uploader {
+ defaultConfig, err := config.LoadDefaultConfig()
+ if err != nil {
+ log.Fatalf("failed to load config: %v", err)
+ }
+
+ client := s3.NewFromConfig(defaultConfig, func(o *s3.Options) {
+ o.HTTPClient = NewHTTPClient(clientConfig)
+ })
+
+ uploader := manager.NewUploader(client, func(u *manager.Uploader) {
+ u.PartSize = sdkConfig.PartSize
+ u.Concurrency = sdkConfig.Concurrency
+ u.BufferProvider = sdkConfig.BufferProvider
+ })
+
+ return uploader
+}
+
+func getUploadPartSize(fileSize, uploadPartSize int64) int64 {
+ partSize := uploadPartSize
+
+ if fileSize/partSize > int64(manager.MaxUploadParts) {
+ partSize = (fileSize / int64(manager.MaxUploadParts)) + 1
+ }
+
+ return partSize
+}
+
+var benchConfig BenchmarkConfig
+
+type BenchmarkConfig struct {
+ bucket string
+ tempdir string
+ clientConfig ClientConfig
+ sizes string
+ parts string
+ concurrency string
+ bufferSize string
+}
+
+func (b *BenchmarkConfig) SetupFlags(prefix string, flagSet *flag.FlagSet) {
+ flagSet.StringVar(&b.bucket, "bucket", "", "Bucket to use for benchmark")
+ flagSet.StringVar(&b.tempdir, "temp", os.TempDir(), "location to create temporary files")
+
+ flagSet.StringVar(&b.sizes, "size",
+ fmt.Sprintf("%d,%d",
+ 5*sdkio.MebiByte,
+ 1*sdkio.GibiByte), "file sizes to benchmark separated by comma")
+
+ flagSet.StringVar(&b.parts, "part",
+ fmt.Sprintf("%d,%d,%d",
+ manager.DefaultUploadPartSize,
+ 25*sdkio.MebiByte,
+ 100*sdkio.MebiByte), "part sizes to benchmark separated by comma")
+
+ flagSet.StringVar(&b.concurrency, "concurrency",
+ fmt.Sprintf("%d,%d,%d",
+ manager.DefaultUploadConcurrency,
+ 2*manager.DefaultUploadConcurrency,
+ 100),
+ "concurrences to benchmark separated by comma")
+
+ flagSet.StringVar(&b.bufferSize, "buffer", fmt.Sprintf("%d,%d", 0, 1*sdkio.MebiByte), "buffer sizes to benchmark separated by comma")
+ b.clientConfig.SetupFlags(prefix, flagSet)
+}
+
+func (b *BenchmarkConfig) BufferSizes() []int {
+ ints, err := b.stringToInt(b.bufferSize)
+ if err != nil {
+ panic(fmt.Sprintf("failed to parse buffer sizes: %v", err))
+ }
+
+ return ints
+}
+
+func (b *BenchmarkConfig) FileSizes() []int64 {
+ ints, err := b.stringToInt64(b.sizes)
+ if err != nil {
+ panic(fmt.Sprintf("failed to parse file sizes: %v", err))
+ }
+
+ return ints
+}
+
+func (b *BenchmarkConfig) PartSizes() []int64 {
+ ints, err := b.stringToInt64(b.parts)
+ if err != nil {
+ panic(fmt.Sprintf("failed to parse part sizes: %v", err))
+ }
+
+ return ints
+}
+
+func (b *BenchmarkConfig) Concurrences() []int {
+ ints, err := b.stringToInt(b.concurrency)
+ if err != nil {
+ panic(fmt.Sprintf("failed to parse concurrences: %v", err))
+ }
+
+ return ints
+}
+
+func (b *BenchmarkConfig) stringToInt(s string) ([]int, error) {
+ int64s, err := b.stringToInt64(s)
+ if err != nil {
+ return nil, err
+ }
+
+ var ints []int
+ for i := range int64s {
+ ints = append(ints, int(int64s[i]))
+ }
+
+ return ints, nil
+}
+
+func (b *BenchmarkConfig) stringToInt64(s string) ([]int64, error) {
+ var sizes []int64
+
+ split := strings.Split(s, ",")
+
+ for _, size := range split {
+ size = strings.Trim(size, " ")
+ i, err := strconv.ParseInt(size, 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("invalid integer %s: %v", size, err)
+ }
+
+ sizes = append(sizes, i)
+ }
+
+ return sizes, nil
+}
+
+func BenchmarkUpload(b *testing.B) {
+ baseSdkConfig := SDKConfig{}
+
+ for _, fileSize := range benchConfig.FileSizes() {
+ b.Run(fmt.Sprintf("%s File", integration.SizeToName(int(fileSize))), func(b *testing.B) {
+ for _, concurrency := range benchConfig.Concurrences() {
+ b.Run(fmt.Sprintf("%d Concurrency", concurrency), func(b *testing.B) {
+ for _, partSize := range benchConfig.PartSizes() {
+ if partSize > fileSize {
+ continue
+ }
+ partSize = getUploadPartSize(fileSize, partSize)
+ b.Run(fmt.Sprintf("%s PartSize", integration.SizeToName(int(partSize))), func(b *testing.B) {
+ for _, bufferSize := range benchConfig.BufferSizes() {
+ var name string
+ if bufferSize == 0 {
+ name = "Unbuffered"
+ } else {
+ name = fmt.Sprintf("%s Buffer", integration.SizeToName(bufferSize))
+ }
+ b.Run(name, func(b *testing.B) {
+ sdkConfig := baseSdkConfig
+
+ sdkConfig.Concurrency = concurrency
+ sdkConfig.PartSize = partSize
+ if bufferSize > 0 {
+ sdkConfig.BufferProvider = manager.NewBufferedReadSeekerWriteToPool(bufferSize)
+ }
+
+ for i := 0; i < b.N; i++ {
+ for {
+ b.ResetTimer()
+ reader := aws.ReadSeekCloser(io.LimitReader(&awstesting.EndlessReader{}, fileSize))
+ err := benchUpload(b, benchConfig.bucket, integration.MustUUID(), reader, sdkConfig, benchConfig.clientConfig)
+ if err != nil {
+ b.Logf("upload failed, retrying: %v", err)
+ continue
+ }
+ break
+ }
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
+
+func benchUpload(b *testing.B, bucket, key string, reader io.ReadSeeker, sdkConfig SDKConfig, clientConfig ClientConfig) error {
+ uploader := newUploader(clientConfig, sdkConfig)
+ _, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: &bucket,
+ Key: &key,
+ Body: reader,
+ })
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func TestMain(m *testing.M) {
+ if strings.EqualFold(os.Getenv("BUILD_ONLY"), "true") {
+ os.Exit(0)
+ }
+ benchConfig.SetupFlags("", flag.CommandLine)
+ flag.Parse()
+ os.Exit(m.Run())
+}
diff --git a/feature/s3/manager/internal/testing/endpoints.go b/feature/s3/manager/internal/testing/endpoints.go
new file mode 100644
index 00000000000..aa2f62ed72c
--- /dev/null
+++ b/feature/s3/manager/internal/testing/endpoints.go
@@ -0,0 +1,14 @@
+package testing
+
+import (
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+)
+
+// EndpointResolverFunc is a mock S3 endpoint resolver that wraps the given function.
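+//
+// A minimal usage sketch (the URL is illustrative):
+//
+// resolver := EndpointResolverFunc(func(region string, options s3.ResolverOptions) (aws.Endpoint, error) {
+// return aws.Endpoint{URL: "https://s3." + region + ".amazonaws.com"}, nil
+// })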
+type EndpointResolverFunc func(region string, options s3.ResolverOptions) (aws.Endpoint, error)
+
+// ResolveEndpoint returns the results from the wrapped function.
+func (m EndpointResolverFunc) ResolveEndpoint(region string, options s3.ResolverOptions) (aws.Endpoint, error) {
+ return m(region, options)
+}
diff --git a/feature/s3/manager/internal/testing/rand.go b/feature/s3/manager/internal/testing/rand.go
new file mode 100644
index 00000000000..2c8d27e194f
--- /dev/null
+++ b/feature/s3/manager/internal/testing/rand.go
@@ -0,0 +1,28 @@
+package testing
+
+import (
+ "fmt"
+ "math/rand"
+
+ "github.com/aws/aws-sdk-go-v2/internal/sdkio"
+)
+
+var randBytes = func() []byte {
+ rr := rand.New(rand.NewSource(0))
+ b := make([]byte, 10*sdkio.MebiByte)
+
+ if _, err := rr.Read(b); err != nil {
+ panic(fmt.Sprintf("failed to read random bytes, %v", err))
+ }
+ return b
+}()
+
+// GetTestBytes returns a pseudo-random []byte of length size
+func GetTestBytes(size int) []byte {
+ if len(randBytes) >= size {
+ return randBytes[:size]
+ }
+
+ b := append(randBytes, GetTestBytes(size-len(randBytes))...)
+ return b
+}
diff --git a/feature/s3/manager/internal/testing/upload.go b/feature/s3/manager/internal/testing/upload.go
new file mode 100644
index 00000000000..067f35e5dd2
--- /dev/null
+++ b/feature/s3/manager/internal/testing/upload.go
@@ -0,0 +1,196 @@
+package testing
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "sync"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+)
+
+// UploadLoggingClient is a mock client that can be used to record and stub responses for testing the manager.Uploader.
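+//
+// For example, a test can stub a single operation to force a failure (a
+// sketch; the error value is illustrative):
+//
+// client, invocations, _ := NewUploadLoggingClient(nil)
+// client.UploadPartFn = func(u *UploadLoggingClient, in *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
+// return nil, fmt.Errorf("simulated upload part failure")
+// }
+//
+// Invocations and Params record the operation names called and their inputs.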
+type UploadLoggingClient struct {
+ Invocations []string
+ Params []interface{}
+
+ ConsumeBody bool
+
+ PutObjectFn func(*UploadLoggingClient, *s3.PutObjectInput) (*s3.PutObjectOutput, error)
+ UploadPartFn func(*UploadLoggingClient, *s3.UploadPartInput) (*s3.UploadPartOutput, error)
+ CreateMultipartUploadFn func(*UploadLoggingClient, *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error)
+ CompleteMultipartUploadFn func(*UploadLoggingClient, *s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error)
+ AbortMultipartUploadFn func(*UploadLoggingClient, *s3.AbortMultipartUploadInput) (*s3.AbortMultipartUploadOutput, error)
+
+ ignoredOperations []string
+
+ PartNum int
+ m sync.Mutex
+}
+
+func (u *UploadLoggingClient) simulateHTTPClientOption(optFns ...func(*s3.Options)) error {
+ o := s3.Options{
+ HTTPClient: httpDoFunc(func(request *http.Request) (*http.Response, error) {
+ return &http.Response{
+ Request: request,
+ }, nil
+ }),
+ }
+
+ for _, fn := range optFns {
+ fn(&o)
+ }
+
+ _, err := o.HTTPClient.Do(&http.Request{URL: &url.URL{
+ Scheme: "https",
+ Host: "mock.amazonaws.com",
+ Path: "/key",
+ RawQuery: "foo=bar",
+ }})
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+type httpDoFunc func(*http.Request) (*http.Response, error)
+
+func (f httpDoFunc) Do(r *http.Request) (*http.Response, error) {
+ return f(r)
+}
+
+func (u *UploadLoggingClient) traceOperation(name string, params interface{}) {
+ if contains(u.ignoredOperations, name) {
+ return
+ }
+
+ u.Invocations = append(u.Invocations, name)
+ u.Params = append(u.Params, params)
+}
+
+// PutObject is the S3 PutObject API.
+func (u *UploadLoggingClient) PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) {
+ u.m.Lock()
+ defer u.m.Unlock()
+
+ if u.ConsumeBody {
+ io.Copy(ioutil.Discard, params.Body)
+ }
+
+ u.traceOperation("PutObject", params)
+ if err := u.simulateHTTPClientOption(optFns...); err != nil {
+ return nil, err
+ }
+
+ if u.PutObjectFn != nil {
+ return u.PutObjectFn(u, params)
+ }
+
+ return &s3.PutObjectOutput{
+ VersionId: aws.String("VERSION-ID"),
+ }, nil
+}
+
+// UploadPart is the S3 UploadPart API.
+func (u *UploadLoggingClient) UploadPart(ctx context.Context, params *s3.UploadPartInput, optFns ...func(*s3.Options)) (*s3.UploadPartOutput, error) {
+ u.m.Lock()
+ defer u.m.Unlock()
+
+ if u.ConsumeBody {
+ io.Copy(ioutil.Discard, params.Body)
+ }
+
+ u.traceOperation("UploadPart", params)
+ if err := u.simulateHTTPClientOption(optFns...); err != nil {
+ return nil, err
+ }
+
+ u.PartNum++
+
+ if u.UploadPartFn != nil {
+ return u.UploadPartFn(u, params)
+ }
+
+ return &s3.UploadPartOutput{
+ ETag: aws.String(fmt.Sprintf("ETAG%d", u.PartNum)),
+ }, nil
+}
+
+// CreateMultipartUpload is the S3 CreateMultipartUpload API.
+func (u *UploadLoggingClient) CreateMultipartUpload(ctx context.Context, params *s3.CreateMultipartUploadInput, optFns ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error) {
+ u.m.Lock()
+ defer u.m.Unlock()
+
+ u.traceOperation("CreateMultipartUpload", params)
+ if err := u.simulateHTTPClientOption(optFns...); err != nil {
+ return nil, err
+ }
+
+ if u.CreateMultipartUploadFn != nil {
+ return u.CreateMultipartUploadFn(u, params)
+ }
+
+ return &s3.CreateMultipartUploadOutput{
+ UploadId: aws.String("UPLOAD-ID"),
+ }, nil
+}
+
+// CompleteMultipartUpload is the S3 CompleteMultipartUpload API.
+func (u *UploadLoggingClient) CompleteMultipartUpload(ctx context.Context, params *s3.CompleteMultipartUploadInput, optFns ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error) {
+ u.m.Lock()
+ defer u.m.Unlock()
+
+ u.traceOperation("CompleteMultipartUpload", params)
+ if err := u.simulateHTTPClientOption(optFns...); err != nil {
+ return nil, err
+ }
+
+ if u.CompleteMultipartUploadFn != nil {
+ return u.CompleteMultipartUploadFn(u, params)
+ }
+
+ return &s3.CompleteMultipartUploadOutput{
+ Location: aws.String("http://location"),
+ VersionId: aws.String("VERSION-ID"),
+ }, nil
+}
+
+// AbortMultipartUpload is the S3 AbortMultipartUpload API.
+func (u *UploadLoggingClient) AbortMultipartUpload(ctx context.Context, params *s3.AbortMultipartUploadInput, optFns ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error) {
+ u.m.Lock()
+ defer u.m.Unlock()
+
+ u.traceOperation("AbortMultipartUpload", params)
+ if err := u.simulateHTTPClientOption(optFns...); err != nil {
+ return nil, err
+ }
+
+ if u.AbortMultipartUploadFn != nil {
+ return u.AbortMultipartUploadFn(u, params)
+ }
+
+ return &s3.AbortMultipartUploadOutput{}, nil
+}
+
+// NewUploadLoggingClient returns a new UploadLoggingClient.
+func NewUploadLoggingClient(ignoreOps []string) (*UploadLoggingClient, *[]string, *[]interface{}) {
+ client := &UploadLoggingClient{
+ ignoredOperations: ignoreOps,
+ }
+
+ return client, &client.Invocations, &client.Params
+}
+
+func contains(src []string, s string) bool {
+ for _, v := range src {
+ if s == v {
+ return true
+ }
+ }
+ return false
+}
diff --git a/feature/s3/manager/pool.go b/feature/s3/manager/pool.go
new file mode 100644
index 00000000000..6b93a3bc443
--- /dev/null
+++ b/feature/s3/manager/pool.go
@@ -0,0 +1,251 @@
+package manager
+
+import (
+ "context"
+ "fmt"
+ "sync"
+)
+
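+// byteSlicePool is an abstraction over a pool of reusable []byte slices of a
+// fixed size. Get blocks until a slice or an allocation permit is available,
+// Put returns a slice to the pool, and ModifyCapacity grows or shrinks the
+// total number of slices the pool may hand out.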
+type byteSlicePool interface {
+ Get(context.Context) (*[]byte, error)
+ Put(*[]byte)
+ ModifyCapacity(int)
+ SliceSize() int64
+ Close()
+}
+
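+// maxSlicePool is a byteSlicePool implementation bounded to a maximum
+// capacity. A minimal usage sketch (the size and capacity are illustrative):
+//
+// pool := newMaxSlicePool(1024)
+// pool.ModifyCapacity(2) // allow at most two outstanding slices
+// bs, err := pool.Get(context.Background())
+// if err == nil {
+// // use *bs, then return it to the pool
+// pool.Put(bs)
+// }
+// pool.Close()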
+type maxSlicePool struct {
+ // allocator is defined as a function pointer to allow
+ // for test cases to instrument custom tracers when allocations
+ // occur.
+ allocator sliceAllocator
+
+ slices chan *[]byte
+ allocations chan struct{}
+ capacityChange chan struct{}
+
+ max int
+ sliceSize int64
+
+ mtx sync.RWMutex
+}
+
+func newMaxSlicePool(sliceSize int64) *maxSlicePool {
+ p := &maxSlicePool{sliceSize: sliceSize}
+ p.allocator = p.newSlice
+
+ return p
+}
+
+var errZeroCapacity = fmt.Errorf("get called on zero capacity pool")
+
+func (p *maxSlicePool) Get(ctx context.Context) (*[]byte, error) {
+ // check if context is canceled before attempting to get a slice
+ // this ensures priority is given to the cancel case first
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ }
+
+ p.mtx.RLock()
+
+ for {
+ select {
+ case bs, ok := <-p.slices:
+ p.mtx.RUnlock()
+ if !ok {
+ // attempt to get on a zero capacity pool
+ return nil, errZeroCapacity
+ }
+ return bs, nil
+ case <-ctx.Done():
+ p.mtx.RUnlock()
+ return nil, ctx.Err()
+ default:
+ // pass
+ }
+
+ select {
+ case _, ok := <-p.allocations:
+ p.mtx.RUnlock()
+ if !ok {
+ // attempt to get on a zero capacity pool
+ return nil, errZeroCapacity
+ }
+ return p.allocator(), nil
+ case <-ctx.Done():
+ p.mtx.RUnlock()
+ return nil, ctx.Err()
+ default:
+ // No slices or allocation permits are available, so release the
+ // read lock and wait for a capacity-change notification. This
+ // prevents a deadlock that can occur around sync.RWMutex: when a
+ // write-lock request is pending in ModifyCapacity, no new readers
+ // may acquire the read lock. Without releasing here, Get could hold
+ // the read lock indefinitely waiting for capacity, ModifyCapacity
+ // would wait for the write lock, and Put would block on a read lock
+ // queued behind ModifyCapacity.
+
+ // Short-circuit if the pool capacity is zero.
+ if p.max == 0 {
+ p.mtx.RUnlock()
+ return nil, errZeroCapacity
+ }
+
+ // Since we will be releasing the read-lock we need to take the reference to the channel.
+ // Since channels are references we will still get notified if slices are added, or if
+ // the channel is closed due to a capacity modification. This specifically avoids a data race condition
+ // where ModifyCapacity both closes a channel and initializes a new one while we don't have a read-lock.
+ c := p.capacityChange
+
+ p.mtx.RUnlock()
+
+ select {
+ case <-c:
+ p.mtx.RLock()
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+ }
+ }
+}
+
+func (p *maxSlicePool) Put(bs *[]byte) {
+ p.mtx.RLock()
+ defer p.mtx.RUnlock()
+
+ if p.max == 0 {
+ return
+ }
+
+ select {
+ case p.slices <- bs:
+ p.notifyCapacity()
+ default:
+ // If the channel is full when attempting to add the slice then drop
+ // the slice. This prevents a deadlock if the channel is already at
+ // max capacity, and allows returned allocations that are no longer
+ // needed to be reaped.
+ }
+}
+
+func (p *maxSlicePool) ModifyCapacity(delta int) {
+ if delta == 0 {
+ return
+ }
+
+ p.mtx.Lock()
+ defer p.mtx.Unlock()
+
+ p.max += delta
+
+ if p.max == 0 {
+ p.empty()
+ return
+ }
+
+ if p.capacityChange != nil {
+ close(p.capacityChange)
+ }
+ p.capacityChange = make(chan struct{}, p.max)
+
+ origAllocations := p.allocations
+ p.allocations = make(chan struct{}, p.max)
+
+ newAllocs := len(origAllocations) + delta
+ for i := 0; i < newAllocs; i++ {
+ p.allocations <- struct{}{}
+ }
+
+ if origAllocations != nil {
+ close(origAllocations)
+ }
+
+ origSlices := p.slices
+ p.slices = make(chan *[]byte, p.max)
+ if origSlices == nil {
+ return
+ }
+
+ close(origSlices)
+ for bs := range origSlices {
+ select {
+ case p.slices <- bs:
+ default:
+ // If the new channel blocks while adding slices from the old channel
+ // then we drop the slice. The logic here is to prevent a deadlock situation
+ // if the new channel has a smaller capacity than the old.
+ }
+ }
+}
+
+func (p *maxSlicePool) notifyCapacity() {
+ select {
+ case p.capacityChange <- struct{}{}:
+ default:
+ // This *shouldn't* happen as the channel is both buffered to the max pool capacity size and is resized
+ // on capacity modifications. This is just a safety to ensure that a blocking situation can't occur.
+ }
+}
+
+func (p *maxSlicePool) SliceSize() int64 {
+ return p.sliceSize
+}
+
+func (p *maxSlicePool) Close() {
+ p.mtx.Lock()
+ defer p.mtx.Unlock()
+ p.empty()
+}
+
+func (p *maxSlicePool) empty() {
+ p.max = 0
+
+ if p.capacityChange != nil {
+ close(p.capacityChange)
+ p.capacityChange = nil
+ }
+
+ if p.allocations != nil {
+ close(p.allocations)
+ for range p.allocations {
+ // drain channel
+ }
+ p.allocations = nil
+ }
+
+ if p.slices != nil {
+ close(p.slices)
+ for range p.slices {
+ // drain channel
+ }
+ p.slices = nil
+ }
+}
+
+func (p *maxSlicePool) newSlice() *[]byte {
+ bs := make([]byte, p.sliceSize)
+ return &bs
+}
+
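+// returnCapacityPoolCloser wraps a byteSlicePool and records capacity added
+// through ModifyCapacity so that Close returns that capacity to the underlying
+// pool rather than closing it.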
+type returnCapacityPoolCloser struct {
+ byteSlicePool
+ returnCapacity int
+}
+
+func (n *returnCapacityPoolCloser) ModifyCapacity(delta int) {
+ if delta > 0 {
+ n.returnCapacity = -1 * delta
+ }
+ n.byteSlicePool.ModifyCapacity(delta)
+}
+
+func (n *returnCapacityPoolCloser) Close() {
+ if n.returnCapacity < 0 {
+ n.byteSlicePool.ModifyCapacity(n.returnCapacity)
+ }
+}
+
+type sliceAllocator func() *[]byte
+
+var newByteSlicePool = func(sliceSize int64) byteSlicePool {
+ return newMaxSlicePool(sliceSize)
+}
diff --git a/feature/s3/manager/pool_test.go b/feature/s3/manager/pool_test.go
new file mode 100644
index 00000000000..9c6302e9b39
--- /dev/null
+++ b/feature/s3/manager/pool_test.go
@@ -0,0 +1,197 @@
+package manager
+
+import (
+ "context"
+ "sync"
+ "sync/atomic"
+ "testing"
+)
+
+func TestMaxSlicePool(t *testing.T) {
+ pool := newMaxSlicePool(0)
+
+ var wg sync.WaitGroup
+ for i := 0; i < 100; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ // increase pool capacity by 2
+ pool.ModifyCapacity(2)
+
+ // remove 2 items
+ bsOne, err := pool.Get(context.Background())
+ if err != nil {
+ t.Errorf("failed to get slice from pool: %v", err)
+ }
+ bsTwo, err := pool.Get(context.Background())
+ if err != nil {
+ t.Errorf("failed to get slice from pool: %v", err)
+ }
+
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+
+ // attempt to remove a 3rd in parallel
+ bs, err := pool.Get(context.Background())
+ if err != nil {
+ t.Errorf("failed to get slice from pool: %v", err)
+ }
+ pool.Put(bs)
+
+ // attempt to remove a 4th that has been canceled
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ bs, err = pool.Get(ctx)
+ if err == nil {
+ pool.Put(bs)
+ t.Errorf("expected no slice to be returned")
+ return
+ }
+ }()
+
+ pool.Put(bsOne)
+
+ <-done
+
+ pool.ModifyCapacity(-1)
+
+ pool.Put(bsTwo)
+
+ pool.ModifyCapacity(-1)
+
+ // any excess returns should drop
+ rando := make([]byte, 0)
+ pool.Put(&rando)
+ }()
+ }
+ wg.Wait()
+
+ if e, a := 0, len(pool.slices); e != a {
+ t.Errorf("expected %v, got %v", e, a)
+ }
+ if e, a := 0, len(pool.allocations); e != a {
+ t.Errorf("expected %v, got %v", e, a)
+ }
+ if e, a := 0, pool.max; e != a {
+ t.Errorf("expected %v, got %v", e, a)
+ }
+
+ _, err := pool.Get(context.Background())
+ if err == nil {
+ t.Errorf("expected error on zero capacity pool")
+ }
+
+ pool.Close()
+}
+
+func TestPoolShouldPreferAllocatedSlicesOverNewAllocations(t *testing.T) {
+ pool := newMaxSlicePool(0)
+ defer pool.Close()
+
+ // Prepare pool: make it so that pool contains 1 allocated slice and 1 allocation permit
+ pool.ModifyCapacity(2)
+ initialSlice, err := pool.Get(context.Background())
+ if err != nil {
+ t.Errorf("failed to get slice from pool: %v", err)
+ }
+ pool.Put(initialSlice)
+
+ for i := 0; i < 100; i++ {
+ newSlice, err := pool.Get(context.Background())
+ if err != nil {
+ t.Errorf("failed to get slice from pool: %v", err)
+ return
+ }
+
+ if newSlice != initialSlice {
+ t.Errorf("pool allocated a new slice despite it having pre-allocated one")
+ return
+ }
+ pool.Put(newSlice)
+ }
+}
+
+type recordedPartPool struct {
+ recordedAllocs uint64
+ recordedGets uint64
+ recordedOutstanding int64
+ *maxSlicePool
+}
+
+func newRecordedPartPool(sliceSize int64) *recordedPartPool {
+ sp := newMaxSlicePool(sliceSize)
+
+ rp := &recordedPartPool{}
+
+ allocator := sp.allocator
+ sp.allocator = func() *[]byte {
+ atomic.AddUint64(&rp.recordedAllocs, 1)
+ return allocator()
+ }
+
+ rp.maxSlicePool = sp
+
+ return rp
+}
+
+func (r *recordedPartPool) Get(ctx context.Context) (*[]byte, error) {
+ atomic.AddUint64(&r.recordedGets, 1)
+ atomic.AddInt64(&r.recordedOutstanding, 1)
+ return r.maxSlicePool.Get(ctx)
+}
+
+func (r *recordedPartPool) Put(b *[]byte) {
+ atomic.AddInt64(&r.recordedOutstanding, -1)
+ r.maxSlicePool.Put(b)
+}
+
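+// swapByteSlicePool replaces the package-level byteSlicePool constructor for
+// the duration of a test and returns a func that restores the original:
+//
+// defer swapByteSlicePool(func(sliceSize int64) byteSlicePool {
+// return newRecordedPartPool(sliceSize)
+// })()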
+func swapByteSlicePool(f func(sliceSize int64) byteSlicePool) func() {
+ orig := newByteSlicePool
+
+ newByteSlicePool = f
+
+ return func() {
+ newByteSlicePool = orig
+ }
+}
+
+type syncSlicePool struct {
+ sync.Pool
+ sliceSize int64
+}
+
+func newSyncSlicePool(sliceSize int64) *syncSlicePool {
+ p := &syncSlicePool{sliceSize: sliceSize}
+ p.New = func() interface{} {
+ bs := make([]byte, p.sliceSize)
+ return &bs
+ }
+ return p
+}
+
+func (s *syncSlicePool) Get(ctx context.Context) (*[]byte, error) {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ return s.Pool.Get().(*[]byte), nil
+ }
+}
+
+func (s *syncSlicePool) Put(bs *[]byte) {
+ s.Pool.Put(bs)
+}
+
+func (s *syncSlicePool) ModifyCapacity(_ int) {}
+
+func (s *syncSlicePool) SliceSize() int64 {
+ return s.sliceSize
+}
+
+func (s *syncSlicePool) Close() {}
diff --git a/feature/s3/manager/read_seeker_write_to.go b/feature/s3/manager/read_seeker_write_to.go
new file mode 100644
index 00000000000..ce117c32a13
--- /dev/null
+++ b/feature/s3/manager/read_seeker_write_to.go
@@ -0,0 +1,65 @@
+package manager
+
+import (
+ "io"
+ "sync"
+)
+
+// ReadSeekerWriteTo defines an interface implementing io.WriteTo and io.ReadSeeker
+type ReadSeekerWriteTo interface {
+ io.ReadSeeker
+ io.WriterTo
+}
+
+// BufferedReadSeekerWriteTo wraps a BufferedReadSeeker with an io.WriterTo
+// implementation.
+type BufferedReadSeekerWriteTo struct {
+ *BufferedReadSeeker
+}
+
+// WriteTo writes to the given io.Writer from BufferedReadSeeker until there's no more data to write or
+// an error occurs. Returns the number of bytes written and any error encountered during the write.
+func (b *BufferedReadSeekerWriteTo) WriteTo(writer io.Writer) (int64, error) {
+ return io.Copy(writer, b.BufferedReadSeeker)
+}
+
+// ReadSeekerWriteToProvider provides an implementation of io.WriteTo for an io.ReadSeeker
+type ReadSeekerWriteToProvider interface {
+ GetWriteTo(seeker io.ReadSeeker) (r ReadSeekerWriteTo, cleanup func())
+}
+
+// BufferedReadSeekerWriteToPool uses a sync.Pool to create and reuse
+// []byte slices for buffering parts in memory
+type BufferedReadSeekerWriteToPool struct {
+ pool sync.Pool
+}
+
+// NewBufferedReadSeekerWriteToPool will return a new BufferedReadSeekerWriteToPool that will create
+// a pool of reusable buffers. If size is less than 64 KiB then the buffer
+// will default to 64 KiB, because io.Copy from writers or readers that don't support
+// io.WriteTo or io.ReadFrom respectively will default to copying 32 KiB.
+func NewBufferedReadSeekerWriteToPool(size int) *BufferedReadSeekerWriteToPool {
+ if size < 65536 {
+ size = 65536
+ }
+
+ return &BufferedReadSeekerWriteToPool{
+ pool: sync.Pool{New: func() interface{} {
+ return make([]byte, size)
+ }},
+ }
+}
+
+// GetWriteTo will wrap the provided io.ReadSeeker with a BufferedReadSeekerWriteTo.
+// The provided cleanup must be called after operations have been completed on the
+// returned ReadSeekerWriteTo in order to signal the return of resources to the pool.
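+//
+// A minimal usage sketch:
+//
+// wt, cleanup := p.GetWriteTo(seeker)
+// defer cleanup()
+// n, err := wt.WriteTo(w)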
+func (p *BufferedReadSeekerWriteToPool) GetWriteTo(seeker io.ReadSeeker) (r ReadSeekerWriteTo, cleanup func()) {
+ buffer := p.pool.Get().([]byte)
+
+ r = &BufferedReadSeekerWriteTo{BufferedReadSeeker: NewBufferedReadSeeker(seeker, buffer)}
+ cleanup = func() {
+ p.pool.Put(buffer)
+ }
+
+ return r, cleanup
+}
diff --git a/feature/s3/manager/shared_test.go b/feature/s3/manager/shared_test.go
new file mode 100644
index 00000000000..ab2deef300d
--- /dev/null
+++ b/feature/s3/manager/shared_test.go
@@ -0,0 +1,4 @@
+package manager_test
+
+var buf12MB = make([]byte, 1024*1024*12)
+var buf2MB = make([]byte, 1024*1024*2)
diff --git a/feature/s3/manager/upload.go b/feature/s3/manager/upload.go
new file mode 100644
index 00000000000..90aad5a9fd4
--- /dev/null
+++ b/feature/s3/manager/upload.go
@@ -0,0 +1,685 @@
+package manager
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "sort"
+ "sync"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/aws/middleware"
+ "github.com/aws/aws-sdk-go-v2/internal/awsutil"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+ "github.com/aws/aws-sdk-go-v2/service/s3/types"
+)
+
+// MaxUploadParts is the maximum allowed number of parts in a multi-part upload
+// on Amazon S3.
+const MaxUploadParts int32 = 10000
+
+// MinUploadPartSize is the minimum allowed part size when uploading a part to
+// Amazon S3.
+const MinUploadPartSize int64 = 1024 * 1024 * 5
+
+// DefaultUploadPartSize is the default part size to buffer chunks of a
+// payload into.
+const DefaultUploadPartSize = MinUploadPartSize
+
+// DefaultUploadConcurrency is the default number of goroutines to spin up when
+// using Upload().
+const DefaultUploadConcurrency = 5
+
+// A MultiUploadFailure wraps a failed S3 multipart upload. An error returned
+// will satisfy this interface when a multipart upload failed to upload all
+// chunks to S3. In the case of a failure the UploadID is needed to operate on
+// the chunks, if any, which were uploaded.
+//
+// Example:
+//
+// u := s3manager.NewUploader(client)
+// output, err := u.upload(context.Background(), input)
+// if err != nil {
+// var multierr s3manager.MultiUploadFailure
+// if errors.As(err, &multierr) {
+// fmt.Printf("upload failure UploadID=%s, %s\n", multierr.UploadID(), multierr.Error())
+// } else {
+// fmt.Printf("upload failure, %s\n", err.Error())
+// }
+// }
+//
+type MultiUploadFailure interface {
+ error
+
+ // UploadID returns the upload id for the S3 multipart upload that failed.
+ UploadID() string
+}
+
+// A multiUploadError wraps the upload ID of a failed s3 multipart upload.
+// Composed of the original error and the upload's ID.
+//
+// Should be used for an error that occurred failing a S3 multipart upload,
+// and an upload ID is available. If an upload ID is not available a more
+// relevant error should be returned instead.
+type multiUploadError struct {
+ err error
+
+ // ID for multipart upload which failed.
+ uploadID string
+}
+
+// Error returns the string representation of the error.
+//
+// Satisfies the error interface.
+func (m *multiUploadError) Error() string {
+ var extra string
+ if m.err != nil {
+ extra = fmt.Sprintf(", cause: %s", m.err.Error())
+ }
+ return fmt.Sprintf("upload multipart failed, upload id: %s%s", m.uploadID, extra)
+}
+
+// Unwrap returns the underlying error that caused the upload failure.
+func (m *multiUploadError) Unwrap() error {
+ return m.err
+}
+
+// UploadID returns the id of the S3 upload which failed.
+func (m *multiUploadError) UploadID() string {
+ return m.uploadID
+}
+
+// UploadOutput represents a response from the Upload() call.
+type UploadOutput struct {
+ // The URL where the object was uploaded to.
+ Location string
+
+ // The version of the object that was uploaded. Will only be populated if
+ // the S3 Bucket is versioned. If the bucket is not versioned this field
+ // will not be set.
+ VersionID *string
+
+ // The ID for a multipart upload to S3. In the case of an error the error
+ // can be cast to the MultiUploadFailure interface to extract the upload ID.
+ UploadID string
+}
+
+// WithUploaderRequestOptions appends to the Uploader's API client options.
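+//
+// Example (an illustrative sketch; client and the region override shown are
+// placeholders):
+//
+//	uploader := NewUploader(client, WithUploaderRequestOptions(func(o *s3.Options) {
+//		o.Region = "us-west-2"
+//	}))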
+func WithUploaderRequestOptions(opts ...func(*s3.Options)) func(*Uploader) {
+ return func(u *Uploader) {
+ u.ClientOptions = append(u.ClientOptions, opts...)
+ }
+}
+
+// Uploader provides the ability to upload objects to S3 via Upload(). It is
+// safe to call Upload() on this structure for multiple objects and across
+// concurrent goroutines. Mutating the Uploader's properties concurrently is
+// not safe.
+type Uploader struct {
+ // The buffer size (in bytes) to use when buffering data into chunks and
+ // sending them as parts to S3. The minimum allowed part size is 5MB, and
+ // if this value is set to zero, the DefaultUploadPartSize value will be used.
+ PartSize int64
+
+ // The number of goroutines to spin up in parallel per call to Upload when
+ // sending parts. If this is set to zero, the DefaultUploadConcurrency value
+ // will be used.
+ //
+ // The concurrency pool is not shared between calls to Upload.
+ Concurrency int
+
+ // Setting this value to true will cause the SDK to avoid calling
+ // AbortMultipartUpload on a failure, leaving all successfully uploaded
+ // parts on S3 for manual recovery.
+ //
+ // Note that storing parts of an incomplete multipart upload counts towards
+ // space usage on S3 and will add additional costs if not cleaned up.
+ LeavePartsOnError bool
+
+ // MaxUploadParts is the maximum number of parts which will be uploaded to
+ // S3, and is used to calculate the part size of the object to be uploaded.
+ // E.g: a 5GB file, with MaxUploadParts set to 100, will be uploaded as 100
+ // parts of 50MB each, subject to the S3 limit of MaxUploadParts (10,000 parts).
+ //
+ // MaxUploadParts must not be used to limit the total number of bytes uploaded.
+ // Use a type like io.LimitReader (https://golang.org/pkg/io/#LimitedReader)
+ // instead, as shown in the sketch after this struct. An io.LimitReader is
+ // helpful when uploading an unbounded reader to S3, and you know its maximum
+ // size. Otherwise the reader's returned io.EOF error must be used to signal
+ // end of stream.
+ //
+ // Defaults to the package's MaxUploadParts const value.
+ MaxUploadParts int32
+
+ // The client to use when uploading to S3.
+ S3 UploadAPIClient
+
+ // List of request options that will be passed down to individual API
+ // operation requests made by the uploader.
+ ClientOptions []func(*s3.Options)
+
+ // Defines the buffer strategy used when uploading a part
+ BufferProvider ReadSeekerWriteToProvider
+
+ // partPool allows for the re-usage of streaming payload part buffers between upload calls
+ partPool byteSlicePool
+}
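+
+// Bounding the total bytes uploaded with an io.LimitReader instead of
+// MaxUploadParts (an illustrative sketch; unboundedReader and maxBytes are
+// placeholders):
+//
+//	body := io.LimitReader(unboundedReader, maxBytes)
+//	_, err := uploader.Upload(ctx, &s3.PutObjectInput{
+//		Bucket: aws.String("bucket"),
+//		Key:    aws.String("key"),
+//		Body:   body,
+//	})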
+
+// NewUploader creates a new Uploader instance to upload objects to S3. Pass in
+// additional functional options to customize the uploader's behavior. Requires
+// an UploadAPIClient, such as an s3.Client created by s3.NewFromConfig.
+//
+// Example:
+// // Load AWS Config
+// cfg, err := config.LoadDefaultConfig()
+// if err != nil {
+// panic(err)
+// }
+//
+// // Create an S3 Client with the config
+// client := s3.NewFromConfig(cfg)
+//
+// // Create an uploader passing it the client
+// uploader := s3manager.NewUploader(client)
+//
+// // Create an uploader with the client and custom options
+// uploader := s3manager.NewUploader(client, func(u *s3manager.Uploader) {
+// u.PartSize = 64 * 1024 * 1024 // 64MB per part
+// })
+func NewUploader(client UploadAPIClient, options ...func(*Uploader)) *Uploader {
+ u := &Uploader{
+ S3: client,
+ PartSize: DefaultUploadPartSize,
+ Concurrency: DefaultUploadConcurrency,
+ LeavePartsOnError: false,
+ MaxUploadParts: MaxUploadParts,
+ BufferProvider: defaultUploadBufferProvider(),
+ }
+
+ for _, option := range options {
+ option(u)
+ }
+
+ u.partPool = newByteSlicePool(u.PartSize)
+
+ return u
+}
+
+// Upload uploads an object to S3, intelligently buffering large
+// files into smaller chunks and sending them in parallel across multiple
+// goroutines. You can configure the buffer size and concurrency through the
+// Uploader parameters.
+//
+// Additional functional options can be provided to configure the individual
+// upload. These options are applied to a copy of the Uploader instance Upload
+// is called from, so modifying them will not impact the original Uploader.
+//
+// Use the WithUploaderRequestOptions helper function to pass in request
+// options that will be applied to all API operations made with this uploader.
+//
+// It is safe to call this method concurrently across goroutines.
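+//
+// Example (an illustrative sketch; bucket, key, and f are placeholders):
+//
+//	result, err := uploader.Upload(ctx, &s3.PutObjectInput{
+//		Bucket: aws.String("my-bucket"),
+//		Key:    aws.String("my-key"),
+//		Body:   f,
+//	}, func(u *Uploader) {
+//		u.PartSize = 10 * 1024 * 1024 // override part size for this call only
+//	})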
+func (u Uploader) Upload(ctx context.Context, input *s3.PutObjectInput, opts ...func(*Uploader)) (*UploadOutput, error) {
+ i := uploader{in: input, cfg: u, ctx: ctx}
+
+ // Copy ClientOptions
+ clientOptions := make([]func(*s3.Options), 0, len(i.cfg.ClientOptions)+1)
+ clientOptions = append(clientOptions, func(o *s3.Options) {
+ o.APIOptions = append(o.APIOptions, middleware.AddUserAgentKey(userAgentKey))
+ })
+ clientOptions = append(clientOptions, i.cfg.ClientOptions...)
+ i.cfg.ClientOptions = clientOptions
+
+ for _, opt := range opts {
+ opt(&i.cfg)
+ }
+
+ return i.upload()
+}
+
+// internal structure to manage an upload to S3.
+type uploader struct {
+ ctx context.Context
+ cfg Uploader
+
+ in *s3.PutObjectInput
+
+ readerPos int64 // current reader position
+ totalSize int64 // set to -1 if the size is not known
+}
+
+// internal logic for deciding whether to upload a single part or use a
+// multipart upload.
+func (u *uploader) upload() (*UploadOutput, error) {
+ if err := u.init(); err != nil {
+ return nil, fmt.Errorf("unable to initialize upload: %w", err)
+ }
+ defer u.cfg.partPool.Close()
+
+ if u.cfg.PartSize < MinUploadPartSize {
+ return nil, fmt.Errorf("part size must be at least %d bytes", MinUploadPartSize)
+ }
+
+ // Do one read to determine if we have more than one part
+ reader, _, cleanup, err := u.nextReader()
+ if err == io.EOF { // single part
+ return u.singlePart(reader, cleanup)
+ } else if err != nil {
+ cleanup()
+ return nil, fmt.Errorf("read upload data failed: %w", err)
+ }
+
+ mu := multiuploader{uploader: u}
+ return mu.upload(reader, cleanup)
+}
+
+// init will initialize all default options.
+func (u *uploader) init() error {
+ if u.cfg.Concurrency == 0 {
+ u.cfg.Concurrency = DefaultUploadConcurrency
+ }
+ if u.cfg.PartSize == 0 {
+ u.cfg.PartSize = DefaultUploadPartSize
+ }
+ if u.cfg.MaxUploadParts == 0 {
+ u.cfg.MaxUploadParts = MaxUploadParts
+ }
+
+ // Try to get the total size for some optimizations
+ if err := u.initSize(); err != nil {
+ return err
+ }
+
+ // If PartSize was changed or partPool was never set up then we need to allocate a new pool
+ // so that we return []byte slices of the correct size.
+ poolCap := u.cfg.Concurrency + 1
+ if u.cfg.partPool == nil || u.cfg.partPool.SliceSize() != u.cfg.PartSize {
+ u.cfg.partPool = newByteSlicePool(u.cfg.PartSize)
+ u.cfg.partPool.ModifyCapacity(poolCap)
+ } else {
+ u.cfg.partPool = &returnCapacityPoolCloser{byteSlicePool: u.cfg.partPool}
+ u.cfg.partPool.ModifyCapacity(poolCap)
+ }
+
+ return nil
+}
+
+// initSize tries to detect the total stream size, setting u.totalSize. If
+// the size is not known, totalSize is set to -1.
+func (u *uploader) initSize() error {
+ u.totalSize = -1
+
+ switch r := u.in.Body.(type) {
+ case io.Seeker:
+ n, err := aws.SeekerLen(r)
+ if err != nil {
+ return err
+ }
+ u.totalSize = n
+
+ // Try to adjust partSize if it is too small and account for
+ // integer division truncation.
+ if u.totalSize/u.cfg.PartSize >= int64(u.cfg.MaxUploadParts) {
+ // Add one to the part size to account for remainders
+ // during the size calculation. e.g odd number of bytes.
+ u.cfg.PartSize = (u.totalSize / int64(u.cfg.MaxUploadParts)) + 1
+ }
+ }
+
+ return nil
+}
+
+// nextReader returns a seekable reader representing the next packet of data.
+// This operation increases the shared u.readerPos counter, but note that it
+// does not need to be wrapped in a mutex because nextReader is only called
+// from the main thread.
+func (u *uploader) nextReader() (io.ReadSeeker, int, func(), error) {
+ switch r := u.in.Body.(type) {
+ case readerAtSeeker:
+ var err error
+
+ n := u.cfg.PartSize
+ if u.totalSize >= 0 {
+ bytesLeft := u.totalSize - u.readerPos
+
+ if bytesLeft <= u.cfg.PartSize {
+ err = io.EOF
+ n = bytesLeft
+ }
+ }
+
+ var (
+ reader io.ReadSeeker
+ cleanup func()
+ )
+
+ reader = io.NewSectionReader(r, u.readerPos, n)
+ if u.cfg.BufferProvider != nil {
+ reader, cleanup = u.cfg.BufferProvider.GetWriteTo(reader)
+ } else {
+ cleanup = func() {}
+ }
+
+ u.readerPos += n
+
+ return reader, int(n), cleanup, err
+
+ default:
+ part, err := u.cfg.partPool.Get(u.ctx)
+ if err != nil {
+ return nil, 0, func() {}, err
+ }
+
+ n, err := readFillBuf(r, *part)
+ u.readerPos += int64(n)
+
+ cleanup := func() {
+ u.cfg.partPool.Put(part)
+ }
+
+ return bytes.NewReader((*part)[0:n]), n, cleanup, err
+ }
+}
+
+func readFillBuf(r io.Reader, b []byte) (offset int, err error) {
+ for offset < len(b) && err == nil {
+ var n int
+ n, err = r.Read(b[offset:])
+ offset += n
+ }
+
+ return offset, err
+}
+
+// singlePart contains upload logic for uploading a single chunk via
+// a regular PutObject request. Multipart uploads require at least two
+// parts, and every part except the last must be at least 5MB.
+func (u *uploader) singlePart(r io.ReadSeeker, cleanup func()) (*UploadOutput, error) {
+ defer cleanup()
+
+ params := &s3.PutObjectInput{}
+ awsutil.Copy(params, u.in)
+ params.Body = r
+
+ // Wrap the API client's HTTP client to record the request URL, which is
+ // used as the Location in the returned UploadOutput.
+
+ var locationRecorder recordLocationClient
+ out, err := u.cfg.S3.PutObject(u.ctx, params, append(u.cfg.ClientOptions, locationRecorder.WrapClient())...)
+ if err != nil {
+ return nil, err
+ }
+
+ return &UploadOutput{
+ Location: locationRecorder.location,
+ VersionID: out.VersionId,
+ }, nil
+}
+
+type httpClient interface {
+ Do(r *http.Request) (*http.Response, error)
+}
+
+type recordLocationClient struct {
+ httpClient
+ location string
+}
+
+func (c *recordLocationClient) WrapClient() func(o *s3.Options) {
+ return func(o *s3.Options) {
+ c.httpClient = o.HTTPClient
+ o.HTTPClient = c
+ }
+}
+
+func (c *recordLocationClient) Do(r *http.Request) (resp *http.Response, err error) {
+ resp, err = c.httpClient.Do(r)
+ if err != nil {
+ return resp, err
+ }
+
+ if resp.Request != nil && resp.Request.URL != nil {
+ url := *resp.Request.URL
+ url.RawQuery = ""
+ c.location = url.String()
+ }
+
+ return resp, err
+}
+
+// internal structure to manage a specific multipart upload to S3.
+type multiuploader struct {
+ *uploader
+ wg sync.WaitGroup
+ m sync.Mutex
+ err error
+ uploadID string
+ parts completedParts
+}
+
+// keeps track of a single chunk of data being sent to S3.
+type chunk struct {
+ buf io.ReadSeeker
+ num int32
+ cleanup func()
+}
+
+// completedParts is a wrapper to make parts sortable by their part number,
+// since S3 requires this list to be sent in sorted order.
+type completedParts []*types.CompletedPart
+
+func (a completedParts) Len() int { return len(a) }
+func (a completedParts) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
+func (a completedParts) Less(i, j int) bool { return *a[i].PartNumber < *a[j].PartNumber }
+
+// upload will perform a multipart upload using the firstBuf buffer containing
+// the first chunk of data.
+func (u *multiuploader) upload(firstBuf io.ReadSeeker, cleanup func()) (*UploadOutput, error) {
+ params := &s3.CreateMultipartUploadInput{}
+ awsutil.Copy(params, u.in)
+
+ // Create the multipart
+ var locationRecorder recordLocationClient
+ resp, err := u.cfg.S3.CreateMultipartUpload(u.ctx, params, append(u.cfg.ClientOptions, locationRecorder.WrapClient())...)
+ if err != nil {
+ cleanup()
+ return nil, err
+ }
+ u.uploadID = *resp.UploadId
+
+ // Create the workers
+ ch := make(chan chunk, u.cfg.Concurrency)
+ for i := 0; i < u.cfg.Concurrency; i++ {
+ u.wg.Add(1)
+ go u.readChunk(ch)
+ }
+
+ // Send part 1 to the workers
+ var num int32 = 1
+ ch <- chunk{buf: firstBuf, num: num, cleanup: cleanup}
+
+ // Read and queue the rest of the parts
+ for u.geterr() == nil && err == nil {
+ var (
+ reader io.ReadSeeker
+ nextChunkLen int
+ ok bool
+ )
+
+ reader, nextChunkLen, cleanup, err = u.nextReader()
+ ok, err = u.shouldContinue(num, nextChunkLen, err)
+ if !ok {
+ cleanup()
+ if err != nil {
+ u.seterr(err)
+ }
+ break
+ }
+
+ num++
+
+ ch <- chunk{buf: reader, num: num, cleanup: cleanup}
+ }
+
+ // Close the channel, wait for workers, and complete upload
+ close(ch)
+ u.wg.Wait()
+ complete := u.complete()
+
+ if err := u.geterr(); err != nil {
+ return nil, &multiUploadError{
+ err: err,
+ uploadID: u.uploadID,
+ }
+ }
+
+ return &UploadOutput{
+ Location: locationRecorder.location,
+ VersionID: complete.VersionId,
+ UploadID: u.uploadID,
+ }, nil
+}
+
+func (u *multiuploader) shouldContinue(part int32, nextChunkLen int, err error) (bool, error) {
+ if err != nil && err != io.EOF {
+ return false, fmt.Errorf("read multipart upload data failed, %w", err)
+ }
+
+ if nextChunkLen == 0 {
+ // No need to upload empty part, if file was empty to start
+ // with empty single part would of been created and never
+ // started multipart upload.
+ return false, nil
+ }
+
+ part++
+ // This upload exceeded maximum number of supported parts, error now.
+ if part > u.cfg.MaxUploadParts || part > MaxUploadParts {
+ var msg string
+ if part > u.cfg.MaxUploadParts {
+ msg = fmt.Sprintf("exceeded total allowed configured MaxUploadParts (%d). Adjust PartSize to fit in this limit",
+ u.cfg.MaxUploadParts)
+ } else {
+ msg = fmt.Sprintf("exceeded total allowed S3 limit MaxUploadParts (%d). Adjust PartSize to fit in this limit",
+ MaxUploadParts)
+ }
+ return false, fmt.Errorf("%s", msg)
+ }
+
+ return true, err
+}
+
+// readChunk runs in worker goroutines to pull chunks off of the ch channel
+// and send() them as UploadPart requests.
+func (u *multiuploader) readChunk(ch chan chunk) {
+ defer u.wg.Done()
+ for {
+ data, ok := <-ch
+
+ if !ok {
+ break
+ }
+
+ if u.geterr() == nil {
+ if err := u.send(data); err != nil {
+ u.seterr(err)
+ }
+ }
+
+ data.cleanup()
+ }
+}
+
+// send performs an UploadPart request and keeps track of the completed
+// part information.
+func (u *multiuploader) send(c chunk) error {
+ params := &s3.UploadPartInput{
+ Bucket: u.in.Bucket,
+ Key: u.in.Key,
+ Body: c.buf,
+ UploadId: &u.uploadID,
+ SSECustomerAlgorithm: u.in.SSECustomerAlgorithm,
+ SSECustomerKey: u.in.SSECustomerKey,
+ PartNumber: &c.num,
+ }
+
+ resp, err := u.cfg.S3.UploadPart(u.ctx, params, u.cfg.ClientOptions...)
+ if err != nil {
+ return err
+ }
+
+ n := c.num
+ completed := &types.CompletedPart{ETag: resp.ETag, PartNumber: &n}
+
+ u.m.Lock()
+ u.parts = append(u.parts, completed)
+ u.m.Unlock()
+
+ return nil
+}
+
+// geterr is a thread-safe getter for the error object
+func (u *multiuploader) geterr() error {
+ u.m.Lock()
+ defer u.m.Unlock()
+
+ return u.err
+}
+
+// seterr is a thread-safe setter for the error object
+func (u *multiuploader) seterr(e error) {
+ u.m.Lock()
+ defer u.m.Unlock()
+
+ u.err = e
+}
+
+// fail will abort the multipart upload unless LeavePartsOnError is set to true.
+func (u *multiuploader) fail() {
+ if u.cfg.LeavePartsOnError {
+ return
+ }
+
+ params := &s3.AbortMultipartUploadInput{
+ Bucket: u.in.Bucket,
+ Key: u.in.Key,
+ UploadId: &u.uploadID,
+ }
+ _, err := u.cfg.S3.AbortMultipartUpload(u.ctx, params, u.cfg.ClientOptions...)
+ if err != nil {
+ // TODO: Add logging
+ //logMessage(u.cfg.S3, aws.LogDebug, fmt.Sprintf("failed to abort multipart upload, %v", err))
+ _ = err
+ }
+}
+
+// complete successfully completes a multipart upload and returns the response.
+func (u *multiuploader) complete() *s3.CompleteMultipartUploadOutput {
+ if u.geterr() != nil {
+ u.fail()
+ return nil
+ }
+
+ // Parts must be sorted in PartNumber order.
+ sort.Sort(u.parts)
+
+ params := &s3.CompleteMultipartUploadInput{
+ Bucket: u.in.Bucket,
+ Key: u.in.Key,
+ UploadId: &u.uploadID,
+ MultipartUpload: &types.CompletedMultipartUpload{Parts: u.parts},
+ }
+ resp, err := u.cfg.S3.CompleteMultipartUpload(u.ctx, params, u.cfg.ClientOptions...)
+ if err != nil {
+ u.seterr(err)
+ u.fail()
+ }
+
+ return resp
+}
+
+type readerAtSeeker interface {
+ io.ReaderAt
+ io.ReadSeeker
+}
diff --git a/feature/s3/manager/upload_internal_test.go b/feature/s3/manager/upload_internal_test.go
new file mode 100644
index 00000000000..03088bf0df8
--- /dev/null
+++ b/feature/s3/manager/upload_internal_test.go
@@ -0,0 +1,320 @@
+package manager
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "strconv"
+ "sync"
+ "sync/atomic"
+ "testing"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ s3testing "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/testing"
+ "github.com/aws/aws-sdk-go-v2/internal/sdkio"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+)
+
+type testReader struct {
+ br *bytes.Reader
+ m sync.Mutex
+}
+
+func (r *testReader) Read(p []byte) (n int, err error) {
+ r.m.Lock()
+ defer r.m.Unlock()
+ return r.br.Read(p)
+}
+
+func TestUploadByteSlicePool(t *testing.T) {
+ cases := map[string]struct {
+ PartSize int64
+ FileSize int64
+ Concurrency int
+ ExAllocations uint64
+ }{
+ "single part, single concurrency": {
+ PartSize: sdkio.MebiByte * 5,
+ FileSize: sdkio.MebiByte * 5,
+ ExAllocations: 2,
+ Concurrency: 1,
+ },
+ "multi-part, single concurrency": {
+ PartSize: sdkio.MebiByte * 5,
+ FileSize: sdkio.MebiByte * 10,
+ ExAllocations: 2,
+ Concurrency: 1,
+ },
+ "multi-part, multiple concurrency": {
+ PartSize: sdkio.MebiByte * 5,
+ FileSize: sdkio.MebiByte * 20,
+ ExAllocations: 3,
+ Concurrency: 2,
+ },
+ }
+
+ for name, tt := range cases {
+ t.Run(name, func(t *testing.T) {
+ var p *recordedPartPool
+
+ unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool {
+ p = newRecordedPartPool(sliceSize)
+ return p
+ })
+ defer unswap()
+
+ client, _, _ := s3testing.NewUploadLoggingClient(nil)
+
+ uploader := NewUploader(client, func(u *Uploader) {
+ u.PartSize = tt.PartSize
+ u.Concurrency = tt.Concurrency
+ })
+
+ expected := s3testing.GetTestBytes(int(tt.FileSize))
+ _, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ Body: &testReader{br: bytes.NewReader(expected)},
+ })
+ if err != nil {
+ t.Errorf("expected no error, but got %v", err)
+ }
+
+ if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 {
+ t.Fatalf("expected zero outsnatding pool parts, got %d", v)
+ }
+
+ gets, allocs := atomic.LoadUint64(&p.recordedGets), atomic.LoadUint64(&p.recordedAllocs)
+
+ t.Logf("total gets %v, total allocations %v", gets, allocs)
+ if e, a := tt.ExAllocations, allocs; a > e {
+ t.Errorf("expected %v allocations, got %v", e, a)
+ }
+ })
+ }
+}
+
+func TestUploadByteSlicePool_Failures(t *testing.T) {
+ const (
+ putObject = "PutObject"
+ createMultipartUpload = "CreateMultipartUpload"
+ uploadPart = "UploadPart"
+ completeMultipartUpload = "CompleteMultipartUpload"
+ )
+
+ cases := map[string]struct {
+ PartSize int64
+ FileSize int64
+ Operations []string
+ }{
+ "single part": {
+ PartSize: sdkio.MebiByte * 5,
+ FileSize: sdkio.MebiByte * 4,
+ Operations: []string{
+ putObject,
+ },
+ },
+ "multi-part": {
+ PartSize: sdkio.MebiByte * 5,
+ FileSize: sdkio.MebiByte * 10,
+ Operations: []string{
+ createMultipartUpload,
+ uploadPart,
+ completeMultipartUpload,
+ },
+ },
+ }
+
+ for name, tt := range cases {
+ t.Run(name, func(t *testing.T) {
+ for _, operation := range tt.Operations {
+ t.Run(operation, func(t *testing.T) {
+ var p *recordedPartPool
+
+ unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool {
+ p = newRecordedPartPool(sliceSize)
+ return p
+ })
+ defer unswap()
+
+ client, _, _ := s3testing.NewUploadLoggingClient(nil)
+
+ switch operation {
+ case putObject:
+ client.PutObjectFn = func(*s3testing.UploadLoggingClient, *s3.PutObjectInput) (*s3.PutObjectOutput, error) {
+ return nil, fmt.Errorf("put object failure")
+ }
+ case createMultipartUpload:
+ client.CreateMultipartUploadFn = func(*s3testing.UploadLoggingClient, *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error) {
+ return nil, fmt.Errorf("create multipart upload failure")
+ }
+ case uploadPart:
+ client.UploadPartFn = func(*s3testing.UploadLoggingClient, *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
+ return nil, fmt.Errorf("upload part failure")
+ }
+ case completeMultipartUpload:
+ client.CompleteMultipartUploadFn = func(*s3testing.UploadLoggingClient, *s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error) {
+ return nil, fmt.Errorf("complete multipart upload failure")
+ }
+ }
+
+ uploader := NewUploader(client, func(u *Uploader) {
+ u.Concurrency = 1
+ u.PartSize = tt.PartSize
+ })
+
+ expected := s3testing.GetTestBytes(int(tt.FileSize))
+ _, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ Body: &testReader{br: bytes.NewReader(expected)},
+ })
+ if err == nil {
+ t.Fatalf("expected error but got none")
+ }
+
+ if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 {
+ t.Fatalf("expected zero outsnatding pool parts, got %d", v)
+ }
+ })
+ }
+ })
+ }
+}
+
+func TestUploadByteSlicePoolConcurrentMultiPartSize(t *testing.T) {
+ var (
+ pools []*recordedPartPool
+ mtx sync.Mutex
+ )
+
+ unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool {
+ mtx.Lock()
+ defer mtx.Unlock()
+ b := newRecordedPartPool(sliceSize)
+ pools = append(pools, b)
+ return b
+ })
+ defer unswap()
+
+ client, _, _ := s3testing.NewUploadLoggingClient(nil)
+
+ uploader := NewUploader(client, func(u *Uploader) {
+ u.PartSize = 5 * sdkio.MebiByte
+ u.Concurrency = 2
+ })
+
+ var wg sync.WaitGroup
+ for i := 0; i < 2; i++ {
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ expected := s3testing.GetTestBytes(int(15 * sdkio.MebiByte))
+ _, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ Body: &testReader{br: bytes.NewReader(expected)},
+ })
+ if err != nil {
+ t.Errorf("expected no error, but got %v", err)
+ }
+ }()
+ go func() {
+ defer wg.Done()
+ expected := s3testing.GetTestBytes(int(15 * sdkio.MebiByte))
+ _, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ Body: &testReader{br: bytes.NewReader(expected)},
+ }, func(u *Uploader) {
+ u.PartSize = 6 * sdkio.MebiByte
+ })
+ if err != nil {
+ t.Errorf("expected no error, but got %v", err)
+ }
+ }()
+ }
+
+ wg.Wait()
+
+ if e, a := 3, len(pools); e != a {
+ t.Errorf("expected %v, got %v", e, a)
+ }
+
+ for _, p := range pools {
+ if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 {
+ t.Fatalf("expected zero outsnatding pool parts, got %d", v)
+ }
+
+ t.Logf("total gets %v, total allocations %v",
+ atomic.LoadUint64(&p.recordedGets),
+ atomic.LoadUint64(&p.recordedAllocs))
+ }
+}
+
+func BenchmarkPools(b *testing.B) {
+ cases := []struct {
+ PartSize int64
+ FileSize int64
+ Concurrency int
+ ExAllocations uint64
+ }{
+ 0: {
+ PartSize: sdkio.MebiByte * 5,
+ FileSize: sdkio.MebiByte * 5,
+ Concurrency: 1,
+ },
+ 1: {
+ PartSize: sdkio.MebiByte * 5,
+ FileSize: sdkio.MebiByte * 10,
+ Concurrency: 1,
+ },
+ 2: {
+ PartSize: sdkio.MebiByte * 5,
+ FileSize: sdkio.MebiByte * 20,
+ Concurrency: 2,
+ },
+ 3: {
+ PartSize: sdkio.MebiByte * 5,
+ FileSize: sdkio.MebiByte * 250,
+ Concurrency: 10,
+ },
+ }
+
+ client, _, _ := s3testing.NewUploadLoggingClient(nil)
+
+ pools := map[string]func(sliceSize int64) byteSlicePool{
+ "sync.Pool": func(sliceSize int64) byteSlicePool {
+ return newSyncSlicePool(sliceSize)
+ },
+ "custom": func(sliceSize int64) byteSlicePool {
+ return newMaxSlicePool(sliceSize)
+ },
+ }
+
+ for name, poolFunc := range pools {
+ b.Run(name, func(b *testing.B) {
+ unswap := swapByteSlicePool(poolFunc)
+ defer unswap()
+ for i, c := range cases {
+ b.Run(strconv.Itoa(i), func(b *testing.B) {
+ uploader := NewUploader(client, func(u *Uploader) {
+ u.PartSize = c.PartSize
+ u.Concurrency = c.Concurrency
+ })
+
+ expected := s3testing.GetTestBytes(int(c.FileSize))
+ b.ResetTimer()
+ _, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ Body: &testReader{br: bytes.NewReader(expected)},
+ })
+ if err != nil {
+ b.Fatalf("expected no error, but got %v", err)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/feature/s3/manager/upload_test.go b/feature/s3/manager/upload_test.go
new file mode 100644
index 00000000000..17a42c5336d
--- /dev/null
+++ b/feature/s3/manager/upload_test.go
@@ -0,0 +1,1134 @@
+package manager_test
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "reflect"
+ "regexp"
+ "sort"
+ "strconv"
+ "strings"
+ "testing"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/aws/retry"
+ "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
+ s3testing "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/testing"
+ "github.com/aws/aws-sdk-go-v2/internal/awstesting"
+ "github.com/aws/aws-sdk-go-v2/internal/sdk"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+ "github.com/aws/aws-sdk-go-v2/service/s3/types"
+ "github.com/google/go-cmp/cmp"
+)
+
+// getReaderLength discards the bytes from reader and returns the length
+func getReaderLength(r io.Reader) int64 {
+ n, _ := io.Copy(ioutil.Discard, r)
+ return n
+}
+
+func TestUploadOrderMulti(t *testing.T) {
+ c, invocations, args := s3testing.NewUploadLoggingClient(nil)
+ u := manager.NewUploader(c)
+
+ resp, err := u.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key - value"),
+ Body: bytes.NewReader(buf12MB),
+ ServerSideEncryption: types.ServerSideEncryptionAwsKms,
+ SSEKMSKeyId: aws.String("KmsId"),
+ ContentType: aws.String("content/type"),
+ })
+
+ if err != nil {
+ t.Errorf("Expected no error but received %v", err)
+ }
+
+ if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart",
+ "UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 {
+ t.Error(diff)
+ }
+
+ if e, a := `https://mock.amazonaws.com/key`, resp.Location; e != a {
+ t.Errorf("expect %q, got %q", e, a)
+ }
+
+ if "UPLOAD-ID" != resp.UploadID {
+ t.Errorf("expect %q, got %q", "UPLOAD-ID", resp.UploadID)
+ }
+
+ if "VERSION-ID" != *resp.VersionID {
+ t.Errorf("expect %q, got %q", "VERSION-ID", *resp.VersionID)
+ }
+
+ // Validate input values
+
+ // UploadPart
+ for i := 1; i < 4; i++ {
+ v := aws.ToString((*args)[i].(*s3.UploadPartInput).UploadId)
+ if "UPLOAD-ID" != v {
+ t.Errorf("Expected %q, but received %q", "UPLOAD-ID", v)
+ }
+ }
+
+ // CompleteMultipartUpload
+ v := aws.ToString((*args)[4].(*s3.CompleteMultipartUploadInput).UploadId)
+ if "UPLOAD-ID" != v {
+ t.Errorf("Expected %q, but received %q", "UPLOAD-ID", v)
+ }
+
+ parts := (*args)[4].(*s3.CompleteMultipartUploadInput).MultipartUpload.Parts
+
+ for i := 0; i < 3; i++ {
+ num := aws.ToInt32(parts[i].PartNumber)
+ etag := aws.ToString(parts[i].ETag)
+
+ if int32(i+1) != num {
+ t.Errorf("expect %d, got %d", i+1, num)
+ }
+
+ if matched, err := regexp.MatchString(`^ETAG\d+$`, etag); !matched || err != nil {
+ t.Errorf("Failed regexp expression `^ETAG\\d+$`")
+ }
+ }
+
+ // Custom headers
+ cmu := (*args)[0].(*s3.CreateMultipartUploadInput)
+
+ if e, a := types.ServerSideEncryptionAwsKms, cmu.ServerSideEncryption; e != a {
+ t.Errorf("expect %q, got %q", e, a)
+ }
+
+ if e, a := "KmsId", aws.ToString(cmu.SSEKMSKeyId); e != a {
+ t.Errorf("expect %q, got %q", e, a)
+ }
+
+ if e, a := "content/type", aws.ToString(cmu.ContentType); e != a {
+ t.Errorf("expect %q, got %q", e, a)
+ }
+}
+
+func TestUploadOrderMultiDifferentPartSize(t *testing.T) {
+ s, ops, args := s3testing.NewUploadLoggingClient(nil)
+ mgr := manager.NewUploader(s, func(u *manager.Uploader) {
+ u.PartSize = 1024 * 1024 * 7
+ u.Concurrency = 1
+ })
+ _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: bytes.NewReader(buf12MB),
+ })
+
+ if err != nil {
+ t.Errorf("expect no error, got %v", err)
+ }
+
+ vals := []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}
+ if !reflect.DeepEqual(vals, *ops) {
+ t.Errorf("expect %v, got %v", vals, *ops)
+ }
+
+ // Part lengths
+ if len := getReaderLength((*args)[1].(*s3.UploadPartInput).Body); 1024*1024*7 != len {
+ t.Errorf("expect %d, got %d", 1024*1024*7, len)
+ }
+ if len := getReaderLength((*args)[2].(*s3.UploadPartInput).Body); 1024*1024*5 != len {
+ t.Errorf("expect %d, got %d", 1024*1024*5, len)
+ }
+}
+
+func TestUploadIncreasePartSize(t *testing.T) {
+ s, invocations, args := s3testing.NewUploadLoggingClient(nil)
+ mgr := manager.NewUploader(s, func(u *manager.Uploader) {
+ u.Concurrency = 1
+ u.MaxUploadParts = 2
+ })
+ _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: bytes.NewReader(buf12MB),
+ })
+
+ if err != nil {
+ t.Errorf("expect no error, got %v", err)
+ }
+
+ if int64(manager.DefaultUploadPartSize) != mgr.PartSize {
+ t.Errorf("expect %d, got %d", manager.DefaultUploadPartSize, mgr.PartSize)
+ }
+
+ if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 {
+ t.Error(diff)
+ }
+
+ // Part lengths
+ if len := getReaderLength((*args)[1].(*s3.UploadPartInput).Body); (1024*1024*6)+1 != len {
+ t.Errorf("expect %d, got %d", (1024*1024*6)+1, len)
+ }
+
+ if len := getReaderLength((*args)[2].(*s3.UploadPartInput).Body); (1024*1024*6)-1 != len {
+ t.Errorf("expect %d, got %d", (1024*1024*6)-1, len)
+ }
+}
+
+func TestUploadFailIfPartSizeTooSmall(t *testing.T) {
+ mgr := manager.NewUploader(s3.New(s3.Options{}), func(u *manager.Uploader) {
+ u.PartSize = 5
+ })
+ resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: bytes.NewReader(buf12MB),
+ })
+
+ if resp != nil {
+ t.Errorf("Expected response to be nil, but received %v", resp)
+ }
+
+ if err == nil {
+ t.Errorf("Expected error, but received nil")
+ }
+
+ if e, a := "part size must be at least", err.Error(); !strings.Contains(a, e) {
+ t.Errorf("expect %v to be in %v", e, a)
+ }
+}
+
+func TestUploadOrderSingle(t *testing.T) {
+ client, invocations, params := s3testing.NewUploadLoggingClient(nil)
+ mgr := manager.NewUploader(client)
+ resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key - value"),
+ Body: bytes.NewReader(buf2MB),
+ ServerSideEncryption: types.ServerSideEncryptionAwsKms,
+ SSEKMSKeyId: aws.String("KmsId"),
+ ContentType: aws.String("content/type"),
+ })
+
+ if err != nil {
+ t.Errorf("expect no error but received %v", err)
+ }
+
+ if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 {
+ t.Error(diff)
+ }
+
+ if e, a := `https://mock.amazonaws.com/key`, resp.Location; e != a {
+ t.Errorf("expect %q, got %q", e, a)
+ }
+
+ if e := "VERSION-ID"; e != *resp.VersionID {
+ t.Errorf("expect %q, got %q", e, *resp.VersionID)
+ }
+
+ if len(resp.UploadID) > 0 {
+ t.Errorf("expect empty string, got %q", resp.UploadID)
+ }
+
+ putObjectInput := (*params)[0].(*s3.PutObjectInput)
+
+ if e, a := types.ServerSideEncryptionAwsKms, putObjectInput.ServerSideEncryption; e != a {
+ t.Errorf("expect %q, got %q", e, a)
+ }
+
+ if e, a := "KmsId", aws.ToString(putObjectInput.SSEKMSKeyId); e != a {
+ t.Errorf("expect %q, got %q", e, a)
+ }
+
+ if e, a := "content/type", aws.ToString(putObjectInput.ContentType); e != a {
+ t.Errorf("Expected %q, but received %q", e, a)
+ }
+}
+
+func TestUploadOrderSingleFailure(t *testing.T) {
+ client, ops, _ := s3testing.NewUploadLoggingClient(nil)
+
+ client.PutObjectFn = func(*s3testing.UploadLoggingClient, *s3.PutObjectInput) (*s3.PutObjectOutput, error) {
+ return nil, fmt.Errorf("put object failure")
+ }
+
+ mgr := manager.NewUploader(client)
+ resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: bytes.NewReader(buf2MB),
+ })
+
+ if err == nil {
+ t.Error("expect error, got nil")
+ }
+
+ if diff := cmp.Diff([]string{"PutObject"}, *ops); len(diff) > 0 {
+ t.Error(diff)
+ }
+
+ if resp != nil {
+ t.Errorf("expect response to be nil, got %v", resp)
+ }
+}
+
+func TestUploadOrderZero(t *testing.T) {
+ c, invocations, params := s3testing.NewUploadLoggingClient(nil)
+ mgr := manager.NewUploader(c)
+ resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: bytes.NewReader(make([]byte, 0)),
+ })
+
+ if err != nil {
+ t.Errorf("expect no error, got %v", err)
+ }
+
+ if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 {
+ t.Error(diff)
+ }
+
+ if len(resp.Location) == 0 {
+ t.Error("expect Location to not be empty")
+ }
+
+ if len(resp.UploadID) > 0 {
+ t.Errorf("expect empty string, got %q", resp.UploadID)
+ }
+
+ if e, a := int64(0), getReaderLength((*params)[0].(*s3.PutObjectInput).Body); e != a {
+ t.Errorf("Expected %d, but received %d", e, a)
+ }
+}
+
+func TestUploadOrderMultiFailure(t *testing.T) {
+ c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
+
+ c.UploadPartFn = func(u *s3testing.UploadLoggingClient, _ *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
+ if u.PartNum == 2 {
+ return nil, fmt.Errorf("an unexpected error")
+ }
+ return &s3.UploadPartOutput{ETag: aws.String(fmt.Sprintf("ETAG%d", u.PartNum))}, nil
+ }
+
+ mgr := manager.NewUploader(c, func(u *manager.Uploader) {
+ u.Concurrency = 1
+ })
+ _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: bytes.NewReader(buf12MB),
+ })
+
+ if err == nil {
+ t.Error("expect error, got nil")
+ }
+
+ if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "AbortMultipartUpload"}, *invocations); len(diff) > 0 {
+ t.Error(diff)
+ }
+}
+
+func TestUploadOrderMultiFailureOnComplete(t *testing.T) {
+ c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
+
+ c.CompleteMultipartUploadFn = func(*s3testing.UploadLoggingClient, *s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error) {
+ return nil, fmt.Errorf("complete multipart error")
+ }
+
+ mgr := manager.NewUploader(c, func(u *manager.Uploader) {
+ u.Concurrency = 1
+ })
+ _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: bytes.NewReader(buf12MB),
+ })
+
+ if err == nil {
+ t.Error("expect error, got nil")
+ }
+
+ if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "UploadPart",
+ "CompleteMultipartUpload", "AbortMultipartUpload"}, *invocations); len(diff) > 0 {
+ t.Error(diff)
+ }
+}
+
+func TestUploadOrderMultiFailureOnCreate(t *testing.T) {
+ c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
+
+ c.CreateMultipartUploadFn = func(*s3testing.UploadLoggingClient, *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error) {
+ return nil, fmt.Errorf("create multipart upload failure")
+ }
+
+ mgr := manager.NewUploader(c)
+ _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: bytes.NewReader(make([]byte, 1024*1024*12)),
+ })
+
+ if err == nil {
+ t.Error("expect error, got nil")
+ }
+
+ if diff := cmp.Diff([]string{"CreateMultipartUpload"}, *invocations); len(diff) > 0 {
+ t.Error(diff)
+ }
+}
+
+func TestUploadOrderMultiFailureLeaveParts(t *testing.T) {
+ c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
+
+ c.UploadPartFn = func(u *s3testing.UploadLoggingClient, _ *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
+ if u.PartNum == 2 {
+ return nil, fmt.Errorf("upload part failure")
+ }
+ return &s3.UploadPartOutput{ETag: aws.String(fmt.Sprintf("ETAG%d", u.PartNum))}, nil
+ }
+
+ mgr := manager.NewUploader(c, func(u *manager.Uploader) {
+ u.Concurrency = 1
+ u.LeavePartsOnError = true
+ })
+ _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: bytes.NewReader(make([]byte, 1024*1024*12)),
+ })
+
+ if err == nil {
+ t.Error("expect error, got nil")
+ }
+
+ if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart"}, *invocations); len(diff) > 0 {
+ t.Error(diff)
+ }
+}
+
+type failreader struct {
+ times int
+ failCount int
+}
+
+func (f *failreader) Read(b []byte) (int, error) {
+ f.failCount++
+ if f.failCount >= f.times {
+ return 0, fmt.Errorf("random failure")
+ }
+ return len(b), nil
+}
+
+func TestUploadOrderReadFail1(t *testing.T) {
+ c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
+ mgr := manager.NewUploader(c)
+ _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: &failreader{times: 1},
+ })
+ if err == nil {
+ t.Fatalf("expect error to not be nil")
+ }
+
+ if e, a := "random failure", err.Error(); !strings.Contains(a, e) {
+ t.Errorf("expect %v, got %v", e, a)
+ }
+
+ if diff := cmp.Diff([]string(nil), *invocations); len(diff) > 0 {
+ t.Error(diff)
+ }
+}
+
+func TestUploadOrderReadFail2(t *testing.T) {
+ c, invocations, _ := s3testing.NewUploadLoggingClient([]string{"UploadPart"})
+ mgr := manager.NewUploader(c, func(u *manager.Uploader) {
+ u.Concurrency = 1
+ })
+ _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: &failreader{times: 2},
+ })
+ if err == nil {
+ t.Fatalf("expect error to not be nil")
+ }
+
+ if e, a := "random failure", err.Error(); !strings.Contains(a, e) {
+ t.Errorf("expect %v, got %q", e, a)
+ }
+
+ if diff := cmp.Diff([]string{"CreateMultipartUpload", "AbortMultipartUpload"}, *invocations); len(diff) > 0 {
+ t.Error(diff)
+ }
+}
+
+type sizedReader struct {
+ size int
+ cur int
+ err error
+}
+
+func (s *sizedReader) Read(p []byte) (n int, err error) {
+ if s.cur >= s.size {
+ if s.err == nil {
+ s.err = io.EOF
+ }
+ return 0, s.err
+ }
+
+ n = len(p)
+ s.cur += len(p)
+ if s.cur > s.size {
+ n -= s.cur - s.size
+ }
+
+ return n, err
+}
+
+func TestUploadOrderMultiBufferedReader(t *testing.T) {
+ c, invocations, params := s3testing.NewUploadLoggingClient(nil)
+ mgr := manager.NewUploader(c)
+ _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: &sizedReader{size: 1024 * 1024 * 12},
+ })
+ if err != nil {
+ t.Errorf("expect no error, got %v", err)
+ }
+
+ if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart",
+ "UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 {
+ t.Error(diff)
+ }
+
+ // Part lengths
+ var parts []int64
+ for i := 1; i <= 3; i++ {
+ parts = append(parts, getReaderLength((*params)[i].(*s3.UploadPartInput).Body))
+ }
+ sort.Slice(parts, func(i, j int) bool {
+ return parts[i] < parts[j]
+ })
+
+ if diff := cmp.Diff([]int64{1024 * 1024 * 2, 1024 * 1024 * 5, 1024 * 1024 * 5}, parts); len(diff) > 0 {
+ t.Error(diff)
+ }
+}
+
+func TestUploadOrderMultiBufferedReaderPartial(t *testing.T) {
+ c, invocations, params := s3testing.NewUploadLoggingClient(nil)
+ mgr := manager.NewUploader(c)
+ _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: &sizedReader{size: 1024 * 1024 * 12, err: io.EOF},
+ })
+ if err != nil {
+ t.Errorf("expect no error, got %v", err)
+ }
+
+ if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart",
+ "UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 {
+ t.Error(diff)
+ }
+
+ // Part lengths
+ var parts []int64
+ for i := 1; i <= 3; i++ {
+ parts = append(parts, getReaderLength((*params)[i].(*s3.UploadPartInput).Body))
+ }
+ sort.Slice(parts, func(i, j int) bool {
+ return parts[i] < parts[j]
+ })
+
+ if diff := cmp.Diff([]int64{1024 * 1024 * 2, 1024 * 1024 * 5, 1024 * 1024 * 5}, parts); len(diff) > 0 {
+ t.Error(diff)
+ }
+}
+
+// TestUploadOrderMultiBufferedReaderEOF tests the edge case where the
+// file size is the same as part size.
+func TestUploadOrderMultiBufferedReaderEOF(t *testing.T) {
+ c, invocations, params := s3testing.NewUploadLoggingClient(nil)
+ mgr := manager.NewUploader(c)
+ _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: &sizedReader{size: 1024 * 1024 * 10, err: io.EOF},
+ })
+
+ if err != nil {
+ t.Errorf("expect no error, got %v", err)
+ }
+
+ if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 {
+ t.Error(diff)
+ }
+
+ // Part lengths
+ var parts []int64
+ for i := 1; i <= 2; i++ {
+ parts = append(parts, getReaderLength((*params)[i].(*s3.UploadPartInput).Body))
+ }
+ sort.Slice(parts, func(i, j int) bool {
+ return parts[i] < parts[j]
+ })
+
+ if diff := cmp.Diff([]int64{1024 * 1024 * 5, 1024 * 1024 * 5}, parts); len(diff) > 0 {
+ t.Error(diff)
+ }
+}
+
+func TestUploadOrderMultiBufferedReaderExceedTotalParts(t *testing.T) {
+ c, invocations, _ := s3testing.NewUploadLoggingClient([]string{"UploadPart"})
+ mgr := manager.NewUploader(c, func(u *manager.Uploader) {
+ u.Concurrency = 1
+ u.MaxUploadParts = 2
+ })
+ resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: &sizedReader{size: 1024 * 1024 * 12},
+ })
+ if err == nil {
+ t.Fatal("expect error, got nil")
+ }
+
+ if resp != nil {
+ t.Errorf("expect nil, got %v", resp)
+ }
+
+ if diff := cmp.Diff([]string{"CreateMultipartUpload", "AbortMultipartUpload"}, *invocations); len(diff) > 0 {
+ t.Error(diff)
+ }
+
+ if !strings.Contains(err.Error(), "configured MaxUploadParts (2)") {
+ t.Errorf("expect 'configured MaxUploadParts (2)', got %q", err.Error())
+ }
+}
+
+func TestUploadOrderSingleBufferedReader(t *testing.T) {
+ c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
+ mgr := manager.NewUploader(c)
+ resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: &sizedReader{size: 1024 * 1024 * 2},
+ })
+
+ if err != nil {
+ t.Errorf("expect no error, got %v", err)
+ }
+
+ if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 {
+ t.Error(diff)
+ }
+
+ if len(resp.Location) == 0 {
+ t.Error("expect a value in Location")
+ }
+
+ if len(resp.UploadID) > 0 {
+ t.Errorf("expect no value, got %q", resp.UploadID)
+ }
+}
+
+func TestUploadZeroLenObject(t *testing.T) {
+ client, invocations, _ := s3testing.NewUploadLoggingClient(nil)
+
+ mgr := manager.NewUploader(client)
+ resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: strings.NewReader(""),
+ })
+
+ if err != nil {
+ t.Errorf("expect no error but received %v", err)
+ }
+ if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 {
+ t.Errorf("expect request to have been made, but was not, %v", diff)
+ }
+
+ // TODO: not needed?
+ if len(resp.Location) == 0 {
+ t.Error("expect a non-empty string value for Location")
+ }
+
+ if len(resp.UploadID) > 0 {
+ t.Errorf("expect empty string, but received %q", resp.UploadID)
+ }
+}
+
+type testIncompleteReader struct {
+ Size int64
+ read int64
+}
+
+func (r *testIncompleteReader) Read(p []byte) (n int, err error) {
+ r.read += int64(len(p))
+ if r.read >= r.Size {
+ return int(r.read - r.Size), io.ErrUnexpectedEOF
+ }
+ return len(p), nil
+}
+
+func TestUploadUnexpectedEOF(t *testing.T) {
+ c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
+ mgr := manager.NewUploader(c, func(u *manager.Uploader) {
+ u.Concurrency = 1
+ u.PartSize = manager.MinUploadPartSize
+ })
+ _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: &testIncompleteReader{
+ Size: manager.MinUploadPartSize + 1,
+ },
+ })
+ if err == nil {
+ t.Error("expect error, got nil")
+ }
+
+ // Ensure upload started.
+ if e, a := "CreateMultipartUpload", (*invocations)[0]; e != a {
+ t.Errorf("expect %q, got %q", e, a)
+ }
+
+ // Part may or may not be sent because of timing of sending parts and
+ // reading next part in upload manager. Just check for the last abort.
+ if e, a := "AbortMultipartUpload", (*invocations)[len(*invocations)-1]; e != a {
+ t.Errorf("expect %q, got %q", e, a)
+ }
+}
+
+func TestSSE(t *testing.T) {
+ client, _, _ := s3testing.NewUploadLoggingClient(nil)
+ client.UploadPartFn = func(u *s3testing.UploadLoggingClient, params *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
+ if params.SSECustomerAlgorithm == nil {
+ t.Fatal("SSECustomerAlgoritm should not be nil")
+ }
+ if params.SSECustomerKey == nil {
+ t.Fatal("SSECustomerKey should not be nil")
+ }
+ return &s3.UploadPartOutput{ETag: aws.String(fmt.Sprintf("ETAG%d", u.PartNum))}, nil
+ }
+
+ mgr := manager.NewUploader(client, func(u *manager.Uploader) {
+ u.Concurrency = 5
+ })
+
+ _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ SSECustomerAlgorithm: aws.String("AES256"),
+ SSECustomerKey: aws.String("foo"),
+ Body: bytes.NewBuffer(make([]byte, 1024*1024*10)),
+ })
+
+ if err != nil {
+ t.Fatal("Expected no error, but received" + err.Error())
+ }
+}
+
+func TestUploadWithContextCanceled(t *testing.T) {
+ u := manager.NewUploader(s3.New(s3.Options{
+ UsePathStyle: true,
+ }))
+
+ params := s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: bytes.NewReader(make([]byte, 0)),
+ }
+
+ ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
+ ctx.Error = fmt.Errorf("context canceled")
+ close(ctx.DoneCh)
+
+ _, err := u.Upload(ctx, &params)
+ if err == nil {
+ t.Fatalf("expect error, got nil")
+ }
+
+ if e, a := "canceled", err.Error(); !strings.Contains(a, e) {
+ t.Errorf("expected error message to contain %q, but did not %q", e, a)
+ }
+}
+
+// S3 Uploader incorrectly fails an upload if the content being uploaded
+// has a size of MinPartSize * MaxUploadParts.
+// Github: aws/aws-sdk-go#2557
+func TestUploadMaxPartsEOF(t *testing.T) {
+ c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
+ mgr := manager.NewUploader(c, func(u *manager.Uploader) {
+ u.Concurrency = 1
+ u.PartSize = manager.DefaultUploadPartSize
+ u.MaxUploadParts = 2
+ })
+ f := bytes.NewReader(make([]byte, int(mgr.PartSize)*int(mgr.MaxUploadParts)))
+
+ r1 := io.NewSectionReader(f, 0, manager.DefaultUploadPartSize)
+ r2 := io.NewSectionReader(f, manager.DefaultUploadPartSize, 2*manager.DefaultUploadPartSize)
+ body := io.MultiReader(r1, r2)
+
+ _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("Bucket"),
+ Key: aws.String("Key"),
+ Body: body,
+ })
+
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+
+ expectOps := []string{
+ "CreateMultipartUpload",
+ "UploadPart",
+ "UploadPart",
+ "CompleteMultipartUpload",
+ }
+ if diff := cmp.Diff(expectOps, *invocations); len(diff) > 0 {
+ t.Error(diff)
+ }
+}
+
+func createTempFile(t *testing.T, size int64) (*os.File, func(*testing.T), error) {
+ file, err := ioutil.TempFile(os.TempDir(), aws.SDKName+t.Name())
+ if err != nil {
+ return nil, nil, err
+ }
+ filename := file.Name()
+ if err := file.Truncate(size); err != nil {
+ return nil, nil, err
+ }
+
+ return file,
+ func(t *testing.T) {
+ if err := file.Close(); err != nil {
+ t.Errorf("failed to close temp file, %s, %v", filename, err)
+ }
+ if err := os.Remove(filename); err != nil {
+ t.Errorf("failed to remove temp file, %s, %v", filename, err)
+ }
+ },
+ nil
+}
+
+func buildFailHandlers(tb testing.TB, parts, retry int) []http.Handler {
+ handlers := make([]http.Handler, parts)
+ for i := 0; i < len(handlers); i++ {
+ handlers[i] = &failPartHandler{
+ tb: tb,
+ failsRemaining: retry,
+ successHandler: successPartHandler{tb: tb},
+ }
+ }
+
+ return handlers
+}
+
+func TestUploadRetry(t *testing.T) {
+ const numParts, retries = 3, 10
+
+ testFile, testFileCleanup, err := createTempFile(t, manager.DefaultUploadPartSize*numParts)
+ if err != nil {
+ t.Fatalf("failed to create test file, %v", err)
+ }
+ defer testFileCleanup(t)
+
+ cases := map[string]struct {
+ Body io.Reader
+ PartHandlers func(testing.TB) []http.Handler
+ }{
+ "bytes.Buffer": {
+ Body: bytes.NewBuffer(make([]byte, manager.DefaultUploadPartSize*numParts)),
+ PartHandlers: func(tb testing.TB) []http.Handler {
+ return buildFailHandlers(tb, numParts, retries)
+ },
+ },
+ "bytes.Reader": {
+ Body: bytes.NewReader(make([]byte, manager.DefaultUploadPartSize*numParts)),
+ PartHandlers: func(tb testing.TB) []http.Handler {
+ return buildFailHandlers(tb, numParts, retries)
+ },
+ },
+ "os.File": {
+ Body: testFile,
+ PartHandlers: func(tb testing.TB) []http.Handler {
+ return buildFailHandlers(tb, numParts, retries)
+ },
+ },
+ }
+
+ for name, c := range cases {
+ t.Run(name, func(t *testing.T) {
+ restoreSleep := sdk.TestingUseNoOpSleep()
+ defer restoreSleep()
+
+ mux := newMockS3UploadServer(t, c.PartHandlers(t))
+ server := httptest.NewServer(mux)
+ defer server.Close()
+
+ client := s3.New(s3.Options{
+ EndpointResolver: s3testing.EndpointResolverFunc(func(region string, options s3.ResolverOptions) (aws.Endpoint, error) {
+ return aws.Endpoint{
+ URL: server.URL,
+ }, nil
+ }),
+ UsePathStyle: true,
+ Retryer: retry.NewStandard(func(o *retry.StandardOptions) {
+ o.MaxAttempts = retries + 1
+ }),
+ })
+
+ uploader := manager.NewUploader(client)
+ _, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ Body: c.Body,
+ })
+
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ })
+ }
+}
+
+func TestUploadBufferStrategy(t *testing.T) {
+ cases := map[string]struct {
+ PartSize int64
+ Size int64
+ Strategy manager.ReadSeekerWriteToProvider
+ callbacks int
+ }{
+ "NoBuffer": {
+ PartSize: manager.DefaultUploadPartSize,
+ Strategy: nil,
+ },
+ "SinglePart": {
+ PartSize: manager.DefaultUploadPartSize,
+ Size: manager.DefaultUploadPartSize,
+ Strategy: &recordedBufferProvider{size: int(manager.DefaultUploadPartSize)},
+ callbacks: 1,
+ },
+ "MultiPart": {
+ PartSize: manager.DefaultUploadPartSize,
+ Size: manager.DefaultUploadPartSize * 2,
+ Strategy: &recordedBufferProvider{size: int(manager.DefaultUploadPartSize)},
+ callbacks: 2,
+ },
+ }
+
+ for name, tCase := range cases {
+ t.Run(name, func(t *testing.T) {
+ client, _, _ := s3testing.NewUploadLoggingClient(nil)
+ client.ConsumeBody = true
+
+ uploader := manager.NewUploader(client, func(u *manager.Uploader) {
+ u.PartSize = tCase.PartSize
+ u.BufferProvider = tCase.Strategy
+ u.Concurrency = 1
+ })
+
+ expected := s3testing.GetTestBytes(int(tCase.Size))
+ _, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ Body: bytes.NewReader(expected),
+ })
+ if err != nil {
+ t.Fatalf("failed to upload file: %v", err)
+ }
+
+ switch strat := tCase.Strategy.(type) {
+ case *recordedBufferProvider:
+ if !bytes.Equal(expected, strat.content) {
+ t.Errorf("content buffered did not match expected")
+ }
+ if tCase.callbacks != strat.callbackCount {
+ t.Errorf("expected %v, got %v callbacks", tCase.callbacks, strat.callbackCount)
+ }
+ }
+ })
+ }
+}
+
+type mockS3UploadServer struct {
+ *http.ServeMux
+
+ tb testing.TB
+ partHandler []http.Handler
+}
+
+func newMockS3UploadServer(tb testing.TB, partHandler []http.Handler) *mockS3UploadServer {
+ s := &mockS3UploadServer{
+ ServeMux: http.NewServeMux(),
+ partHandler: partHandler,
+ tb: tb,
+ }
+
+ s.HandleFunc("/", s.handleRequest)
+
+ return s
+}
+
+func (s mockS3UploadServer) handleRequest(w http.ResponseWriter, r *http.Request) {
+ defer r.Body.Close()
+
+ _, hasUploads := r.URL.Query()["uploads"]
+
+ switch {
+ case r.Method == "POST" && hasUploads:
+ // CreateMultipartUpload
+ w.Header().Set("Content-Length", strconv.Itoa(len(createUploadResp)))
+ w.Write([]byte(createUploadResp))
+
+ case r.Method == "PUT":
+ // UploadPart
+ partNumStr := r.URL.Query().Get("partNumber")
+ id, err := strconv.Atoi(partNumStr)
+ if err != nil {
+ failRequest(w, 400, "BadRequest",
+ fmt.Sprintf("unable to parse partNumber, %q, %v",
+ partNumStr, err))
+ return
+ }
+ id--
+ if id < 0 || id >= len(s.partHandler) {
+ failRequest(w, 400, "BadRequest",
+ fmt.Sprintf("invalid partNumber %v", id))
+ return
+ }
+ s.partHandler[id].ServeHTTP(w, r)
+
+ case r.Method == "POST":
+ // CompleteMultipartUpload
+ w.Header().Set("Content-Length", strconv.Itoa(len(completeUploadResp)))
+ w.Write([]byte(completeUploadResp))
+
+ case r.Method == "DELETE":
+ // AbortMultipartUpload
+ w.Header().Set("Content-Length", strconv.Itoa(len(abortUploadResp)))
+ w.WriteHeader(200)
+ w.Write([]byte(abortUploadResp))
+
+ default:
+ failRequest(w, 400, "BadRequest",
+ fmt.Sprintf("invalid request %v %v", r.Method, r.URL))
+ }
+}
+
+func failRequest(w http.ResponseWriter, status int, code, msg string) {
+ msg = fmt.Sprintf(baseRequestErrorResp, code, msg)
+ w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
+ w.WriteHeader(status)
+ w.Write([]byte(msg))
+}
+
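+// successPartHandler drains the request body, verifies the byte count matches
+// the Content-Length header, and replies with a canned UploadPart response.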
+type successPartHandler struct {
+ tb testing.TB
+}
+
+func (h successPartHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ defer r.Body.Close()
+
+ n, err := io.Copy(ioutil.Discard, r.Body)
+ if err != nil {
+ failRequest(w, 400, "BadRequest",
+ fmt.Sprintf("failed to read body, %v", err))
+ return
+ }
+
+ contLenStr := r.Header.Get("Content-Length")
+ expectLen, err := strconv.ParseInt(contLenStr, 10, 64)
+ if err != nil {
+ h.tb.Logf("expect content-length, got %q, %v", contLenStr, err)
+ failRequest(w, 400, "BadRequest",
+ fmt.Sprintf("unable to get content-length %v", err))
+ return
+ }
+ if e, a := expectLen, n; e != a {
+ h.tb.Logf("expect %v read, got %v", e, a)
+ failRequest(w, 400, "BadRequest",
+ fmt.Sprintf(
+ "content-length and body do not match, %v, %v", e, a))
+ return
+ }
+
+ w.Header().Set("Content-Length", strconv.Itoa(len(uploadPartResp)))
+ w.Write([]byte(uploadPartResp))
+}
+
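+// failPartHandler responds with a 500 InternalException until failsRemaining
+// is exhausted, then delegates to successHandler when one is configured.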
+type failPartHandler struct {
+ tb testing.TB
+
+ failsRemaining int
+ successHandler http.Handler
+}
+
+func (h *failPartHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ defer r.Body.Close()
+
+ if h.failsRemaining == 0 && h.successHandler != nil {
+ h.successHandler.ServeHTTP(w, r)
+ return
+ }
+
+ io.Copy(ioutil.Discard, r.Body)
+
+ failRequest(w, 500, "InternalException",
+ fmt.Sprintf("mock error, partNumber %v", r.URL.Query().Get("partNumber")))
+
+ h.failsRemaining--
+}
+
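+// recordedBufferProvider is a ReadSeekerWriteToProvider that records the bytes
+// it buffered and counts cleanup callbacks so tests can assert on both.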
+type recordedBufferProvider struct {
+ content []byte
+ size int
+ callbackCount int
+}
+
+func (r *recordedBufferProvider) GetWriteTo(seeker io.ReadSeeker) (manager.ReadSeekerWriteTo, func()) {
+ b := make([]byte, r.size)
+ w := &manager.BufferedReadSeekerWriteTo{BufferedReadSeeker: manager.NewBufferedReadSeeker(seeker, b)}
+
+ return w, func() {
+ r.content = append(r.content, b...)
+ r.callbackCount++
+ }
+}
+
+const createUploadResp = `<CreateMultipartUploadResponse>
+	<Bucket>bucket</Bucket>
+	<Key>key</Key>
+	<UploadId>abc123</UploadId>
+</CreateMultipartUploadResponse>`
+
+const uploadPartResp = `<UploadPartResponse>
+	<ETag>key</ETag>
+</UploadPartResponse>`
+
+const baseRequestErrorResp = `<Error>
+	<Code>%s</Code>
+	<Message>%s</Message>
+	<RequestId>request-id</RequestId>
+	<HostId>host-id</HostId>
+</Error>`
+
+const completeUploadResp = `<CompleteMultipartUploadResponse>
+	<Bucket>bucket</Bucket>
+	<Key>key</Key>
+	<ETag>key</ETag>
+	<Location>https://bucket.us-west-2.amazonaws.com/key</Location>
+	<UploadId>abc123</UploadId>
+</CompleteMultipartUploadResponse>`
+
+const abortUploadResp = `<AbortMultipartUploadResponse></AbortMultipartUploadResponse>`
diff --git a/feature/s3/manager/writer_read_from.go b/feature/s3/manager/writer_read_from.go
new file mode 100644
index 00000000000..3df983a652a
--- /dev/null
+++ b/feature/s3/manager/writer_read_from.go
@@ -0,0 +1,75 @@
+package manager
+
+import (
+ "bufio"
+ "io"
+ "sync"
+
+ "github.com/aws/aws-sdk-go-v2/internal/sdkio"
+)
+
+// WriterReadFrom combines the io.Writer and io.ReaderFrom interfaces.
+type WriterReadFrom interface {
+ io.Writer
+ io.ReaderFrom
+}
+
+// WriterReadFromProvider provides a WriterReadFrom wrapping a given io.Writer,
+// along with a cleanup function that must be called once the wrapper is no
+// longer needed.
+type WriterReadFromProvider interface {
+ GetReadFrom(writer io.Writer) (w WriterReadFrom, cleanup func())
+}
+
+type bufferedWriter interface {
+ WriterReadFrom
+ Flush() error
+ Reset(io.Writer)
+}
+
+type bufferedReadFrom struct {
+ bufferedWriter
+}
+
+func (b *bufferedReadFrom) ReadFrom(r io.Reader) (int64, error) {
+ n, err := b.bufferedWriter.ReadFrom(r)
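+	// Flush so buffered bytes reach the underlying writer, but never mask a
+	// ReadFrom error with a Flush error.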
+ if flushErr := b.Flush(); flushErr != nil && err == nil {
+ err = flushErr
+ }
+ return n, err
+}
+
+// PooledBufferedReadFromProvider is a WriterReadFromProvider that uses a sync.Pool
+// to manage allocation and reuse of *bufio.Writer structures.
+type PooledBufferedReadFromProvider struct {
+ pool sync.Pool
+}
+
+// NewPooledBufferedWriterReadFromProvider returns a new PooledBufferedReadFromProvider.
+// Size is used to control the size of the underlying *bufio.Writer created for
+// calls to GetReadFrom; sizes below 32 KiB fall back to a 64 KiB default.
+func NewPooledBufferedWriterReadFromProvider(size int) *PooledBufferedReadFromProvider {
+	// Guard against undersized buffers: anything below the 32 KiB minimum is
+	// replaced with the 64 KiB default.
+	if size < int(32*sdkio.KibiByte) {
+		size = int(64 * sdkio.KibiByte)
+	}
+
+ return &PooledBufferedReadFromProvider{
+ pool: sync.Pool{
+ New: func() interface{} {
+ return &bufferedReadFrom{bufferedWriter: bufio.NewWriterSize(nil, size)}
+ },
+ },
+ }
+}
+
+// GetReadFrom takes an io.Writer and wraps it with a type which satisfies the
+// WriterReadFrom interface. Additionally, a cleanup function is returned which
+// must be called after usage of the WriterReadFrom has been completed in order
+// to allow the reuse of the *bufio.Writer.
+func (p *PooledBufferedReadFromProvider) GetReadFrom(writer io.Writer) (r WriterReadFrom, cleanup func()) {
+ buffer := p.pool.Get().(*bufferedReadFrom)
+ buffer.Reset(writer)
+ r = buffer
+ cleanup = func() {
+ buffer.Reset(nil) // Reset to nil writer to release reference
+ p.pool.Put(buffer)
+ }
+ return r, cleanup
+}
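+
+// A minimal usage sketch (dst and src are assumed to be any io.Writer and
+// io.Reader): obtain a pooled wrapper, stream through it, and call cleanup to
+// return the buffer to the pool.
+//
+//	provider := NewPooledBufferedWriterReadFromProvider(int(64 * sdkio.KibiByte))
+//	w, cleanup := provider.GetReadFrom(dst)
+//	defer cleanup()
+//	if _, err := w.ReadFrom(src); err != nil {
+//		// handle error
+//	}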
diff --git a/feature/s3/manager/writer_read_from_test.go b/feature/s3/manager/writer_read_from_test.go
new file mode 100644
index 00000000000..4f59f68cdc3
--- /dev/null
+++ b/feature/s3/manager/writer_read_from_test.go
@@ -0,0 +1,73 @@
+package manager
+
+import (
+ "fmt"
+ "io"
+ "reflect"
+ "testing"
+)
+
+type testBufioWriter struct {
+ ReadFromN int64
+ ReadFromErr error
+ FlushReturn error
+}
+
+func (t testBufioWriter) Write(p []byte) (n int, err error) {
+ panic("unused")
+}
+
+func (t testBufioWriter) ReadFrom(r io.Reader) (n int64, err error) {
+ return t.ReadFromN, t.ReadFromErr
+}
+
+func (t testBufioWriter) Flush() error {
+ return t.FlushReturn
+}
+
+func (t *testBufioWriter) Reset(io.Writer) {
+ panic("unused")
+}
+
+func TestBufferedReadFromFlusher_ReadFrom(t *testing.T) {
+ cases := map[string]struct {
+ w testBufioWriter
+ expectedErr error
+ }{
+ "no errors": {},
+ "error returned from underlying ReadFrom": {
+ w: testBufioWriter{
+ ReadFromN: 42,
+ ReadFromErr: fmt.Errorf("readfrom"),
+ },
+ expectedErr: fmt.Errorf("readfrom"),
+ },
+ "error returned from Flush": {
+ w: testBufioWriter{
+ ReadFromN: 7,
+ FlushReturn: fmt.Errorf("flush"),
+ },
+ expectedErr: fmt.Errorf("flush"),
+ },
+ "error returned from ReadFrom and Flush": {
+ w: testBufioWriter{
+ ReadFromN: 1337,
+ ReadFromErr: fmt.Errorf("readfrom"),
+ FlushReturn: fmt.Errorf("flush"),
+ },
+ expectedErr: fmt.Errorf("readfrom"),
+ },
+ }
+
+	for name, tCase := range cases {
+		t.Run(name, func(t *testing.T) {
+			readFromFlusher := bufferedReadFrom{bufferedWriter: &tCase.w}
+			n, err := readFromFlusher.ReadFrom(nil)
+			if e, a := tCase.w.ReadFromN, n; e != a {
+				t.Errorf("expected %v bytes, got %v", e, a)
+			}
+			if e, a := tCase.expectedErr, err; !reflect.DeepEqual(e, a) {
+				t.Errorf("expected error %v, got %v", e, a)
+			}
+		})
+	}
+}