From a7db927da9cf4c6cf242a5db83e44a16d75a8291 Mon Sep 17 00:00:00 2001 From: Brenna N Epp Date: Tue, 3 Dec 2024 23:42:37 -0800 Subject: [PATCH] fix(storage): add backoff to gRPC write retries (#11200) --- storage/bucket_test.go | 8 +- storage/client_test.go | 6 +- storage/grpc_client.go | 134 ++++++++++++++++++++++++----- storage/grpc_client_test.go | 166 ++++++++++++++++++++++++++++++++++++ storage/http_client.go | 9 +- storage/invoke.go | 6 +- storage/invoke_test.go | 44 ++++++---- storage/storage.go | 78 ++++++++++++++--- storage/storage_test.go | 40 ++++----- 9 files changed, 411 insertions(+), 80 deletions(-) diff --git a/storage/bucket_test.go b/storage/bucket_test.go index abd117d6e45a..06cec765905d 100644 --- a/storage/bucket_test.go +++ b/storage/bucket_test.go @@ -1116,11 +1116,11 @@ func TestBucketRetryer(t *testing.T) { WithErrorFunc(func(err error) bool { return false })) }, want: &retryConfig{ - backoff: &gax.Backoff{ + backoff: gaxBackoffFromStruct(&gax.Backoff{ Initial: 2 * time.Second, Max: 30 * time.Second, Multiplier: 3, - }, + }), policy: RetryAlways, maxAttempts: expectedAttempts(5), shouldRetry: func(err error) bool { return false }, @@ -1135,9 +1135,9 @@ func TestBucketRetryer(t *testing.T) { })) }, want: &retryConfig{ - backoff: &gax.Backoff{ + backoff: gaxBackoffFromStruct(&gax.Backoff{ Multiplier: 3, - }}, + })}, }, { name: "set policy only", diff --git a/storage/client_test.go b/storage/client_test.go index 0ec2044cdc24..928a99a2f2d8 100644 --- a/storage/client_test.go +++ b/storage/client_test.go @@ -1396,7 +1396,7 @@ func TestRetryMaxAttemptsEmulated(t *testing.T) { instructions := map[string][]string{"storage.buckets.get": {"return-503", "return-503", "return-503", "return-503", "return-503"}} testID := createRetryTest(t, client, instructions) ctx = callctx.SetHeaders(ctx, "x-retry-test-id", testID) - config := &retryConfig{maxAttempts: expectedAttempts(3), backoff: &gax.Backoff{Initial: 10 * time.Millisecond}} + config := &retryConfig{maxAttempts: expectedAttempts(3), backoff: gaxBackoffFromStruct(&gax.Backoff{Initial: 10 * time.Millisecond})} _, err = client.GetBucket(ctx, bucket, nil, idempotent(true), withRetryConfig(config)) var ae *apierror.APIError @@ -1421,7 +1421,7 @@ func TestTimeoutErrorEmulated(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) defer cancel() time.Sleep(5 * time.Nanosecond) - config := &retryConfig{backoff: &gax.Backoff{Initial: 10 * time.Millisecond}} + config := &retryConfig{backoff: gaxBackoffFromStruct(&gax.Backoff{Initial: 10 * time.Millisecond})} _, err := client.GetBucket(ctx, bucket, nil, idempotent(true), withRetryConfig(config)) // Error may come through as a context.DeadlineExceeded (HTTP) or status.DeadlineExceeded (gRPC) @@ -1447,7 +1447,7 @@ func TestRetryDeadlineExceedeEmulated(t *testing.T) { instructions := map[string][]string{"storage.buckets.get": {"return-504", "return-504"}} testID := createRetryTest(t, client, instructions) ctx = callctx.SetHeaders(ctx, "x-retry-test-id", testID) - config := &retryConfig{maxAttempts: expectedAttempts(4), backoff: &gax.Backoff{Initial: 10 * time.Millisecond}} + config := &retryConfig{maxAttempts: expectedAttempts(4), backoff: gaxBackoffFromStruct(&gax.Backoff{Initial: 10 * time.Millisecond})} if _, err := client.GetBucket(ctx, bucket, nil, idempotent(true), withRetryConfig(config)); err != nil { t.Fatalf("GetBucket: got unexpected error %v, want nil", err) } diff --git a/storage/grpc_client.go b/storage/grpc_client.go index 937360a4afd8..bb746b8f031a 100644 --- a/storage/grpc_client.go +++ b/storage/grpc_client.go @@ -29,6 +29,7 @@ import ( "cloud.google.com/go/internal/trace" gapic "cloud.google.com/go/storage/internal/apiv2" "cloud.google.com/go/storage/internal/apiv2/storagepb" + "github.com/google/uuid" "github.com/googleapis/gax-go/v2" "google.golang.org/api/googleapi" "google.golang.org/api/iterator" @@ -1223,7 +1224,7 @@ func (c *grpcStorageClient) OpenWriter(params *openWriterParams, opts ...storage } } - o, off, err := gw.uploadBuffer(recvd, offset, doneReading) + o, off, err := gw.uploadBuffer(recvd, offset, doneReading, newUploadBufferRetryConfig(gw.settings)) if err != nil { err = checkCanceled(err) errorf(err) @@ -2091,12 +2092,7 @@ func (w *gRPCWriter) queryProgress() (int64, error) { // completed. // // Returns object, persisted size, and any error that is not retriable. -func (w *gRPCWriter) uploadBuffer(recvd int, start int64, doneReading bool) (*storagepb.Object, int64, error) { - var shouldRetry = ShouldRetry - if w.settings.retry != nil && w.settings.retry.shouldRetry != nil { - shouldRetry = w.settings.retry.shouldRetry - } - +func (w *gRPCWriter) uploadBuffer(recvd int, start int64, doneReading bool, retryConfig *uploadBufferRetryConfig) (*storagepb.Object, int64, error) { var err error var lastWriteOfEntireObject bool @@ -2143,6 +2139,7 @@ sendBytes: // label this loop so that we can use a continue statement from a nes if w.stream == nil { hds := []string{"x-goog-request-params", fmt.Sprintf("bucket=projects/_/buckets/%s", url.QueryEscape(w.bucket))} ctx := gax.InsertMetadataIntoOutgoingContext(w.ctx, hds...) + ctx = setInvocationHeaders(ctx, retryConfig.invocationID, retryConfig.attempts) w.stream, err = w.c.raw.BidiWriteObject(ctx) if err != nil { @@ -2188,7 +2185,11 @@ sendBytes: // label this loop so that we can use a continue statement from a nes // Retriable errors mean we should start over and attempt to // resend the entire buffer via a new stream. // If not retriable, falling through will return the error received. - if shouldRetry(err) { + err = retryConfig.retriable(w.ctx, err) + + if err == nil { + retryConfig.doBackOff(w.ctx) + // TODO: Add test case for failure modes of querying progress. writeOffset, err = w.determineOffset(start) if err != nil { @@ -2230,11 +2231,17 @@ sendBytes: // label this loop so that we can use a continue statement from a nes if !lastWriteOfEntireObject { resp, err := w.stream.Recv() - // Retriable errors mean we should start over and attempt to - // resend the entire buffer via a new stream. - // If not retriable, falling through will return the error received - // from closing the stream. - if shouldRetry(err) { + if err != nil { + // Retriable errors mean we should start over and attempt to + // resend the entire buffer via a new stream. + // If not retriable, falling through will return the error received + // from closing the stream. + err = retryConfig.retriable(w.ctx, err) + if err != nil { + return nil, 0, err + } + + retryConfig.doBackOff(w.ctx) writeOffset, err = w.determineOffset(start) if err != nil { return nil, 0, err @@ -2246,9 +2253,6 @@ sendBytes: // label this loop so that we can use a continue statement from a nes continue sendBytes } - if err != nil { - return nil, 0, err - } if resp.GetPersistedSize() != writeOffset { // Retry if not all bytes were persisted. @@ -2274,7 +2278,14 @@ sendBytes: // label this loop so that we can use a continue statement from a nes var obj *storagepb.Object for obj == nil { resp, err := w.stream.Recv() - if shouldRetry(err) { + + if err != nil { + err = retryConfig.retriable(w.ctx, err) + if err != nil { + return nil, 0, err + } + retryConfig.doBackOff(w.ctx) + writeOffset, err = w.determineOffset(start) if err != nil { return nil, 0, err @@ -2283,9 +2294,6 @@ sendBytes: // label this loop so that we can use a continue statement from a nes w.stream = nil continue sendBytes } - if err != nil { - return nil, 0, err - } obj = resp.GetResource() } @@ -2370,3 +2378,89 @@ func checkCanceled(err error) error { return err } + +type uploadBufferRetryConfig struct { + attempts int + invocationID string + config *retryConfig + lastErr error +} + +func newUploadBufferRetryConfig(settings *settings) *uploadBufferRetryConfig { + config := settings.retry + + if config == nil { + config = defaultRetry.clone() + } + + if config.shouldRetry == nil { + config.shouldRetry = ShouldRetry + } + + if config.backoff == nil { + config.backoff = &gaxBackoff{} + } else { + config.backoff.SetMultiplier(settings.retry.backoff.GetMultiplier()) + config.backoff.SetInitial(settings.retry.backoff.GetInitial()) + config.backoff.SetMax(settings.retry.backoff.GetMax()) + } + + return &uploadBufferRetryConfig{ + attempts: 1, + invocationID: uuid.New().String(), + config: config, + } +} + +// retriable determines if a retry is necessary and if so returns a nil error; +// otherwise it returns the error to be surfaced to the user. +func (retry *uploadBufferRetryConfig) retriable(ctx context.Context, err error) error { + if err == nil { + // a nil err does not need to be retried + return nil + } + if err != context.Canceled && err != context.DeadlineExceeded { + retry.lastErr = err + } + + if retry.config.policy == RetryNever { + return err + } + + if retry.config.maxAttempts != nil && retry.attempts >= *retry.config.maxAttempts { + return fmt.Errorf("storage: retry failed after %v attempts; last error: %w", retry.attempts, err) + } + + retry.attempts++ + + // Explicitly check context cancellation so that we can distinguish between a + // DEADLINE_EXCEEDED error from the server and a user-set context deadline. + // Unfortunately gRPC will codes.DeadlineExceeded (which may be retryable if it's + // sent by the server) in both cases. + ctxErr := ctx.Err() + if errors.Is(ctxErr, context.Canceled) || errors.Is(ctxErr, context.DeadlineExceeded) { + if retry.lastErr != nil { + return fmt.Errorf("retry failed with %v; last error: %w", ctxErr, retry.lastErr) + } + return ctxErr + } + + if !retry.config.shouldRetry(err) { + return err + } + return nil +} + +// doBackOff pauses for the appropriate amount of time; it should be called after +// encountering a retriable error. +func (retry *uploadBufferRetryConfig) doBackOff(ctx context.Context) error { + p := retry.config.backoff.Pause() + + if ctxErr := gax.Sleep(ctx, p); ctxErr != nil { + if retry.lastErr != nil { + return fmt.Errorf("retry failed with %v; last error: %w", ctxErr, retry.lastErr) + } + return ctxErr + } + return nil +} diff --git a/storage/grpc_client_test.go b/storage/grpc_client_test.go index d75975e40ecc..86230964d2c1 100644 --- a/storage/grpc_client_test.go +++ b/storage/grpc_client_test.go @@ -16,15 +16,22 @@ package storage import ( "bytes" + "context" "crypto/md5" + "errors" "hash/crc32" + "io" "math/rand" "testing" "time" "cloud.google.com/go/storage/internal/apiv2/storagepb" "github.com/google/go-cmp/cmp" + "google.golang.org/api/option" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/mem" + "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" ) @@ -218,3 +225,162 @@ func TestBytesCodecV2(t *testing.T) { }) } } + +func TestWriteBackoff(t *testing.T) { + ctx := context.Background() + + writeStream := &MockWriteStream{} + streamInterceptor := grpc.WithStreamInterceptor( + func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return writeStream, nil + }) + + c, err := NewGRPCClient(ctx, option.WithoutAuthentication(), option.WithGRPCDialOption(streamInterceptor)) + if err != nil { + t.Fatalf("NewGRPCClient: %v", err) + } + + w := gRPCWriter{ + c: (c.tc).(*grpcStorageClient), + buf: []byte("012345689abcdefg"), + chunkSize: 10, + ctx: ctx, + attrs: &ObjectAttrs{}, + } + + maxAttempts := 4 // be careful changing this value; some tests depend on it + + retriableErr := status.Errorf(codes.Internal, "a retriable error") + nonretriableErr := status.Errorf(codes.PermissionDenied, "a non-retriable error") + testCases := []struct { + desc string + recvErrs []error + sendErrs []error + closeSend bool // if true, will mimic finishing the upload so CloseSend will be called + expectedRetries int + expectedErr error // final err from uploadBuffer should wrap this err + }{ + { + desc: "retriable error on receive", + recvErrs: []error{retriableErr, retriableErr, retriableErr, retriableErr}, + expectedRetries: maxAttempts - 1, + expectedErr: retriableErr, + }, + { + desc: "retriable error on send", + sendErrs: []error{io.EOF}, + recvErrs: []error{retriableErr, retriableErr, retriableErr, retriableErr}, + expectedRetries: maxAttempts - 1, + expectedErr: retriableErr, + }, + { + desc: "non-retriable error on send", + sendErrs: []error{io.EOF}, + recvErrs: []error{nonretriableErr}, + expectedRetries: 0, + expectedErr: nonretriableErr, + }, + { + desc: "non-retriable err after send closed", + recvErrs: []error{nil, nonretriableErr}, + expectedRetries: 0, + expectedErr: nonretriableErr, + closeSend: true, + }, + { + desc: "retriable err after send closed", + recvErrs: []error{nil, retriableErr, retriableErr, retriableErr, retriableErr}, + expectedRetries: maxAttempts - 1, + expectedErr: retriableErr, + closeSend: true, + }, + } + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + backoff := &mockBackoff{} + retryConfig := newUploadBufferRetryConfig(&settings{}) + retryConfig.config.maxAttempts = &maxAttempts + retryConfig.config.backoff = backoff + + writeStream.RecvMsgErrToReturn = test.recvErrs + writeStream.SendMsgErrToReturn = test.sendErrs + + _, _, err = w.uploadBuffer(0, 0, test.closeSend, retryConfig) + if !errors.Is(err, test.expectedErr) { + t.Fatalf("uploadBuffer, want err to wrap: %v, got err: %v", test.expectedErr, err) + } + + if got, want := backoff.pauseCallsCount, test.expectedRetries; got != want { + t.Errorf("backoff.Pause was called %d times, expected %d", got, want) + } + }) + } +} + +type MockWriteStream struct { + grpc.ClientStream + SendMsgErrToReturn []error // if the array is empty, returns nil; otherwise, pops the first value + RecvMsgErrToReturn []error // if the array is empty, returns nil; otherwise, pops the first value +} + +func (s *MockWriteStream) SendMsg(m any) error { + if len(s.SendMsgErrToReturn) < 1 { + return nil + } + err := s.SendMsgErrToReturn[0] + s.SendMsgErrToReturn = s.SendMsgErrToReturn[1:] + return err +} + +// Retriable errors mean we should start over and attempt to +// resend the entire buffer via a new stream. +// If not retriable, falling through will return the error received. +func (s *MockWriteStream) RecvMsg(m any) error { + if len(s.RecvMsgErrToReturn) < 1 { + return nil + } + err := s.RecvMsgErrToReturn[0] + s.RecvMsgErrToReturn = s.RecvMsgErrToReturn[1:] + return err +} + +func (s *MockWriteStream) CloseSend() error { + return nil +} + +// mockBackoff keeps count of the number of time Pause was called. Not thread-safe. +type mockBackoff struct { + Initial time.Duration + Max time.Duration + Multiplier float64 + pauseCallsCount int +} + +func (b *mockBackoff) Pause() time.Duration { + b.pauseCallsCount++ + return 0 +} + +func (b *mockBackoff) SetInitial(i time.Duration) { + b.Initial = i +} + +func (b *mockBackoff) SetMax(m time.Duration) { + b.Max = m +} + +func (b *mockBackoff) SetMultiplier(m float64) { + b.Multiplier = m +} + +func (b *mockBackoff) GetInitial() time.Duration { + return b.Initial +} + +func (b *mockBackoff) GetMax() time.Duration { + return b.Max +} + +func (b *mockBackoff) GetMultiplier() float64 { + return b.Multiplier +} diff --git a/storage/http_client.go b/storage/http_client.go index 221078f3e262..f0a0853f5bd1 100644 --- a/storage/http_client.go +++ b/storage/http_client.go @@ -34,6 +34,7 @@ import ( "cloud.google.com/go/iam/apiv1/iampb" "cloud.google.com/go/internal/optional" "cloud.google.com/go/internal/trace" + "github.com/googleapis/gax-go/v2" "github.com/googleapis/gax-go/v2/callctx" "golang.org/x/oauth2/google" "google.golang.org/api/googleapi" @@ -1022,7 +1023,13 @@ func (c *httpStorageClient) OpenWriter(params *openWriterParams, opts ...storage } if useRetry { if s.retry != nil { - call.WithRetry(s.retry.backoff, s.retry.shouldRetry) + bo := &gax.Backoff{} + if s.retry.backoff != nil { + bo.Multiplier = s.retry.backoff.GetMultiplier() + bo.Initial = s.retry.backoff.GetInitial() + bo.Max = s.retry.backoff.GetMax() + } + call.WithRetry(bo, s.retry.shouldRetry) } else { call.WithRetry(nil, nil) } diff --git a/storage/invoke.go b/storage/invoke.go index 99783f3df47b..b1e838fc7193 100644 --- a/storage/invoke.go +++ b/storage/invoke.go @@ -58,9 +58,9 @@ func run(ctx context.Context, call func(ctx context.Context) error, retry *retry } bo := gax.Backoff{} if retry.backoff != nil { - bo.Multiplier = retry.backoff.Multiplier - bo.Initial = retry.backoff.Initial - bo.Max = retry.backoff.Max + bo.Multiplier = retry.backoff.GetMultiplier() + bo.Initial = retry.backoff.GetInitial() + bo.Max = retry.backoff.GetMax() } var errorFunc func(err error) bool = ShouldRetry if retry.shouldRetry != nil { diff --git a/storage/invoke_test.go b/storage/invoke_test.go index c04c76ba8fa7..c99bbd3ed6f9 100644 --- a/storage/invoke_test.go +++ b/storage/invoke_test.go @@ -27,7 +27,6 @@ import ( "testing" "time" - "github.com/googleapis/gax-go/v2" "github.com/googleapis/gax-go/v2/callctx" "google.golang.org/api/googleapi" "google.golang.org/grpc/codes" @@ -42,12 +41,13 @@ func TestInvoke(t *testing.T) { for _, test := range []struct { desc string - count int // Number of times to return retryable error. + count int // Maximum number of times to return initialErr. initialErr error // Error to return initially. finalErr error // Error to return after count returns of retryCode. retry *retryConfig isIdempotentValue bool expectFinalErr bool + expectedAttempts int }{ { desc: "test fn never returns initial error with count=0", @@ -56,6 +56,7 @@ func TestInvoke(t *testing.T) { finalErr: nil, isIdempotentValue: true, expectFinalErr: true, + expectedAttempts: 1, }, { desc: "non-retryable error is returned without retrying", @@ -64,6 +65,7 @@ func TestInvoke(t *testing.T) { finalErr: nil, isIdempotentValue: true, expectFinalErr: false, + expectedAttempts: 1, }, { desc: "retryable error is retried", @@ -72,6 +74,7 @@ func TestInvoke(t *testing.T) { finalErr: nil, isIdempotentValue: true, expectFinalErr: true, + expectedAttempts: 2, }, { desc: "retryable gRPC error is retried", @@ -80,6 +83,7 @@ func TestInvoke(t *testing.T) { finalErr: nil, isIdempotentValue: true, expectFinalErr: true, + expectedAttempts: 2, }, { desc: "returns non-retryable error after retryable error", @@ -88,6 +92,7 @@ func TestInvoke(t *testing.T) { finalErr: errors.New("bar"), isIdempotentValue: true, expectFinalErr: true, + expectedAttempts: 2, }, { desc: "retryable 5xx error is retried", @@ -96,6 +101,7 @@ func TestInvoke(t *testing.T) { finalErr: nil, isIdempotentValue: true, expectFinalErr: true, + expectedAttempts: 3, }, { desc: "retriable error not retried when non-idempotent", @@ -104,6 +110,7 @@ func TestInvoke(t *testing.T) { finalErr: nil, isIdempotentValue: false, expectFinalErr: false, + expectedAttempts: 1, }, { desc: "non-idempotent retriable error retried when policy is RetryAlways", @@ -113,6 +120,7 @@ func TestInvoke(t *testing.T) { isIdempotentValue: false, retry: &retryConfig{policy: RetryAlways}, expectFinalErr: true, + expectedAttempts: 3, }, { desc: "retriable error not retried when policy is RetryNever", @@ -122,6 +130,7 @@ func TestInvoke(t *testing.T) { isIdempotentValue: true, retry: &retryConfig{policy: RetryNever}, expectFinalErr: false, + expectedAttempts: 1, }, { desc: "non-retriable error not retried when policy is RetryAlways", @@ -131,6 +140,7 @@ func TestInvoke(t *testing.T) { isIdempotentValue: true, retry: &retryConfig{policy: RetryAlways}, expectFinalErr: false, + expectedAttempts: 1, }, { desc: "non-retriable error retried with custom fn", @@ -143,7 +153,8 @@ func TestInvoke(t *testing.T) { return err == io.ErrNoProgress }, }, - expectFinalErr: true, + expectFinalErr: true, + expectedAttempts: 3, }, { desc: "retriable error not retried with custom fn", @@ -156,7 +167,8 @@ func TestInvoke(t *testing.T) { return err == io.ErrNoProgress }, }, - expectFinalErr: false, + expectFinalErr: false, + expectedAttempts: 1, }, { desc: "error not retried when policy is RetryNever despite custom fn", @@ -170,7 +182,8 @@ func TestInvoke(t *testing.T) { }, policy: RetryNever, }, - expectFinalErr: false, + expectFinalErr: false, + expectedAttempts: 1, }, { desc: "non-idempotent retriable error retried when policy is RetryAlways till maxAttempts", @@ -180,6 +193,7 @@ func TestInvoke(t *testing.T) { isIdempotentValue: false, retry: &retryConfig{policy: RetryAlways, maxAttempts: expectedAttempts(2)}, expectFinalErr: false, + expectedAttempts: 2, }, { desc: "non-idempotent retriable error not retried when policy is RetryNever with maxAttempts set", @@ -189,6 +203,7 @@ func TestInvoke(t *testing.T) { isIdempotentValue: false, retry: &retryConfig{policy: RetryNever, maxAttempts: expectedAttempts(2)}, expectFinalErr: false, + expectedAttempts: 1, }, { desc: "non-retriable error retried with custom fn till maxAttempts", @@ -202,7 +217,8 @@ func TestInvoke(t *testing.T) { }, maxAttempts: expectedAttempts(2), }, - expectFinalErr: false, + expectFinalErr: false, + expectedAttempts: 2, }, { desc: "non-idempotent retriable error retried when policy is RetryAlways till maxAttempts where count equals to maxAttempts-1", @@ -212,6 +228,7 @@ func TestInvoke(t *testing.T) { isIdempotentValue: false, retry: &retryConfig{policy: RetryAlways, maxAttempts: expectedAttempts(4)}, expectFinalErr: true, + expectedAttempts: 4, }, { desc: "non-idempotent retriable error retried when policy is RetryAlways till maxAttempts where count equals to maxAttempts", @@ -221,6 +238,7 @@ func TestInvoke(t *testing.T) { isIdempotentValue: true, retry: &retryConfig{policy: RetryAlways, maxAttempts: expectedAttempts(4)}, expectFinalErr: false, + expectedAttempts: 4, }, { desc: "non-idempotent retriable error not retried when policy is RetryAlways with maxAttempts equals to zero", @@ -230,6 +248,7 @@ func TestInvoke(t *testing.T) { isIdempotentValue: true, retry: &retryConfig{maxAttempts: expectedAttempts(0), policy: RetryAlways}, expectFinalErr: false, + expectedAttempts: 1, }, } { t.Run(test.desc, func(s *testing.T) { @@ -255,22 +274,17 @@ func TestInvoke(t *testing.T) { if test.retry == nil { test.retry = defaultRetry.clone() } - test.retry.backoff = &gax.Backoff{Initial: time.Millisecond} + bo := &gaxBackoff{} + bo.Initial = time.Millisecond + test.retry.backoff = bo got := run(ctx, call, test.retry, test.isIdempotentValue) if test.expectFinalErr && !errors.Is(got, test.finalErr) { s.Errorf("got %v, want %v", got, test.finalErr) } else if !test.expectFinalErr && !errors.Is(got, test.initialErr) { s.Errorf("got %v, want %v", got, test.initialErr) } - wantAttempts := 1 + test.count - if !test.expectFinalErr { - wantAttempts = 1 - } - if test.retry != nil && test.retry.maxAttempts != nil && *test.retry.maxAttempts != 0 && test.retry.policy != RetryNever { - wantAttempts = *test.retry.maxAttempts - } - wantClientHeader := strings.ReplaceAll(initialClientHeader, "gccl-attempt-count/1", fmt.Sprintf("gccl-attempt-count/%v", wantAttempts)) + wantClientHeader := strings.ReplaceAll(initialClientHeader, "gccl-attempt-count/1", fmt.Sprintf("gccl-attempt-count/%v", test.expectedAttempts)) if gotClientHeader != wantClientHeader { t.Errorf("case %q, retry header:\ngot %v\nwant %v", test.desc, gotClientHeader, wantClientHeader) } diff --git a/storage/storage.go b/storage/storage.go index 66fba72e6a49..e112d124ab65 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -2216,7 +2216,7 @@ type withBackoff struct { } func (wb *withBackoff) apply(config *retryConfig) { - config.backoff = &wb.backoff + config.backoff = gaxBackoffFromStruct(&wb.backoff) } // WithMaxAttempts configures the maximum number of times an API call can be made @@ -2307,8 +2307,58 @@ func (wef *withErrorFunc) apply(config *retryConfig) { config.shouldRetry = wef.shouldRetry } +type backoff interface { + Pause() time.Duration + + SetInitial(time.Duration) + SetMax(time.Duration) + SetMultiplier(float64) + + GetInitial() time.Duration + GetMax() time.Duration + GetMultiplier() float64 +} + +func gaxBackoffFromStruct(bo *gax.Backoff) *gaxBackoff { + if bo == nil { + return nil + } + b := &gaxBackoff{} + b.Backoff = *bo + return b +} + +// gaxBackoff is a gax.Backoff that implements the backoff interface +type gaxBackoff struct { + gax.Backoff +} + +func (b *gaxBackoff) SetInitial(i time.Duration) { + b.Initial = i +} + +func (b *gaxBackoff) SetMax(m time.Duration) { + b.Max = m +} + +func (b *gaxBackoff) SetMultiplier(m float64) { + b.Multiplier = m +} + +func (b *gaxBackoff) GetInitial() time.Duration { + return b.Initial +} + +func (b *gaxBackoff) GetMax() time.Duration { + return b.Max +} + +func (b *gaxBackoff) GetMultiplier() float64 { + return b.Multiplier +} + type retryConfig struct { - backoff *gax.Backoff + backoff backoff policy RetryPolicy shouldRetry func(err error) bool maxAttempts *int @@ -2318,22 +2368,22 @@ func (r *retryConfig) clone() *retryConfig { if r == nil { return nil } - - var bo *gax.Backoff - if r.backoff != nil { - bo = &gax.Backoff{ - Initial: r.backoff.Initial, - Max: r.backoff.Max, - Multiplier: r.backoff.Multiplier, - } - } - - return &retryConfig{ - backoff: bo, + newConfig := &retryConfig{ + backoff: nil, policy: r.policy, shouldRetry: r.shouldRetry, maxAttempts: r.maxAttempts, } + + if r.backoff != nil { + bo := &gaxBackoff{} + bo.Initial = r.backoff.GetInitial() + bo.Max = r.backoff.GetMax() + bo.Multiplier = r.backoff.GetMultiplier() + newConfig.backoff = bo + } + + return newConfig } // composeSourceObj wraps a *raw.ComposeRequestSourceObjects, but adds the methods diff --git a/storage/storage_test.go b/storage/storage_test.go index 4ef1a92f1e1d..89ac12552f51 100644 --- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -986,11 +986,11 @@ func TestObjectRetryer(t *testing.T) { WithErrorFunc(func(err error) bool { return false })) }, want: &retryConfig{ - backoff: &gax.Backoff{ + backoff: gaxBackoffFromStruct(&gax.Backoff{ Initial: 2 * time.Second, Max: 30 * time.Second, Multiplier: 3, - }, + }), maxAttempts: expectedAttempts(5), policy: RetryAlways, shouldRetry: func(err error) bool { return false }, @@ -1005,9 +1005,9 @@ func TestObjectRetryer(t *testing.T) { })) }, want: &retryConfig{ - backoff: &gax.Backoff{ + backoff: gaxBackoffFromStruct(&gax.Backoff{ Multiplier: 3, - }}, + })}, }, { name: "set policy only", @@ -1083,11 +1083,11 @@ func TestClientSetRetry(t *testing.T) { WithErrorFunc(func(err error) bool { return false }), }, want: &retryConfig{ - backoff: &gax.Backoff{ + backoff: gaxBackoffFromStruct(&gax.Backoff{ Initial: 2 * time.Second, Max: 30 * time.Second, Multiplier: 3, - }, + }), maxAttempts: expectedAttempts(5), policy: RetryAlways, shouldRetry: func(err error) bool { return false }, @@ -1101,9 +1101,9 @@ func TestClientSetRetry(t *testing.T) { }), }, want: &retryConfig{ - backoff: &gax.Backoff{ + backoff: gaxBackoffFromStruct(&gax.Backoff{ Multiplier: 3, - }}, + })}, }, { name: "set policy only", @@ -1198,11 +1198,11 @@ func TestRetryer(t *testing.T) { WithErrorFunc(ShouldRetry), }, want: &retryConfig{ - backoff: &gax.Backoff{ + backoff: gaxBackoffFromStruct(&gax.Backoff{ Initial: time.Minute, Max: time.Hour, Multiplier: 6, - }, + }), shouldRetry: ShouldRetry, maxAttempts: expectedAttempts(11), policy: RetryAlways, @@ -1221,11 +1221,11 @@ func TestRetryer(t *testing.T) { WithErrorFunc(ShouldRetry), }, want: &retryConfig{ - backoff: &gax.Backoff{ + backoff: gaxBackoffFromStruct(&gax.Backoff{ Initial: time.Minute, Max: time.Hour, Multiplier: 6, - }, + }), shouldRetry: ShouldRetry, maxAttempts: expectedAttempts(7), policy: RetryAlways, @@ -1285,10 +1285,10 @@ func TestRetryer(t *testing.T) { policy: RetryAlways, maxAttempts: expectedAttempts(5), shouldRetry: ShouldRetry, - backoff: &gax.Backoff{ + backoff: gaxBackoffFromStruct(&gax.Backoff{ Initial: time.Nanosecond, Max: time.Microsecond, - }, + }), }, }, { @@ -1307,10 +1307,10 @@ func TestRetryer(t *testing.T) { }), }, want: &retryConfig{ - backoff: &gax.Backoff{ + backoff: gaxBackoffFromStruct(&gax.Backoff{ Initial: time.Nanosecond, Max: time.Microsecond, - }, + }), }, }, { @@ -1330,10 +1330,10 @@ func TestRetryer(t *testing.T) { policy: RetryNever, maxAttempts: expectedAttempts(5), shouldRetry: ShouldRetry, - backoff: &gax.Backoff{ + backoff: gaxBackoffFromStruct(&gax.Backoff{ Initial: time.Nanosecond, Max: time.Second, - }, + }), }, }, { @@ -1349,9 +1349,9 @@ func TestRetryer(t *testing.T) { }), }, want: &retryConfig{ - backoff: &gax.Backoff{ + backoff: gaxBackoffFromStruct(&gax.Backoff{ Multiplier: 4, - }, + }), }, }, }