From ebad8bf72fd71eb785c14aa6f7e678a15449c0e1 Mon Sep 17 00:00:00 2001 From: "mergify[bot]" <37929162+mergify[bot]@users.noreply.github.com> Date: Mon, 31 Oct 2022 13:45:40 +0100 Subject: [PATCH] make policy change handler try all fleet hosts before failing (#1329) (#1576) It changes the remote client to: - when creating a new client: - succeed if at least one host is health - shuffle the hosts, avoiding all the agents reaching to the same fleet-server on the first request - makes `(remote.*Client).Send` try all the hosts before failing, returning a multi-error if all hosts fail - if debug logs are enabled, `Send` will log each error with debug level - modifies `remote.requestClient`: - now `requestClient` holds its host - remove `requestFunc` - `(remopte.requestClient).newRequest uses the new `host` property to build the final URL for the request (cherry picked from commit 19f82223b0af9f4e114530956f833060120da667) Co-authored-by: Anderson Queiroz --- ...ltiple-Fleet-Server-hosts-are-handled.yaml | 35 +++ .../handlers/handler_action_policy_change.go | 14 +- internal/pkg/fleetapi/client/client.go | 2 +- internal/pkg/remote/client.go | 227 ++++++++++-------- internal/pkg/remote/client_test.go | 140 +++++++---- 5 files changed, 273 insertions(+), 145 deletions(-) create mode 100644 changelog/fragments/1666281194-Fix-how-multiple-Fleet-Server-hosts-are-handled.yaml diff --git a/changelog/fragments/1666281194-Fix-how-multiple-Fleet-Server-hosts-are-handled.yaml b/changelog/fragments/1666281194-Fix-how-multiple-Fleet-Server-hosts-are-handled.yaml new file mode 100644 index 00000000000..c0f13aa3d9c --- /dev/null +++ b/changelog/fragments/1666281194-Fix-how-multiple-Fleet-Server-hosts-are-handled.yaml @@ -0,0 +1,35 @@ +# Kind can be one of: +# - breaking-change: a change to previously-documented behavior +# - deprecation: functionality that is being removed in a later release +# - bug-fix: fixes a problem in a previous version +# - enhancement: extends functionality but does not break or fix existing behavior +# - feature: new functionality +# - known-issue: problems that we are aware of in a given version +# - security: impacts on the security of a product or a user’s deployment. +# - upgrade: important information for someone upgrading from a prior version +# - other: does not fit into any of the other categories +kind: bug-fix + +# Change summary; a 80ish characters long description of the change. +summary: Fix how multiple Fleet Server hosts are handled + +# Long description; in case the summary is not enough to describe the change +# this field accommodate a description without length limits. +description: It fixes the bug when the Elastic Agent would be enrolled using + a valid Fleet Server URL, but the policy would contain more than one, being + the first URL unreachable. In that case the Elastic Agent would enroll with + Fleet Server, but become unhealthy as it'd get stuck trying only the first, + unreachable Fleet Server host. + +# Affected component; a word indicating the component this changeset affects. +#component: + +# PR number; optional; the PR number that added the changeset. +# If not present is automatically filled by the tooling finding the PR where this changelog fragment has been added. +# NOTE: the tooling supports backports, so it's able to fill the original PR number instead of the backport PR number. +# Please provide it if you are adding a fragment for a different PR. +pr: 1329 + +# Issue number; optional; the GitHub issue related to this changeset (either closes or is part of). +# If not present is automatically filled by the tooling with the issue linked to the PR number. +#issue: 1234 diff --git a/internal/pkg/agent/application/pipeline/actions/handlers/handler_action_policy_change.go b/internal/pkg/agent/application/pipeline/actions/handlers/handler_action_policy_change.go index ad75299e420..5551e9461c7 100644 --- a/internal/pkg/agent/application/pipeline/actions/handlers/handler_action_policy_change.go +++ b/internal/pkg/agent/application/pipeline/actions/handlers/handler_action_policy_change.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "io/ioutil" + "net/http" "sort" "time" @@ -142,14 +143,17 @@ func (h *PolicyChange) handleFleetServerHosts(ctx context.Context, c *config.Con err, "fail to create API client with updated hosts", errors.TypeNetwork, errors.M("hosts", h.config.Fleet.Client.Hosts)) } + ctx, cancel := context.WithTimeout(ctx, apiStatusTimeout) defer cancel() - resp, err := client.Send(ctx, "GET", "/api/status", nil, nil, nil) + + resp, err := client.Send(ctx, http.MethodGet, "/api/status", nil, nil, nil) if err != nil { return errors.New( - err, "fail to communicate with updated API client hosts", + err, "fail to communicate with Fleet Server API client hosts", errors.TypeNetwork, errors.M("hosts", h.config.Fleet.Client.Hosts)) } + // discard body for proper cancellation and connection reuse _, _ = io.Copy(ioutil.Discard, resp.Body) resp.Body.Close() @@ -157,15 +161,17 @@ func (h *PolicyChange) handleFleetServerHosts(ctx context.Context, c *config.Con reader, err := fleetToReader(h.agentInfo, h.config) if err != nil { return errors.New( - err, "fail to persist updated API client hosts", + err, "fail to persist new Fleet Server API client hosts", errors.TypeUnexpected, errors.M("hosts", h.config.Fleet.Client.Hosts)) } + err = h.store.Save(reader) if err != nil { return errors.New( - err, "fail to persist updated API client hosts", + err, "fail to persist new Fleet Server API client hosts", errors.TypeFilesystem, errors.M("hosts", h.config.Fleet.Client.Hosts)) } + for _, setter := range h.setters { setter.SetClient(client) } diff --git a/internal/pkg/fleetapi/client/client.go b/internal/pkg/fleetapi/client/client.go index 4470f0259a8..0f478497bb6 100644 --- a/internal/pkg/fleetapi/client/client.go +++ b/internal/pkg/fleetapi/client/client.go @@ -87,7 +87,7 @@ func NewWithConfig(log *logger.Logger, cfg remote.Config) (*remote.Client, error // ExtractError extracts error from a fleet-server response func ExtractError(resp io.Reader) error { - // Lets try to extract a high level fleet-server error. + // Let's try to extract a high level fleet-server error. e := &struct { StatusCode int `json:"statusCode"` Error string `json:"error"` diff --git a/internal/pkg/remote/client.go b/internal/pkg/remote/client.go index 085ab2bfe0e..5c8fd5c9a34 100644 --- a/internal/pkg/remote/client.go +++ b/internal/pkg/remote/client.go @@ -6,14 +6,17 @@ package remote import ( "context" + "fmt" "io" + "math/rand" "net/http" "net/url" + "sort" "strings" "sync" "time" - "github.com/pkg/errors" + "github.com/hashicorp/go-multierror" urlutil "github.com/elastic/elastic-agent-libs/kibana" "github.com/elastic/elastic-agent-libs/transport/httpcommon" @@ -26,33 +29,32 @@ const ( retryOnBadConnTimeout = 5 * time.Minute ) -type requestFunc func(string, string, url.Values, io.Reader) (*http.Request, error) type wrapperFunc func(rt http.RoundTripper) (http.RoundTripper, error) type requestClient struct { - request requestFunc + host string client http.Client lastUsed time.Time lastErr error lastErrOcc time.Time } -// Client wraps an http.Client and takes care of making the raw calls, the client should -// stay simple and specificals should be implemented in external action instead of adding new methods -// to the client. For authenticated calls or sending fields on every request, create customer RoundTripper -// implementations that will take care of the boiler plates. +// Client wraps a http.Client and takes care of making the raw calls, the client should +// stay simple and specifics should be implemented in external action instead of adding new methods +// to the client. For authenticated calls or sending fields on every request, create a custom RoundTripper +// implementation that will take care of the boilerplate. type Client struct { - log *logger.Logger - lock sync.Mutex - clients []*requestClient - config Config + log *logger.Logger + clientLock sync.Mutex + clients []*requestClient + config Config } // NewConfigFromURL returns a Config based on a received host. func NewConfigFromURL(URL string) (Config, error) { u, err := url.Parse(URL) if err != nil { - return Config{}, errors.Wrap(err, "could not parse url") + return Config{}, fmt.Errorf("could not parse url: %w", err) } c := DefaultClientConfig() @@ -76,7 +78,7 @@ func NewWithRawConfig(log *logger.Logger, config *config.Config, wrapper wrapper cfg := Config{} if err := config.Unpack(&cfg); err != nil { - return nil, errors.Wrap(err, "invalidate configuration") + return nil, fmt.Errorf("invalidate configuration: %w", err) } return NewWithConfig(l, cfg, wrapper) @@ -97,11 +99,14 @@ func NewWithConfig(log *logger.Logger, cfg Config, wrapper wrapperFunc) (*Client } hosts := cfg.GetHosts() - clients := make([]*requestClient, len(hosts)) - for i, host := range cfg.GetHosts() { - connStr, err := urlutil.MakeURL(string(cfg.Protocol), p, host, 0) + hostCount := len(hosts) + log.With("hosts", hosts).Debugf( + "creating remote client with %d hosts", hostCount) + clients := make([]*requestClient, hostCount) + for i, host := range hosts { + baseURL, err := urlutil.MakeURL(string(cfg.Protocol), p, host, 0) if err != nil { - return nil, errors.Wrap(err, "invalid fleet-server endpoint") + return nil, fmt.Errorf("invalid fleet-server endpoint: %w", err) } transport, err := cfg.Transport.RoundTripper( @@ -115,7 +120,7 @@ func NewWithConfig(log *logger.Logger, cfg Config, wrapper wrapperFunc) (*Client if wrapper != nil { transport, err = wrapper(transport) if err != nil { - return nil, errors.Wrap(err, "fail to create transport client") + return nil, fmt.Errorf("fail to create transport client: %w", err) } } @@ -125,17 +130,17 @@ func NewWithConfig(log *logger.Logger, cfg Config, wrapper wrapperFunc) (*Client } clients[i] = &requestClient{ - request: prefixRequestFactory(connStr), - client: httpClient, + host: baseURL, + client: httpClient, } } - return new(log, cfg, clients...) + return newClient(log, cfg, clients...) } -// Send executes a direct calls against the API, the method will takes cares of cloning -// also add necessary headers for likes: "Content-Type", "Accept", and "kbn-xsrf". -// No assumptions is done on the response concerning the received format, this will be the responsibility +// Send executes a direct calls against the API, the method will take care of cloning and +// also adding the necessary headers likes: "Content-Type", "Accept", and "kbn-xsrf". +// No assumptions are done on the response concerning the received format, this will be the responsibility // of the implementation to correctly unpack any received data. // // NOTE: @@ -155,45 +160,62 @@ func (c *Client) Send( } c.log.Debugf("Request method: %s, path: %s, reqID: %s", method, path, reqID) - c.lock.Lock() - defer c.lock.Unlock() - requester := c.nextRequester() + c.clientLock.Lock() + defer c.clientLock.Unlock() - req, err := requester.request(method, path, params, body) - if err != nil { - return nil, errors.Wrapf(err, "fail to create HTTP request using method %s to %s", method, path) - } + var resp *http.Response + var multiErr error - // Add generals headers to the request, we are dealing exclusively with JSON. - // Content-Type / Accepted type can be override from the called. - req.Header.Set("Content-Type", "application/json") - req.Header.Add("Accept", "application/json") - // This header should be specific to fleet-server or remove it - req.Header.Set("kbn-xsrf", "1") // Without this Kibana will refuse to answer the request. + c.sortClients() + for i, requester := range c.clients { + req, err := requester.newRequest(method, path, params, body) + if err != nil { + return nil, fmt.Errorf( + "fail to create HTTP request using method %s to %s: %w", + method, path, err) + } - // If available, add the request id as an HTTP header - if reqID != "" { - req.Header.Add("X-Request-ID", reqID) - } + // Add generals headers to the request, we are dealing exclusively with JSON. + // Content-Type / Accepted type can be overridden by the caller. + req.Header.Set("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + // This header should be specific to fleet-server or remove it + req.Header.Set("kbn-xsrf", "1") // Without this Kibana will refuse to answer the request. - // copy headers. - for header, values := range headers { - for _, v := range values { - req.Header.Add(header, v) + // If available, add the request id as an HTTP header + if reqID != "" { + req.Header.Add("X-Request-ID", reqID) } - } - requester.lastUsed = time.Now().UTC() + // copy headers. + for header, values := range headers { + for _, v := range values { + req.Header.Add(header, v) + } + } + + requester.lastUsed = time.Now().UTC() + + resp, err = requester.client.Do(req.WithContext(ctx)) + if err != nil { + requester.lastErr = err + requester.lastErrOcc = time.Now().UTC() + + msg := fmt.Sprintf("requester %d/%d to host %s errored", + i, len(c.clients), requester.host) + multiErr = multierror.Append(multiErr, fmt.Errorf("%s: %w", msg, err)) + + // Using debug level as the error is only relevant if all clients fail. + c.log.With("error", err).Debugf(msg) + continue + } - resp, err := requester.client.Do(req.WithContext(ctx)) - if err != nil { - requester.lastErr = err - requester.lastErrOcc = time.Now().UTC() - } else { requester.lastErr = nil requester.lastErrOcc = time.Time{} + return resp, nil } - return resp, err + + return nil, fmt.Errorf("all hosts failed: %w", multiErr) } // URI returns the remote URI. @@ -202,67 +224,78 @@ func (c *Client) URI() string { return string(c.config.Protocol) + "://" + host + "/" + c.config.Path } -// new creates new API client. -func new( +// newClient creates a new API client. +func newClient( log *logger.Logger, cfg Config, - httpClients ...*requestClient, + clients ...*requestClient, ) (*Client, error) { + // Shuffle so all the agents don't access the hosts in the same order + rand.Shuffle(len(clients), func(i, j int) { + clients[i], clients[j] = clients[j], clients[i] + }) + c := &Client{ log: log, - clients: httpClients, + clients: clients, config: cfg, } return c, nil } -// nextRequester returns the requester to use. -// -// It excludes clients that have errored in the last 5 minutes. -func (c *Client) nextRequester() *requestClient { - var selected *requestClient - +// sortClients sort the clients according to the following priority: +// - never used +// - without errors, last used first when more than one does not have errors +// - last errored. +// It also removes the last error after retryOnBadConnTimeout has elapsed. +func (c *Client) sortClients() { now := time.Now().UTC() - for _, requester := range c.clients { - if requester.lastErr != nil && now.Sub(requester.lastErrOcc) > retryOnBadConnTimeout { - requester.lastErr = nil - requester.lastErrOcc = time.Time{} + + sort.Slice(c.clients, func(i, j int) bool { + // First, set them good if the timout has elapsed + if c.clients[i].lastErr != nil && + now.Sub(c.clients[i].lastErrOcc) > retryOnBadConnTimeout { + c.clients[i].lastErr = nil + c.clients[i].lastErrOcc = time.Time{} } - if requester.lastErr != nil { - continue + if c.clients[j].lastErr != nil && + now.Sub(c.clients[j].lastErrOcc) > retryOnBadConnTimeout { + c.clients[j].lastErr = nil + c.clients[j].lastErrOcc = time.Time{} } - if requester.lastUsed.IsZero() { - // never been used, instant winner! - selected = requester - break + + // Pick not yet used first, but if both haven't been used yet, + // we return false to comply with the sort.Interface definition. + if c.clients[i].lastUsed.IsZero() && + c.clients[j].lastUsed.IsZero() { + return false } - if selected == nil { - selected = requester - continue + + // Pick not yet used first + if c.clients[i].lastUsed.IsZero() { + return true } - if requester.lastUsed.Before(selected.lastUsed) { - selected = requester + + // If none has errors, pick the last used + // Then, the one without errors + if c.clients[i].lastErr == nil && + c.clients[j].lastErr == nil { + return c.clients[i].lastUsed.Before(c.clients[j].lastUsed) } - } - if selected == nil { - // all are erroring; select the oldest one that errored - for _, requester := range c.clients { - if selected == nil { - selected = requester - continue - } - if requester.lastErrOcc.Before(selected.lastErrOcc) { - selected = requester - } + + // Then, the one without error + if c.clients[i].lastErr == nil { + return true } - } - return selected + + // Lastly, the one that errored last + return c.clients[i].lastUsed.Before(c.clients[j].lastUsed) + }) } -func prefixRequestFactory(URL string) requestFunc { - return func(method, path string, params url.Values, body io.Reader) (*http.Request, error) { - path = strings.TrimPrefix(path, "/") - newPath := strings.Join([]string{URL, path, "?", params.Encode()}, "") - return http.NewRequest(method, newPath, body) //nolint:noctx // keep old behaviour - } +func (r requestClient) newRequest(method string, path string, params url.Values, body io.Reader) (*http.Request, error) { + path = strings.TrimPrefix(path, "/") + newPath := strings.Join([]string{r.host, path, "?", params.Encode()}, "") + + return http.NewRequest(method, newPath, body) } diff --git a/internal/pkg/remote/client_test.go b/internal/pkg/remote/client_test.go index 6ea546f8128..887bc9817b2 100644 --- a/internal/pkg/remote/client_test.go +++ b/internal/pkg/remote/client_test.go @@ -58,7 +58,8 @@ func TestPortDefaults(t *testing.T) { c, err := NewWithConfig(l, cfg, nil) require.NoError(t, err) - r, err := c.nextRequester().request("GET", "/", nil, strings.NewReader("")) + c.sortClients() + r, err := c.clients[0].newRequest(http.MethodGet, "/", nil, strings.NewReader("")) require.NoError(t, err) if tc.ExpectedPort > 0 { @@ -77,13 +78,13 @@ func TestHTTPClient(t *testing.T) { l, err := logger.New("", false) require.NoError(t, err) + const successResp = `{"message":"hello"}` t.Run("Guard against double slashes on path", withServer( func(t *testing.T) *http.ServeMux { - msg := `{ message: "hello" }` mux := http.NewServeMux() mux.HandleFunc("/nested/echo-hello", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - fmt.Fprint(w, msg) + fmt.Fprint(w, successResp) }) return addCatchAll(mux, t) }, func(t *testing.T, host string) { @@ -97,23 +98,22 @@ func TestHTTPClient(t *testing.T) { client, err := NewWithConfig(l, c, noopWrapper) require.NoError(t, err) - resp, err := client.Send(ctx, "GET", "/nested/echo-hello", nil, nil, nil) + resp, err := client.Send(ctx, http.MethodGet, "/nested/echo-hello", nil, nil, nil) require.NoError(t, err) body, err := ioutil.ReadAll(resp.Body) require.NoError(t, err) defer resp.Body.Close() - assert.Equal(t, `{ message: "hello" }`, string(body)) + assert.Equal(t, successResp, string(body)) }, )) t.Run("Simple call", withServer( func(t *testing.T) *http.ServeMux { - msg := `{ message: "hello" }` mux := http.NewServeMux() mux.HandleFunc("/echo-hello", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - fmt.Fprint(w, msg) + fmt.Fprint(w, successResp) }) return mux }, func(t *testing.T, host string) { @@ -123,23 +123,22 @@ func TestHTTPClient(t *testing.T) { client, err := NewWithRawConfig(nil, cfg, nil) require.NoError(t, err) - resp, err := client.Send(ctx, "GET", "/echo-hello", nil, nil, nil) + resp, err := client.Send(ctx, http.MethodGet, "/echo-hello", nil, nil, nil) require.NoError(t, err) body, err := ioutil.ReadAll(resp.Body) require.NoError(t, err) defer resp.Body.Close() - assert.Equal(t, `{ message: "hello" }`, string(body)) + assert.Equal(t, successResp, string(body)) }, )) t.Run("Simple call with a prefix path", withServer( func(t *testing.T) *http.ServeMux { - msg := `{ message: "hello" }` mux := http.NewServeMux() mux.HandleFunc("/mycustompath/echo-hello", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - fmt.Fprint(w, msg) + fmt.Fprint(w, successResp) }) return mux }, func(t *testing.T, host string) { @@ -150,23 +149,62 @@ func TestHTTPClient(t *testing.T) { client, err := NewWithRawConfig(nil, cfg, nil) require.NoError(t, err) - resp, err := client.Send(ctx, "GET", "/echo-hello", nil, nil, nil) + resp, err := client.Send(ctx, http.MethodGet, "/echo-hello", nil, nil, nil) require.NoError(t, err) body, err := ioutil.ReadAll(resp.Body) require.NoError(t, err) defer resp.Body.Close() - assert.Equal(t, `{ message: "hello" }`, string(body)) + assert.Equal(t, successResp, string(body)) }, )) + t.Run("Tries all the hosts", withServer( + func(t *testing.T) *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc("/echo-hello", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, successResp) + }) + return mux + }, func(t *testing.T, host string) { + one := &requestClient{host: "http://must.fail-1.co/"} + two := &requestClient{host: "http://must.fail-2.co/"} + three := &requestClient{host: fmt.Sprintf("http://%s/", host)} + + c := &Client{clients: []*requestClient{one, two, three}, log: l} + require.NoError(t, err) + resp, err := c.Send(ctx, http.MethodGet, "/echo-hello", nil, nil, nil) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + body, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, successResp, string(body)) + }, + )) + + t.Run("Return last error", func(t *testing.T) { + client := &Client{ + log: l, + clients: []*requestClient{ + {host: "http://must.fail-1.co/"}, + {host: "http://must.fail-2.co/"}, + {host: "http://must.fail-3.co/"}, + }} + + resp, err := client.Send(ctx, http.MethodGet, "/echo-hello", nil, nil, nil) + assert.Contains(t, err.Error(), "http://must.fail-3.co/") // error contains last host + assert.Nil(t, resp) + }) + t.Run("Custom user agent", withServer( func(t *testing.T) *http.ServeMux { - msg := `{ message: "hello" }` mux := http.NewServeMux() mux.HandleFunc("/echo-hello", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - fmt.Fprint(w, msg) + fmt.Fprint(w, successResp) require.Equal(t, r.Header.Get("User-Agent"), "custom-agent") }) return mux @@ -180,23 +218,22 @@ func TestHTTPClient(t *testing.T) { }) require.NoError(t, err) - resp, err := client.Send(ctx, "GET", "/echo-hello", nil, nil, nil) + resp, err := client.Send(ctx, http.MethodGet, "/echo-hello", nil, nil, nil) require.NoError(t, err) body, err := ioutil.ReadAll(resp.Body) require.NoError(t, err) defer resp.Body.Close() - assert.Equal(t, `{ message: "hello" }`, string(body)) + assert.Equal(t, successResp, string(body)) }, )) t.Run("Allows to debug HTTP request between a client and a server", withServer( func(t *testing.T) *http.ServeMux { - msg := `{ "message": "hello" }` mux := http.NewServeMux() mux.HandleFunc("/echo-hello", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - fmt.Fprint(w, msg) + fmt.Fprint(w, successResp) }) return mux }, func(t *testing.T, host string) { @@ -212,16 +249,16 @@ func TestHTTPClient(t *testing.T) { }) require.NoError(t, err) - resp, err := client.Send(ctx, "GET", "/echo-hello", nil, nil, bytes.NewBuffer([]byte("hello"))) + resp, err := client.Send(ctx, http.MethodGet, "/echo-hello", nil, nil, bytes.NewBuffer([]byte("hello"))) require.NoError(t, err) body, err := ioutil.ReadAll(resp.Body) require.NoError(t, err) defer resp.Body.Close() - assert.Equal(t, `{ "message": "hello" }`, string(body)) + assert.Equal(t, successResp, string(body)) for _, m := range debugger.messages { - fmt.Println(m) + fmt.Println(m) //nolint:forbidigo // printing debug messages on a test. } assert.Equal(t, 1, len(debugger.messages)) @@ -230,11 +267,10 @@ func TestHTTPClient(t *testing.T) { t.Run("RequestId", withServer( func(t *testing.T) *http.ServeMux { - msg := `{ message: "hello" }` mux := http.NewServeMux() mux.HandleFunc("/echo-hello", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - fmt.Fprint(w, msg) + fmt.Fprint(w, successResp) require.NotEmpty(t, r.Header.Get("X-Request-ID")) }) return mux @@ -245,48 +281,58 @@ func TestHTTPClient(t *testing.T) { client, err := NewWithRawConfig(nil, cfg, nil) require.NoError(t, err) - resp, err := client.Send(ctx, "GET", "/echo-hello", nil, nil, nil) + resp, err := client.Send(ctx, http.MethodGet, "/echo-hello", nil, nil, nil) require.NoError(t, err) body, err := ioutil.ReadAll(resp.Body) require.NoError(t, err) defer resp.Body.Close() - assert.Equal(t, `{ message: "hello" }`, string(body)) + assert.Equal(t, successResp, string(body)) }, )) } -func TestNextRequester(t *testing.T) { +func TestSortClients(t *testing.T) { t.Run("Picks first requester on initial call", func(t *testing.T) { one := &requestClient{} two := &requestClient{} - client, err := new(nil, Config{}, one, two) + client, err := newClient(nil, Config{}, one, two) require.NoError(t, err) - assert.Equal(t, one, client.nextRequester()) + + client.sortClients() + + assert.Equal(t, one, client.clients[0]) }) t.Run("Picks second requester when first has error", func(t *testing.T) { one := &requestClient{ + lastUsed: time.Now().UTC(), lastErr: fmt.Errorf("fake error"), lastErrOcc: time.Now().UTC(), } two := &requestClient{} - client, err := new(nil, Config{}, one, two) + client, err := newClient(nil, Config{}, one, two) require.NoError(t, err) - assert.Equal(t, two, client.nextRequester()) + + client.sortClients() + + assert.Equal(t, two, client.clients[0]) }) - t.Run("Picks second requester when first has used", func(t *testing.T) { + t.Run("Picks second requester when first has been used", func(t *testing.T) { one := &requestClient{ lastUsed: time.Now().UTC(), } two := &requestClient{} - client, err := new(nil, Config{}, one, two) + client, err := newClient(nil, Config{}, one, two) require.NoError(t, err) - assert.Equal(t, two, client.nextRequester()) + + client.sortClients() + + assert.Equal(t, two, client.clients[0]) }) - t.Run("Picks second requester when its oldest", func(t *testing.T) { + t.Run("Picks second requester when it's the oldest", func(t *testing.T) { one := &requestClient{ lastUsed: time.Now().UTC().Add(-time.Minute), } @@ -296,12 +342,15 @@ func TestNextRequester(t *testing.T) { three := &requestClient{ lastUsed: time.Now().UTC().Add(-2 * time.Minute), } - client, err := new(nil, Config{}, one, two, three) + client, err := newClient(nil, Config{}, one, two, three) require.NoError(t, err) - assert.Equal(t, two, client.nextRequester()) + + client.sortClients() + + assert.Equal(t, two, client.clients[0]) }) - t.Run("Picks third requester when its second has error and first is last used", func(t *testing.T) { + t.Run("Picks third requester when second has error and first is last used", func(t *testing.T) { one := &requestClient{ lastUsed: time.Now().UTC().Add(-time.Minute), } @@ -313,9 +362,11 @@ func TestNextRequester(t *testing.T) { three := &requestClient{ lastUsed: time.Now().UTC().Add(-2 * time.Minute), } - client, err := new(nil, Config{}, one, two, three) - require.NoError(t, err) - assert.Equal(t, three, client.nextRequester()) + client := &Client{clients: []*requestClient{one, two, three}} + + client.sortClients() + + assert.Equal(t, three, client.clients[0]) }) t.Run("Picks second requester when its oldest and all have old errors", func(t *testing.T) { @@ -334,9 +385,12 @@ func TestNextRequester(t *testing.T) { lastErr: fmt.Errorf("fake error"), lastErrOcc: time.Now().Add(-2 * time.Minute), } - client, err := new(nil, Config{}, one, two, three) + client, err := newClient(nil, Config{}, one, two, three) require.NoError(t, err) - assert.Equal(t, two, client.nextRequester()) + + client.sortClients() + + assert.Equal(t, two, client.clients[0]) }) }