Skip to content

Commit

Permalink
Merge pull request #501 from liamg/patch-1
Browse files Browse the repository at this point in the history
fix: Avoid panic when s3 URL is invalid
  • Loading branch information
radeksimko committed Aug 14, 2024
2 parents 5a63fd9 + 8339301 commit 4f07d24
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 8 deletions.
14 changes: 13 additions & 1 deletion get_s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,23 +268,35 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
region = "us-east-1"
}
pathParts := strings.SplitN(u.Path, "/", 3)
if len(pathParts) < 3 {
err = fmt.Errorf("URL is not a valid S3 URL")
return
}
bucket = pathParts[1]
path = pathParts[2]
// vhost-style, dash region indication
case 4:
// Parse the region out of the first part of the host
// Parse the region out of the second part of the host
region = strings.TrimPrefix(strings.TrimPrefix(hostParts[1], "s3-"), "s3")
if region == "" {
err = fmt.Errorf("URL is not a valid S3 URL")
return
}
pathParts := strings.SplitN(u.Path, "/", 2)
if len(pathParts) < 2 {
err = fmt.Errorf("URL is not a valid S3 URL")
return
}
bucket = hostParts[0]
path = pathParts[1]
//vhost-style, dot region indication
case 5:
region = hostParts[2]
pathParts := strings.SplitN(u.Path, "/", 2)
if len(pathParts) < 2 {
err = fmt.Errorf("URL is not a valid S3 URL")
return
}
bucket = hostParts[0]
path = pathParts[1]

Expand Down
65 changes: 58 additions & 7 deletions get_s3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,13 @@ func TestS3Getter_ClientMode_collision(t *testing.T) {

func TestS3Getter_Url(t *testing.T) {
var s3tests = []struct {
name string
url string
region string
bucket string
path string
version string
name string
url string
region string
bucket string
path string
version string
expectedErr string
}{
{
name: "AWSv1234",
Expand Down Expand Up @@ -220,6 +221,11 @@ func TestS3Getter_Url(t *testing.T) {
path: "hello.txt",
version: "",
},
{
name: "malformed s3 url",
url: "s3::https://s3.amazonaws.com/bucket",
expectedErr: "URL is not a valid S3 URL",
},
}

for i, pt := range s3tests {
Expand All @@ -238,7 +244,15 @@ func TestS3Getter_Url(t *testing.T) {
region, bucket, path, version, creds, err := g.parseUrl(u)

if err != nil {
t.Fatalf("err: %s", err)
if pt.expectedErr == "" {
t.Fatalf("err: %s", err)
}
if err.Error() != pt.expectedErr {
t.Fatalf("expected %s, got %s", pt.expectedErr, err.Error())
}
return
} else if pt.expectedErr != "" {
t.Fatalf("expected error, got none")
}
if region != pt.region {
t.Fatalf("expected %s, got %s", pt.region, region)
Expand All @@ -258,3 +272,40 @@ func TestS3Getter_Url(t *testing.T) {
})
}
}

func Test_S3Getter_ParseUrl_Malformed(t *testing.T) {
tests := []struct {
name string
url string
}{
{
name: "path style",
url: "https://s3.amazonaws.com/bucket",
},
{
name: "vhost-style, dash region indication",
url: "https://bucket.s3-us-east-1.amazonaws.com",
},
{
name: "vhost-style, dot region indication",
url: "https://bucket.s3.us-east-1.amazonaws.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := new(S3Getter)
u, err := url.Parse(tt.url)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
_, _, _, _, _, err = g.parseUrl(u)
if err == nil {
t.Fatalf("expected error, got none")
}
if err.Error() != "URL is not a valid S3 URL" {
t.Fatalf("expected error 'URL is not a valid S3 URL', got %s", err.Error())
}
})
}

}

0 comments on commit 4f07d24

Please sign in to comment.