diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 8a1927a39ca..3aa988d406b 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -1,5 +1,8 @@ ### SDK Features ### SDK Enhancements +* `service/kinesis`: Add support for retrying service specific API errors ([#2751](https://github.com/aws/aws-sdk-go/pull/2751) + * Adds support for retrying the Kinesis API error, LimitExceededException. + * Fixes [#1376](https://github.com/aws/aws-sdk-go/issues/1376) ### SDK Bugs diff --git a/service/kinesis/customizations.go b/service/kinesis/customizations.go index f618f0da698..0ab636735ef 100644 --- a/service/kinesis/customizations.go +++ b/service/kinesis/customizations.go @@ -9,14 +9,14 @@ import ( var readDuration = 5 * time.Second func init() { - ops := []string{ - opGetRecords, - } - initRequest = func(r *request.Request) { - for _, operation := range ops { - if r.Operation.Name == operation { - r.ApplyOptions(request.WithResponseReadTimeout(readDuration)) - } - } + initRequest = customizeRequest +} + +func customizeRequest(r *request.Request) { + if r.Operation.Name == opGetRecords { + r.ApplyOptions(request.WithResponseReadTimeout(readDuration)) } + + // Service specific error codes. Github(aws/aws-sdk-go#1376) + r.RetryErrorCodes = append(r.RetryErrorCodes, ErrCodeLimitExceededException) } diff --git a/service/kinesis/customizations_test.go b/service/kinesis/customizations_test.go index f21c399ab4c..e35dfeb2257 100644 --- a/service/kinesis/customizations_test.go +++ b/service/kinesis/customizations_test.go @@ -1,13 +1,17 @@ package kinesis import ( + "bytes" + "fmt" "io" + "io/ioutil" "net/http" "testing" "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/corehandlers" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/awstesting/unit" ) @@ -87,3 +91,45 @@ func TestKinesisGetRecordsNoTimeout(t *testing.T) { t.Errorf("Expected no error, but received %v", err) } } + +func TestKinesisCustomRetryErrorCodes(t *testing.T) { + svc := New(unit.Session, &aws.Config{ + MaxRetries: aws.Int(1), + LogLevel: aws.LogLevel(aws.LogDebugWithHTTPBody), + }) + svc.Handlers.Validate.Clear() + + const jsonErr = `{"__type":%q, "message":"some error message"}` + var reqCount int + resps := []*http.Response{ + { + StatusCode: 400, + Header: http.Header{}, + Body: ioutil.NopCloser(bytes.NewReader( + []byte(fmt.Sprintf(jsonErr, ErrCodeLimitExceededException)), + )), + }, + { + StatusCode: 200, + Header: http.Header{}, + Body: ioutil.NopCloser(bytes.NewReader([]byte{})), + }, + } + + req, _ := svc.GetRecordsRequest(&GetRecordsInput{}) + req.Handlers.Send.Swap(corehandlers.SendHandler.Name, request.NamedHandler{ + Name: "custom send handler", + Fn: func(r *request.Request) { + r.HTTPResponse = resps[reqCount] + reqCount++ + }, + }) + + if err := req.Send(); err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if e, a := 2, reqCount; e != a { + t.Errorf("expect %v requests, got %v", e, a) + } +}