diff --git a/pmtiles/bucket.go b/pmtiles/bucket.go index 70850d5..8664352 100644 --- a/pmtiles/bucket.go +++ b/pmtiles/bucket.go @@ -13,12 +13,19 @@ import ( "os" "path" "path/filepath" + "strconv" "strings" + "cloud.google.com/go/storage" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/s3" "github.com/cespare/xxhash/v2" "gocloud.dev/blob" + "google.golang.org/api/googleapi" ) // Bucket is an abstration over a gocloud or plain HTTP bucket. @@ -211,35 +218,86 @@ func (ba BucketAdapter) NewRangeReader(ctx context.Context, key string, offset, return body, err } +func etagToGeneration(etag string) int64 { + i, _ := strconv.ParseInt(etag, 10, 64) + return i +} + +func generationToEtag(generation int64) string { + return strconv.FormatInt(generation, 10) +} + +func setProviderEtag(asFunc func(interface{}) bool, etag string) { + var awsV1Req *s3.GetObjectInput + var azblobReq *azblob.DownloadStreamOptions + var gcsHandle **storage.ObjectHandle + if asFunc(&awsV1Req) { + awsV1Req.IfMatch = aws.String(etag) + } else if asFunc(&azblobReq) { + azEtag := azcore.ETag(etag) + azblobReq.AccessConditions = &azblob.AccessConditions{ + ModifiedAccessConditions: &container.ModifiedAccessConditions{ + IfMatch: &azEtag, + }, + } + } else if asFunc(&gcsHandle) { + *gcsHandle = (*gcsHandle).If(storage.Conditions{ + GenerationMatch: etagToGeneration(etag), + }) + } +} + +func getProviderErrorStatusCode(err error) int { + var awsV1Err awserr.RequestFailure + var azureErr *azcore.ResponseError + var gcpErr *googleapi.Error + + if errors.As(err, &awsV1Err); awsV1Err != nil { + return awsV1Err.StatusCode() + } else if errors.As(err, &azureErr); azureErr != nil { + return azureErr.StatusCode + } else if errors.As(err, &gcpErr); gcpErr != nil { + return gcpErr.Code + } + return 404 +} + +func getProviderEtag(reader *blob.Reader) string { + var awsV1Resp s3.GetObjectOutput + var azureResp azblob.DownloadStreamResponse + var gcpResp *storage.Reader + + if reader.As(&awsV1Resp) { + return *awsV1Resp.ETag + } else if reader.As(&azureResp) { + return string(*azureResp.ETag) + } else if reader.As(&gcpResp) { + return generationToEtag(gcpResp.Attrs.Generation) + } + + return "" +} + func (ba BucketAdapter) NewRangeReaderEtag(ctx context.Context, key string, offset, length int64, etag string) (io.ReadCloser, string, int, error) { reader, err := ba.Bucket.NewRangeReader(ctx, key, offset, length, &blob.ReaderOptions{ BeforeRead: func(asFunc func(interface{}) bool) error { - var req *s3.GetObjectInput - if len(etag) > 0 && asFunc(&req) { - req.IfMatch = &etag + if len(etag) > 0 { + setProviderEtag(asFunc, etag) } return nil }, }) status := 206 if err != nil { - var resp awserr.RequestFailure - errors.As(err, &resp) - status = 404 - if resp != nil { - status = resp.StatusCode() - if isRefreshRequiredCode(resp.StatusCode()) { - return nil, "", resp.StatusCode(), &RefreshRequiredError{resp.StatusCode()} - } + status = getProviderErrorStatusCode(err) + if isRefreshRequiredCode(status) { + return nil, "", status, &RefreshRequiredError{status} } + return nil, "", status, err } - resultETag := "" - var resp s3.GetObjectOutput - if reader.As(&resp) { - resultETag = *resp.ETag - } - return reader, resultETag, status, nil + + return reader, getProviderEtag(reader), status, nil } func (ba BucketAdapter) Close() error { diff --git a/pmtiles/bucket_test.go b/pmtiles/bucket_test.go index cc976d1..eee13e8 100644 --- a/pmtiles/bucket_test.go +++ b/pmtiles/bucket_test.go @@ -2,6 +2,7 @@ package pmtiles import ( "context" + "errors" "io" "net/http" "os" @@ -9,8 +10,14 @@ import ( "strings" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/service/s3" "github.com/stretchr/testify/assert" _ "gocloud.dev/blob/fileblob" + "google.golang.org/api/googleapi" ) func TestNormalizeLocalFile(t *testing.T) { @@ -206,3 +213,54 @@ func TestFileShorterThan16K(t *testing.T) { assert.Nil(t, err) assert.Equal(t, 3, len(data)) } + +func TestSetProviderEtagAws(t *testing.T) { + var awsV1Req s3.GetObjectInput + assert.Nil(t, awsV1Req.IfMatch) + asFunc := func(i interface{}) bool { + v, ok := i.(**s3.GetObjectInput) + if ok { + *v = &awsV1Req + } + return true + } + setProviderEtag(asFunc, "123") + assert.Equal(t, aws.String("123"), awsV1Req.IfMatch) +} + +func TestSetProviderEtagAzure(t *testing.T) { + var azOptions azblob.DownloadStreamOptions + assert.Nil(t, azOptions.AccessConditions) + asFunc := func(i interface{}) bool { + v, ok := i.(**azblob.DownloadStreamOptions) + if ok { + *v = &azOptions + } + return ok + } + setProviderEtag(asFunc, "123") + assert.Equal(t, azcore.ETag("123"), *azOptions.AccessConditions.ModifiedAccessConditions.IfMatch) +} + +func TestGetProviderErrorStatusCode(t *testing.T) { + awsErr := awserr.NewRequestFailure(awserr.New("", "", nil), 500, "") + statusCode := getProviderErrorStatusCode(awsErr) + assert.Equal(t, 500, statusCode) + + azureErr := &azcore.ResponseError{StatusCode: 500} + statusCode = getProviderErrorStatusCode(azureErr) + assert.Equal(t, 500, statusCode) + + gcpErr := &googleapi.Error{Code: 500} + statusCode = getProviderErrorStatusCode(gcpErr) + assert.Equal(t, 500, statusCode) + + err := errors.New("generic error") + statusCode = getProviderErrorStatusCode(err) + assert.Equal(t, 404, statusCode) +} + +func TestGenerationEtag(t *testing.T) { + assert.Equal(t, int64(123), etagToGeneration("123")) + assert.Equal(t, "123", generationToEtag(int64(123))) +}