Skip to content

Commit

Permalink
parallelize requests
Browse files Browse the repository at this point in the history
  • Loading branch information
Codelax committed Sep 22, 2022
1 parent 2551ee9 commit 1dbff93
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 16 deletions.
80 changes: 64 additions & 16 deletions scw/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ import (
"net/http"
"net/http/httputil"
"reflect"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/scaleway/scaleway-sdk-go/internal/auth"
"github.com/scaleway/scaleway-sdk-go/internal/errors"
"github.com/scaleway/scaleway-sdk-go/logger"

goerrors "errors"
)

// Client is the Scaleway client which performs API requests.
Expand Down Expand Up @@ -168,7 +168,7 @@ func (c *Client) Do(req *ScalewayRequest, res interface{}, opts ...RequestOption
}

if req.zones != nil {
return c.doListAllZones(req, res, req.zones)
return c.doListZones(req, res, req.zones)
}

if req.allPages {
Expand Down Expand Up @@ -347,36 +347,84 @@ func (c *Client) doListAll(req *ScalewayRequest, res interface{}) (err error) {
return errors.New("%T does not support pagination", res)
}

func (c *Client) doListAllZones(req *ScalewayRequest, res interface{}, zones []Zone) (err error) {
func (c *Client) doListZones(req *ScalewayRequest, res interface{}, zones []Zone) (err error) {
if response, isLister := res.(lister); isLister {
// Prepare request with %zone% that can be replaced with actual zone
path := req.Path
for _, zone := range AllZones {
if strings.Contains(req.Path, string(zone)) {
path = strings.ReplaceAll(req.Path, string(zone), "%zone%")
}
}
for _, zone := range zones {
req.Path = strings.ReplaceAll(path, "%zone%", string(zone))

nextZone := newVariableFromType(response)
err := c.doListAll(req, nextZone)
if err != nil {
responseError := &ResponseError{}
if !goerrors.As(err, &responseError) || responseError.StatusCode != 404 {
return err
// Requests are parallelized
responseMutex := sync.Mutex{}
requestGroup := sync.WaitGroup{}
errChan := make(chan error, len(zones))

requestGroup.Add(len(zones))
for _, zone := range zones {
go func(zone Zone) {
defer requestGroup.Done()
req := req.clone()
req.Path = strings.ReplaceAll(path, "%zone%", string(zone))

zoneResponse := newVariableFromType(response)
err := c.doListAll(req, zoneResponse)
if err != nil {
errChan <- err
}
}
_, err = response.UnsafeAppend(nextZone)
if err != nil {
return err
responseMutex.Lock()
_, err = response.UnsafeAppend(zoneResponse)
responseMutex.Unlock()
if err != nil {
errChan <- err
}
}(zone)
}
requestGroup.Wait()

L:
for {
select {
case newErr := <-errChan:
err = errors.Wrap(err, newErr.Error())
default:
break L
}
}
close(errChan)
if err != nil {
return err
}

sortResponseByZone(res)
return nil
}

return errors.New("%T does not support pagination", res)
}

func sortSliceByZone(list interface{}) {
listValue := reflect.ValueOf(list)
sort.Slice(list, func(i, j int) bool {
zone1 := listValue.Index(i).Elem().FieldByName("Zone").Interface().(Zone)
zone2 := listValue.Index(j).Elem().FieldByName("Zone").Interface().(Zone)
return zone1 < zone2
})
}

func sortResponseByZone(res interface{}) {
resType := reflect.TypeOf(res).Elem()
fields := reflect.VisibleFields(resType)
for _, field := range fields {
if field.Type.Kind() == reflect.Slice {
sortSliceByZone(reflect.ValueOf(res).Elem().FieldByName(field.Name).Interface())
return
}
}
}

// newVariableFromType returns a variable set to the zero value of the given type
func newVariableFromType(t interface{}) interface{} {
// reflect.New always create a pointer, that's why we use reflect.Indirect before
Expand Down
16 changes: 16 additions & 0 deletions scw/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,19 @@ func (req *ScalewayRequest) validate() error {
// nothing so far
return nil
}

func (req *ScalewayRequest) clone() *ScalewayRequest {
clonedReq := &ScalewayRequest{
Method: req.Method,
Path: req.Path,
Headers: req.Headers.Clone(),
ctx: req.ctx,
auth: req.auth,
allPages: req.allPages,
zones: req.zones,
}
if req.Query != nil {
clonedReq.Query = url.Values(http.Header(req.Query).Clone())
}
return clonedReq
}

0 comments on commit 1dbff93

Please sign in to comment.