From 76a35e71beee1408957a8dc3c7230154deb7a0bd Mon Sep 17 00:00:00 2001 From: Anthony Romano Date: Thu, 7 Sep 2017 15:10:32 -0700 Subject: [PATCH] client: fail over to next endpoint on oneshot failure Fixes #8515 --- client/client.go | 19 ++++++++++--------- client/client_test.go | 33 +++++++++++++++++++++++++++------ 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/client/client.go b/client/client.go index 03054150ccd..3c8948252f8 100644 --- a/client/client.go +++ b/client/client.go @@ -371,12 +371,7 @@ func (c *httpClusterClient) Do(ctx context.Context, act httpAction) (*http.Respo if err == context.Canceled || err == context.DeadlineExceeded { return nil, nil, err } - if isOneShot { - return nil, nil, err - } - continue - } - if resp.StatusCode/100 == 5 { + } else if resp.StatusCode/100 == 5 { switch resp.StatusCode { case http.StatusInternalServerError, http.StatusServiceUnavailable: // TODO: make sure this is a no leader response @@ -384,10 +379,16 @@ func (c *httpClusterClient) Do(ctx context.Context, act httpAction) (*http.Respo default: cerr.Errors = append(cerr.Errors, fmt.Errorf("client: etcd member %s returns server error [%s]", eps[k].String(), http.StatusText(resp.StatusCode))) } - if isOneShot { - return nil, nil, cerr.Errors[0] + err = cerr.Errors[0] + } + if err != nil { + if !isOneShot { + continue } - continue + c.Lock() + c.pinned = (k + 1) % leps + c.Unlock() + return nil, nil, err } if k != pinned { c.Lock() diff --git a/client/client_test.go b/client/client_test.go index 0a355c5e445..40328a1e9b8 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -17,6 +17,7 @@ package client import ( "context" "errors" + "fmt" "io" "io/ioutil" "math/rand" @@ -304,7 +305,9 @@ func TestHTTPClusterClientDo(t *testing.T) { fakeErr := errors.New("fake!") fakeURL := url.URL{} tests := []struct { - client *httpClusterClient + client *httpClusterClient + ctx context.Context + wantCode int wantErr error wantPinned int @@ -395,10 +398,30 @@ func TestHTTPClusterClientDo(t *testing.T) { wantCode: http.StatusTeapot, wantPinned: 1, }, + + // 500-level errors cause one shot Do to fallthrough to next endpoint + { + client: &httpClusterClient{ + endpoints: []url.URL{fakeURL, fakeURL}, + clientFactory: newStaticHTTPClientFactory( + []staticHTTPResponse{ + {resp: http.Response{StatusCode: http.StatusBadGateway}}, + {resp: http.Response{StatusCode: http.StatusTeapot}}, + }, + ), + rand: rand.New(rand.NewSource(0)), + }, + ctx: context.WithValue(context.Background(), &oneShotCtxValue, &oneShotCtxValue), + wantErr: fmt.Errorf("client: etcd member returns server error [Bad Gateway]"), + wantPinned: 1, + }, } for i, tt := range tests { - resp, _, err := tt.client.Do(context.Background(), nil) + if tt.ctx == nil { + tt.ctx = context.Background() + } + resp, _, err := tt.client.Do(tt.ctx, nil) if !reflect.DeepEqual(tt.wantErr, err) { t.Errorf("#%d: got err=%v, want=%v", i, err, tt.wantErr) continue @@ -407,11 +430,9 @@ func TestHTTPClusterClientDo(t *testing.T) { if resp == nil { if tt.wantCode != 0 { t.Errorf("#%d: resp is nil, want=%d", i, tt.wantCode) + continue } - continue - } - - if resp.StatusCode != tt.wantCode { + } else if resp.StatusCode != tt.wantCode { t.Errorf("#%d: resp code=%d, want=%d", i, resp.StatusCode, tt.wantCode) continue }