diff --git a/workflow/artifacts/gcs/gcs.go b/workflow/artifacts/gcs/gcs.go
index 9d62ce609b97..70960640a474 100644
--- a/workflow/artifacts/gcs/gcs.go
+++ b/workflow/artifacts/gcs/gcs.go
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"io"
 	"io/ioutil"
+	"net/url"
 	"os"
 	"path/filepath"
 	"strings"
@@ -14,12 +15,14 @@ import (
 	"github.com/argoproj/pkg/file"
 	log "github.com/sirupsen/logrus"
 	"golang.org/x/oauth2/google"
+	"google.golang.org/api/googleapi"
 	"google.golang.org/api/iterator"
 	"google.golang.org/api/option"
 	"k8s.io/apimachinery/pkg/util/wait"
 
 	"github.com/argoproj/argo-workflows/v3/errors"
 	wfv1 "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
+	waitutil "github.com/argoproj/argo-workflows/v3/util/wait"
 	"github.com/argoproj/argo-workflows/v3/workflow/artifacts/common"
 )
 
@@ -28,7 +31,44 @@ type ArtifactDriver struct {
 	ServiceAccountKey string
 }
 
-var _ common.ArtifactDriver = &ArtifactDriver{}
+var (
+	_            common.ArtifactDriver = &ArtifactDriver{}
+	defaultRetry                       = wait.Backoff{Duration: time.Second * 2, Factor: 2.0, Steps: 5, Jitter: 0.1}
+)
+
+// isTransientGCSErr reports whether err, or any error it wraps, is worth
+// retrying: io.ErrUnexpectedEOF, a googleapi.Error with code 429 or 5xx, a
+// url.Error mentioning a refused or reset connection, or a Temporary() error.
+// Adapted from https://github.com/googleapis/google-cloud-go/blob/master/storage/go110.go
+func isTransientGCSErr(err error) bool {
+	if err == io.ErrUnexpectedEOF {
+		return true
+	}
+	switch e := err.(type) {
+	case *googleapi.Error:
+		// Retry on 429 and 5xx, according to
+		// https://cloud.google.com/storage/docs/exponential-backoff.
+		return e.Code == 429 || (e.Code >= 500 && e.Code < 600)
+	case *url.Error:
+		// Retry socket-level errors ECONNREFUSED and ECONNRESET (from syscall).
+		// Unfortunately the error type is unexported, so we resort to string
+		// matching.
+		retriable := []string{"connection refused", "connection reset"}
+		for _, s := range retriable {
+			if strings.Contains(e.Error(), s) {
+				return true
+			}
+		}
+	case interface{ Temporary() bool }:
+		if e.Temporary() {
+			return true
+		}
+	}
+	if e, ok := err.(interface{ Unwrap() error }); ok {
+		return isTransientGCSErr(e.Unwrap())
+	}
+	return false
+}
 
 func (g *ArtifactDriver) newGCSClient() (*storage.Client, error) {
 	if g.ServiceAccountKey != "" {
@@ -62,19 +102,20 @@ func newGCSClientDefault() (*storage.Client, error) {
 
 // Load function downloads objects from GCS
 func (g *ArtifactDriver) Load(inputArtifact *wfv1.Artifact, path string) error {
-	err := wait.ExponentialBackoff(wait.Backoff{Duration: time.Second * 2, Factor: 2.0, Steps: 5, Jitter: 0.1},
+	err := waitutil.Backoff(defaultRetry,
 		func() (bool, error) {
 			log.Infof("GCS Load path: %s, key: %s", path, inputArtifact.GCS.Key)
 			gcsClient, err := g.newGCSClient()
 			if err != nil {
 				log.Warnf("Failed to create new GCS client: %v", err)
-				return false, err
+				// true ends the retry loop, so stop only when the error is not transient.
+				return !isTransientGCSErr(err), err
 			}
 			defer gcsClient.Close()
 			err = downloadObjects(gcsClient, inputArtifact.GCS.Bucket, inputArtifact.GCS.Key, path)
 			if err != nil {
 				log.Warnf("Failed to download objects from GCS: %v", err)
-				return false, err
+				return !isTransientGCSErr(err), err
 			}
 			return true, nil
 		})
@@ -161,17 +202,18 @@ func listByPrefix(client *storage.Client, bucket, prefix, delim string) ([]strin
 
 // Save an artifact to GCS compliant storage, e.g., uploading a local file to GCS bucket
 func (g *ArtifactDriver) Save(path string, outputArtifact *wfv1.Artifact) error {
-	err := wait.ExponentialBackoff(wait.Backoff{Duration: time.Second * 2, Factor: 2.0, Steps: 5, Jitter: 0.1},
+	err := waitutil.Backoff(defaultRetry,
 		func() (bool, error) {
 			log.Infof("GCS Save path: %s, key: %s", path, outputArtifact.GCS.Key)
 			client, err := g.newGCSClient()
 			if err != nil {
-				return false, err
+				// true ends the retry loop, so stop only when the error is not transient.
+				return !isTransientGCSErr(err), err
 			}
 			defer client.Close()
 			err = uploadObjects(client, outputArtifact.GCS.Bucket, outputArtifact.GCS.Key, path)
 			if err != nil {
-				return false, err
+				return !isTransientGCSErr(err), err
 			}
 			return true, nil
 		})
diff --git a/workflow/artifacts/gcs/gcs_test.go b/workflow/artifacts/gcs/gcs_test.go
new file mode 100644
index 000000000000..e70361fb957e
--- /dev/null
+++ b/workflow/artifacts/gcs/gcs_test.go
@@ -0,0 +1,38 @@
+package gcs
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"net/url"
+	"testing"
+
+	"google.golang.org/api/googleapi"
+)
+
+type tlsHandshakeTimeoutError struct{}
+
+func (tlsHandshakeTimeoutError) Temporary() bool { return true }
+func (tlsHandshakeTimeoutError) Error() string   { return "net/http: TLS handshake timeout" }
+
+func TestIsTransientGCSErr(t *testing.T) {
+	for _, test := range []struct {
+		err         error
+		shouldretry bool
+	}{
+		{&googleapi.Error{Code: 0}, false},
+		{&googleapi.Error{Code: 429}, true},
+		{&googleapi.Error{Code: 518}, true},
+		{&googleapi.Error{Code: 599}, true},
+		{&url.Error{Op: "blah", URL: "blah", Err: errors.New("connection refused")}, true},
+		{io.ErrUnexpectedEOF, true},
+		{&tlsHandshakeTimeoutError{}, true},
+		{fmt.Errorf("Test unwrapping of a temporary error: %w", &googleapi.Error{Code: 500}), true},
+		{fmt.Errorf("Test unwrapping of a non-retriable error: %w", &googleapi.Error{Code: 400}), false},
+	} {
+		got := isTransientGCSErr(test.err)
+		if got != test.shouldretry {
+			t.Errorf("%+v: got %v, want %v", test, got, test.shouldretry)
+		}
+	}
+}