From efc148e8a664bdac0b6042f4a7ecf93c299b28f5 Mon Sep 17 00:00:00 2001 From: Mark Theunissen Date: Wed, 20 Nov 2024 17:22:29 +1100 Subject: [PATCH] additional test for negative case, test coverage for policy, remove comment --- functional_tests.go | 211 +++++++++++++++++--- post-policy.go | 53 +++-- post-policy_test.go | 459 ++++++++++++++++++++++++++++++++++++++++++++ retry.go | 8 +- 4 files changed, 679 insertions(+), 52 deletions(-) create mode 100644 post-policy_test.go diff --git a/functional_tests.go b/functional_tests.go index 714bb3cc9..a3dbc49c9 100644 --- a/functional_tests.go +++ b/functional_tests.go @@ -5643,9 +5643,6 @@ func testPresignedPostPolicy() { "policy": "", } - // Seed random based on current time. - rand.Seed(time.Now().Unix()) - // Instantiate new minio client object c, err := minio.New(os.Getenv(serverEndpoint), &minio.Options{ @@ -5692,38 +5689,13 @@ func testPresignedPostPolicy() { } policy := minio.NewPostPolicy() - - if err := policy.SetBucket(""); err == nil { - logError(testName, function, args, startTime, "", "SetBucket did not fail for invalid conditions", err) - return - } - if err := policy.SetKey(""); err == nil { - logError(testName, function, args, startTime, "", "SetKey did not fail for invalid conditions", err) - return - } - if err := policy.SetExpires(time.Date(1, time.January, 1, 0, 0, 0, 0, time.UTC)); err == nil { - logError(testName, function, args, startTime, "", "SetExpires did not fail for invalid conditions", err) - return - } - if err := policy.SetContentType(""); err == nil { - logError(testName, function, args, startTime, "", "SetContentType did not fail for invalid conditions", err) - return - } - if err := policy.SetContentLengthRange(1024*1024, 1024); err == nil { - logError(testName, function, args, startTime, "", "SetContentLengthRange did not fail for invalid conditions", err) - return - } - if err := policy.SetUserMetadata("", ""); err == nil { - logError(testName, function, args, startTime, "", "SetUserMetadata did not fail for invalid conditions", err) - return - } - policy.SetBucket(bucketName) policy.SetKey(objectName) policy.SetExpires(time.Now().UTC().AddDate(0, 0, 10)) // expires in 10 days policy.SetContentType("binary/octet-stream") policy.SetContentLengthRange(10, 1024*1024) policy.SetUserMetadata(metadataKey, metadataValue) + policy.SetContentEncoding("gzip") // Add CRC32C checksum := minio.ChecksumCRC32C.ChecksumBytes(buf) @@ -5865,6 +5837,186 @@ func testPresignedPostPolicy() { logSuccess(testName, function, args, startTime) } +// testPresignedPostPolicyWrongFile tests that when we have a policy with a checksum, we cannot POST the wrong file +func testPresignedPostPolicyWrongFile() { + // initialize logging params + startTime := time.Now() + testName := getFuncName() + function := "PresignedPostPolicy(policy)" + args := map[string]interface{}{ + "policy": "", + } + + // Instantiate new minio client object + c, err := minio.New(os.Getenv(serverEndpoint), + &minio.Options{ + Creds: credentials.NewStaticV4(os.Getenv(accessKey), os.Getenv(secretKey), ""), + Transport: createHTTPTransport(), + Secure: mustParseBool(os.Getenv(enableHTTPS)), + }) + if err != nil { + logError(testName, function, args, startTime, "", "MinIO client object creation failed", err) + return + } + + // Enable tracing, write to stderr. + // c.TraceOn(os.Stderr) + + // Set user agent. + c.SetAppInfo("MinIO-go-FunctionalTest", appVersion) + + // Generate a new random bucket name. + bucketName := randString(60, rand.NewSource(time.Now().UnixNano()), "minio-go-test-") + + // Make a new bucket in 'us-east-1' (source bucket). + err = c.MakeBucket(context.Background(), bucketName, minio.MakeBucketOptions{Region: "us-east-1"}) + if err != nil { + logError(testName, function, args, startTime, "", "MakeBucket failed", err) + return + } + + defer cleanupBucket(bucketName, c) + + // Generate 33K of data. + reader := getDataReader("datafile-33-kB") + defer reader.Close() + + objectName := randString(60, rand.NewSource(time.Now().UnixNano()), "") + // Azure requires the key to not start with a number + metadataKey := randString(60, rand.NewSource(time.Now().UnixNano()), "user") + metadataValue := randString(60, rand.NewSource(time.Now().UnixNano()), "") + + buf, err := io.ReadAll(reader) + if err != nil { + logError(testName, function, args, startTime, "", "ReadAll failed", err) + return + } + + policy := minio.NewPostPolicy() + policy.SetBucket(bucketName) + policy.SetKey(objectName) + policy.SetExpires(time.Now().UTC().AddDate(0, 0, 10)) // expires in 10 days + policy.SetContentType("binary/octet-stream") + policy.SetContentLengthRange(10, 1024*1024) + policy.SetUserMetadata(metadataKey, metadataValue) + + // Add CRC32C of the 33kB file that the policy will explicitly allow. + checksum := minio.ChecksumCRC32C.ChecksumBytes(buf) + err = policy.SetChecksum(checksum) + if err != nil { + logError(testName, function, args, startTime, "", "SetChecksum failed", err) + return + } + + args["policy"] = policy.String() + + presignedPostPolicyURL, formData, err := c.PresignedPostPolicy(context.Background(), policy) + if err != nil { + logError(testName, function, args, startTime, "", "PresignedPostPolicy failed", err) + return + } + + // At this stage, we have a policy that allows us to upload datafile-33-kB. + // Test that uploading datafile-10-kB, with a different checksum, fails as expected + filePath := getMintDataDirFilePath("datafile-10-kB") + if filePath == "" { + // Make a temp file with 10 KB data. + file, err := os.CreateTemp(os.TempDir(), "PresignedPostPolicyTest") + if err != nil { + logError(testName, function, args, startTime, "", "TempFile creation failed", err) + return + } + if _, err = io.Copy(file, getDataReader("datafile-10-kB")); err != nil { + logError(testName, function, args, startTime, "", "Copy failed", err) + return + } + if err = file.Close(); err != nil { + logError(testName, function, args, startTime, "", "File Close failed", err) + return + } + filePath = file.Name() + } + fileReader := getDataReader("datafile-10-kB") + defer fileReader.Close() + buf10k, err := io.ReadAll(fileReader) + if err != nil { + logError(testName, function, args, startTime, "", "ReadAll failed", err) + return + } + otherChecksum := minio.ChecksumCRC32C.ChecksumBytes(buf10k) + + var formBuf bytes.Buffer + writer := multipart.NewWriter(&formBuf) + for k, v := range formData { + if k == "x-amz-checksum-crc32c" { + v = otherChecksum.Encoded() + } + writer.WriteField(k, v) + } + + // Add file to post request + f, err := os.Open(filePath) + defer f.Close() + if err != nil { + logError(testName, function, args, startTime, "", "File open failed", err) + return + } + w, err := writer.CreateFormFile("file", filePath) + if err != nil { + logError(testName, function, args, startTime, "", "CreateFormFile failed", err) + return + } + _, err = io.Copy(w, f) + if err != nil { + logError(testName, function, args, startTime, "", "Copy failed", err) + return + } + writer.Close() + + httpClient := &http.Client{ + Timeout: 30 * time.Second, + Transport: createHTTPTransport(), + } + args["url"] = presignedPostPolicyURL.String() + + req, err := http.NewRequest(http.MethodPost, presignedPostPolicyURL.String(), bytes.NewReader(formBuf.Bytes())) + if err != nil { + logError(testName, function, args, startTime, "", "HTTP request failed", err) + return + } + + req.Header.Set("Content-Type", writer.FormDataContentType()) + + // Make the POST request with the form data. + res, err := httpClient.Do(req) + if err != nil { + logError(testName, function, args, startTime, "", "HTTP request failed", err) + return + } + defer res.Body.Close() + if res.StatusCode != http.StatusForbidden { + logError(testName, function, args, startTime, "", "HTTP request unexpected status", errors.New(res.Status)) + return + } + + // Read the response body, ensure it has checksum failure message + resBody, err := io.ReadAll(res.Body) + if err != nil { + logError(testName, function, args, startTime, "", "ReadAll failed", err) + return + } + + // Normalize the response body, because S3 uses quotes around the policy condition components + // in the error message, MinIO does not. + resBodyStr := strings.ReplaceAll(string(resBody), `"`, "") + if !strings.Contains(resBodyStr, "Policy Condition failed: [eq, $x-amz-checksum-crc32c, aHnJMw==]") { + logError(testName, function, args, startTime, "", "Unexpected response body", errors.New(resBodyStr)) + return + } + + logSuccess(testName, function, args, startTime) +} + // Tests copy object func testCopyObject() { // initialize logging params @@ -14977,6 +15129,7 @@ func main() { testGetObjectReadAtFunctional() testGetObjectReadAtWhenEOFWasReached() testPresignedPostPolicy() + testPresignedPostPolicyWrongFile() testCopyObject() testComposeObjectErrorCases() testCompose10KSources() diff --git a/post-policy.go b/post-policy.go index 1b01f4101..00c770d74 100644 --- a/post-policy.go +++ b/post-policy.go @@ -85,7 +85,7 @@ func (p *PostPolicy) SetExpires(t time.Time) error { // SetKey - Sets an object name for the policy based upload. func (p *PostPolicy) SetKey(key string) error { - if strings.TrimSpace(key) == "" || key == "" { + if strings.TrimSpace(key) == "" { return errInvalidArgument("Object name is empty.") } policyCond := policyCondition{ @@ -118,7 +118,7 @@ func (p *PostPolicy) SetKeyStartsWith(keyStartsWith string) error { // SetBucket - Sets bucket at which objects will be uploaded to. func (p *PostPolicy) SetBucket(bucketName string) error { - if strings.TrimSpace(bucketName) == "" || bucketName == "" { + if strings.TrimSpace(bucketName) == "" { return errInvalidArgument("Bucket name is empty.") } policyCond := policyCondition{ @@ -135,7 +135,7 @@ func (p *PostPolicy) SetBucket(bucketName string) error { // SetCondition - Sets condition for credentials, date and algorithm func (p *PostPolicy) SetCondition(matchType, condition, value string) error { - if strings.TrimSpace(value) == "" || value == "" { + if strings.TrimSpace(value) == "" { return errInvalidArgument("No value specified for condition") } @@ -156,7 +156,7 @@ func (p *PostPolicy) SetCondition(matchType, condition, value string) error { // SetTagging - Sets tagging for the object for this policy based upload. func (p *PostPolicy) SetTagging(tagging string) error { - if strings.TrimSpace(tagging) == "" || tagging == "" { + if strings.TrimSpace(tagging) == "" { return errInvalidArgument("No tagging specified.") } _, err := tags.ParseObjectXML(strings.NewReader(tagging)) @@ -178,7 +178,7 @@ func (p *PostPolicy) SetTagging(tagging string) error { // SetContentType - Sets content-type of the object for this policy // based upload. func (p *PostPolicy) SetContentType(contentType string) error { - if strings.TrimSpace(contentType) == "" || contentType == "" { + if strings.TrimSpace(contentType) == "" { return errInvalidArgument("No content type specified.") } policyCond := policyCondition{ @@ -211,7 +211,7 @@ func (p *PostPolicy) SetContentTypeStartsWith(contentTypeStartsWith string) erro // SetContentDisposition - Sets content-disposition of the object for this policy func (p *PostPolicy) SetContentDisposition(contentDisposition string) error { - if strings.TrimSpace(contentDisposition) == "" || contentDisposition == "" { + if strings.TrimSpace(contentDisposition) == "" { return errInvalidArgument("No content disposition specified.") } policyCond := policyCondition{ @@ -226,27 +226,44 @@ func (p *PostPolicy) SetContentDisposition(contentDisposition string) error { return nil } +// SetContentEncoding - Sets content-encoding of the object for this policy +func (p *PostPolicy) SetContentEncoding(contentEncoding string) error { + if strings.TrimSpace(contentEncoding) == "" { + return errInvalidArgument("No content encoding specified.") + } + policyCond := policyCondition{ + matchType: "eq", + condition: "$Content-Encoding", + value: contentEncoding, + } + if err := p.addNewPolicy(policyCond); err != nil { + return err + } + p.formData["Content-Encoding"] = contentEncoding + return nil +} + // SetContentLengthRange - Set new min and max content length // condition for all incoming uploads. -func (p *PostPolicy) SetContentLengthRange(min, max int64) error { - if min > max { +func (p *PostPolicy) SetContentLengthRange(minv, maxv int64) error { + if minv > maxv { return errInvalidArgument("Minimum limit is larger than maximum limit.") } - if min < 0 { + if minv < 0 { return errInvalidArgument("Minimum limit cannot be negative.") } - if max <= 0 { + if maxv <= 0 { return errInvalidArgument("Maximum limit cannot be non-positive.") } - p.contentLengthRange.min = min - p.contentLengthRange.max = max + p.contentLengthRange.min = minv + p.contentLengthRange.max = maxv return nil } // SetSuccessActionRedirect - Sets the redirect success url of the object for this policy // based upload. func (p *PostPolicy) SetSuccessActionRedirect(redirect string) error { - if strings.TrimSpace(redirect) == "" || redirect == "" { + if strings.TrimSpace(redirect) == "" { return errInvalidArgument("Redirect is empty") } policyCond := policyCondition{ @@ -264,7 +281,7 @@ func (p *PostPolicy) SetSuccessActionRedirect(redirect string) error { // SetSuccessStatusAction - Sets the status success code of the object for this policy // based upload. func (p *PostPolicy) SetSuccessStatusAction(status string) error { - if strings.TrimSpace(status) == "" || status == "" { + if strings.TrimSpace(status) == "" { return errInvalidArgument("Status is empty") } policyCond := policyCondition{ @@ -282,10 +299,10 @@ func (p *PostPolicy) SetSuccessStatusAction(status string) error { // SetUserMetadata - Set user metadata as a key/value couple. // Can be retrieved through a HEAD request or an event. func (p *PostPolicy) SetUserMetadata(key, value string) error { - if strings.TrimSpace(key) == "" || key == "" { + if strings.TrimSpace(key) == "" { return errInvalidArgument("Key is empty") } - if strings.TrimSpace(value) == "" || value == "" { + if strings.TrimSpace(value) == "" { return errInvalidArgument("Value is empty") } headerName := fmt.Sprintf("x-amz-meta-%s", key) @@ -304,7 +321,7 @@ func (p *PostPolicy) SetUserMetadata(key, value string) error { // SetUserMetadataStartsWith - Set how an user metadata should starts with. // Can be retrieved through a HEAD request or an event. func (p *PostPolicy) SetUserMetadataStartsWith(key, value string) error { - if strings.TrimSpace(key) == "" || key == "" { + if strings.TrimSpace(key) == "" { return errInvalidArgument("Key is empty") } headerName := fmt.Sprintf("x-amz-meta-%s", key) @@ -326,8 +343,6 @@ func (p *PostPolicy) SetChecksum(c Checksum) error { p.formData[amzChecksumAlgo] = c.Type.String() p.formData[c.Type.Key()] = c.Encoded() - // Needed for S3 compatibility. MinIO ignores the checksum keys in the policy. - // https://github.com/minio/minio/blob/RELEASE.2024-08-29T01-40-52Z/cmd/postpolicyform.go#L60-L65 policyCond := policyCondition{ matchType: "eq", condition: fmt.Sprintf("$%s", amzChecksumAlgo), diff --git a/post-policy_test.go b/post-policy_test.go new file mode 100644 index 000000000..c105e053c --- /dev/null +++ b/post-policy_test.go @@ -0,0 +1,459 @@ +/* + * MinIO Go Library for Amazon S3 Compatible Cloud Storage + * Copyright 2015-2023 MinIO, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package minio + +import ( + "strings" + "testing" + "time" + + "github.com/minio/minio-go/v7/pkg/encrypt" +) + +func TestPostPolicySetExpires(t *testing.T) { + tests := []struct { + name string + input time.Time + wantErr bool + wantResult string + }{ + { + name: "valid time", + input: time.Date(2023, time.March, 2, 15, 4, 5, 0, time.UTC), + wantErr: false, + wantResult: "2023-03-02T15:04:05", + }, + { + name: "time before 1970", + input: time.Date(1, time.January, 1, 0, 0, 0, 0, time.UTC), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pp := NewPostPolicy() + + err := pp.SetExpires(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err) + } + + if tt.wantResult != "" { + result := pp.String() + if !strings.Contains(result, tt.wantResult) { + t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result) + } + } + }) + } +} + +func TestPostPolicySetKey(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + wantResult string + }{ + { + name: "valid key", + input: "my-object", + wantResult: `"eq","$key","my-object"`, + }, + { + name: "empty key", + input: "", + wantErr: true, + }, + { + name: "key with spaces", + input: " ", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pp := NewPostPolicy() + + err := pp.SetKey(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err) + } + + if tt.wantResult != "" { + result := pp.String() + if !strings.Contains(result, tt.wantResult) { + t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result) + } + } + }) + } +} + +func TestPostPolicySetKeyStartsWith(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "valid key prefix", + input: "my-prefix/", + want: `["starts-with","$key","my-prefix/"]`, + }, + { + name: "empty prefix (allow any key)", + input: "", + want: `["starts-with","$key",""]`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pp := NewPostPolicy() + + err := pp.SetKeyStartsWith(tt.input) + if err != nil { + t.Errorf("%s: want no error, got: %v", tt.name, err) + } + + if tt.want != "" { + result := pp.String() + if !strings.Contains(result, tt.want) { + t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.want, result) + } + } + }) + } +} + +func TestPostPolicySetBucket(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + wantResult string + }{ + { + name: "valid bucket", + input: "my-bucket", + wantResult: `"eq","$bucket","my-bucket"`, + }, + { + name: "empty bucket", + input: "", + wantErr: true, + }, + { + name: "bucket with spaces", + input: " ", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pp := NewPostPolicy() + + err := pp.SetBucket(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err) + } + + if tt.wantResult != "" { + result := pp.String() + if !strings.Contains(result, tt.wantResult) { + t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result) + } + } + }) + } +} + +func TestPostPolicySetCondition(t *testing.T) { + tests := []struct { + name string + matchType string + condition string + value string + wantErr bool + wantResult string + }{ + { + name: "valid eq condition", + matchType: "eq", + condition: "X-Amz-Date", + value: "20210324T000000Z", + wantResult: `"eq","$X-Amz-Date","20210324T000000Z"`, + }, + { + name: "empty value", + matchType: "eq", + condition: "X-Amz-Date", + value: "", + wantErr: true, + }, + { + name: "invalid condition", + matchType: "eq", + condition: "Invalid-Condition", + value: "somevalue", + wantErr: true, + }, + { + name: "valid starts-with condition", + matchType: "starts-with", + condition: "X-Amz-Credential", + value: "my-access-key", + wantResult: `"starts-with","$X-Amz-Credential","my-access-key"`, + }, + { + name: "empty condition", + matchType: "eq", + condition: "", + value: "somevalue", + wantErr: true, + }, + { + name: "empty matchType", + matchType: "", + condition: "X-Amz-Date", + value: "somevalue", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pp := NewPostPolicy() + + err := pp.SetCondition(tt.matchType, tt.condition, tt.value) + if (err != nil) != tt.wantErr { + t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err) + } + + if tt.wantResult != "" { + result := pp.String() + if !strings.Contains(result, tt.wantResult) { + t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result) + } + } + }) + } +} + +func TestPostPolicySetTagging(t *testing.T) { + tests := []struct { + name string + tagging string + wantErr bool + wantResult string + }{ + { + name: "valid tagging", + tagging: `key1value1`, + wantResult: `"eq","$tagging","key1value1"`, + }, + { + name: "empty tagging", + tagging: "", + wantErr: true, + }, + { + name: "whitespace tagging", + tagging: " ", + wantErr: true, + }, + { + name: "invalid XML", + tagging: `key1value1`, + wantErr: true, + }, + { + name: "invalid schema", + tagging: ``, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pp := NewPostPolicy() + + err := pp.SetTagging(tt.tagging) + if (err != nil) != tt.wantErr { + t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err) + } + + if tt.wantResult != "" { + result := pp.String() + if !strings.Contains(result, tt.wantResult) { + t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result) + } + } + }) + } +} + +func TestPostPolicySetUserMetadata(t *testing.T) { + tests := []struct { + name string + key string + value string + wantErr bool + wantResult string + }{ + { + name: "valid metadata", + key: "user-key", + value: "user-value", + wantResult: `"eq","$x-amz-meta-user-key","user-value"`, + }, + { + name: "empty key", + key: "", + value: "somevalue", + wantErr: true, + }, + { + name: "empty value", + key: "user-key", + value: "", + wantErr: true, + }, + { + name: "key with spaces", + key: " ", + value: "somevalue", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pp := NewPostPolicy() + + err := pp.SetUserMetadata(tt.key, tt.value) + if (err != nil) != tt.wantErr { + t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err) + } + + if tt.wantResult != "" { + result := pp.String() + if !strings.Contains(result, tt.wantResult) { + t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result) + } + } + }) + } +} + +func TestPostPolicySetChecksum(t *testing.T) { + tests := []struct { + name string + checksum Checksum + wantErr bool + wantResult string + }{ + { + name: "valid checksum SHA256", + checksum: ChecksumSHA256.ChecksumBytes([]byte("somerandomdata")), + wantResult: `[["eq","$x-amz-checksum-algorithm","SHA256"],["eq","$x-amz-checksum-sha256","29/7Qm/iMzZ1O3zMbO0luv6mYWyS6JIqPYV9lc8w1PA="]]`, + }, + { + name: "valid checksum CRC32", + checksum: ChecksumCRC32.ChecksumBytes([]byte("somerandomdata")), + wantResult: `[["eq","$x-amz-checksum-algorithm","CRC32"],["eq","$x-amz-checksum-crc32","7sOPnw=="]]`, + }, + { + name: "empty checksum", + checksum: Checksum{}, + wantResult: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pp := NewPostPolicy() + + err := pp.SetChecksum(tt.checksum) + if (err != nil) != tt.wantErr { + t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err) + } + + if tt.wantResult != "" { + result := pp.String() + if !strings.Contains(result, tt.wantResult) { + t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result) + } + } + }) + } +} + +func TestPostPolicySetEncryption(t *testing.T) { + tests := []struct { + name string + sseType string + keyID string + want map[string]string + }{ + { + name: "SSE-S3 encryption", + sseType: "SSE-S3", + keyID: "my-key-id", + want: map[string]string{ + "X-Amz-Server-Side-Encryption": "aws:kms", + "X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id": "my-key-id", + }, + }, + { + name: "SSE-C encryption with Key ID", + sseType: "SSE-C", + keyID: "my-key-id", + want: map[string]string{ + "X-Amz-Server-Side-Encryption-Customer-Key": "bXktc2VjcmV0LWtleTEyMzQ1Njc4OTBhYmNkZWZnaGk=", + "X-Amz-Server-Side-Encryption-Customer-Key-Md5": "T1mefJwyXBH43sRtfEgRZQ==", + "X-Amz-Server-Side-Encryption-Customer-Algorithm": "AES256", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pp := NewPostPolicy() + + var sse encrypt.ServerSide + var err error + if tt.sseType == "SSE-S3" { + sse, err = encrypt.NewSSEKMS(tt.keyID, nil) + if err != nil { + t.Fatalf("Failed to create SSE-KMS: %v", err) + } + } else if tt.sseType == "SSE-C" { + sse, err = encrypt.NewSSEC([]byte("my-secret-key1234567890abcdefghi")) + if err != nil { + t.Fatalf("Failed to create SSE-C: %v", err) + } + } else { + t.Fatalf("Unknown SSE type: %s", tt.sseType) + } + + pp.SetEncryption(sse) + + for k, v := range tt.want { + if pp.formData[k] != v { + t.Errorf("%s: want %s: %s, got: %s", tt.name, k, v, pp.formData[k]) + } + } + }) + } +} diff --git a/retry.go b/retry.go index d15eb5901..134121b43 100644 --- a/retry.go +++ b/retry.go @@ -45,7 +45,7 @@ var DefaultRetryCap = time.Second // newRetryTimer creates a timer with exponentially increasing // delays until the maximum retry attempts are reached. -func (c *Client) newRetryTimer(ctx context.Context, maxRetry int, unit, cap time.Duration, jitter float64) <-chan int { +func (c *Client) newRetryTimer(ctx context.Context, maxRetry int, unit, maxSleep time.Duration, jitter float64) <-chan int { attemptCh := make(chan int) // computes the exponential backoff duration according to @@ -59,10 +59,10 @@ func (c *Client) newRetryTimer(ctx context.Context, maxRetry int, unit, cap time jitter = MaxJitter } - // sleep = random_between(0, min(cap, base * 2 ** attempt)) + // sleep = random_between(0, min(maxSleep, base * 2 ** attempt)) sleep := unit * time.Duration(1< cap { - sleep = cap + if sleep > maxSleep { + sleep = maxSleep } if jitter != NoJitter { sleep -= time.Duration(c.random.Float64() * float64(sleep) * jitter)