From 6c9368b710560906810a93ce0775e331fbda0e2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dr=C4=93m=20Darios?= Date: Fri, 17 Mar 2023 20:55:22 -0700 Subject: [PATCH 1/2] Added exponential backoff of calls to api gateway api Starts at 2 seconds and retries 5 times while backing off exponentially. --- .../aws/repository/api_gateway_repository.go | 180 ++++++++++++++++++ 1 file changed, 180 insertions(+) diff --git a/enumeration/remote/aws/repository/api_gateway_repository.go b/enumeration/remote/aws/repository/api_gateway_repository.go index bb69d7dd6..68a528b22 100644 --- a/enumeration/remote/aws/repository/api_gateway_repository.go +++ b/enumeration/remote/aws/repository/api_gateway_repository.go @@ -2,6 +2,11 @@ package repository import ( "fmt" + "math" + "strings" + "time" + + "github.com/sirupsen/logrus" "github.com/snyk/driftctl/enumeration/remote/cache" "github.com/aws/aws-sdk-go/aws" @@ -30,6 +35,8 @@ type apigatewayRepository struct { cache cache.Cache } +const MaxRetries = 5 + func NewApiGatewayRepository(session *session.Session, c cache.Cache) *apigatewayRepository { return &apigatewayRepository{ apigateway.New(session), @@ -37,14 +44,39 @@ func NewApiGatewayRepository(session *session.Session, c cache.Cache) *apigatewa } } +func retryOnFailure(callback func() error, message string) error { + retries := 0 + retry := true + + var err error + for retry && retries < MaxRetries { + sleepTime := time.Duration(math.Pow(2, float64(retries))) * 2 * time.Second + logrus.Warn(message, "Attempt number ", retries+1, "/", MaxRetries, ". Retrying after sleeping for ", sleepTime, "...") + time.Sleep(sleepTime) + logrus.Debug("Awake! Attempting to make API call again.") + + err = callback() + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { + retry = true + } else { + retry = false + } + + retries++ + } + return err +} + func (r *apigatewayRepository) ListAllRestApis() ([]*apigateway.RestApi, error) { cacheKey := "apigatewayListAllRestApis" v := r.cache.GetAndLock(cacheKey) defer r.cache.Unlock(cacheKey) if v != nil { + logrus.Debug("Getting all rest APIs from cache") return v.([]*apigateway.RestApi), nil } + logrus.Debug("Making a call to get rest APIs not found in cache") var restApis []*apigateway.RestApi input := apigateway.GetRestApisInput{} err := r.client.GetRestApisPages(&input, @@ -53,7 +85,22 @@ func (r *apigatewayRepository) ListAllRestApis() ([]*apigateway.RestApi, error) return !lastPage }, ) + if err != nil { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + err = r.client.GetRestApisPages(&input, + func(resp *apigateway.GetRestApisOutput, lastPage bool) bool { + restApis = append(restApis, resp.Items...) + return !lastPage + }, + ) + return err + }, "Error caught during GetRestApisPages!") + } + + if err != nil { + logrus.Error("error in list all apis") return nil, err } @@ -67,6 +114,16 @@ func (r *apigatewayRepository) GetAccount() (*apigateway.Account, error) { } account, err := r.client.GetAccount(&apigateway.GetAccountInput{}) + + if err != nil { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + input := apigateway.GetAccountInput{} + account, err = r.client.GetAccount(&input) + return err + }, "Error caught during GetAccount!") + } + if err != nil { return nil, err } @@ -77,6 +134,7 @@ func (r *apigatewayRepository) GetAccount() (*apigateway.Account, error) { func (r *apigatewayRepository) ListAllApiKeys() ([]*apigateway.ApiKey, error) { if v := r.cache.Get("apigatewayListAllApiKeys"); v != nil { + logrus.Debug("Getting api keys from cache") return v.([]*apigateway.ApiKey), nil } @@ -88,6 +146,20 @@ func (r *apigatewayRepository) ListAllApiKeys() ([]*apigateway.ApiKey, error) { return !lastPage }, ) + + if err != nil { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + err = r.client.GetApiKeysPages(&input, + func(resp *apigateway.GetApiKeysOutput, lastPage bool) bool { + apiKeys = append(apiKeys, resp.Items...) + return !lastPage + }, + ) + return err + }, "Error caught during GetApiKeysPages!") + } + if err != nil { return nil, err } @@ -99,14 +171,26 @@ func (r *apigatewayRepository) ListAllApiKeys() ([]*apigateway.ApiKey, error) { func (r *apigatewayRepository) ListAllRestApiAuthorizers(apiId string) ([]*apigateway.Authorizer, error) { cacheKey := fmt.Sprintf("apigatewayListAllRestApiAuthorizers_api_%s", apiId) if v := r.cache.Get(cacheKey); v != nil { + logrus.Debug("Getting api authorizers from cache ", apiId) return v.([]*apigateway.Authorizer), nil } + logrus.Debug("Making a call to API for specific authorizers not found in cache: ", apiId) input := &apigateway.GetAuthorizersInput{ RestApiId: &apiId, } resources, err := r.client.GetAuthorizers(input) + if err != nil { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to API for specific authorizers not found in cache: ", apiId) + resources, err = r.client.GetAuthorizers(input) + return err + }, "Error caught during GetAuthorizers with input "+apiId+"!") + } + + if err != nil { + logrus.Error("error in api authorizer") return nil, err } @@ -119,14 +203,26 @@ func (r *apigatewayRepository) ListAllRestApiStages(apiId string) ([]*apigateway v := r.cache.GetAndLock(cacheKey) defer r.cache.Unlock(cacheKey) if v != nil { + logrus.Debug("Getting api stages from cache ", apiId) return v.([]*apigateway.Stage), nil } + logrus.Debug("Making a call to API for specific stage not found in cache: ", apiId) input := &apigateway.GetStagesInput{ RestApiId: &apiId, } resources, err := r.client.GetStages(input) + + if err != nil { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to API for specific stage not found in cache: ", apiId) + resources, err = r.client.GetStages(input) + return err + }, "Error caught during GetStages with input "+apiId+"!") + } + if err != nil { + logrus.Error("error in api stage") return nil, err } @@ -139,9 +235,11 @@ func (r *apigatewayRepository) ListAllRestApiResources(apiId string) ([]*apigate v := r.cache.GetAndLock(cacheKey) defer r.cache.Unlock(cacheKey) if v != nil { + logrus.Debug("Getting api resource from cache ", apiId) return v.([]*apigateway.Resource), nil } + logrus.Debug("Making a call to API for specific resource not found in cache ", apiId) var resources []*apigateway.Resource input := &apigateway.GetResourcesInput{ RestApiId: &apiId, @@ -151,7 +249,20 @@ func (r *apigatewayRepository) ListAllRestApiResources(apiId string) ([]*apigate resources = append(resources, res.Items...) return !lastPage }) + + if err != nil { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + err = r.client.GetResourcesPages(input, func(res *apigateway.GetResourcesOutput, lastPage bool) bool { + resources = append(resources, res.Items...) + return !lastPage + }) + return err + }, "Error caught during GetResourcesPages with input "+apiId+"!") + } + if err != nil { + logrus.Error("error in api resource") return nil, err } @@ -175,6 +286,20 @@ func (r *apigatewayRepository) ListAllDomainNames() ([]*apigateway.DomainName, e return !lastPage }, ) + + if err != nil { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + err = r.client.GetDomainNamesPages(&input, + func(resp *apigateway.GetDomainNamesOutput, lastPage bool) bool { + domainNames = append(domainNames, resp.Items...) + return !lastPage + }, + ) + return err + }, "Error caught during GetDomainNamesPages!") + } + if err != nil { return nil, err } @@ -196,6 +321,20 @@ func (r *apigatewayRepository) ListAllVpcLinks() ([]*apigateway.UpdateVpcLinkOut return !lastPage }, ) + + if err != nil { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + err = r.client.GetVpcLinksPages(&input, + func(resp *apigateway.GetVpcLinksOutput, lastPage bool) bool { + vpcLinks = append(vpcLinks, resp.Items...) + return !lastPage + }, + ) + return err + }, "Error caught during GetVpcLinksPages!") + } + if err != nil { return nil, err } @@ -214,6 +353,15 @@ func (r *apigatewayRepository) ListAllRestApiRequestValidators(apiId string) ([] RestApiId: &apiId, } resources, err := r.client.GetRequestValidators(input) + + if err != nil { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + resources, err = r.client.GetRequestValidators(input) + return err + }, "Error caught during GetRequestValidators with input "+apiId+"!") + } + if err != nil { return nil, err } @@ -236,6 +384,18 @@ func (r *apigatewayRepository) ListAllDomainNameBasePathMappings(domainName stri mappings = append(mappings, res.Items...) return !lastPage }) + + if err != nil { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + err = r.client.GetBasePathMappingsPages(input, func(res *apigateway.GetBasePathMappingsOutput, lastPage bool) bool { + mappings = append(mappings, res.Items...) + return !lastPage + }) + return err + }, "Error caught during GetBasePathMappingsPages with input "+domainName+"!") + } + if err != nil { return nil, err } @@ -258,6 +418,17 @@ func (r *apigatewayRepository) ListAllRestApiModels(apiId string) ([]*apigateway resources = append(resources, res.Items...) return !lastPage }) + + if err != nil { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + err = r.client.GetModelsPages(input, func(res *apigateway.GetModelsOutput, lastPage bool) bool { + resources = append(resources, res.Items...) + return !lastPage + }) + return err + }, "Error caught during GetModelsPages with input "+apiId+"!") + } if err != nil { return nil, err } @@ -276,6 +447,15 @@ func (r *apigatewayRepository) ListAllRestApiGatewayResponses(apiId string) ([]* RestApiId: &apiId, } resources, err := r.client.GetGatewayResponses(input) + + if err != nil { + err = retryOnFailure(func() error { + logrus.Debug("Making a call to get rest APIs not found in cache") + resources, err = r.client.GetGatewayResponses(input) + return err + }, "Error caught during GetGatewayResponses with input "+apiId+"!") + } + if err != nil { return nil, err } From 38ad4d2ac3f949be3513341b12f223cf45900e8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dr=C4=93m=20Darios?= Date: Fri, 17 Mar 2023 21:17:02 -0700 Subject: [PATCH 2/2] Added check for TooManyRequestsException to exit earlier --- .../aws/repository/api_gateway_repository.go | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/enumeration/remote/aws/repository/api_gateway_repository.go b/enumeration/remote/aws/repository/api_gateway_repository.go index 68a528b22..5da06f7bd 100644 --- a/enumeration/remote/aws/repository/api_gateway_repository.go +++ b/enumeration/remote/aws/repository/api_gateway_repository.go @@ -86,7 +86,7 @@ func (r *apigatewayRepository) ListAllRestApis() ([]*apigateway.RestApi, error) }, ) - if err != nil { + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { err = retryOnFailure(func() error { logrus.Debug("Making a call to get rest APIs not found in cache") err = r.client.GetRestApisPages(&input, @@ -100,7 +100,6 @@ func (r *apigatewayRepository) ListAllRestApis() ([]*apigateway.RestApi, error) } if err != nil { - logrus.Error("error in list all apis") return nil, err } @@ -115,7 +114,7 @@ func (r *apigatewayRepository) GetAccount() (*apigateway.Account, error) { account, err := r.client.GetAccount(&apigateway.GetAccountInput{}) - if err != nil { + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { err = retryOnFailure(func() error { logrus.Debug("Making a call to get rest APIs not found in cache") input := apigateway.GetAccountInput{} @@ -147,7 +146,7 @@ func (r *apigatewayRepository) ListAllApiKeys() ([]*apigateway.ApiKey, error) { }, ) - if err != nil { + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { err = retryOnFailure(func() error { logrus.Debug("Making a call to get rest APIs not found in cache") err = r.client.GetApiKeysPages(&input, @@ -181,7 +180,7 @@ func (r *apigatewayRepository) ListAllRestApiAuthorizers(apiId string) ([]*apiga } resources, err := r.client.GetAuthorizers(input) - if err != nil { + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { err = retryOnFailure(func() error { logrus.Debug("Making a call to API for specific authorizers not found in cache: ", apiId) resources, err = r.client.GetAuthorizers(input) @@ -190,7 +189,6 @@ func (r *apigatewayRepository) ListAllRestApiAuthorizers(apiId string) ([]*apiga } if err != nil { - logrus.Error("error in api authorizer") return nil, err } @@ -213,7 +211,7 @@ func (r *apigatewayRepository) ListAllRestApiStages(apiId string) ([]*apigateway } resources, err := r.client.GetStages(input) - if err != nil { + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { err = retryOnFailure(func() error { logrus.Debug("Making a call to API for specific stage not found in cache: ", apiId) resources, err = r.client.GetStages(input) @@ -250,7 +248,7 @@ func (r *apigatewayRepository) ListAllRestApiResources(apiId string) ([]*apigate return !lastPage }) - if err != nil { + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { err = retryOnFailure(func() error { logrus.Debug("Making a call to get rest APIs not found in cache") err = r.client.GetResourcesPages(input, func(res *apigateway.GetResourcesOutput, lastPage bool) bool { @@ -262,7 +260,6 @@ func (r *apigatewayRepository) ListAllRestApiResources(apiId string) ([]*apigate } if err != nil { - logrus.Error("error in api resource") return nil, err } @@ -287,7 +284,7 @@ func (r *apigatewayRepository) ListAllDomainNames() ([]*apigateway.DomainName, e }, ) - if err != nil { + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { err = retryOnFailure(func() error { logrus.Debug("Making a call to get rest APIs not found in cache") err = r.client.GetDomainNamesPages(&input, @@ -322,7 +319,7 @@ func (r *apigatewayRepository) ListAllVpcLinks() ([]*apigateway.UpdateVpcLinkOut }, ) - if err != nil { + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { err = retryOnFailure(func() error { logrus.Debug("Making a call to get rest APIs not found in cache") err = r.client.GetVpcLinksPages(&input, @@ -354,7 +351,7 @@ func (r *apigatewayRepository) ListAllRestApiRequestValidators(apiId string) ([] } resources, err := r.client.GetRequestValidators(input) - if err != nil { + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { err = retryOnFailure(func() error { logrus.Debug("Making a call to get rest APIs not found in cache") resources, err = r.client.GetRequestValidators(input) @@ -385,7 +382,7 @@ func (r *apigatewayRepository) ListAllDomainNameBasePathMappings(domainName stri return !lastPage }) - if err != nil { + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { err = retryOnFailure(func() error { logrus.Debug("Making a call to get rest APIs not found in cache") err = r.client.GetBasePathMappingsPages(input, func(res *apigateway.GetBasePathMappingsOutput, lastPage bool) bool { @@ -419,7 +416,7 @@ func (r *apigatewayRepository) ListAllRestApiModels(apiId string) ([]*apigateway return !lastPage }) - if err != nil { + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { err = retryOnFailure(func() error { logrus.Debug("Making a call to get rest APIs not found in cache") err = r.client.GetModelsPages(input, func(res *apigateway.GetModelsOutput, lastPage bool) bool { @@ -448,7 +445,7 @@ func (r *apigatewayRepository) ListAllRestApiGatewayResponses(apiId string) ([]* } resources, err := r.client.GetGatewayResponses(input) - if err != nil { + if err != nil && strings.Contains(err.Error(), "TooManyRequestsException") { err = retryOnFailure(func() error { logrus.Debug("Making a call to get rest APIs not found in cache") resources, err = r.client.GetGatewayResponses(input)