diff --git a/functional_tests.go b/functional_tests.go
index 714bb3cc9..a3dbc49c9 100644
--- a/functional_tests.go
+++ b/functional_tests.go
@@ -5643,9 +5643,6 @@ func testPresignedPostPolicy() {
"policy": "",
}
- // Seed random based on current time.
- rand.Seed(time.Now().Unix())
-
// Instantiate new minio client object
c, err := minio.New(os.Getenv(serverEndpoint),
&minio.Options{
@@ -5692,38 +5689,13 @@ func testPresignedPostPolicy() {
}
policy := minio.NewPostPolicy()
-
- if err := policy.SetBucket(""); err == nil {
- logError(testName, function, args, startTime, "", "SetBucket did not fail for invalid conditions", err)
- return
- }
- if err := policy.SetKey(""); err == nil {
- logError(testName, function, args, startTime, "", "SetKey did not fail for invalid conditions", err)
- return
- }
- if err := policy.SetExpires(time.Date(1, time.January, 1, 0, 0, 0, 0, time.UTC)); err == nil {
- logError(testName, function, args, startTime, "", "SetExpires did not fail for invalid conditions", err)
- return
- }
- if err := policy.SetContentType(""); err == nil {
- logError(testName, function, args, startTime, "", "SetContentType did not fail for invalid conditions", err)
- return
- }
- if err := policy.SetContentLengthRange(1024*1024, 1024); err == nil {
- logError(testName, function, args, startTime, "", "SetContentLengthRange did not fail for invalid conditions", err)
- return
- }
- if err := policy.SetUserMetadata("", ""); err == nil {
- logError(testName, function, args, startTime, "", "SetUserMetadata did not fail for invalid conditions", err)
- return
- }
-
policy.SetBucket(bucketName)
policy.SetKey(objectName)
policy.SetExpires(time.Now().UTC().AddDate(0, 0, 10)) // expires in 10 days
policy.SetContentType("binary/octet-stream")
policy.SetContentLengthRange(10, 1024*1024)
policy.SetUserMetadata(metadataKey, metadataValue)
+ policy.SetContentEncoding("gzip")
// Add CRC32C
checksum := minio.ChecksumCRC32C.ChecksumBytes(buf)
@@ -5865,6 +5837,186 @@ func testPresignedPostPolicy() {
logSuccess(testName, function, args, startTime)
}
+// testPresignedPostPolicyWrongFile tests that when we have a policy with a checksum, we cannot POST the wrong file
+func testPresignedPostPolicyWrongFile() {
+ // initialize logging params
+ startTime := time.Now()
+ testName := getFuncName()
+ function := "PresignedPostPolicy(policy)"
+ args := map[string]interface{}{
+ "policy": "",
+ }
+
+ // Instantiate new minio client object
+ c, err := minio.New(os.Getenv(serverEndpoint),
+ &minio.Options{
+ Creds: credentials.NewStaticV4(os.Getenv(accessKey), os.Getenv(secretKey), ""),
+ Transport: createHTTPTransport(),
+ Secure: mustParseBool(os.Getenv(enableHTTPS)),
+ })
+ if err != nil {
+ logError(testName, function, args, startTime, "", "MinIO client object creation failed", err)
+ return
+ }
+
+ // Enable tracing, write to stderr.
+ // c.TraceOn(os.Stderr)
+
+ // Set user agent.
+ c.SetAppInfo("MinIO-go-FunctionalTest", appVersion)
+
+ // Generate a new random bucket name.
+ bucketName := randString(60, rand.NewSource(time.Now().UnixNano()), "minio-go-test-")
+
+ // Make a new bucket in 'us-east-1' (source bucket).
+ err = c.MakeBucket(context.Background(), bucketName, minio.MakeBucketOptions{Region: "us-east-1"})
+ if err != nil {
+ logError(testName, function, args, startTime, "", "MakeBucket failed", err)
+ return
+ }
+
+ defer cleanupBucket(bucketName, c)
+
+ // Generate 33K of data.
+ reader := getDataReader("datafile-33-kB")
+ defer reader.Close()
+
+ objectName := randString(60, rand.NewSource(time.Now().UnixNano()), "")
+ // Azure requires the key to not start with a number
+ metadataKey := randString(60, rand.NewSource(time.Now().UnixNano()), "user")
+ metadataValue := randString(60, rand.NewSource(time.Now().UnixNano()), "")
+
+ buf, err := io.ReadAll(reader)
+ if err != nil {
+ logError(testName, function, args, startTime, "", "ReadAll failed", err)
+ return
+ }
+
+ policy := minio.NewPostPolicy()
+ policy.SetBucket(bucketName)
+ policy.SetKey(objectName)
+ policy.SetExpires(time.Now().UTC().AddDate(0, 0, 10)) // expires in 10 days
+ policy.SetContentType("binary/octet-stream")
+ policy.SetContentLengthRange(10, 1024*1024)
+ policy.SetUserMetadata(metadataKey, metadataValue)
+
+ // Add CRC32C of the 33kB file that the policy will explicitly allow.
+ checksum := minio.ChecksumCRC32C.ChecksumBytes(buf)
+ err = policy.SetChecksum(checksum)
+ if err != nil {
+ logError(testName, function, args, startTime, "", "SetChecksum failed", err)
+ return
+ }
+
+ args["policy"] = policy.String()
+
+ presignedPostPolicyURL, formData, err := c.PresignedPostPolicy(context.Background(), policy)
+ if err != nil {
+ logError(testName, function, args, startTime, "", "PresignedPostPolicy failed", err)
+ return
+ }
+
+ // At this stage, we have a policy that allows us to upload datafile-33-kB.
+ // Test that uploading datafile-10-kB, with a different checksum, fails as expected
+ filePath := getMintDataDirFilePath("datafile-10-kB")
+ if filePath == "" {
+ // Make a temp file with 10 KB data.
+ file, err := os.CreateTemp(os.TempDir(), "PresignedPostPolicyTest")
+ if err != nil {
+ logError(testName, function, args, startTime, "", "TempFile creation failed", err)
+ return
+ }
+ if _, err = io.Copy(file, getDataReader("datafile-10-kB")); err != nil {
+ logError(testName, function, args, startTime, "", "Copy failed", err)
+ return
+ }
+ if err = file.Close(); err != nil {
+ logError(testName, function, args, startTime, "", "File Close failed", err)
+ return
+ }
+ filePath = file.Name()
+ }
+ fileReader := getDataReader("datafile-10-kB")
+ defer fileReader.Close()
+ buf10k, err := io.ReadAll(fileReader)
+ if err != nil {
+ logError(testName, function, args, startTime, "", "ReadAll failed", err)
+ return
+ }
+ otherChecksum := minio.ChecksumCRC32C.ChecksumBytes(buf10k)
+
+ var formBuf bytes.Buffer
+ writer := multipart.NewWriter(&formBuf)
+ for k, v := range formData {
+ if k == "x-amz-checksum-crc32c" {
+ v = otherChecksum.Encoded()
+ }
+ writer.WriteField(k, v)
+ }
+
+ // Add file to post request
+ f, err := os.Open(filePath)
+ defer f.Close()
+ if err != nil {
+ logError(testName, function, args, startTime, "", "File open failed", err)
+ return
+ }
+ w, err := writer.CreateFormFile("file", filePath)
+ if err != nil {
+ logError(testName, function, args, startTime, "", "CreateFormFile failed", err)
+ return
+ }
+ _, err = io.Copy(w, f)
+ if err != nil {
+ logError(testName, function, args, startTime, "", "Copy failed", err)
+ return
+ }
+ writer.Close()
+
+ httpClient := &http.Client{
+ Timeout: 30 * time.Second,
+ Transport: createHTTPTransport(),
+ }
+ args["url"] = presignedPostPolicyURL.String()
+
+ req, err := http.NewRequest(http.MethodPost, presignedPostPolicyURL.String(), bytes.NewReader(formBuf.Bytes()))
+ if err != nil {
+ logError(testName, function, args, startTime, "", "HTTP request failed", err)
+ return
+ }
+
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+
+ // Make the POST request with the form data.
+ res, err := httpClient.Do(req)
+ if err != nil {
+ logError(testName, function, args, startTime, "", "HTTP request failed", err)
+ return
+ }
+ defer res.Body.Close()
+ if res.StatusCode != http.StatusForbidden {
+ logError(testName, function, args, startTime, "", "HTTP request unexpected status", errors.New(res.Status))
+ return
+ }
+
+ // Read the response body, ensure it has checksum failure message
+ resBody, err := io.ReadAll(res.Body)
+ if err != nil {
+ logError(testName, function, args, startTime, "", "ReadAll failed", err)
+ return
+ }
+
+ // Normalize the response body, because S3 uses quotes around the policy condition components
+ // in the error message, MinIO does not.
+ resBodyStr := strings.ReplaceAll(string(resBody), `"`, "")
+ if !strings.Contains(resBodyStr, "Policy Condition failed: [eq, $x-amz-checksum-crc32c, aHnJMw==]") {
+ logError(testName, function, args, startTime, "", "Unexpected response body", errors.New(resBodyStr))
+ return
+ }
+
+ logSuccess(testName, function, args, startTime)
+}
+
// Tests copy object
func testCopyObject() {
// initialize logging params
@@ -14977,6 +15129,7 @@ func main() {
testGetObjectReadAtFunctional()
testGetObjectReadAtWhenEOFWasReached()
testPresignedPostPolicy()
+ testPresignedPostPolicyWrongFile()
testCopyObject()
testComposeObjectErrorCases()
testCompose10KSources()
diff --git a/post-policy.go b/post-policy.go
index 1b01f4101..00c770d74 100644
--- a/post-policy.go
+++ b/post-policy.go
@@ -85,7 +85,7 @@ func (p *PostPolicy) SetExpires(t time.Time) error {
// SetKey - Sets an object name for the policy based upload.
func (p *PostPolicy) SetKey(key string) error {
- if strings.TrimSpace(key) == "" || key == "" {
+ if strings.TrimSpace(key) == "" {
return errInvalidArgument("Object name is empty.")
}
policyCond := policyCondition{
@@ -118,7 +118,7 @@ func (p *PostPolicy) SetKeyStartsWith(keyStartsWith string) error {
// SetBucket - Sets bucket at which objects will be uploaded to.
func (p *PostPolicy) SetBucket(bucketName string) error {
- if strings.TrimSpace(bucketName) == "" || bucketName == "" {
+ if strings.TrimSpace(bucketName) == "" {
return errInvalidArgument("Bucket name is empty.")
}
policyCond := policyCondition{
@@ -135,7 +135,7 @@ func (p *PostPolicy) SetBucket(bucketName string) error {
// SetCondition - Sets condition for credentials, date and algorithm
func (p *PostPolicy) SetCondition(matchType, condition, value string) error {
- if strings.TrimSpace(value) == "" || value == "" {
+ if strings.TrimSpace(value) == "" {
return errInvalidArgument("No value specified for condition")
}
@@ -156,7 +156,7 @@ func (p *PostPolicy) SetCondition(matchType, condition, value string) error {
// SetTagging - Sets tagging for the object for this policy based upload.
func (p *PostPolicy) SetTagging(tagging string) error {
- if strings.TrimSpace(tagging) == "" || tagging == "" {
+ if strings.TrimSpace(tagging) == "" {
return errInvalidArgument("No tagging specified.")
}
_, err := tags.ParseObjectXML(strings.NewReader(tagging))
@@ -178,7 +178,7 @@ func (p *PostPolicy) SetTagging(tagging string) error {
// SetContentType - Sets content-type of the object for this policy
// based upload.
func (p *PostPolicy) SetContentType(contentType string) error {
- if strings.TrimSpace(contentType) == "" || contentType == "" {
+ if strings.TrimSpace(contentType) == "" {
return errInvalidArgument("No content type specified.")
}
policyCond := policyCondition{
@@ -211,7 +211,7 @@ func (p *PostPolicy) SetContentTypeStartsWith(contentTypeStartsWith string) erro
// SetContentDisposition - Sets content-disposition of the object for this policy
func (p *PostPolicy) SetContentDisposition(contentDisposition string) error {
- if strings.TrimSpace(contentDisposition) == "" || contentDisposition == "" {
+ if strings.TrimSpace(contentDisposition) == "" {
return errInvalidArgument("No content disposition specified.")
}
policyCond := policyCondition{
@@ -226,27 +226,44 @@ func (p *PostPolicy) SetContentDisposition(contentDisposition string) error {
return nil
}
+// SetContentEncoding - Sets content-encoding of the object for this policy
+func (p *PostPolicy) SetContentEncoding(contentEncoding string) error {
+ if strings.TrimSpace(contentEncoding) == "" {
+ return errInvalidArgument("No content encoding specified.")
+ }
+ policyCond := policyCondition{
+ matchType: "eq",
+ condition: "$Content-Encoding",
+ value: contentEncoding,
+ }
+ if err := p.addNewPolicy(policyCond); err != nil {
+ return err
+ }
+ p.formData["Content-Encoding"] = contentEncoding
+ return nil
+}
+
// SetContentLengthRange - Set new min and max content length
// condition for all incoming uploads.
-func (p *PostPolicy) SetContentLengthRange(min, max int64) error {
- if min > max {
+func (p *PostPolicy) SetContentLengthRange(minv, maxv int64) error {
+ if minv > maxv {
return errInvalidArgument("Minimum limit is larger than maximum limit.")
}
- if min < 0 {
+ if minv < 0 {
return errInvalidArgument("Minimum limit cannot be negative.")
}
- if max <= 0 {
+ if maxv <= 0 {
return errInvalidArgument("Maximum limit cannot be non-positive.")
}
- p.contentLengthRange.min = min
- p.contentLengthRange.max = max
+ p.contentLengthRange.min = minv
+ p.contentLengthRange.max = maxv
return nil
}
// SetSuccessActionRedirect - Sets the redirect success url of the object for this policy
// based upload.
func (p *PostPolicy) SetSuccessActionRedirect(redirect string) error {
- if strings.TrimSpace(redirect) == "" || redirect == "" {
+ if strings.TrimSpace(redirect) == "" {
return errInvalidArgument("Redirect is empty")
}
policyCond := policyCondition{
@@ -264,7 +281,7 @@ func (p *PostPolicy) SetSuccessActionRedirect(redirect string) error {
// SetSuccessStatusAction - Sets the status success code of the object for this policy
// based upload.
func (p *PostPolicy) SetSuccessStatusAction(status string) error {
- if strings.TrimSpace(status) == "" || status == "" {
+ if strings.TrimSpace(status) == "" {
return errInvalidArgument("Status is empty")
}
policyCond := policyCondition{
@@ -282,10 +299,10 @@ func (p *PostPolicy) SetSuccessStatusAction(status string) error {
// SetUserMetadata - Set user metadata as a key/value couple.
// Can be retrieved through a HEAD request or an event.
func (p *PostPolicy) SetUserMetadata(key, value string) error {
- if strings.TrimSpace(key) == "" || key == "" {
+ if strings.TrimSpace(key) == "" {
return errInvalidArgument("Key is empty")
}
- if strings.TrimSpace(value) == "" || value == "" {
+ if strings.TrimSpace(value) == "" {
return errInvalidArgument("Value is empty")
}
headerName := fmt.Sprintf("x-amz-meta-%s", key)
@@ -304,7 +321,7 @@ func (p *PostPolicy) SetUserMetadata(key, value string) error {
// SetUserMetadataStartsWith - Set how an user metadata should starts with.
// Can be retrieved through a HEAD request or an event.
func (p *PostPolicy) SetUserMetadataStartsWith(key, value string) error {
- if strings.TrimSpace(key) == "" || key == "" {
+ if strings.TrimSpace(key) == "" {
return errInvalidArgument("Key is empty")
}
headerName := fmt.Sprintf("x-amz-meta-%s", key)
@@ -326,8 +343,6 @@ func (p *PostPolicy) SetChecksum(c Checksum) error {
p.formData[amzChecksumAlgo] = c.Type.String()
p.formData[c.Type.Key()] = c.Encoded()
- // Needed for S3 compatibility. MinIO ignores the checksum keys in the policy.
- // https://github.com/minio/minio/blob/RELEASE.2024-08-29T01-40-52Z/cmd/postpolicyform.go#L60-L65
policyCond := policyCondition{
matchType: "eq",
condition: fmt.Sprintf("$%s", amzChecksumAlgo),
diff --git a/post-policy_test.go b/post-policy_test.go
new file mode 100644
index 000000000..c105e053c
--- /dev/null
+++ b/post-policy_test.go
@@ -0,0 +1,459 @@
+/*
+ * MinIO Go Library for Amazon S3 Compatible Cloud Storage
+ * Copyright 2015-2023 MinIO, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package minio
+
+import (
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/minio/minio-go/v7/pkg/encrypt"
+)
+
+func TestPostPolicySetExpires(t *testing.T) {
+ tests := []struct {
+ name string
+ input time.Time
+ wantErr bool
+ wantResult string
+ }{
+ {
+ name: "valid time",
+ input: time.Date(2023, time.March, 2, 15, 4, 5, 0, time.UTC),
+ wantErr: false,
+ wantResult: "2023-03-02T15:04:05",
+ },
+ {
+ name: "time before 1970",
+ input: time.Date(1, time.January, 1, 0, 0, 0, 0, time.UTC),
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ pp := NewPostPolicy()
+
+ err := pp.SetExpires(tt.input)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err)
+ }
+
+ if tt.wantResult != "" {
+ result := pp.String()
+ if !strings.Contains(result, tt.wantResult) {
+ t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result)
+ }
+ }
+ })
+ }
+}
+
+func TestPostPolicySetKey(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ wantErr bool
+ wantResult string
+ }{
+ {
+ name: "valid key",
+ input: "my-object",
+ wantResult: `"eq","$key","my-object"`,
+ },
+ {
+ name: "empty key",
+ input: "",
+ wantErr: true,
+ },
+ {
+ name: "key with spaces",
+ input: " ",
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ pp := NewPostPolicy()
+
+ err := pp.SetKey(tt.input)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err)
+ }
+
+ if tt.wantResult != "" {
+ result := pp.String()
+ if !strings.Contains(result, tt.wantResult) {
+ t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result)
+ }
+ }
+ })
+ }
+}
+
+func TestPostPolicySetKeyStartsWith(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ want string
+ }{
+ {
+ name: "valid key prefix",
+ input: "my-prefix/",
+ want: `["starts-with","$key","my-prefix/"]`,
+ },
+ {
+ name: "empty prefix (allow any key)",
+ input: "",
+ want: `["starts-with","$key",""]`,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ pp := NewPostPolicy()
+
+ err := pp.SetKeyStartsWith(tt.input)
+ if err != nil {
+ t.Errorf("%s: want no error, got: %v", tt.name, err)
+ }
+
+ if tt.want != "" {
+ result := pp.String()
+ if !strings.Contains(result, tt.want) {
+ t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.want, result)
+ }
+ }
+ })
+ }
+}
+
+func TestPostPolicySetBucket(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ wantErr bool
+ wantResult string
+ }{
+ {
+ name: "valid bucket",
+ input: "my-bucket",
+ wantResult: `"eq","$bucket","my-bucket"`,
+ },
+ {
+ name: "empty bucket",
+ input: "",
+ wantErr: true,
+ },
+ {
+ name: "bucket with spaces",
+ input: " ",
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ pp := NewPostPolicy()
+
+ err := pp.SetBucket(tt.input)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err)
+ }
+
+ if tt.wantResult != "" {
+ result := pp.String()
+ if !strings.Contains(result, tt.wantResult) {
+ t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result)
+ }
+ }
+ })
+ }
+}
+
+func TestPostPolicySetCondition(t *testing.T) {
+ tests := []struct {
+ name string
+ matchType string
+ condition string
+ value string
+ wantErr bool
+ wantResult string
+ }{
+ {
+ name: "valid eq condition",
+ matchType: "eq",
+ condition: "X-Amz-Date",
+ value: "20210324T000000Z",
+ wantResult: `"eq","$X-Amz-Date","20210324T000000Z"`,
+ },
+ {
+ name: "empty value",
+ matchType: "eq",
+ condition: "X-Amz-Date",
+ value: "",
+ wantErr: true,
+ },
+ {
+ name: "invalid condition",
+ matchType: "eq",
+ condition: "Invalid-Condition",
+ value: "somevalue",
+ wantErr: true,
+ },
+ {
+ name: "valid starts-with condition",
+ matchType: "starts-with",
+ condition: "X-Amz-Credential",
+ value: "my-access-key",
+ wantResult: `"starts-with","$X-Amz-Credential","my-access-key"`,
+ },
+ {
+ name: "empty condition",
+ matchType: "eq",
+ condition: "",
+ value: "somevalue",
+ wantErr: true,
+ },
+ {
+ name: "empty matchType",
+ matchType: "",
+ condition: "X-Amz-Date",
+ value: "somevalue",
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ pp := NewPostPolicy()
+
+ err := pp.SetCondition(tt.matchType, tt.condition, tt.value)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err)
+ }
+
+ if tt.wantResult != "" {
+ result := pp.String()
+ if !strings.Contains(result, tt.wantResult) {
+ t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result)
+ }
+ }
+ })
+ }
+}
+
+func TestPostPolicySetTagging(t *testing.T) {
+ tests := []struct {
+ name string
+ tagging string
+ wantErr bool
+ wantResult string
+ }{
+ {
+ name: "valid tagging",
+ tagging: `key1value1`,
+ wantResult: `"eq","$tagging","key1value1"`,
+ },
+ {
+ name: "empty tagging",
+ tagging: "",
+ wantErr: true,
+ },
+ {
+ name: "whitespace tagging",
+ tagging: " ",
+ wantErr: true,
+ },
+ {
+ name: "invalid XML",
+ tagging: `key1value1`,
+ wantErr: true,
+ },
+ {
+ name: "invalid schema",
+ tagging: ``,
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ pp := NewPostPolicy()
+
+ err := pp.SetTagging(tt.tagging)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err)
+ }
+
+ if tt.wantResult != "" {
+ result := pp.String()
+ if !strings.Contains(result, tt.wantResult) {
+ t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result)
+ }
+ }
+ })
+ }
+}
+
+func TestPostPolicySetUserMetadata(t *testing.T) {
+ tests := []struct {
+ name string
+ key string
+ value string
+ wantErr bool
+ wantResult string
+ }{
+ {
+ name: "valid metadata",
+ key: "user-key",
+ value: "user-value",
+ wantResult: `"eq","$x-amz-meta-user-key","user-value"`,
+ },
+ {
+ name: "empty key",
+ key: "",
+ value: "somevalue",
+ wantErr: true,
+ },
+ {
+ name: "empty value",
+ key: "user-key",
+ value: "",
+ wantErr: true,
+ },
+ {
+ name: "key with spaces",
+ key: " ",
+ value: "somevalue",
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ pp := NewPostPolicy()
+
+ err := pp.SetUserMetadata(tt.key, tt.value)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err)
+ }
+
+ if tt.wantResult != "" {
+ result := pp.String()
+ if !strings.Contains(result, tt.wantResult) {
+ t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result)
+ }
+ }
+ })
+ }
+}
+
+func TestPostPolicySetChecksum(t *testing.T) {
+ tests := []struct {
+ name string
+ checksum Checksum
+ wantErr bool
+ wantResult string
+ }{
+ {
+ name: "valid checksum SHA256",
+ checksum: ChecksumSHA256.ChecksumBytes([]byte("somerandomdata")),
+ wantResult: `[["eq","$x-amz-checksum-algorithm","SHA256"],["eq","$x-amz-checksum-sha256","29/7Qm/iMzZ1O3zMbO0luv6mYWyS6JIqPYV9lc8w1PA="]]`,
+ },
+ {
+ name: "valid checksum CRC32",
+ checksum: ChecksumCRC32.ChecksumBytes([]byte("somerandomdata")),
+ wantResult: `[["eq","$x-amz-checksum-algorithm","CRC32"],["eq","$x-amz-checksum-crc32","7sOPnw=="]]`,
+ },
+ {
+ name: "empty checksum",
+ checksum: Checksum{},
+ wantResult: "",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ pp := NewPostPolicy()
+
+ err := pp.SetChecksum(tt.checksum)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("%s: want error: %v, got: %v", tt.name, tt.wantErr, err)
+ }
+
+ if tt.wantResult != "" {
+ result := pp.String()
+ if !strings.Contains(result, tt.wantResult) {
+ t.Errorf("%s: want result to contain: '%s', got: '%s'", tt.name, tt.wantResult, result)
+ }
+ }
+ })
+ }
+}
+
+func TestPostPolicySetEncryption(t *testing.T) {
+ tests := []struct {
+ name string
+ sseType string
+ keyID string
+ want map[string]string
+ }{
+ {
+ name: "SSE-S3 encryption",
+ sseType: "SSE-S3",
+ keyID: "my-key-id",
+ want: map[string]string{
+ "X-Amz-Server-Side-Encryption": "aws:kms",
+ "X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id": "my-key-id",
+ },
+ },
+ {
+ name: "SSE-C encryption with Key ID",
+ sseType: "SSE-C",
+ keyID: "my-key-id",
+ want: map[string]string{
+ "X-Amz-Server-Side-Encryption-Customer-Key": "bXktc2VjcmV0LWtleTEyMzQ1Njc4OTBhYmNkZWZnaGk=",
+ "X-Amz-Server-Side-Encryption-Customer-Key-Md5": "T1mefJwyXBH43sRtfEgRZQ==",
+ "X-Amz-Server-Side-Encryption-Customer-Algorithm": "AES256",
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ pp := NewPostPolicy()
+
+ var sse encrypt.ServerSide
+ var err error
+ if tt.sseType == "SSE-S3" {
+ sse, err = encrypt.NewSSEKMS(tt.keyID, nil)
+ if err != nil {
+ t.Fatalf("Failed to create SSE-KMS: %v", err)
+ }
+ } else if tt.sseType == "SSE-C" {
+ sse, err = encrypt.NewSSEC([]byte("my-secret-key1234567890abcdefghi"))
+ if err != nil {
+ t.Fatalf("Failed to create SSE-C: %v", err)
+ }
+ } else {
+ t.Fatalf("Unknown SSE type: %s", tt.sseType)
+ }
+
+ pp.SetEncryption(sse)
+
+ for k, v := range tt.want {
+ if pp.formData[k] != v {
+ t.Errorf("%s: want %s: %s, got: %s", tt.name, k, v, pp.formData[k])
+ }
+ }
+ })
+ }
+}
diff --git a/retry.go b/retry.go
index d15eb5901..134121b43 100644
--- a/retry.go
+++ b/retry.go
@@ -45,7 +45,7 @@ var DefaultRetryCap = time.Second
// newRetryTimer creates a timer with exponentially increasing
// delays until the maximum retry attempts are reached.
-func (c *Client) newRetryTimer(ctx context.Context, maxRetry int, unit, cap time.Duration, jitter float64) <-chan int {
+func (c *Client) newRetryTimer(ctx context.Context, maxRetry int, unit, maxSleep time.Duration, jitter float64) <-chan int {
attemptCh := make(chan int)
// computes the exponential backoff duration according to
@@ -59,10 +59,10 @@ func (c *Client) newRetryTimer(ctx context.Context, maxRetry int, unit, cap time
jitter = MaxJitter
}
- // sleep = random_between(0, min(cap, base * 2 ** attempt))
+ // sleep = random_between(0, min(maxSleep, base * 2 ** attempt))
sleep := unit * time.Duration(1< cap {
- sleep = cap
+ if sleep > maxSleep {
+ sleep = maxSleep
}
if jitter != NoJitter {
sleep -= time.Duration(c.random.Float64() * float64(sleep) * jitter)